In [1]:
!pip install pydicom numpy scikit-image pillow scipy SimpleITK


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m26.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [5]:
"""
Medical Image Training - v9_dense_connections (Dense Connections)
Configuration: 64 LapSRN channels, 5 blocks | 128 DRRN channels, 25 blocks | LeakyReLU | Dense Connections
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from tqdm import tqdm
from datetime import datetime
import json
import warnings
warnings.filterwarnings('ignore')

# ==============================================================================
# CONFIGURATION
# ==============================================================================

class Config:
    VERSION = 'v9_dense_connections'
    DATA_DIR = './preprocessed_data'
    SAVE_DIR = './trained_models_v9'
    
    EPOCHS_SR = 50
    EPOCHS_CLASS = 30
    BATCH_SIZE = 16
    LEARNING_RATE = 1e-4
    
    LAPSRN_SCALE = 4
    DRRN_SCALE = 2
    TOTAL_SCALE = 8
    
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # v9_dense_connections specific - ADDING DENSE CONNECTIONS
    LAPSRN_CHANNELS = 64
    LAPSRN_BLOCKS = 5
    DRRN_CHANNELS = 128
    DRRN_BLOCKS = 25
    KERNEL_SIZE = 3
    ACTIVATION = 'leaky'
    BACKBONE = 'resnet50'
    USE_DENSE = True  # NEW: Enable dense connections


# ==============================================================================
# DATASETS
# ==============================================================================

class SuperResolutionDataset(Dataset):
    def __init__(self, preprocessed_data_dir, hr_patch_size=64, scale_factor=4):
        self.hr_patch_size = hr_patch_size
        self.lr_patch_size = hr_patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files) * 4
    
    def __getitem__(self, idx):
        img_idx = idx // 4
        img_path = self.image_files[img_idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = np.random.randint(0, max(1, h - self.hr_patch_size + 1))
        left = np.random.randint(0, max(1, w - self.hr_patch_size + 1))
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class DRRNDataset(Dataset):
    def __init__(self, preprocessed_data_dir, patch_size=64, scale_factor=2):
        self.hr_patch_size = patch_size
        self.lr_patch_size = patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files) * 4
    
    def __getitem__(self, idx):
        img_idx = idx // 4
        img_path = self.image_files[img_idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = np.random.randint(0, max(1, h - self.hr_patch_size + 1))
        left = np.random.randint(0, max(1, w - self.hr_patch_size + 1))
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class ClassificationDataset(Dataset):
    def __init__(self, preprocessed_data_dir, enhance_size=224):
        self.enhance_size = enhance_size
        self.data = []
        
        category_map = {'Normal': 0, 'Ischemia': 1, 'Bleeding': 2}
        urgency_map = {'Normal': 0.1, 'Ischemia': 0.7, 'Bleeding': 0.95}
        
        for category, label in category_map.items():
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.data.append({
                            'path': os.path.join(category_path, filename),
                            'label': label,
                            'urgency': urgency_map[category]
                        })
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        img = Image.open(sample['path']).convert('L')
        img = img.resize((self.enhance_size, self.enhance_size), Image.BICUBIC)
        img_array = np.array(img, dtype=np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array.copy()).unsqueeze(0).float()
        
        return img_tensor, sample['label'], sample['urgency']


# ==============================================================================
# BUILDING BLOCKS WITH DENSE CONNECTIONS
# ==============================================================================

class DenseBlock(nn.Module):
    """Dense block - each conv receives all previous outputs as input"""
    def __init__(self, channels, kernel_size=3, num_layers=4, growth_rate=16):
        super().__init__()
        padding = kernel_size // 2
        self.num_layers = num_layers
        self.growth_rate = growth_rate
        
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_channels = channels + i * growth_rate
            self.layers.append(nn.Sequential(
                nn.Conv2d(in_channels, growth_rate, kernel_size, padding=padding),
                nn.LeakyReLU(0.2, True)
            ))
        
        # 1x1 conv to compress back to original channel count
        total_channels = channels + num_layers * growth_rate
        self.compress = nn.Conv2d(total_channels, channels, 1)
    
    def forward(self, x):
        features = [x]
        for layer in self.layers:
            concat = torch.cat(features, dim=1)
            out = layer(concat)
            features.append(out)
        
        # Concatenate all features
        dense_out = torch.cat(features, dim=1)
        # Compress back to original channels
        out = self.compress(dense_out)
        # Global residual
        return out + x


class ResidualBlock(nn.Module):
    """Standard residual block (used outside dense blocks)"""
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.conv2(out)
        return self.activation(out + residual)


class RecursiveBlock(nn.Module):
    """Standard recursive block (used outside dense blocks)"""
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.activation(self.conv2(out))
        return out + residual


# ==============================================================================
# MODELS - v9_dense_connections
# ==============================================================================

class LapSRN(nn.Module):
    """v9_dense_connections: LapSRN with Dense Blocks replacing residual blocks"""
    def __init__(self, scale_factor=4, num_channels=1, use_dense=True):
        super().__init__()
        self.scale_factor = scale_factor
        self.num_levels = 2  # 2x2 = 4x
        ch = 64
        
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, ch, 3, padding=1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.pyramid_levels = nn.ModuleList()
        self.image_reconstruction = nn.ModuleList()
        
        for _ in range(self.num_levels):
            layers = []
            if use_dense:
                # Use dense blocks (each dense block replaces multiple residual blocks)
                layers.append(DenseBlock(ch, kernel_size=3, num_layers=4, growth_rate=16))
                layers.append(DenseBlock(ch, kernel_size=3, num_layers=4, growth_rate=16))
            else:
                for _ in range(5):
                    layers.append(ResidualBlock(ch, 3))
            layers.append(nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, True))
            
            self.pyramid_levels.append(nn.Sequential(*layers))
            self.image_reconstruction.append(nn.Conv2d(ch, num_channels, 3, padding=1))
    
    def forward(self, x):
        features = self.feature_extraction(x)
        outputs = []
        current_features = features
        
        for level_idx in range(self.num_levels):
            current_features = self.pyramid_levels[level_idx](current_features)
            img_out = self.image_reconstruction[level_idx](current_features)
            
            if level_idx > 0:
                img_out = img_out + F.interpolate(outputs[-1], scale_factor=2, mode='bilinear', align_corners=False)
            else:
                img_out = img_out + F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            
            outputs.append(img_out)
        
        return outputs[-1], outputs


class DRRN(nn.Module):
    """v9_dense_connections: DRRN with Dense Blocks at multi-scale collection points"""
    def __init__(self, num_channels=1, scale_factor=2, use_dense=True):
        super().__init__()
        self.scale_factor = scale_factor
        ch = 128
        
        self.input_conv = nn.Conv2d(num_channels, ch, 3, padding=1)
        
        # Standard recursive blocks
        self.recursive_blocks = nn.ModuleList()
        for _ in range(25):
            self.recursive_blocks.append(RecursiveBlock(ch, 3))
        
        # Dense blocks placed at collection points for richer feature fusion
        self.dense_blocks = nn.ModuleList()
        if use_dense:
            for _ in range(3):  # one per collection point
                self.dense_blocks.append(DenseBlock(ch, kernel_size=3, num_layers=4, growth_rate=32))
        
        self.use_dense = use_dense
        
        self.fusion = nn.Sequential(
            nn.Conv2d(ch * 3, ch, 1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.upsample = nn.Sequential(
            nn.Conv2d(ch, ch * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, True)
        )
        
        self.output_conv = nn.Sequential(
            nn.Conv2d(ch, 64, 3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, num_channels, 3, padding=1)
        )
    
    def forward(self, x):
        input_upsampled = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)
        
        features = self.input_conv(x)
        multi_scale_features = []
        current = features
        
        collect_indices = [8, 16, 24]
        dense_idx = 0
        
        for idx, block in enumerate(self.recursive_blocks):
            current = block(current)
            if idx in collect_indices:
                # Apply dense block before collecting the feature
                if self.use_dense:
                    current = self.dense_blocks[dense_idx](current)
                    dense_idx += 1
                multi_scale_features.append(current)
        
        fused = torch.cat(multi_scale_features, dim=1)
        fused = self.fusion(fused)
        upsampled = self.upsample(fused)
        output = self.output_conv(upsampled)
        
        return output + input_upsampled


class MedicalImageClassifier(nn.Module):
    """v9_dense_connections: Standard ResNet50 classifier (same as baseline)"""
    def __init__(self, num_classes=3):
        super().__init__()
        
        from torchvision import models
        self.backbone = models.resnet50(pretrained=True)
        self.backbone.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        
        self.classification_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        
        self.urgency_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
        self.feature_head = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(True)
        )
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classification_head(features), self.urgency_head(features), self.feature_head(features)


# ==============================================================================
# TRAINING
# ==============================================================================

def train_model():
    config = Config()
    
    print(f"\n{'='*80}")
    print(f"TRAINING {config.VERSION.upper()}")
    print(f"{'='*80}")
    print(f"Configuration:")
    print(f"  - LapSRN: {config.LAPSRN_CHANNELS} channels, {config.LAPSRN_BLOCKS} blocks")
    print(f"  - DRRN: {config.DRRN_CHANNELS} channels, {config.DRRN_BLOCKS} blocks")
    print(f"  - Kernel: {config.KERNEL_SIZE}x{config.KERNEL_SIZE}")
    print(f"  - Activation: {config.ACTIVATION}")
    print(f"  - Backbone: {config.BACKBONE.upper()}")
    print(f"  - Dense Connections: Enabled")
    print(f"  - Device: {config.DEVICE}")
    print(f"{'='*80}\n")
    
    version_save_dir = os.path.join(config.SAVE_DIR, config.VERSION)
    os.makedirs(version_save_dir, exist_ok=True)
    
    # Initialize models
    lapsrn = LapSRN(use_dense=config.USE_DENSE).to(config.DEVICE)
    drrn = DRRN(use_dense=config.USE_DENSE).to(config.DEVICE)
    classifier = MedicalImageClassifier().to(config.DEVICE)
    
    print(f"\nDense Connection info:")
    print(f"  - Type: DenseBlock (DenseNet-style connections)")
    print(f"  - LapSRN: 2 DenseBlocks per pyramid level (growth_rate=16, num_layers=4)")
    print(f"  - DRRN: 3 DenseBlocks at multi-scale collection points (growth_rate=32, num_layers=4)")
    print(f"  - Each layer receives ALL previous layer outputs as input")
    print(f"  - 1x1 conv compresses concatenated features back to original channels")
    
    # Create datasets
    sr_dataset = SuperResolutionDataset(config.DATA_DIR, hr_patch_size=64, scale_factor=4)
    drrn_dataset = DRRNDataset(config.DATA_DIR, patch_size=64, scale_factor=2)
    class_dataset = ClassificationDataset(config.DATA_DIR, enhance_size=224)
    
    # Split datasets (80/20)
    train_sr, val_sr = torch.utils.data.random_split(sr_dataset, 
        [int(0.8*len(sr_dataset)), len(sr_dataset)-int(0.8*len(sr_dataset))])
    train_drrn, val_drrn = torch.utils.data.random_split(drrn_dataset,
        [int(0.8*len(drrn_dataset)), len(drrn_dataset)-int(0.8*len(drrn_dataset))])
    train_class, val_class = torch.utils.data.random_split(class_dataset,
        [int(0.8*len(class_dataset)), len(class_dataset)-int(0.8*len(class_dataset))])
    
    # DataLoaders
    train_sr_loader = DataLoader(train_sr, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    train_drrn_loader = DataLoader(train_drrn, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    train_class_loader = DataLoader(train_class, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    
    # Train LapSRN
    print("\n" + "="*80)
    print("[1/3] Training LapSRN (16x16 → 64x64, 4x upsampling) + Dense Blocks")
    print("="*80)
    
    optimizer = optim.Adam(lapsrn.parameters(), lr=config.LEARNING_RATE)
    criterion = nn.L1Loss()
    best_loss = float('inf')
    
    for epoch in range(config.EPOCHS_SR):
        lapsrn.train()
        train_loss = 0
        pbar = tqdm(train_sr_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS_SR}')
        
        for lr_imgs, hr_imgs in pbar:
            lr_imgs, hr_imgs = lr_imgs.to(config.DEVICE), hr_imgs.to(config.DEVICE)
            optimizer.zero_grad()
            sr_output, _ = lapsrn(lr_imgs)
            loss = criterion(sr_output, hr_imgs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.6f}'})
        
        avg_loss = train_loss / len(train_sr_loader)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(lapsrn.state_dict(), os.path.join(version_save_dir, 'lapsrn_best.pth'))
    
    print(f"✓ LapSRN training complete (best loss: {best_loss:.6f})")
    
    # Train DRRN
    print("\n" + "="*80)
    print("[2/3] Training DRRN (64x64 → 128x128, 2x upsampling) + Dense Blocks")
    print("="*80)
    
    optimizer = optim.Adam(drrn.parameters(), lr=config.LEARNING_RATE)
    best_loss = float('inf')
    
    for epoch in range(config.EPOCHS_SR):
        drrn.train()
        train_loss = 0
        pbar = tqdm(train_drrn_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS_SR}')
        
        for lr_imgs, hr_imgs in pbar:
            lr_imgs, hr_imgs = lr_imgs.to(config.DEVICE), hr_imgs.to(config.DEVICE)
            optimizer.zero_grad()
            sr_output = drrn(lr_imgs)
            loss = criterion(sr_output, hr_imgs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.6f}'})
        
        avg_loss = train_loss / len(train_drrn_loader)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(drrn.state_dict(), os.path.join(version_save_dir, 'drrn_best.pth'))
    
    print(f"✓ DRRN training complete (best loss: {best_loss:.6f})")
    
    # Train Classifier
    print("\n" + "="*80)
    print("[3/3] Training Classifier (128x128 → 224x224 → Classification)")
    print("="*80)
    
    optimizer = optim.Adam(classifier.parameters(), lr=config.LEARNING_RATE)
    class_criterion = nn.CrossEntropyLoss()
    urgency_criterion = nn.BCELoss()
    best_acc = 0.0
    
    for epoch in range(config.EPOCHS_CLASS):
        classifier.train()
        correct, total = 0, 0
        pbar = tqdm(train_class_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS_CLASS}')
        
        for images, labels, urgency in pbar:
            images = images.to(config.DEVICE)
            labels = labels.to(config.DEVICE)
            urgency = urgency.to(config.DEVICE).unsqueeze(1).float()
            
            optimizer.zero_grad()
            class_out, urgency_out, _ = classifier(images)
            loss = class_criterion(class_out, labels) + 0.5 * urgency_criterion(urgency_out, urgency)
            loss.backward()
            optimizer.step()
            
            _, predicted = torch.max(class_out, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            pbar.set_postfix({'acc': f'{100*correct/total:.2f}%'})
        
        acc = 100 * correct / total
        if acc > best_acc:
            best_acc = acc
            torch.save(classifier.state_dict(), os.path.join(version_save_dir, 'classifier_best.pth'))
    
    print(f"✓ Classifier training complete (best accuracy: {best_acc:.2f}%)")
    
    # Save configuration
    config_dict = {
        'version': config.VERSION,
        'lapsrn_channels': config.LAPSRN_CHANNELS,
        'lapsrn_blocks': config.LAPSRN_BLOCKS,
        'drrn_channels': config.DRRN_CHANNELS,
        'drrn_blocks': config.DRRN_BLOCKS,
        'kernel_size': config.KERNEL_SIZE,
        'activation': config.ACTIVATION,
        'backbone': config.BACKBONE,
        'use_dense': config.USE_DENSE,
        'epochs_sr': config.EPOCHS_SR,
        'epochs_class': config.EPOCHS_CLASS,
        'timestamp': datetime.now().isoformat(),
        'notes': 'Added DenseNet-style dense connections. Each layer receives all previous outputs. 1x1 conv compresses back to original channels.'
    }
    
    with open(os.path.join(version_save_dir, 'config.json'), 'w') as f:
        json.dump(config_dict, f, indent=2)
    
    print(f"\n{'='*80}")
    print("✓ ALL TRAINING COMPLETE!")
    print(f"{'='*80}")
    print(f"Models saved to: {version_save_dir}")
    print("\nPipeline: 16x16 → LapSRN(4x) → 64x64 → DRRN(2x) → 128x128 → Classifier(224x224)")
    print("\nKey difference from v1_baseline:")
    print("  - Replaced residual blocks with DenseBlocks in LapSRN")
    print("  - Added DenseBlocks at DRRN multi-scale collection points")
    print("  - Each conv layer receives ALL previous layer outputs as input")
    print("  - 1x1 conv compresses concatenated features back to original channels")
    print("  - Expected: Better feature reuse, richer representations")


if __name__ == "__main__":
    train_model()


TRAINING V9_DENSE_CONNECTIONS
Configuration:
  - LapSRN: 64 channels, 5 blocks
  - DRRN: 128 channels, 25 blocks
  - Kernel: 3x3
  - Activation: leaky
  - Backbone: RESNET50
  - Dense Connections: Enabled
  - Device: cuda


Dense Connection info:
  - Type: DenseBlock (DenseNet-style connections)
  - LapSRN: 2 DenseBlocks per pyramid level (growth_rate=16, num_layers=4)
  - DRRN: 3 DenseBlocks at multi-scale collection points (growth_rate=32, num_layers=4)
  - Each layer receives ALL previous layer outputs as input
  - 1x1 conv compresses concatenated features back to original channels

[1/3] Training LapSRN (16x16 → 64x64, 4x upsampling) + Dense Blocks


Epoch 1/50: 100% 1328/1328 [00:41<00:00, 31.84it/s, loss=0.028415]
Epoch 2/50: 100% 1328/1328 [00:37<00:00, 34.95it/s, loss=0.028384]
Epoch 3/50: 100% 1328/1328 [00:37<00:00, 35.01it/s, loss=0.011137]
Epoch 4/50: 100% 1328/1328 [00:37<00:00, 35.36it/s, loss=0.020175]
Epoch 5/50: 100% 1328/1328 [00:37<00:00, 35.35it/s, loss=0.012303]
Epoch 6/50: 100% 1328/1328 [00:38<00:00, 34.77it/s, loss=0.011186]
Epoch 7/50: 100% 1328/1328 [00:38<00:00, 34.25it/s, loss=0.011265]
Epoch 8/50: 100% 1328/1328 [00:42<00:00, 31.43it/s, loss=0.016119]
Epoch 9/50: 100% 1328/1328 [00:40<00:00, 33.19it/s, loss=0.011880]
Epoch 10/50: 100% 1328/1328 [00:38<00:00, 34.53it/s, loss=0.013129]
Epoch 11/50: 100% 1328/1328 [00:38<00:00, 34.55it/s, loss=0.015241]
Epoch 12/50: 100% 1328/1328 [00:38<00:00, 34.73it/s, loss=0.000089]
Epoch 13/50: 100% 1328/1328 [00:38<00:00, 34.60it/s, loss=0.016882]
Epoch 14/50: 100% 1328/1328 [00:38<00:00, 34.35it/s, loss=0.015540]
Epoch 15/50: 100% 1328/1328 [00:39<00:00, 33.70it/s, loss

✓ LapSRN training complete (best loss: 0.010331)

[2/3] Training DRRN (64x64 → 128x128, 2x upsampling) + Dense Blocks


Epoch 1/50: 100% 1328/1328 [01:03<00:00, 20.81it/s, loss=0.003032]
Epoch 2/50: 100% 1328/1328 [00:59<00:00, 22.46it/s, loss=0.004422]
Epoch 3/50: 100% 1328/1328 [00:59<00:00, 22.27it/s, loss=0.002460]
Epoch 4/50: 100% 1328/1328 [00:59<00:00, 22.50it/s, loss=0.002153]
Epoch 5/50: 100% 1328/1328 [01:09<00:00, 19.13it/s, loss=0.002845]
Epoch 6/50: 100% 1328/1328 [01:19<00:00, 16.69it/s, loss=0.003618]
Epoch 7/50: 100% 1328/1328 [01:20<00:00, 16.51it/s, loss=0.002769]
Epoch 8/50: 100% 1328/1328 [01:27<00:00, 15.21it/s, loss=0.003031]
Epoch 9/50: 100% 1328/1328 [01:32<00:00, 14.34it/s, loss=0.002098]
Epoch 10/50: 100% 1328/1328 [01:32<00:00, 14.30it/s, loss=0.000063]
Epoch 11/50: 100% 1328/1328 [01:32<00:00, 14.33it/s, loss=0.004511]
Epoch 12/50: 100% 1328/1328 [01:32<00:00, 14.35it/s, loss=0.002387]
Epoch 13/50: 100% 1328/1328 [01:32<00:00, 14.33it/s, loss=0.008759]
Epoch 14/50: 100% 1328/1328 [01:32<00:00, 14.35it/s, loss=0.001149]
Epoch 15/50: 100% 1328/1328 [01:32<00:00, 14.31it/s, loss

✓ DRRN training complete (best loss: 0.002348)

[3/3] Training Classifier (128x128 → 224x224 → Classification)


Epoch 1/30:   0% 0/332 [00:00<?, ?it/s]MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xB
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvBinWinogradRxSf3x2; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xB
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvHipImplicitGemmGroupBwdXdlops; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xB
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xW
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvHipImplicitGemmGroupWrwXdlops; key: 2x2048x7x7x1x1x1x1x1024x16x0x0x0x2x2x0x1x1x0x0x1xNCHWxFP32xW
MIOpen(HIP): Error [ParseContents] Duplicate ID (ignored): ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC; key: 2x512x7x7x1x3x3x1x512x16x1x1x0x2x2x0x1x1x0x0x1

✓ Classifier training complete (best accuracy: 98.10%)

✓ ALL TRAINING COMPLETE!
Models saved to: ./trained_models_v9/v9_dense_connections

Pipeline: 16x16 → LapSRN(4x) → 64x64 → DRRN(2x) → 128x128 → Classifier(224x224)

Key difference from v1_baseline:
  - Replaced residual blocks with DenseBlocks in LapSRN
  - Added DenseBlocks at DRRN multi-scale collection points
  - Each conv layer receives ALL previous layer outputs as input
  - 1x1 conv compresses concatenated features back to original channels
  - Expected: Better feature reuse, richer representations





In [6]:
"""
Model Evaluation Script - v9_dense_connections (Notebook-Friendly)
Calculates PSNR, SSIM, Accuracy, and Classification Metrics

Usage in Jupyter Notebook:
    from evaluate_v9_dense_connections import evaluate_model
    
    results = evaluate_model(
        version='v9_dense_connections',
        data_dir='./preprocessed_data',
        model_dir='./trained_models_v9'
    )
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from tqdm import tqdm
import json
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# ==============================================================================
# DATASETS
# ==============================================================================

class SuperResolutionDataset(Dataset):
    def __init__(self, preprocessed_data_dir, hr_patch_size=64, scale_factor=4):
        self.hr_patch_size = hr_patch_size
        self.lr_patch_size = hr_patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = (h - self.hr_patch_size) // 2
        left = (w - self.hr_patch_size) // 2
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class DRRNDataset(Dataset):
    def __init__(self, preprocessed_data_dir, patch_size=64, scale_factor=2):
        self.hr_patch_size = patch_size
        self.lr_patch_size = patch_size // scale_factor
        self.scale_factor = scale_factor
        self.image_files = []
        
        for category in ['Normal', 'Ischemia', 'Bleeding']:
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.image_files.append(os.path.join(category_path, filename))
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        img = Image.open(img_path).convert('L')
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        h, w = img_array.shape
        if h < self.hr_patch_size or w < self.hr_patch_size:
            img = Image.fromarray((img_array * 255).astype(np.uint8))
            img = img.resize((self.hr_patch_size, self.hr_patch_size), Image.BICUBIC)
            img_array = np.array(img, dtype=np.float32) / 255.0
            h, w = img_array.shape
        
        top = (h - self.hr_patch_size) // 2
        left = (w - self.hr_patch_size) // 2
        hr_patch = img_array[top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        
        hr_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
        lr_pil = hr_pil.resize((self.lr_patch_size, self.lr_patch_size), Image.BICUBIC)
        lr_patch = np.array(lr_pil, dtype=np.float32) / 255.0
        
        lr_tensor = torch.from_numpy(lr_patch.copy()).unsqueeze(0).float()
        hr_tensor = torch.from_numpy(hr_patch.copy()).unsqueeze(0).float()
        
        return lr_tensor, hr_tensor


class ClassificationDataset(Dataset):
    def __init__(self, preprocessed_data_dir, enhance_size=224):
        self.enhance_size = enhance_size
        self.data = []
        
        category_map = {'Normal': 0, 'Ischemia': 1, 'Bleeding': 2}
        urgency_map = {'Normal': 0.1, 'Ischemia': 0.7, 'Bleeding': 0.95}
        
        for category, label in category_map.items():
            category_path = os.path.join(preprocessed_data_dir, category, '6_Final_Stripped')
            if os.path.exists(category_path):
                for filename in os.listdir(category_path):
                    if filename.endswith('.png'):
                        self.data.append({
                            'path': os.path.join(category_path, filename),
                            'label': label,
                            'urgency': urgency_map[category]
                        })
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        img = Image.open(sample['path']).convert('L')
        img = img.resize((self.enhance_size, self.enhance_size), Image.BICUBIC)
        img_array = np.array(img, dtype=np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array.copy()).unsqueeze(0).float()
        
        return img_tensor, sample['label'], sample['urgency']


# ==============================================================================
# MODEL DEFINITIONS - v9_dense_connections
# ==============================================================================

class DenseBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, num_layers=4, growth_rate=16):
        super().__init__()
        padding = kernel_size // 2
        self.num_layers = num_layers
        self.growth_rate = growth_rate
        
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_channels = channels + i * growth_rate
            self.layers.append(nn.Sequential(
                nn.Conv2d(in_channels, growth_rate, kernel_size, padding=padding),
                nn.LeakyReLU(0.2, True)
            ))
        
        total_channels = channels + num_layers * growth_rate
        self.compress = nn.Conv2d(total_channels, channels, 1)
    
    def forward(self, x):
        features = [x]
        for layer in self.layers:
            concat = torch.cat(features, dim=1)
            out = layer(concat)
            features.append(out)
        
        dense_out = torch.cat(features, dim=1)
        out = self.compress(dense_out)
        return out + x


class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.conv2(out)
        return self.activation(out + residual)


class RecursiveBlock(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.activation = nn.LeakyReLU(0.2, True)
    
    def forward(self, x):
        residual = x
        out = self.activation(self.conv1(x))
        out = self.activation(self.conv2(out))
        return out + residual


class LapSRN(nn.Module):
    def __init__(self, scale_factor=4, num_channels=1, use_dense=True):
        super().__init__()
        self.scale_factor = scale_factor
        self.num_levels = 2
        ch = 64
        
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, ch, 3, padding=1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.pyramid_levels = nn.ModuleList()
        self.image_reconstruction = nn.ModuleList()
        
        for _ in range(self.num_levels):
            layers = []
            if use_dense:
                layers.append(DenseBlock(ch, kernel_size=3, num_layers=4, growth_rate=16))
                layers.append(DenseBlock(ch, kernel_size=3, num_layers=4, growth_rate=16))
            else:
                for _ in range(5):
                    layers.append(ResidualBlock(ch, 3))
            layers.append(nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, True))
            
            self.pyramid_levels.append(nn.Sequential(*layers))
            self.image_reconstruction.append(nn.Conv2d(ch, num_channels, 3, padding=1))
    
    def forward(self, x):
        features = self.feature_extraction(x)
        outputs = []
        current_features = features
        
        for level_idx in range(self.num_levels):
            current_features = self.pyramid_levels[level_idx](current_features)
            img_out = self.image_reconstruction[level_idx](current_features)
            
            if level_idx > 0:
                img_out = img_out + F.interpolate(outputs[-1], scale_factor=2, mode='bilinear', align_corners=False)
            else:
                img_out = img_out + F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            
            outputs.append(img_out)
        
        return outputs[-1], outputs


class DRRN(nn.Module):
    def __init__(self, num_channels=1, scale_factor=2, use_dense=True):
        super().__init__()
        self.scale_factor = scale_factor
        ch = 128
        
        self.input_conv = nn.Conv2d(num_channels, ch, 3, padding=1)
        
        self.recursive_blocks = nn.ModuleList()
        for _ in range(25):
            self.recursive_blocks.append(RecursiveBlock(ch, 3))
        
        self.dense_blocks = nn.ModuleList()
        if use_dense:
            for _ in range(3):
                self.dense_blocks.append(DenseBlock(ch, kernel_size=3, num_layers=4, growth_rate=32))
        
        self.use_dense = use_dense
        
        self.fusion = nn.Sequential(
            nn.Conv2d(ch * 3, ch, 1),
            nn.LeakyReLU(0.2, True)
        )
        
        self.upsample = nn.Sequential(
            nn.Conv2d(ch, ch * 4, 3, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, True)
        )
        
        self.output_conv = nn.Sequential(
            nn.Conv2d(ch, 64, 3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, num_channels, 3, padding=1)
        )
    
    def forward(self, x):
        input_upsampled = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)
        
        features = self.input_conv(x)
        multi_scale_features = []
        current = features
        
        collect_indices = [8, 16, 24]
        dense_idx = 0
        
        for idx, block in enumerate(self.recursive_blocks):
            current = block(current)
            if idx in collect_indices:
                if self.use_dense:
                    current = self.dense_blocks[dense_idx](current)
                    dense_idx += 1
                multi_scale_features.append(current)
        
        fused = torch.cat(multi_scale_features, dim=1)
        fused = self.fusion(fused)
        upsampled = self.upsample(fused)
        output = self.output_conv(upsampled)
        
        return output + input_upsampled


class MedicalImageClassifier(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        
        from torchvision import models
        self.backbone = models.resnet50(pretrained=True)
        self.backbone.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        
        self.classification_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        
        self.urgency_head = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
        self.feature_head = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(True)
        )
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classification_head(features), self.urgency_head(features), self.feature_head(features)


# ==============================================================================
# EVALUATION
# ==============================================================================

def evaluate_model(version='v9_dense_connections', data_dir='./preprocessed_data', 
                   model_dir='./trained_models_v9',
                   device='cuda' if torch.cuda.is_available() else 'cpu'):
    
    print(f"\n{'='*80}")
    print(f"EVALUATING MODEL: {version}")
    print(f"{'='*80}\n")
    
    version_dir = os.path.join(model_dir, version)
    config_path = os.path.join(version_dir, 'config.json')
    
    if not os.path.exists(config_path):
        print(f"ERROR: Config file not found at {config_path}")
        return None
    
    with open(config_path, 'r') as f:
        config = json.load(f)
    
    print(f"Loaded configuration for {version}")
    print(f"  LapSRN: {config.get('lapsrn_channels', 64)} channels + DenseBlocks (growth_rate=16)")
    print(f"  DRRN: {config.get('drrn_channels', 128)} channels + DenseBlocks (growth_rate=32)")
    print(f"  Backbone: {config.get('backbone', 'resnet50')}")
    print(f"  Dense Connections: {config.get('use_dense', True)}\n")
    
    # Datasets
    print("Creating datasets...")
    sr_dataset = SuperResolutionDataset(data_dir, hr_patch_size=64, scale_factor=4)
    drrn_dataset = DRRNDataset(data_dir, patch_size=64, scale_factor=2)
    class_dataset = ClassificationDataset(data_dir, enhance_size=224)
    
    sr_loader = DataLoader(sr_dataset, batch_size=16, shuffle=False, num_workers=2)
    drrn_loader = DataLoader(drrn_dataset, batch_size=16, shuffle=False, num_workers=2)
    class_loader = DataLoader(class_dataset, batch_size=16, shuffle=False, num_workers=2)
    
    print(f"  SR dataset: {len(sr_dataset)} samples")
    print(f"  DRRN dataset: {len(drrn_dataset)} samples")
    print(f"  Classification dataset: {len(class_dataset)} samples\n")
    
    # Load models
    print("Loading models...")
    lapsrn = LapSRN(use_dense=True).to(device)
    drrn = DRRN(use_dense=True).to(device)
    classifier = MedicalImageClassifier().to(device)
    
    lapsrn.load_state_dict(torch.load(os.path.join(version_dir, 'lapsrn_best.pth'), map_location=device))
    drrn.load_state_dict(torch.load(os.path.join(version_dir, 'drrn_best.pth'), map_location=device))
    classifier.load_state_dict(torch.load(os.path.join(version_dir, 'classifier_best.pth'), map_location=device))
    
    lapsrn.eval()
    drrn.eval()
    classifier.eval()
    print("✓ Models loaded successfully\n")
    
    # Evaluate LapSRN
    print("="*80)
    print("[1/3] Evaluating LapSRN (16x16 → 64x64)")
    print("="*80)
    
    lapsrn_psnr_list, lapsrn_ssim_list = [], []
    
    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(sr_loader, desc="LapSRN Evaluation"):
            lr_imgs = lr_imgs.to(device)
            sr_output, _ = lapsrn(lr_imgs)
            
            sr_np = sr_output.cpu().numpy()
            hr_np = hr_imgs.numpy()
            
            for i in range(sr_np.shape[0]):
                sr_img = np.clip(sr_np[i, 0], 0, 1)
                hr_img = np.clip(hr_np[i, 0], 0, 1)
                lapsrn_psnr_list.append(psnr(hr_img, sr_img, data_range=1.0))
                lapsrn_ssim_list.append(ssim(hr_img, sr_img, data_range=1.0))
    
    lapsrn_psnr_mean = np.mean(lapsrn_psnr_list)
    lapsrn_ssim_mean = np.mean(lapsrn_ssim_list)
    
    print(f"\n✓ LapSRN Results:")
    print(f"  PSNR: {lapsrn_psnr_mean:.4f} dB")
    print(f"  SSIM: {lapsrn_ssim_mean:.4f}\n")
    
    # Evaluate DRRN
    print("="*80)
    print("[2/3] Evaluating DRRN (64x64 → 128x128)")
    print("="*80)
    
    drrn_psnr_list, drrn_ssim_list = [], []
    
    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(drrn_loader, desc="DRRN Evaluation"):
            lr_imgs = lr_imgs.to(device)
            sr_output = drrn(lr_imgs)
            
            sr_np = sr_output.cpu().numpy()
            hr_np = hr_imgs.numpy()
            
            for i in range(sr_np.shape[0]):
                sr_img = np.clip(sr_np[i, 0], 0, 1)
                hr_img = np.clip(hr_np[i, 0], 0, 1)
                drrn_psnr_list.append(psnr(hr_img, sr_img, data_range=1.0))
                drrn_ssim_list.append(ssim(hr_img, sr_img, data_range=1.0))
    
    drrn_psnr_mean = np.mean(drrn_psnr_list)
    drrn_ssim_mean = np.mean(drrn_ssim_list)
    
    print(f"\n✓ DRRN Results:")
    print(f"  PSNR: {drrn_psnr_mean:.4f} dB")
    print(f"  SSIM: {drrn_ssim_mean:.4f}\n")
    
    # Evaluate Classifier
    print("="*80)
    print("[3/3] Evaluating Classifier")
    print("="*80)
    
    all_preds, all_labels, all_urgency_preds, all_urgency_true = [], [], [], []
    
    with torch.no_grad():
        for images, labels, urgency in tqdm(class_loader, desc="Classifier Evaluation"):
            images = images.to(device)
            class_out, urgency_out, _ = classifier(images)
            _, predicted = torch.max(class_out, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_urgency_preds.extend(urgency_out.cpu().numpy().flatten())
            all_urgency_true.extend(urgency.numpy())
    
    accuracy = accuracy_score(all_labels, all_preds) * 100
    conf_matrix = confusion_matrix(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds, 
                                        target_names=['Normal', 'Ischemia', 'Bleeding'], output_dict=True)
    urgency_mse = np.mean((np.array(all_urgency_preds) - np.array(all_urgency_true))**2)
    urgency_mae = np.mean(np.abs(np.array(all_urgency_preds) - np.array(all_urgency_true)))
    
    print(f"\n✓ Classification Results:")
    print(f"  Accuracy: {accuracy:.2f}%")
    print(f"\n  Confusion Matrix:")
    print(f"  {conf_matrix}")
    print(f"\n  Per-Class Metrics:")
    for class_name in ['Normal', 'Ischemia', 'Bleeding']:
        metrics = class_report[class_name]
        print(f"    {class_name}:")
        print(f"      Precision: {metrics['precision']:.4f}")
        print(f"      Recall: {metrics['recall']:.4f}")
        print(f"      F1-Score: {metrics['f1-score']:.4f}")
    
    print(f"\n  Urgency Prediction:")
    print(f"    MSE: {urgency_mse:.4f}")
    print(f"    MAE: {urgency_mae:.4f}\n")
    
    # Save results
    results = {
        'version': version,
        'lapsrn': { 'psnr': float(lapsrn_psnr_mean), 'ssim': float(lapsrn_ssim_mean) },
        'drrn': { 'psnr': float(drrn_psnr_mean), 'ssim': float(drrn_ssim_mean) },
        'classifier': {
            'accuracy': float(accuracy),
            'confusion_matrix': conf_matrix.tolist(),
            'classification_report': class_report,
            'urgency_mse': float(urgency_mse),
            'urgency_mae': float(urgency_mae)
        }
    }
    
    results_path = os.path.join(version_dir, 'evaluation_results.json')
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"✓ Results saved to: {results_path}")
    print(f"\n{'='*80}")
    print(f"EVALUATION COMPLETE FOR {version}")
    print(f"{'='*80}\n")
    
    return results

In [8]:
results = evaluate_model(version='v9_dense_connections', data_dir='./preprocessed_data', model_dir='./trained_models_v9')


EVALUATING MODEL: v9_dense_connections

Loaded configuration for v9_dense_connections
  LapSRN: 64 channels + DenseBlocks (growth_rate=16)
  DRRN: 128 channels + DenseBlocks (growth_rate=32)
  Backbone: resnet50
  Dense Connections: True

Creating datasets...
  SR dataset: 6636 samples
  DRRN dataset: 6636 samples
  Classification dataset: 6636 samples

Loading models...
✓ Models loaded successfully

[1/3] Evaluating LapSRN (16x16 → 64x64)


LapSRN Evaluation: 100% 415/415 [00:12<00:00, 32.93it/s]



✓ LapSRN Results:
  PSNR: 32.7081 dB
  SSIM: 0.8212

[2/3] Evaluating DRRN (64x64 → 128x128)


DRRN Evaluation: 100% 415/415 [00:12<00:00, 33.29it/s]



✓ DRRN Results:
  PSNR: 45.2190 dB
  SSIM: 0.9839

[3/3] Evaluating Classifier


Classifier Evaluation: 100% 415/415 [00:14<00:00, 28.58it/s]



✓ Classification Results:
  Accuracy: 97.24%

  Confusion Matrix:
  [[4402   21    4]
 [  57 1058    1]
 [  95    5  993]]

  Per-Class Metrics:
    Normal:
      Precision: 0.9666
      Recall: 0.9944
      F1-Score: 0.9803
    Ischemia:
      Precision: 0.9760
      Recall: 0.9480
      F1-Score: 0.9618
    Bleeding:
      Precision: 0.9950
      Recall: 0.9085
      F1-Score: 0.9498

  Urgency Prediction:
    MSE: 0.0130
    MAE: 0.0622

✓ Results saved to: ./trained_models_v9/v9_dense_connections/evaluation_results.json

EVALUATION COMPLETE FOR v9_dense_connections

