In [1]:
!pip install pydicom numpy scikit-image pillow scipy SimpleITK


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m26.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
"""
Medical Image Training - v10_perceptual_loss (Perceptual Loss)
Configuration: 64 LapSRN channels, 5 blocks | 128 DRRN channels, 25 blocks | LeakyReLU | Perceptual Loss
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import numpy as np
from PIL import Image
from tqdm import tqdm
from datetime import datetime
import json
import warnings
warnings.filterwarnings('ignore')

# ==============================================================================
# CONFIGURATION
# ==============================================================================

class Config:
    VERSION = 'v10_perceptual_loss'
    DATA_DIR = './preprocessed_data'
    SAVE_DIR = './trained_models_v10'
    
    EPOCHS_SR = 50
    EPOCHS_CLASS = 30
    BATCH_SIZE = 16
    LEARNING_RATE = 1e-4
    
    LAPSRN_SCALE = 4
    DRRN_SCALE = 2
    TOTAL_SCALE = 8
    
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # v10_perceptual_loss specific - USING PERCEPTUAL LOSS
    LAPSRN_CHANNELS = 64
    LAPSRN_BLOCKS = 5
    DRRN_CHANNELS = 128
    DRRN_BLOCKS = 25
    KERNEL_SIZE = 3
    ACTIVATION = 'leaky'
    BACKBONE = 'resnet50'
    
    # Loss weights
    L1_WEIGHT = 1.0
    PERCEPTUAL_WEIGHT = 0.1  # Perceptual loss weight


# ==============================================================================
# DATASETS
# ==============================================================================

class SuperResolutionDataset(Dataset):
    def __init__(self, preprocessed_data_dir, hr_patch_size=64, scale_factor=4):
        self.hr_patch_size = hr_patch_size
        self.lr_patch_size = hr_patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files) * 4
    
    def __getitem__(self, idx):
        img_idx = idx // 4
        img_path = self.image_files[img_idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = np.random.randint(0, max(1, h - self.hr_patch_size + 1))
        left = np.random.randint(0, max(1, w - self.hr_patch_size + 1))
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class DRRNDataset(Dataset):
    def __init__(self, preprocessed_data_dir, patch_size=64, scale_factor=2):
        self.hr_patch_size = patch_size
        self.lr_patch_size = patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files) * 4
    
    def __getitem__(self, idx):
        img_idx = idx // 4
        img_path = self.image_files[img_idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = np.random.randint(0, max(1, h - self.hr_patch_size + 1))
        left = np.random.randint(0, max(1, w - self.hr_patch_size + 1))
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class ClassificationDataset(Dataset):
    def __init__(self, preprocessed_data_dir, enhance_size=224):
        self.enhance_size = enhance_size
        self.data = []
        
        category_map = {'Normal': 0, 'Ischemia': 1, 'Bleeding': 2}
        urgency_map = {'Normal': 0.1, 'Ischemia': 0.7, 'Bleeding': 0.95}
        
        for category, label in category_map.items():
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.data.append({
                            'path': os.path.join(category_path, filename),
                            'label': label,
                            'urgency': urgency_map[category]
                        })
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        img = Image.open(sample['path']).convert('L')
        img = img.resize((self.enhance_size, self.enhance_size), Image.BICUBIC)
        img_array = np.array(img, dtype=np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array.copy()).unsqueeze(0).float()
        
        return img_tensor, sample['label'], sample['urgency']


# ==============================================================================
# PERCEPTUAL LOSS NETWORK
# ==============================================================================

class PerceptualLossNetwork(nn.Module):
    """
    Uses VGG19 pretrained features to compute perceptual loss.
    Extracts features from relu1_2, relu2_2, relu3_2 layers.
    Adapted for single-channel (grayscale) input by replicating to 3 channels.
    """
    def __init__(self):
        super().__init__()
        vgg = models.vgg19(pretrained=True).features
        
        # Split VGG into feature extraction stages
        # relu1_2: layers 0-3
        # relu2_2: layers 4-8
        # relu3_2: layers 9-13
        self.slice1 = nn.Sequential(*[vgg[i] for i in range(4)])   # relu1_2
        self.slice2 = nn.Sequential(*[vgg[i] for i in range(4, 9)])   # relu2_2
        self.slice3 = nn.Sequential(*[vgg[i] for i in range(9, 14)])  # relu3_2
        
        # Freeze all VGG parameters
        for param in self.parameters():
            param.requires_grad = False
        
        # Feature weights (deeper = more weight)
        self.weights = [1.0, 1.0, 1.0]
    
    def forward(self, x):
        """
        Args:
            x: grayscale image tensor (B, 1, H, W)
        Returns:
            list of feature maps from each VGG stage
        """
        # Convert grayscale to 3-channel by repeating
        x_3ch = x.repeat(1, 3, 1, 1)
        
        h1 = self.slice1(x_3ch)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        
        return [h1, h2, h3]


class CombinedSRLoss(nn.Module):
    """
    Combined L1 + Perceptual loss for super-resolution.
    L1 preserves pixel accuracy, perceptual loss preserves structure/texture.
    """
    def __init__(self, l1_weight=1.0, perceptual_weight=0.1, device='cpu'):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.perceptual_net = PerceptualLossNetwork().to(device)
        self.perceptual_net.eval()
        self.l1_weight = l1_weight
        self.perceptual_weight = perceptual_weight
    
    def forward(self, sr_output, hr_target):
        # L1 pixel loss
        l1 = self.l1_loss(sr_output, hr_target)
        
        # Perceptual loss
        sr_features = self.perceptual_net(sr_output)
        hr_features = self.perceptual_net(hr_target)
        
        perceptual = 0.0
        for i, (sf, hf) in enumerate(zip(sr_features, hr_features)):
            perceptual += self.perceptual_net.weights[i] * F.l1_loss(sf, hf)
        perceptual /= len(sr_features)
        
        total = self.l1_weight * l1 + self.perceptual_weight * perceptual
        return total, l1, perceptual


# ==============================================================================
# BUILDING BLOCKS
# ==============================================================================

class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.conv2(out)
        return self.activation(out + residual)


class RecursiveBlock(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.activation(self.conv2(out))
        return out + residual


# ==============================================================================
# MODELS - v10_perceptual_loss (same architecture as baseline)
# ==============================================================================

class LapSRN(nn.Module):
    """v10_perceptual_loss: Standard LapSRN (same as baseline, loss is different)"""
    def __init__(self, scale_factor=4, num_channels=1):
        super().__init__()
        self.scale_factor = scale_factor
        self.num_levels = 2  # 2x2 = 4x
        ch = 64
        
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, ch, 3, padding=1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.pyramid_levels = nn.ModuleList()
        self.image_reconstruction = nn.ModuleList()
        
        for _ in range(self.num_levels):
            layers = []
            for _ in range(5):
                layers.append(ResidualBlock(ch, 3))
            layers.append(nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, True))
            
            self.pyramid_levels.append(nn.Sequential(*layers))
            self.image_reconstruction.append(nn.Conv2d(ch, num_channels, 3, padding=1))
    
    def forward(self, x):
        features = self.feature_extraction(x)
        outputs = []
        current_features = features
        
        for level_idx in range(self.num_levels):
            current_features = self.pyramid_levels[level_idx](current_features)
            img_out = self.image_reconstruction[level_idx](current_features)
            
            if level_idx > 0:
                img_out = img_out + F.interpolate(outputs[-1], scale_factor=2, mode='bilinear', align_corners=False)
            else:
                img_out = img_out + F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            
            outputs.append(img_out)
        
        return outputs[-1], outputs


class DRRN(nn.Module):
    """v10_perceptual_loss: Standard DRRN (same as baseline, loss is different)"""
    def __init__(self, num_channels=1, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        ch = 128
        
        self.input_conv = nn.Conv2d(num_channels, ch, 3, padding=1)
        
        self.recursive_blocks = nn.ModuleList()
        for _ in range(25):
            self.recursive_blocks.append(RecursiveBlock(ch, 3))
        
        self.fusion = nn.Sequential(
            nn.Conv2d(ch * 3, ch, 1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.upsample = nn.Sequential(
            nn.Conv2d(ch, ch * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, True)
        )
        
        self.output_conv = nn.Sequential(
            nn.Conv2d(ch, 64, 3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, num_channels, 3, padding=1)
        )
    
    def forward(self, x):
        input_upsampled = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)
        
        features = self.input_conv(x)
        multi_scale_features = []
        current = features
        
        collect_indices = [8, 16, 24]
        
        for idx, block in enumerate(self.recursive_blocks):
            current = block(current)
            if idx in collect_indices:
                multi_scale_features.append(current)
        
        fused = torch.cat(multi_scale_features, dim=1)
        fused = self.fusion(fused)
        upsampled = self.upsample(fused)
        output = self.output_conv(upsampled)
        
        return output + input_upsampled


class MedicalImageClassifier(nn.Module):
    """v10_perceptual_loss: Standard ResNet50 classifier (same as baseline)"""
    def __init__(self, num_classes=3):
        super().__init__()
        
        self.backbone = models.resnet50(pretrained=True)
        self.backbone.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        
        self.classification_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        
        self.urgency_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
        self.feature_head = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(True)
        )
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classification_head(features), self.urgency_head(features), self.feature_head(features)


# ==============================================================================
# TRAINING
# ==============================================================================

def train_model():
    config = Config()
    
    print(f"\n{'='*80}")
    print(f"TRAINING {config.VERSION.upper()}")
    print(f"{'='*80}")
    print(f"Configuration:")
    print(f"  - LapSRN: {config.LAPSRN_CHANNELS} channels, {config.LAPSRN_BLOCKS} blocks")
    print(f"  - DRRN: {config.DRRN_CHANNELS} channels, {config.DRRN_BLOCKS} blocks")
    print(f"  - Kernel: {config.KERNEL_SIZE}x{config.KERNEL_SIZE}")
    print(f"  - Activation: {config.ACTIVATION}")
    print(f"  - Backbone: {config.BACKBONE.upper()}")
    print(f"  - Loss: L1 (weight={config.L1_WEIGHT}) + Perceptual (weight={config.PERCEPTUAL_WEIGHT})")
    print(f"  - Device: {config.DEVICE}")
    print(f"{'='*80}\n")
    
    version_save_dir = os.path.join(config.SAVE_DIR, config.VERSION)
    os.makedirs(version_save_dir, exist_ok=True)
    
    # Initialize models
    lapsrn = LapSRN().to(config.DEVICE)
    drrn = DRRN().to(config.DEVICE)
    classifier = MedicalImageClassifier().to(config.DEVICE)
    
    # Initialize perceptual loss (shared for both SR models)
    combined_loss = CombinedSRLoss(
        l1_weight=config.L1_WEIGHT,
        perceptual_weight=config.PERCEPTUAL_WEIGHT,
        device=config.DEVICE
    )
    
    print(f"\nPerceptual Loss info:")
    print(f"  - Feature extractor: VGG19 (pretrained, frozen)")
    print(f"  - Feature layers: relu1_2, relu2_2, relu3_2")
    print(f"  - Grayscale handling: replicate to 3 channels before VGG")
    print(f"  - Total loss = {config.L1_WEIGHT} * L1 + {config.PERCEPTUAL_WEIGHT} * Perceptual")
    print(f"  - Expected: Better structural/texture preservation in SR outputs")
    
    # Create datasets
    sr_dataset = SuperResolutionDataset(config.DATA_DIR, hr_patch_size=64, scale_factor=4)
    drrn_dataset = DRRNDataset(config.DATA_DIR, patch_size=64, scale_factor=2)
    class_dataset = ClassificationDataset(config.DATA_DIR, enhance_size=224)
    
    # Split datasets (80/20)
    train_sr, val_sr = torch.utils.data.random_split(sr_dataset, 
        [int(0.8*len(sr_dataset)), len(sr_dataset)-int(0.8*len(sr_dataset))])
    train_drrn, val_drrn = torch.utils.data.random_split(drrn_dataset,
        [int(0.8*len(drrn_dataset)), len(drrn_dataset)-int(0.8*len(drrn_dataset))])
    train_class, val_class = torch.utils.data.random_split(class_dataset,
        [int(0.8*len(class_dataset)), len(class_dataset)-int(0.8*len(class_dataset))])
    
    # DataLoaders
    train_sr_loader = DataLoader(train_sr, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    train_drrn_loader = DataLoader(train_drrn, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    train_class_loader = DataLoader(train_class, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    
    # Train LapSRN
    print("\n" + "="*80)
    print("[1/3] Training LapSRN (16x16 → 64x64, 4x upsampling) + Perceptual Loss")
    print("="*80)
    
    optimizer = optim.Adam(lapsrn.parameters(), lr=config.LEARNING_RATE)
    best_loss = float('inf')
    
    for epoch in range(config.EPOCHS_SR):
        lapsrn.train()
        train_loss = 0
        pbar = tqdm(train_sr_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS_SR}')
        
        for lr_imgs, hr_imgs in pbar:
            lr_imgs, hr_imgs = lr_imgs.to(config.DEVICE), hr_imgs.to(config.DEVICE)
            optimizer.zero_grad()
            sr_output, _ = lapsrn(lr_imgs)
            loss, l1, perceptual = combined_loss(sr_output, hr_imgs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.6f}', 'l1': f'{l1.item():.6f}', 'perc': f'{perceptual.item():.6f}'})
        
        avg_loss = train_loss / len(train_sr_loader)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(lapsrn.state_dict(), os.path.join(version_save_dir, 'lapsrn_best.pth'))
    
    print(f"✓ LapSRN training complete (best loss: {best_loss:.6f})")
    
    # Train DRRN
    print("\n" + "="*80)
    print("[2/3] Training DRRN (64x64 → 128x128, 2x upsampling) + Perceptual Loss")
    print("="*80)
    
    optimizer = optim.Adam(drrn.parameters(), lr=config.LEARNING_RATE)
    best_loss = float('inf')
    
    for epoch in range(config.EPOCHS_SR):
        drrn.train()
        train_loss = 0
        pbar = tqdm(train_drrn_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS_SR}')
        
        for lr_imgs, hr_imgs in pbar:
            lr_imgs, hr_imgs = lr_imgs.to(config.DEVICE), hr_imgs.to(config.DEVICE)
            optimizer.zero_grad()
            sr_output = drrn(lr_imgs)
            loss, l1, perceptual = combined_loss(sr_output, hr_imgs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.6f}', 'l1': f'{l1.item():.6f}', 'perc': f'{perceptual.item():.6f}'})
        
        avg_loss = train_loss / len(train_drrn_loader)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(drrn.state_dict(), os.path.join(version_save_dir, 'drrn_best.pth'))
    
    print(f"✓ DRRN training complete (best loss: {best_loss:.6f})")
    
    # Train Classifier
    print("\n" + "="*80)
    print("[3/3] Training Classifier (128x128 → 224x224 → Classification)")
    print("="*80)
    
    optimizer = optim.Adam(classifier.parameters(), lr=config.LEARNING_RATE)
    class_criterion = nn.CrossEntropyLoss()
    urgency_criterion = nn.BCELoss()
    best_acc = 0.0
    
    for epoch in range(config.EPOCHS_CLASS):
        classifier.train()
        correct, total = 0, 0
        pbar = tqdm(train_class_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS_CLASS}')
        
        for images, labels, urgency in pbar:
            images = images.to(config.DEVICE)
            labels = labels.to(config.DEVICE)
            urgency = urgency.to(config.DEVICE).unsqueeze(1).float()
            
            optimizer.zero_grad()
            class_out, urgency_out, _ = classifier(images)
            loss = class_criterion(class_out, labels) + 0.5 * urgency_criterion(urgency_out, urgency)
            loss.backward()
            optimizer.step()
            
            _, predicted = torch.max(class_out, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            pbar.set_postfix({'acc': f'{100*correct/total:.2f}%'})
        
        acc = 100 * correct / total
        if acc > best_acc:
            best_acc = acc
            torch.save(classifier.state_dict(), os.path.join(version_save_dir, 'classifier_best.pth'))
    
    print(f"✓ Classifier training complete (best accuracy: {best_acc:.2f}%)")
    
    # Save configuration
    config_dict = {
        'version': config.VERSION,
        'lapsrn_channels': config.LAPSRN_CHANNELS,
        'lapsrn_blocks': config.LAPSRN_BLOCKS,
        'drrn_channels': config.DRRN_CHANNELS,
        'drrn_blocks': config.DRRN_BLOCKS,
        'kernel_size': config.KERNEL_SIZE,
        'activation': config.ACTIVATION,
        'backbone': config.BACKBONE,
        'l1_weight': config.L1_WEIGHT,
        'perceptual_weight': config.PERCEPTUAL_WEIGHT,
        'epochs_sr': config.EPOCHS_SR,
        'epochs_class': config.EPOCHS_CLASS,
        'timestamp': datetime.now().isoformat(),
        'notes': 'Same baseline architecture. SR training uses combined L1 + VGG19 perceptual loss for better structural preservation.'
    }
    
    with open(os.path.join(version_save_dir, 'config.json'), 'w') as f:
        json.dump(config_dict, f, indent=2)
    
    print(f"\n{'='*80}")
    print("✓ ALL TRAINING COMPLETE!")
    print(f"{'='*80}")
    print(f"Models saved to: {version_save_dir}")
    print("\nPipeline: 16x16 → LapSRN(4x) → 64x64 → DRRN(2x) → 128x128 → Classifier(224x224)")
    print("\nKey difference from v1_baseline:")
    print("  - SR loss changed from L1-only to L1 + Perceptual (VGG19)")
    print("  - VGG19 features: relu1_2, relu2_2, relu3_2")
    print("  - Grayscale input replicated to 3ch before passing through VGG19")
    print("  - Expected: Better texture/structure preservation, higher SSIM")


if __name__ == "__main__":
    train_model()


TRAINING V10_PERCEPTUAL_LOSS
Configuration:
  - LapSRN: 64 channels, 5 blocks
  - DRRN: 128 channels, 25 blocks
  - Kernel: 3x3
  - Activation: leaky
  - Backbone: RESNET50
  - Loss: L1 (weight=1.0) + Perceptual (weight=0.1)
  - Device: cuda

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100% 548M/548M [00:00<00:00, 628MB/s] 



Perceptual Loss info:
  - Feature extractor: VGG19 (pretrained, frozen)
  - Feature layers: relu1_2, relu2_2, relu3_2
  - Grayscale handling: replicate to 3 channels before VGG
  - Total loss = 1.0 * L1 + 0.1 * Perceptual
  - Expected: Better structural/texture preservation in SR outputs

[1/3] Training LapSRN (16x16 → 64x64, 4x upsampling) + Perceptual Loss


Epoch 1/50: 100% 1328/1328 [00:43<00:00, 30.83it/s, loss=0.000637, l1=0.000532, perc=0.001051]
Epoch 2/50: 100% 1328/1328 [00:41<00:00, 32.20it/s, loss=0.008034, l1=0.004197, perc=0.038374]
Epoch 3/50: 100% 1328/1328 [00:38<00:00, 34.37it/s, loss=0.014466, l1=0.007409, perc=0.070563]
Epoch 4/50: 100% 1328/1328 [00:38<00:00, 34.37it/s, loss=0.041659, l1=0.023722, perc=0.179379]
Epoch 5/50: 100% 1328/1328 [00:38<00:00, 34.47it/s, loss=0.030805, l1=0.016514, perc=0.142912]
Epoch 6/50: 100% 1328/1328 [00:37<00:00, 35.13it/s, loss=0.001494, l1=0.000706, perc=0.007879]
Epoch 7/50: 100% 1328/1328 [00:38<00:00, 34.65it/s, loss=0.006900, l1=0.003429, perc=0.034709]
Epoch 8/50: 100% 1328/1328 [00:38<00:00, 34.17it/s, loss=0.019197, l1=0.010547, perc=0.086500]
Epoch 9/50: 100% 1328/1328 [00:38<00:00, 34.22it/s, loss=0.015633, l1=0.008664, perc=0.069695]
Epoch 10/50: 100% 1328/1328 [00:39<00:00, 33.78it/s, loss=0.037853, l1=0.022757, perc=0.150962]
Epoch 11/50: 100% 1328/1328 [00:39<00:00, 33.99it

✓ LapSRN training complete (best loss: 0.017746)

[2/3] Training DRRN (64x64 → 128x128, 2x upsampling) + Perceptual Loss


Epoch 1/50: 100% 1328/1328 [01:45<00:00, 12.53it/s, loss=0.001950, l1=0.001084, perc=0.008654]
Epoch 2/50: 100% 1328/1328 [01:47<00:00, 12.33it/s, loss=0.001090, l1=0.000567, perc=0.005233]
Epoch 3/50: 100% 1328/1328 [01:57<00:00, 11.34it/s, loss=0.006807, l1=0.003929, perc=0.028781]
Epoch 4/50: 100% 1328/1328 [01:57<00:00, 11.32it/s, loss=0.011897, l1=0.006469, perc=0.054286]
Epoch 5/50: 100% 1328/1328 [01:57<00:00, 11.35it/s, loss=0.013449, l1=0.007372, perc=0.060774]
Epoch 6/50: 100% 1328/1328 [01:57<00:00, 11.33it/s, loss=0.008577, l1=0.004660, perc=0.039176]
Epoch 7/50: 100% 1328/1328 [01:57<00:00, 11.35it/s, loss=0.006768, l1=0.003887, perc=0.028811]
Epoch 8/50: 100% 1328/1328 [01:57<00:00, 11.32it/s, loss=0.000887, l1=0.000490, perc=0.003969]
Epoch 9/50: 100% 1328/1328 [01:57<00:00, 11.34it/s, loss=0.000537, l1=0.000290, perc=0.002470]
Epoch 10/50: 100% 1328/1328 [01:57<00:00, 11.32it/s, loss=0.005731, l1=0.003134, perc=0.025969]
Epoch 11/50: 100% 1328/1328 [01:57<00:00, 11.31it

✓ DRRN training complete (best loss: 0.004469)

[3/3] Training Classifier (128x128 → 224x224 → Classification)


Epoch 1/30:   0% 0/332 [00:00<?, ?it/s]MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xB
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvBinWinogradRxSf3x2; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xB
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvHipImplicitGemmGroupBwdXdlops; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xB
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xW
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvHipImplicitGemmGroupWrwXdlops; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xW
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC; key: 2x512x7x7x1x3x3x1x512x16x1x1x0x2x2x0x1x1x0x0x1

In [None]:
"""
Model Evaluation Script - v10_perceptual_loss (Notebook-Friendly)
Calculates PSNR, SSIM, Accuracy, and Classification Metrics
Note: v10 uses the same baseline architecture as v1. Perceptual loss only affects training.

Usage in Jupyter Notebook:
    from evaluate_v10_perceptual_loss import evaluate_model
    
    results = evaluate_model(
        version='v10_perceptual_loss',
        data_dir='./preprocessed_data',
        model_dir='./trained_models_v10'
    )
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from tqdm import tqdm
import json
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# ==============================================================================
# DATASETS
# ==============================================================================

class SuperResolutionDataset(Dataset):
    def __init__(self, preprocessed_data_dir, hr_patch_size=64, scale_factor=4):
        self.hr_patch_size = hr_patch_size
        self.lr_patch_size = hr_patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = (h - self.hr_patch_size) // 2
        left = (w - self.hr_patch_size) // 2
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class DRRNDataset(Dataset):
    def __init__(self, preprocessed_data_dir, patch_size=64, scale_factor=2):
        self.hr_patch_size = patch_size
        self.lr_patch_size = patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = (h - self.hr_patch_size) // 2
        left = (w - self.hr_patch_size) // 2
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class ClassificationDataset(Dataset):
    def __init__(self, preprocessed_data_dir, enhance_size=224):
        self.enhance_size = enhance_size
        self.data = []
        
        category_map = {'Normal': 0, 'Ischemia': 1, 'Bleeding': 2}
        urgency_map = {'Normal': 0.1, 'Ischemia': 0.7, 'Bleeding': 0.95}
        
        for category, label in category_map.items():
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.data.append({
                            'path': os.path.join(category_path, filename),
                            'label': label,
                            'urgency': urgency_map[category]
                        })
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        img = Image.open(sample['path']).convert('L')
        img = img.resize((self.enhance_size, self.enhance_size), Image.BICUBIC)
        img_array = np.array(img, dtype=np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array.copy()).unsqueeze(0).float()
        
        return img_tensor, sample['label'], sample['urgency']


# ==============================================================================
# MODEL DEFINITIONS - v10_perceptual_loss (baseline architecture)
# ==============================================================================

class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.conv2(out)
        return self.activation(out + residual)


class RecursiveBlock(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.activation(self.conv2(out))
        return out + residual


class LapSRN(nn.Module):
    def __init__(self, scale_factor=4, num_channels=1):
        super().__init__()
        self.scale_factor = scale_factor
        self.num_levels = 2
        ch = 64
        
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, ch, 3, padding=1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.pyramid_levels = nn.ModuleList()
        self.image_reconstruction = nn.ModuleList()
        
        for _ in range(self.num_levels):
            layers = []
            for _ in range(5):
                layers.append(ResidualBlock(ch, 3))
            layers.append(nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, True))
            
            self.pyramid_levels.append(nn.Sequential(*layers))
            self.image_reconstruction.append(nn.Conv2d(ch, num_channels, 3, padding=1))
    
    def forward(self, x):
        features = self.feature_extraction(x)
        outputs = []
        current_features = features
        
        for level_idx in range(self.num_levels):
            current_features = self.pyramid_levels[level_idx](current_features)
            img_out = self.image_reconstruction[level_idx](current_features)
            
            if level_idx > 0:
                img_out = img_out + F.interpolate(outputs[-1], scale_factor=2, mode='bilinear', align_corners=False)
            else:
                img_out = img_out + F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            
            outputs.append(img_out)
        
        return outputs[-1], outputs


class DRRN(nn.Module):
    def __init__(self, num_channels=1, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        ch = 128
        
        self.input_conv = nn.Conv2d(num_channels, ch, 3, padding=1)
        
        self.recursive_blocks = nn.ModuleList()
        for _ in range(25):
            self.recursive_blocks.append(RecursiveBlock(ch, 3))
        
        self.fusion = nn.Sequential(
            nn.Conv2d(ch * 3, ch, 1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.upsample = nn.Sequential(
            nn.Conv2d(ch, ch * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, True)
        )
        
        self.output_conv = nn.Sequential(
            nn.Conv2d(ch, 64, 3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, num_channels, 3, padding=1)
        )
    
    def forward(self, x):
        input_upsampled = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)
        
        features = self.input_conv(x)
        multi_scale_features = []
        current = features
        
        collect_indices = [8, 16, 24]
        
        for idx, block in enumerate(self.recursive_blocks):
            current = block(current)
            if idx in collect_indices:
                multi_scale_features.append(current)
        
        fused = torch.cat(multi_scale_features, dim=1)
        fused = self.fusion(fused)
        upsampled = self.upsample(fused)
        output = self.output_conv(upsampled)
        
        return output + input_upsampled


class MedicalImageClassifier(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        
        from torchvision import models
        self.backbone = models.resnet50(pretrained=True)
        self.backbone.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        
        self.classification_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        
        self.urgency_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
        self.feature_head = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(True)
        )
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classification_head(features), self.urgency_head(features), self.feature_head(features)


# ==============================================================================
# EVALUATION
# ==============================================================================

def evaluate_model(version='v10_perceptual_loss', data_dir='./preprocessed_data', 
                   model_dir='./trained_models_v10',
                   device='cuda' if torch.cuda.is_available() else 'cpu'):
    
    print(f"\n{'='*80}")
    print(f"EVALUATING MODEL: {version}")
    print(f"{'='*80}\n")
    
    version_dir = os.path.join(model_dir, version)
    config_path = os.path.join(version_dir, 'config.json')
    
    if not os.path.exists(config_path):
        print(f"ERROR: Config file not found at {config_path}")
        return None
    
    with open(config_path, 'r') as f:
        config = json.load(f)
    
    print(f"Loaded configuration for {version}")
    print(f"  LapSRN: {config.get('lapsrn_channels', 64)} channels, {config.get('lapsrn_blocks', 5)} blocks (baseline)")
    print(f"  DRRN: {config.get('drrn_channels', 128)} channels, {config.get('drrn_blocks', 25)} blocks (baseline)")
    print(f"  Backbone: {config.get('backbone', 'resnet50')}")
    print(f"  Perceptual Loss (training only): L1={config.get('l1_weight', 1.0)}, Perceptual={config.get('perceptual_weight', 0.1)}\n")
    
    # Datasets
    print("Creating datasets...")
    sr_dataset = SuperResolutionDataset(data_dir, hr_patch_size=64, scale_factor=4)
    drrn_dataset = DRRNDataset(data_dir, patch_size=64, scale_factor=2)
    class_dataset = ClassificationDataset(data_dir, enhance_size=224)
    
    sr_loader = DataLoader(sr_dataset, batch_size=16, shuffle=False, num_workers=2)
    drrn_loader = DataLoader(drrn_dataset, batch_size=16, shuffle=False, num_workers=2)
    class_loader = DataLoader(class_dataset, batch_size=16, shuffle=False, num_workers=2)
    
    print(f"  SR dataset: {len(sr_dataset)} samples")
    print(f"  DRRN dataset: {len(drrn_dataset)} samples")
    print(f"  Classification dataset: {len(class_dataset)} samples\n")
    
    # Load models
    print("Loading models...")
    lapsrn = LapSRN().to(device)
    drrn = DRRN().to(device)
    classifier = MedicalImageClassifier().to(device)
    
    lapsrn.load_state_dict(torch.load(os.path.join(version_dir, 'lapsrn_best.pth'), map_location=device))
    drrn.load_state_dict(torch.load(os.path.join(version_dir, 'drrn_best.pth'), map_location=device))
    classifier.load_state_dict(torch.load(os.path.join(version_dir, 'classifier_best.pth'), map_location=device))
    
    lapsrn.eval()
    drrn.eval()
    classifier.eval()
    print("✓ Models loaded successfully\n")
    
    # Evaluate LapSRN
    print("="*80)
    print("[1/3] Evaluating LapSRN (16x16 → 64x64)")
    print("="*80)
    
    lapsrn_psnr_list, lapsrn_ssim_list = [], []
    
    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(sr_loader, desc="LapSRN Evaluation"):
            lr_imgs = lr_imgs.to(device)
            sr_output, _ = lapsrn(lr_imgs)
            
            sr_np = sr_output.cpu().numpy()
            hr_np = hr_imgs.numpy()
            
            for i in range(sr_np.shape[0]):
                sr_img = np.clip(sr_np[i, 0], 0, 1)
                hr_img = np.clip(hr_np[i, 0], 0, 1)
                lapsrn_psnr_list.append(psnr(hr_img, sr_img, data_range=1.0))
                lapsrn_ssim_list.append(ssim(hr_img, sr_img, data_range=1.0))
    
    lapsrn_psnr_mean = np.mean(lapsrn_psnr_list)
    lapsrn_ssim_mean = np.mean(lapsrn_ssim_list)
    
    print(f"\n✓ LapSRN Results:")
    print(f"  PSNR: {lapsrn_psnr_mean:.4f} dB")
    print(f"  SSIM: {lapsrn_ssim_mean:.4f}\n")
    
    # Evaluate DRRN
    print("="*80)
    print("[2/3] Evaluating DRRN (64x64 → 128x128)")
    print("="*80)
    
    drrn_psnr_list, drrn_ssim_list = [], []
    
    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(drrn_loader, desc="DRRN Evaluation"):
            lr_imgs = lr_imgs.to(device)
            sr_output = drrn(lr_imgs)
            
            sr_np = sr_output.cpu().numpy()
            hr_np = hr_imgs.numpy()
            
            for i in range(sr_np.shape[0]):
                sr_img = np.clip(sr_np[i, 0], 0, 1)
                hr_img = np.clip(hr_np[i, 0], 0, 1)
                drrn_psnr_list.append(psnr(hr_img, sr_img, data_range=1.0))
                drrn_ssim_list.append(ssim(hr_img, sr_img, data_range=1.0))
    
    drrn_psnr_mean = np.mean(drrn_psnr_list)
    drrn_ssim_mean = np.mean(drrn_ssim_list)
    
    print(f"\n✓ DRRN Results:")
    print(f"  PSNR: {drrn_psnr_mean:.4f} dB")
    print(f"  SSIM: {drrn_ssim_mean:.4f}\n")
    
    # Evaluate Classifier
    print("="*80)
    print("[3/3] Evaluating Classifier")
    print("="*80)
    
    all_preds, all_labels, all_urgency_preds, all_urgency_true = [], [], [], []
    
    with torch.no_grad():
        for images, labels, urgency in tqdm(class_loader, desc="Classifier Evaluation"):
            images = images.to(device)
            class_out, urgency_out, _ = classifier(images)
            _, predicted = torch.max(class_out, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_urgency_preds.extend(urgency_out.cpu().numpy().flatten())
            all_urgency_true.extend(urgency.numpy())
    
    accuracy = accuracy_score(all_labels, all_preds) * 100
    conf_matrix = confusion_matrix(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds, 
                                        target_names=['Normal', 'Ischemia', 'Bleeding'], output_dict=True)
    urgency_mse = np.mean((np.array(all_urgency_preds) - np.array(all_urgency_true))**2)
    urgency_mae = np.mean(np.abs(np.array(all_urgency_preds) - np.array(all_urgency_true)))
    
    print(f"\n✓ Classification Results:")
    print(f"  Accuracy: {accuracy:.2f}%")
    print(f"\n  Confusion Matrix:")
    print(f"  {conf_matrix}")
    print(f"\n  Per-Class Metrics:")
    for class_name in ['Normal', 'Ischemia', 'Bleeding']:
        metrics = class_report[class_name]
        print(f"    {class_name}:")
        print(f"      Precision: {metrics['precision']:.4f}")
        print(f"      Recall: {metrics['recall']:.4f}")
        print(f"      F1-Score: {metrics['f1-score']:.4f}")
    
    print(f"\n  Urgency Prediction:")
    print(f"    MSE: {urgency_mse:.4f}")
    print(f"    MAE: {urgency_mae:.4f}\n")
    
    # Save results
    results = {
        'version': version,
        'lapsrn': { 'psnr': float(lapsrn_psnr_mean), 'ssim': float(lapsrn_ssim_mean) },
        'drrn': { 'psnr': float(drrn_psnr_mean), 'ssim': float(drrn_ssim_mean) },
        'classifier': {
            'accuracy': float(accuracy),
            'confusion_matrix': conf_matrix.tolist(),
            'classification_report': class_report,
            'urgency_mse': float(urgency_mse),
            'urgency_mae': float(urgency_mae)
        }
    }
    
    results_path = os.path.join(version_dir, 'evaluation_results.json')
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"✓ Results saved to: {results_path}")
    print(f"\n{'='*80}")
    print(f"EVALUATION COMPLETE FOR {version}")
    print(f"{'='*80}\n")
    
    return results

In [None]:
results = evaluate_model(version='v10_perceptual_loss', data_dir='./preprocessed_data', model_dir='./trained_models_v10')