In [1]:
!pip install pydicom numpy scikit-image pillow scipy SimpleITK


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m26.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
"""
Medical Image Training - v7_attention (CBAM Attention Mechanisms)
Configuration: 64 LapSRN channels, 5 blocks | 128 DRRN channels, 25 blocks | LeakyReLU | CBAM Attention
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from tqdm import tqdm
from datetime import datetime
import json
import warnings
warnings.filterwarnings('ignore')

# ==============================================================================
# CONFIGURATION
# ==============================================================================

class Config:
    VERSION = 'v7_attention'
    DATA_DIR = './preprocessed_data'
    SAVE_DIR = './trained_models_v7'
    
    EPOCHS_SR = 50
    EPOCHS_CLASS = 30
    BATCH_SIZE = 16
    LEARNING_RATE = 1e-4
    
    LAPSRN_SCALE = 4
    DRRN_SCALE = 2
    TOTAL_SCALE = 8
    
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # v7_attention specific - ADDING CBAM ATTENTION
    LAPSRN_CHANNELS = 64
    LAPSRN_BLOCKS = 5
    DRRN_CHANNELS = 128
    DRRN_BLOCKS = 25
    KERNEL_SIZE = 3
    ACTIVATION = 'leaky'
    BACKBONE = 'resnet50'
    USE_ATTENTION = True  # NEW: Enable CBAM attention


# ==============================================================================
# DATASETS
# ==============================================================================

class SuperResolutionDataset(Dataset):
    def __init__(self, preprocessed_data_dir, hr_patch_size=64, scale_factor=4):
        self.hr_patch_size = hr_patch_size
        self.lr_patch_size = hr_patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files) * 4
    
    def __getitem__(self, idx):
        img_idx = idx // 4
        img_path = self.image_files[img_idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = np.random.randint(0, max(1, h - self.hr_patch_size + 1))
        left = np.random.randint(0, max(1, w - self.hr_patch_size + 1))
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class DRRNDataset(Dataset):
    def __init__(self, preprocessed_data_dir, patch_size=64, scale_factor=2):
        self.hr_patch_size = patch_size
        self.lr_patch_size = patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files) * 4
    
    def __getitem__(self, idx):
        img_idx = idx // 4
        img_path = self.image_files[img_idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = np.random.randint(0, max(1, h - self.hr_patch_size + 1))
        left = np.random.randint(0, max(1, w - self.hr_patch_size + 1))
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class ClassificationDataset(Dataset):
    def __init__(self, preprocessed_data_dir, enhance_size=224):
        self.enhance_size = enhance_size
        self.data = []
        
        category_map = {'Normal': 0, 'Ischemia': 1, 'Bleeding': 2}
        urgency_map = {'Normal': 0.1, 'Ischemia': 0.7, 'Bleeding': 0.95}
        
        for category, label in category_map.items():
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.data.append({
                            'path': os.path.join(category_path, filename),
                            'label': label,
                            'urgency': urgency_map[category]
                        })
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        img = Image.open(sample['path']).convert('L')
        img = img.resize((self.enhance_size, self.enhance_size), Image.BICUBIC)
        img_array = np.array(img, dtype=np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array.copy()).unsqueeze(0).float()
        
        return img_tensor, sample['label'], sample['urgency']


# ==============================================================================
# ATTENTION MODULES (CBAM)
# ==============================================================================

class ChannelAttention(nn.Module):
    """Channel Attention Module"""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """Spatial Attention Module"""
    def __init__(self, kernel_size=7):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return self.sigmoid(out)


class CBAM(nn.Module):
    """Convolutional Block Attention Module"""
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)
    
    def forward(self, x):
        out = x * self.channel_attention(x)
        out = out * self.spatial_attention(out)
        return out


# ==============================================================================
# BUILDING BLOCKS WITH ATTENTION
# ==============================================================================

class ResidualBlock(nn.Module):
    """Residual block with CBAM attention"""
    def __init__(self, channels, kernel_size=3, use_attention=True):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
        
        self.use_attention = use_attention
        if use_attention:
            self.attention = CBAM(channels)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.conv2(out)
        
        if self.use_attention:
            out = self.attention(out)
        
        return self.activation(out + residual)


class RecursiveBlock(nn.Module):
    """Recursive block with CBAM attention"""
    def __init__(self, channels, kernel_size=3, use_attention=True):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
        
        self.use_attention = use_attention
        if use_attention:
            self.attention = CBAM(channels)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.activation(self.conv2(out))
        
        if self.use_attention:
            out = self.attention(out)
        
        return out + residual


# ==============================================================================
# MODELS - v7_attention
# ==============================================================================

class LapSRN(nn.Module):
    """v7_attention: LapSRN with CBAM attention"""
    def __init__(self, scale_factor=4, num_channels=1, use_attention=True):
        super().__init__()
        self.scale_factor = scale_factor
        self.num_levels = 2  # 2x2 = 4x
        ch = 64
        
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, ch, 3, padding=1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.pyramid_levels = nn.ModuleList()
        self.image_reconstruction = nn.ModuleList()
        
        for _ in range(self.num_levels):
            layers = []
            for _ in range(5):
                layers.append(ResidualBlock(ch, 3, use_attention))
            layers.append(nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, True))
            
            self.pyramid_levels.append(nn.Sequential(*layers))
            self.image_reconstruction.append(nn.Conv2d(ch, num_channels, 3, padding=1))
    
    def forward(self, x):
        features = self.feature_extraction(x)
        outputs = []
        current_features = features
        
        for level_idx in range(self.num_levels):
            current_features = self.pyramid_levels[level_idx](current_features)
            img_out = self.image_reconstruction[level_idx](current_features)
            
            if level_idx > 0:
                img_out = img_out + F.interpolate(outputs[-1], scale_factor=2, mode='bilinear', align_corners=False)
            else:
                img_out = img_out + F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            
            outputs.append(img_out)
        
        return outputs[-1], outputs


class DRRN(nn.Module):
    """v7_attention: DRRN with CBAM attention"""
    def __init__(self, num_channels=1, scale_factor=2, use_attention=True):
        super().__init__()
        self.scale_factor = scale_factor
        ch = 128
        
        self.input_conv = nn.Conv2d(num_channels, ch, 3, padding=1)
        
        self.recursive_blocks = nn.ModuleList()
        for _ in range(25):
            self.recursive_blocks.append(RecursiveBlock(ch, 3, use_attention))
        
        self.fusion = nn.Sequential(
            nn.Conv2d(ch * 3, ch, 1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.upsample = nn.Sequential(
            nn.Conv2d(ch, ch * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, True)
        )
        
        self.output_conv = nn.Sequential(
            nn.Conv2d(ch, 64, 3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, num_channels, 3, padding=1)
        )
    
    def forward(self, x):
        input_upsampled = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)
        
        features = self.input_conv(x)
        multi_scale_features = []
        current = features
        
        collect_indices = [8, 16, 24]
        
        for idx, block in enumerate(self.recursive_blocks):
            current = block(current)
            if idx in collect_indices:
                multi_scale_features.append(current)
        
        fused = torch.cat(multi_scale_features, dim=1)
        fused = self.fusion(fused)
        upsampled = self.upsample(fused)
        output = self.output_conv(upsampled)
        
        return output + input_upsampled


class MedicalImageClassifier(nn.Module):
    """v7_attention: Standard ResNet50 classifier (same as baseline)"""
    def __init__(self, num_classes=3):
        super().__init__()
        
        from torchvision import models
        self.backbone = models.resnet50(pretrained=True)
        self.backbone.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        
        self.classification_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        
        self.urgency_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
        self.feature_head = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(True)
        )
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classification_head(features), self.urgency_head(features), self.feature_head(features)


# ==============================================================================
# TRAINING
# ==============================================================================

def train_model():
    config = Config()
    
    print(f"\n{'='*80}")
    print(f"TRAINING {config.VERSION.upper()}")
    print(f"{'='*80}")
    print(f"Configuration:")
    print(f"  - LapSRN: {config.LAPSRN_CHANNELS} channels, {config.LAPSRN_BLOCKS} blocks")
    print(f"  - DRRN: {config.DRRN_CHANNELS} channels, {config.DRRN_BLOCKS} blocks")
    print(f"  - Kernel: {config.KERNEL_SIZE}x{config.KERNEL_SIZE}")
    print(f"  - Activation: {config.ACTIVATION}")
    print(f"  - Backbone: {config.BACKBONE.upper()}")
    print(f"  - Attention: CBAM (Channel + Spatial)")
    print(f"  - Device: {config.DEVICE}")
    print(f"{'='*80}\n")
    
    version_save_dir = os.path.join(config.SAVE_DIR, config.VERSION)
    os.makedirs(version_save_dir, exist_ok=True)
    
    # Initialize models
    lapsrn = LapSRN(use_attention=config.USE_ATTENTION).to(config.DEVICE)
    drrn = DRRN(use_attention=config.USE_ATTENTION).to(config.DEVICE)
    classifier = MedicalImageClassifier().to(config.DEVICE)
    
    print(f"\nAttention mechanism info:")
    print(f"  - Type: CBAM (Convolutional Block Attention Module)")
    print(f"  - Channel Attention: Adaptive avg/max pooling + FC layers")
    print(f"  - Spatial Attention: 7x7 convolution on channel-pooled features")
    print(f"  - Applied after: Every residual/recursive block")
    
    # Create datasets
    sr_dataset = SuperResolutionDataset(config.DATA_DIR, hr_patch_size=64, scale_factor=4)
    drrn_dataset = DRRNDataset(config.DATA_DIR, patch_size=64, scale_factor=2)
    class_dataset = ClassificationDataset(config.DATA_DIR, enhance_size=224)
    
    # Split datasets (80/20)
    train_sr, val_sr = torch.utils.data.random_split(sr_dataset, 
        [int(0.8*len(sr_dataset)), len(sr_dataset)-int(0.8*len(sr_dataset))])
    train_drrn, val_drrn = torch.utils.data.random_split(drrn_dataset,
        [int(0.8*len(drrn_dataset)), len(drrn_dataset)-int(0.8*len(drrn_dataset))])
    train_class, val_class = torch.utils.data.random_split(class_dataset,
        [int(0.8*len(class_dataset)), len(class_dataset)-int(0.8*len(class_dataset))])
    
    # DataLoaders
    train_sr_loader = DataLoader(train_sr, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    train_drrn_loader = DataLoader(train_drrn, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    train_class_loader = DataLoader(train_class, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    
    # Train LapSRN
    print("\n" + "="*80)
    print("[1/3] Training LapSRN (16x16 → 64x64, 4x upsampling) + CBAM")
    print("="*80)
    
    optimizer = optim.Adam(lapsrn.parameters(), lr=config.LEARNING_RATE)
    criterion = nn.L1Loss()
    best_loss = float('inf')
    
    for epoch in range(config.EPOCHS_SR):
        lapsrn.train()
        train_loss = 0
        pbar = tqdm(train_sr_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS_SR}')
        
        for lr_imgs, hr_imgs in pbar:
            lr_imgs, hr_imgs = lr_imgs.to(config.DEVICE), hr_imgs.to(config.DEVICE)
            optimizer.zero_grad()
            sr_output, _ = lapsrn(lr_imgs)
            loss = criterion(sr_output, hr_imgs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.6f}'})
        
        avg_loss = train_loss / len(train_sr_loader)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(lapsrn.state_dict(), os.path.join(version_save_dir, 'lapsrn_best.pth'))
    
    print(f"✓ LapSRN training complete (best loss: {best_loss:.6f})")
    
    # Train DRRN
    print("\n" + "="*80)
    print("[2/3] Training DRRN (64x64 → 128x128, 2x upsampling) + CBAM")
    print("="*80)
    
    optimizer = optim.Adam(drrn.parameters(), lr=config.LEARNING_RATE)
    best_loss = float('inf')
    
    for epoch in range(config.EPOCHS_SR):
        drrn.train()
        train_loss = 0
        pbar = tqdm(train_drrn_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS_SR}')
        
        for lr_imgs, hr_imgs in pbar:
            lr_imgs, hr_imgs = lr_imgs.to(config.DEVICE), hr_imgs.to(config.DEVICE)
            optimizer.zero_grad()
            sr_output = drrn(lr_imgs)
            loss = criterion(sr_output, hr_imgs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.6f}'})
        
        avg_loss = train_loss / len(train_drrn_loader)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(drrn.state_dict(), os.path.join(version_save_dir, 'drrn_best.pth'))
    
    print(f"✓ DRRN training complete (best loss: {best_loss:.6f})")
    
    # Train Classifier
    print("\n" + "="*80)
    print("[3/3] Training Classifier (128x128 → 224x224 → Classification)")
    print("="*80)
    
    optimizer = optim.Adam(classifier.parameters(), lr=config.LEARNING_RATE)
    class_criterion = nn.CrossEntropyLoss()
    urgency_criterion = nn.BCELoss()
    best_acc = 0.0
    
    for epoch in range(config.EPOCHS_CLASS):
        classifier.train()
        correct, total = 0, 0
        pbar = tqdm(train_class_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS_CLASS}')
        
        for images, labels, urgency in pbar:
            images = images.to(config.DEVICE)
            labels = labels.to(config.DEVICE)
            urgency = urgency.to(config.DEVICE).unsqueeze(1).float()
            
            optimizer.zero_grad()
            class_out, urgency_out, _ = classifier(images)
            loss = class_criterion(class_out, labels) + 0.5 * urgency_criterion(urgency_out, urgency)
            loss.backward()
            optimizer.step()
            
            _, predicted = torch.max(class_out, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            pbar.set_postfix({'acc': f'{100*correct/total:.2f}%'})
        
        acc = 100 * correct / total
        if acc > best_acc:
            best_acc = acc
            torch.save(classifier.state_dict(), os.path.join(version_save_dir, 'classifier_best.pth'))
    
    print(f"✓ Classifier training complete (best accuracy: {best_acc:.2f}%)")
    
    # Save configuration
    config_dict = {
        'version': config.VERSION,
        'lapsrn_channels': config.LAPSRN_CHANNELS,
        'lapsrn_blocks': config.LAPSRN_BLOCKS,
        'drrn_channels': config.DRRN_CHANNELS,
        'drrn_blocks': config.DRRN_BLOCKS,
        'kernel_size': config.KERNEL_SIZE,
        'activation': config.ACTIVATION,
        'backbone': config.BACKBONE,
        'use_attention': config.USE_ATTENTION,
        'epochs_sr': config.EPOCHS_SR,
        'epochs_class': config.EPOCHS_CLASS,
        'timestamp': datetime.now().isoformat(),
        'notes': 'Added CBAM attention (channel + spatial) after every residual/recursive block'
    }
    
    with open(os.path.join(version_save_dir, 'config.json'), 'w') as f:
        json.dump(config_dict, f, indent=2)
    
    print(f"\n{'='*80}")
    print("✓ ALL TRAINING COMPLETE!")
    print(f"{'='*80}")
    print(f"Models saved to: {version_save_dir}")
    print("\nPipeline: 16x16 → LapSRN(4x) → 64x64 → DRRN(2x) → 128x128 → Classifier(224x224)")
    print("\nKey difference from v1_baseline:")
    print("  - Added CBAM attention modules after each residual/recursive block")
    print("  - Channel attention: Learns which features to emphasize")
    print("  - Spatial attention: Learns where to focus spatially")
    print("  - Expected: Better feature selection, improved detail preservation")


if __name__ == "__main__":
    train_model()


TRAINING V7_ATTENTION
Configuration:
  - LapSRN: 64 channels, 5 blocks
  - DRRN: 128 channels, 25 blocks
  - Kernel: 3x3
  - Activation: leaky
  - Backbone: RESNET50
  - Attention: CBAM (Channel + Spatial)
  - Device: cuda


Attention mechanism info:
  - Type: CBAM (Convolutional Block Attention Module)
  - Channel Attention: Adaptive avg/max pooling + FC layers
  - Spatial Attention: 7x7 convolution on channel-pooled features
  - Applied after: Every residual/recursive block

[1/3] Training LapSRN (16x16 → 64x64, 4x upsampling) + CBAM


Epoch 1/50: 100% 1328/1328 [00:40<00:00, 32.90it/s, loss=0.024421]
Epoch 2/50: 100% 1328/1328 [00:36<00:00, 36.02it/s, loss=0.009047]
Epoch 3/50: 100% 1328/1328 [00:36<00:00, 36.02it/s, loss=0.014735]
Epoch 4/50: 100% 1328/1328 [00:36<00:00, 36.19it/s, loss=0.013262]
Epoch 5/50: 100% 1328/1328 [00:36<00:00, 36.26it/s, loss=0.000074]
Epoch 6/50: 100% 1328/1328 [00:36<00:00, 36.24it/s, loss=0.000500]
Epoch 7/50: 100% 1328/1328 [00:37<00:00, 35.88it/s, loss=0.001071]
Epoch 8/50: 100% 1328/1328 [00:36<00:00, 35.97it/s, loss=0.007731]
Epoch 9/50: 100% 1328/1328 [00:36<00:00, 36.28it/s, loss=0.018475]
Epoch 10/50: 100% 1328/1328 [00:36<00:00, 35.99it/s, loss=0.022867]
Epoch 11/50: 100% 1328/1328 [00:37<00:00, 35.72it/s, loss=0.011397]
Epoch 12/50: 100% 1328/1328 [00:37<00:00, 35.50it/s, loss=0.008589]
Epoch 13/50: 100% 1328/1328 [00:37<00:00, 35.50it/s, loss=0.005311]
Epoch 14/50: 100% 1328/1328 [00:37<00:00, 35.27it/s, loss=0.012764]
Epoch 15/50: 100% 1328/1328 [00:37<00:00, 35.38it/s, loss

✓ LapSRN training complete (best loss: 0.009958)

[2/3] Training DRRN (64x64 → 128x128, 2x upsampling) + CBAM


Epoch 1/50: 100% 1328/1328 [04:12<00:00,  5.27it/s, loss=0.004676]
Epoch 2/50: 100% 1328/1328 [04:43<00:00,  4.68it/s, loss=0.000093]
Epoch 3/50: 100% 1328/1328 [04:45<00:00,  4.66it/s, loss=0.000213]
Epoch 4/50: 100% 1328/1328 [04:45<00:00,  4.65it/s, loss=0.005010]
Epoch 5/50: 100% 1328/1328 [04:45<00:00,  4.65it/s, loss=0.008265]
Epoch 6/50: 100% 1328/1328 [04:45<00:00,  4.66it/s, loss=0.001057]
Epoch 7/50: 100% 1328/1328 [04:44<00:00,  4.66it/s, loss=0.001676]
Epoch 8/50: 100% 1328/1328 [04:45<00:00,  4.65it/s, loss=0.001375]
Epoch 9/50: 100% 1328/1328 [04:45<00:00,  4.65it/s, loss=0.004319]
Epoch 10/50: 100% 1328/1328 [04:45<00:00,  4.65it/s, loss=0.003074]
Epoch 11/50: 100% 1328/1328 [04:45<00:00,  4.65it/s, loss=0.002373]
Epoch 12/50: 100% 1328/1328 [04:45<00:00,  4.65it/s, loss=0.003004]
Epoch 13/50: 100% 1328/1328 [04:45<00:00,  4.66it/s, loss=0.001025]
Epoch 14/50: 100% 1328/1328 [04:45<00:00,  4.65it/s, loss=0.003516]
Epoch 15/50: 100% 1328/1328 [04:44<00:00,  4.66it/s, loss

✓ DRRN training complete (best loss: 0.002384)

[3/3] Training Classifier (128x128 → 224x224 → Classification)


Epoch 1/30:   0% 0/332 [00:00<?, ?it/s]MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xB
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvBinWinogradRxSf3x2; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xB
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvHipImplicitGemmGroupBwdXdlops; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xB
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xW
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvHipImplicitGemmGroupWrwXdlops; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xW
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC; key: 2x512x7x7x1x3x3x1x512x16x1x1x0x2x2x0x1x1x0x0x1

In [None]:
"""
Model Evaluation Script - v7_attention (Notebook-Friendly)
Calculates PSNR, SSIM, Accuracy, and Classification Metrics

Usage in Jupyter Notebook:
    from evaluate_v7_attention import evaluate_model
    
    results = evaluate_model(
        version='v7_attention',
        data_dir='./preprocessed_data',
        model_dir='./trained_models_v7'
    )
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from tqdm import tqdm
import json
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# ==============================================================================
# DATASETS
# ==============================================================================

class SuperResolutionDataset(Dataset):
    def __init__(self, preprocessed_data_dir, hr_patch_size=64, scale_factor=4):
        self.hr_patch_size = hr_patch_size
        self.lr_patch_size = hr_patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = (h - self.hr_patch_size) // 2
        left = (w - self.hr_patch_size) // 2
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class DRRNDataset(Dataset):
    def __init__(self, preprocessed_data_dir, patch_size=64, scale_factor=2):
        self.hr_patch_size = patch_size
        self.lr_patch_size = patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = (h - self.hr_patch_size) // 2
        left = (w - self.hr_patch_size) // 2
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class ClassificationDataset(Dataset):
    def __init__(self, preprocessed_data_dir, enhance_size=224):
        self.enhance_size = enhance_size
        self.data = []
        
        category_map = {'Normal': 0, 'Ischemia': 1, 'Bleeding': 2}
        urgency_map = {'Normal': 0.1, 'Ischemia': 0.7, 'Bleeding': 0.95}
        
        for category, label in category_map.items():
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.data.append({
                            'path': os.path.join(category_path, filename),
                            'label': label,
                            'urgency': urgency_map[category]
                        })
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        img = Image.open(sample['path']).convert('L')
        img = img.resize((self.enhance_size, self.enhance_size), Image.BICUBIC)
        img_array = np.array(img, dtype=np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array.copy()).unsqueeze(0).float()
        
        return img_tensor, sample['label'], sample['urgency']


# ==============================================================================
# MODEL DEFINITIONS - v7_attention
# ==============================================================================

class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(out))


class CBAM(nn.Module):
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)
    
    def forward(self, x):
        out = x * self.channel_attention(x)
        return out * self.spatial_attention(out)


class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, use_attention=True):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
        self.attention = CBAM(channels)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.conv2(out)
        out = self.attention(out)
        return self.activation(out + residual)


class RecursiveBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, use_attention=True):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
        self.attention = CBAM(channels)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.activation(self.conv2(out))
        out = self.attention(out)
        return out + residual


class LapSRN(nn.Module):
    def __init__(self, scale_factor=4, num_channels=1, use_attention=True):
        super().__init__()
        self.scale_factor = scale_factor
        self.num_levels = 2
        ch = 64
        
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, ch, 3, padding=1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.pyramid_levels = nn.ModuleList()
        self.image_reconstruction = nn.ModuleList()
        
        for _ in range(self.num_levels):
            layers = []
            for _ in range(5):
                layers.append(ResidualBlock(ch, 3, use_attention))
            layers.append(nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, True))
            
            self.pyramid_levels.append(nn.Sequential(*layers))
            self.image_reconstruction.append(nn.Conv2d(ch, num_channels, 3, padding=1))
    
    def forward(self, x):
        features = self.feature_extraction(x)
        outputs = []
        current_features = features
        
        for level_idx in range(self.num_levels):
            current_features = self.pyramid_levels[level_idx](current_features)
            img_out = self.image_reconstruction[level_idx](current_features)
            
            if level_idx > 0:
                img_out = img_out + F.interpolate(outputs[-1], scale_factor=2, mode='bilinear', align_corners=False)
            else:
                img_out = img_out + F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            
            outputs.append(img_out)
        
        return outputs[-1], outputs


class DRRN(nn.Module):
    def __init__(self, num_channels=1, scale_factor=2, use_attention=True):
        super().__init__()
        self.scale_factor = scale_factor
        ch = 128
        
        self.input_conv = nn.Conv2d(num_channels, ch, 3, padding=1)
        
        self.recursive_blocks = nn.ModuleList()
        for _ in range(25):
            self.recursive_blocks.append(RecursiveBlock(ch, 3, use_attention))
        
        self.fusion = nn.Sequential(
            nn.Conv2d(ch * 3, ch, 1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.upsample = nn.Sequential(
            nn.Conv2d(ch, ch * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, True)
        )
        
        self.output_conv = nn.Sequential(
            nn.Conv2d(ch, 64, 3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, num_channels, 3, padding=1)
        )
    
    def forward(self, x):
        input_upsampled = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)
        
        features = self.input_conv(x)
        multi_scale_features = []
        current = features
        
        collect_indices = [8, 16, 24]
        
        for idx, block in enumerate(self.recursive_blocks):
            current = block(current)
            if idx in collect_indices:
                multi_scale_features.append(current)
        
        fused = torch.cat(multi_scale_features, dim=1)
        fused = self.fusion(fused)
        upsampled = self.upsample(fused)
        output = self.output_conv(upsampled)
        
        return output + input_upsampled


class MedicalImageClassifier(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        
        from torchvision import models
        self.backbone = models.resnet50(pretrained=True)
        self.backbone.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        
        self.classification_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        
        self.urgency_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
        self.feature_head = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(True)
        )
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classification_head(features), self.urgency_head(features), self.feature_head(features)


# ==============================================================================
# EVALUATION
# ==============================================================================

def evaluate_model(version='v7_attention', data_dir='./preprocessed_data', 
                   model_dir='./trained_models_v7',
                   device='cuda' if torch.cuda.is_available() else 'cpu'):
    
    print(f"\n{'='*80}")
    print(f"EVALUATING MODEL: {version}")
    print(f"{'='*80}\n")
    
    version_dir = os.path.join(model_dir, version)
    config_path = os.path.join(version_dir, 'config.json')
    
    if not os.path.exists(config_path):
        print(f"ERROR: Config file not found at {config_path}")
        return None
    
    with open(config_path, 'r') as f:
        config = json.load(f)
    
    print(f"Loaded configuration for {version}")
    print(f"  LapSRN: {config.get('lapsrn_channels', 64)} channels, {config.get('lapsrn_blocks', 5)} blocks")
    print(f"  DRRN: {config.get('drrn_channels', 128)} channels, {config.get('drrn_blocks', 25)} blocks")
    print(f"  Backbone: {config.get('backbone', 'resnet50')}")
    print(f"  Attention: {config.get('use_attention', True)}\n")
    
    # Datasets
    print("Creating datasets...")
    sr_dataset = SuperResolutionDataset(data_dir, hr_patch_size=64, scale_factor=4)
    drrn_dataset = DRRNDataset(data_dir, patch_size=64, scale_factor=2)
    class_dataset = ClassificationDataset(data_dir, enhance_size=224)
    
    sr_loader = DataLoader(sr_dataset, batch_size=16, shuffle=False, num_workers=2)
    drrn_loader = DataLoader(drrn_dataset, batch_size=16, shuffle=False, num_workers=2)
    class_loader = DataLoader(class_dataset, batch_size=16, shuffle=False, num_workers=2)
    
    print(f"  SR dataset: {len(sr_dataset)} samples")
    print(f"  DRRN dataset: {len(drrn_dataset)} samples")
    print(f"  Classification dataset: {len(class_dataset)} samples\n")
    
    # Load models
    print("Loading models...")
    lapsrn = LapSRN(use_attention=True).to(device)
    drrn = DRRN(use_attention=True).to(device)
    classifier = MedicalImageClassifier().to(device)
    
    lapsrn.load_state_dict(torch.load(os.path.join(version_dir, 'lapsrn_best.pth'), map_location=device))
    drrn.load_state_dict(torch.load(os.path.join(version_dir, 'drrn_best.pth'), map_location=device))
    classifier.load_state_dict(torch.load(os.path.join(version_dir, 'classifier_best.pth'), map_location=device))
    
    lapsrn.eval()
    drrn.eval()
    classifier.eval()
    print("✓ Models loaded successfully\n")
    
    # Evaluate LapSRN
    print("="*80)
    print("[1/3] Evaluating LapSRN (16x16 → 64x64)")
    print("="*80)
    
    lapsrn_psnr_list, lapsrn_ssim_list = [], []
    
    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(sr_loader, desc="LapSRN Evaluation"):
            lr_imgs = lr_imgs.to(device)
            sr_output, _ = lapsrn(lr_imgs)
            
            sr_np = sr_output.cpu().numpy()
            hr_np = hr_imgs.numpy()
            
            for i in range(sr_np.shape[0]):
                sr_img = np.clip(sr_np[i, 0], 0, 1)
                hr_img = np.clip(hr_np[i, 0], 0, 1)
                lapsrn_psnr_list.append(psnr(hr_img, sr_img, data_range=1.0))
                lapsrn_ssim_list.append(ssim(hr_img, sr_img, data_range=1.0))
    
    lapsrn_psnr_mean = np.mean(lapsrn_psnr_list)
    lapsrn_ssim_mean = np.mean(lapsrn_ssim_list)
    
    print(f"\n✓ LapSRN Results:")
    print(f"  PSNR: {lapsrn_psnr_mean:.4f} dB")
    print(f"  SSIM: {lapsrn_ssim_mean:.4f}\n")
    
    # Evaluate DRRN
    print("="*80)
    print("[2/3] Evaluating DRRN (64x64 → 128x128)")
    print("="*80)
    
    drrn_psnr_list, drrn_ssim_list = [], []
    
    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(drrn_loader, desc="DRRN Evaluation"):
            lr_imgs = lr_imgs.to(device)
            sr_output = drrn(lr_imgs)
            
            sr_np = sr_output.cpu().numpy()
            hr_np = hr_imgs.numpy()
            
            for i in range(sr_np.shape[0]):
                sr_img = np.clip(sr_np[i, 0], 0, 1)
                hr_img = np.clip(hr_np[i, 0], 0, 1)
                drrn_psnr_list.append(psnr(hr_img, sr_img, data_range=1.0))
                drrn_ssim_list.append(ssim(hr_img, sr_img, data_range=1.0))
    
    drrn_psnr_mean = np.mean(drrn_psnr_list)
    drrn_ssim_mean = np.mean(drrn_ssim_list)
    
    print(f"\n✓ DRRN Results:")
    print(f"  PSNR: {drrn_psnr_mean:.4f} dB")
    print(f"  SSIM: {drrn_ssim_mean:.4f}\n")
    
    # Evaluate Classifier
    print("="*80)
    print("[3/3] Evaluating Classifier")
    print("="*80)
    
    all_preds, all_labels, all_urgency_preds, all_urgency_true = [], [], [], []
    
    with torch.no_grad():
        for images, labels, urgency in tqdm(class_loader, desc="Classifier Evaluation"):
            images = images.to(device)
            class_out, urgency_out, _ = classifier(images)
            _, predicted = torch.max(class_out, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_urgency_preds.extend(urgency_out.cpu().numpy().flatten())
            all_urgency_true.extend(urgency.numpy())
    
    accuracy = accuracy_score(all_labels, all_preds) * 100
    conf_matrix = confusion_matrix(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds, 
                                        target_names=['Normal', 'Ischemia', 'Bleeding'], output_dict=True)
    urgency_mse = np.mean((np.array(all_urgency_preds) - np.array(all_urgency_true))**2)
    urgency_mae = np.mean(np.abs(np.array(all_urgency_preds) - np.array(all_urgency_true)))
    
    print(f"\n✓ Classification Results:")
    print(f"  Accuracy: {accuracy:.2f}%")
    print(f"\n  Confusion Matrix:")
    print(f"  {conf_matrix}")
    print(f"\n  Per-Class Metrics:")
    for class_name in ['Normal', 'Ischemia', 'Bleeding']:
        metrics = class_report[class_name]
        print(f"    {class_name}:")
        print(f"      Precision: {metrics['precision']:.4f}")
        print(f"      Recall: {metrics['recall']:.4f}")
        print(f"      F1-Score: {metrics['f1-score']:.4f}")
    
    print(f"\n  Urgency Prediction:")
    print(f"    MSE: {urgency_mse:.4f}")
    print(f"    MAE: {urgency_mae:.4f}\n")
    
    # Save results
    results = {
        'version': version,
        'lapsrn': { 'psnr': float(lapsrn_psnr_mean), 'ssim': float(lapsrn_ssim_mean) },
        'drrn': { 'psnr': float(drrn_psnr_mean), 'ssim': float(drrn_ssim_mean) },
        'classifier': {
            'accuracy': float(accuracy),
            'confusion_matrix': conf_matrix.tolist(),
            'classification_report': class_report,
            'urgency_mse': float(urgency_mse),
            'urgency_mae': float(urgency_mae)
        }
    }
    
    results_path = os.path.join(version_dir, 'evaluation_results.json')
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"✓ Results saved to: {results_path}")
    print(f"\n{'='*80}")
    print(f"EVALUATION COMPLETE FOR {version}")
    print(f"{'='*80}\n")
    
    return results

In [None]:
results = evaluate_model(version='v7_attention', data_dir='./preprocessed_data', model_dir='./trained_models_v7')