In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -q
!pip install segmentation-models-pytorch albumentations opencv-python scikit-learn tqdm -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import sys
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Deep Learning imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import segmentation_models_pytorch as smp

# Additional imports
from PIL import Image
import gdown
import zipfile
import json
from datetime import datetime

Configuration

In [None]:
class Config:
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    BASE_PATH = '/content/glaucoma_data'
    REFUGE_PATH = os.path.join(BASE_PATH, 'REFUGE')
    DRISHTI_PATH = os.path.join(BASE_PATH, 'Drishti-GS1')

    IMAGE_SIZE = 256
    BATCH_SIZE = 4
    NUM_EPOCHS = 20
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5

    NUM_CLASSES = 2  # OD, OC
    SEED = 42

    TRAIN_RATIO = 0.7
    VAL_RATIO = 0.15
    TEST_RATIO = 0.15

    CHECKPOINT_DIR = '/content/checkpoints'
    RESULTS_DIR = '/content/results'

    @staticmethod
    def ensure_dirs():
        os.makedirs(Config.BASE_PATH, exist_ok=True)
        os.makedirs(Config.CHECKPOINT_DIR, exist_ok=True)
        os.makedirs(Config.RESULTS_DIR, exist_ok=True)

Loss Functions

In [None]:
class DiceLoss(nn.Module):

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        pred = torch.sigmoid(pred)
        intersection = (pred * target).sum(dim=(2, 3))
        union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()


class TverskyLoss(nn.Module):
    """Tversky Loss with focal term"""
    def __init__(self, alpha=0.3, beta=0.7, gamma=1.33, smooth=1.0):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, pred, target):
        pred = torch.sigmoid(pred)

        tp = (pred * target).sum(dim=(2, 3))
        fp = (pred * (1 - target)).sum(dim=(2, 3))
        fn = ((1 - pred) * target).sum(dim=(2, 3))

        tversky = tp / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        focal_tversky = (1 - tversky) ** self.gamma

        return focal_tversky.mean()


class UnifiedFocalLoss(nn.Module):

    def __init__(self, delta=0.6, gamma=2.0, smooth=1.0):
        super(UnifiedFocalLoss, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, pred, target):
        pred_sig = torch.sigmoid(pred)

        bce_loss = self.bce(pred, target)
        focal_ce = ((1 - pred_sig) ** self.gamma) * bce_loss
        focal_ce = focal_ce.mean()

        tp = (pred_sig * target).sum(dim=(2, 3))
        fp = (pred_sig * (1 - target)).sum(dim=(2, 3))
        fn = ((1 - pred_sig) * target).sum(dim=(2, 3))

        tversky = tp / (tp + 0.3 * fp + 0.7 * fn + self.smooth)
        focal_tversky = (1 - tversky) ** (1 / self.gamma)
        focal_tversky = focal_tversky.mean()

        unified_focal = (1 - self.delta) * focal_ce + self.delta * focal_tversky

        return unified_focal


class CombinedLoss(nn.Module):

    def __init__(self, alpha=0.5):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.dice_loss = DiceLoss()
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, target):
        dice = self.dice_loss(pred, target)
        bce = self.bce_loss(pred, target)
        return self.alpha * dice + (1 - self.alpha) * bce

Attention mechanisms

In [None]:
class AttentionGate(nn.Module):
    """Attention Gate for U-Net"""
    def __init__(self, in_channels, gating_channels, inter_channels=None):
        super(AttentionGate, self).__init__()
        if inter_channels is None:
            inter_channels = in_channels // 2

        self.query_conv = nn.Conv2d(gating_channels, inter_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.psi = nn.Sequential(
            nn.Conv2d(inter_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x, gating):
        query = self.query_conv(gating)
        key = self.key_conv(x)
        value = self.value_conv(x)

        psi = self.psi(nn.functional.relu(query + key))
        return x * psi


class AttentionBlock(nn.Module):
    """Attention Block with channel and spatial attention"""
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 16, channels, 1),
            nn.Sigmoid()
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        ca = self.channel_attention(x)
        x = x * ca
        sa = self.spatial_attention(x)
        x = x * sa
        return x

Custom UNet

In [None]:
class AttentionUNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(AttentionUNet, self).__init__()
        self.features = features

        self.enc1 = self._conv_block(in_channels, features[0])
        self.pool1 = nn.MaxPool2d(2, 2)

        self.enc2 = self._conv_block(features[0], features[1])
        self.pool2 = nn.MaxPool2d(2, 2)

        self.enc3 = self._conv_block(features[1], features[2])
        self.pool3 = nn.MaxPool2d(2, 2)

        self.enc4 = self._conv_block(features[2], features[3])
        self.pool4 = nn.MaxPool2d(2, 2)

        self.bottleneck = self._conv_block(features[3], features[3] * 2)

        self.upconv4 = nn.ConvTranspose2d(features[3] * 2, features[3], 2, 2)
        self.att4 = AttentionGate(features[3], features[3])
        self.dec4 = self._conv_block(features[3] * 2, features[3])

        self.upconv3 = nn.ConvTranspose2d(features[3], features[2], 2, 2)
        self.att3 = AttentionGate(features[2], features[2])
        self.dec3 = self._conv_block(features[2] * 2, features[2])

        self.upconv2 = nn.ConvTranspose2d(features[2], features[1], 2, 2)
        self.att2 = AttentionGate(features[1], features[1])
        self.dec2 = self._conv_block(features[1] * 2, features[1])

        self.upconv1 = nn.ConvTranspose2d(features[1], features[0], 2, 2)
        self.att1 = AttentionGate(features[0], features[0])
        self.dec1 = self._conv_block(features[0] * 2, features[0])

        self.final = nn.Conv2d(features[0], out_channels, 1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):

        enc1 = self.enc1(x)
        x = self.pool1(enc1)

        enc2 = self.enc2(x)
        x = self.pool2(enc2)

        enc3 = self.enc3(x)
        x = self.pool3(enc3)

        enc4 = self.enc4(x)
        x = self.pool4(enc4)

        bottleneck = self.bottleneck(x)

        x = self.upconv4(bottleneck)
        enc4 = self.att4(enc4, x)
        x = torch.cat([x, enc4], dim=1)
        x = self.dec4(x)

        x = self.upconv3(x)
        enc3 = self.att3(enc3, x)
        x = torch.cat([x, enc3], dim=1)
        x = self.dec3(x)

        x = self.upconv2(x)
        enc2 = self.att2(enc2, x)
        x = torch.cat([x, enc2], dim=1)
        x = self.dec2(x)

        x = self.upconv1(x)
        enc1 = self.att1(enc1, x)
        x = torch.cat([x, enc1], dim=1)
        x = self.dec1(x)

        return self.final(x)

In [None]:
class GlaucomaDataset(Dataset):

    def __init__(self, images, masks_od, masks_oc, transform=None):
        self.images = images
        self.masks_od = masks_od
        self.masks_oc = masks_oc
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = cv2.imread(self.images[idx])
        if image is None:
            image = np.zeros((Config.IMAGE_SIZE, Config.IMAGE_SIZE, 3), dtype=np.uint8)
        else:
            image = cv2.resize(image, (Config.IMAGE_SIZE, Config.IMAGE_SIZE))

        mask_od = cv2.imread(self.masks_od[idx], cv2.IMREAD_GRAYSCALE)
        if mask_od is None:
            mask_od = np.zeros((Config.IMAGE_SIZE, Config.IMAGE_SIZE), dtype=np.uint8)
        else:
            mask_od = cv2.resize(mask_od, (Config.IMAGE_SIZE, Config.IMAGE_SIZE))

        mask_oc = cv2.imread(self.masks_oc[idx], cv2.IMREAD_GRAYSCALE)
        if mask_oc is None:
            mask_oc = np.zeros((Config.IMAGE_SIZE, Config.IMAGE_SIZE), dtype=np.uint8)
        else:
            mask_oc = cv2.resize(mask_oc, (Config.IMAGE_SIZE, Config.IMAGE_SIZE))

        if self.transform:
            augmented = self.transform(image=image, mask=mask_od)
            image = augmented['image']
            mask_od = augmented['mask']

        image = image.astype(np.float32) / 255.0
        mask_od = mask_od.astype(np.float32) / 255.0
        mask_oc = mask_oc.astype(np.float32) / 255.0

        image = torch.from_numpy(image.transpose(2, 0, 1))
        mask_od = torch.from_numpy(np.expand_dims(mask_od, 0))
        mask_oc = torch.from_numpy(np.expand_dims(mask_oc, 0))

        return image, mask_od, mask_oc

Metrics

In [None]:
def calculate_dice(pred, target, smooth=1.0):
    """Dice Coefficient"""
    pred = (pred > 0.5).float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    dice = (2.0 * intersection + smooth) / (union + smooth)
    return dice.item()


def calculate_iou(pred, target, smooth=1.0):
    """IoU (Jaccard Index)"""
    pred = (pred > 0.5).float()
    intersection = (pred * target).sum()
    union = (pred + target).sum() - intersection
    iou = (intersection + smooth) / (union + smooth)
    return iou.item()


def calculate_sensitivity(pred, target):
    """Sensitivity (Recall)"""
    pred = (pred > 0.5).float()
    tp = (pred * target).sum().item()
    fn = ((1 - pred) * target).sum().item()
    sensitivity = tp / (tp + fn + 1e-6)
    return sensitivity


def calculate_specificity(pred, target):
    """Specificity"""
    pred = (pred > 0.5).float()
    tn = ((1 - pred) * (1 - target)).sum().item()
    fp = (pred * (1 - target)).sum().item()
    specificity = tn / (tn + fp + 1e-6)
    return specificity


def calculate_vcdr(mask_od, mask_oc):
    """Vertical Cup-to-Disc Ratio"""
    od_rows = np.where(mask_od.sum(axis=1) > 0)[0]
    oc_rows = np.where(mask_oc.sum(axis=1) > 0)[0]

    if len(od_rows) == 0 or len(oc_rows) == 0:
        return 0.0

    od_height = od_rows[-1] - od_rows[0] + 1
    oc_height = oc_rows[-1] - oc_rows[0] + 1

    vcdr = oc_height / (od_height + 1e-6)
    return vcdr

Training class

In [None]:
class Trainer:

    def __init__(self, model, train_loader, val_loader, test_loader,
                 criterion, optimizer, device, exp_name):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.exp_name = exp_name
        self.scaler = GradScaler()

        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_dice_od': [],
            'val_dice_oc': []
        }

        self.best_dice = 0.0

    def train_epoch(self):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0

        for images, masks_od, masks_oc in tqdm(self.train_loader, desc='Training'):
            images = images.to(self.device)
            masks_od = masks_od.to(self.device)
            masks_oc = masks_oc.to(self.device)

            self.optimizer.zero_grad()

            with autocast():
                pred = self.model(images)

                loss_od = self.criterion(pred, masks_od)
                loss_oc = self.criterion(pred, masks_oc)
                loss = (loss_od + loss_oc) / 2.0

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()

        avg_loss = total_loss / len(self.train_loader)
        self.history['train_loss'].append(avg_loss)
        return avg_loss

    def validate(self):
        self.model.eval()
        total_loss = 0.0
        dice_od_list = []
        dice_oc_list = []

        with torch.no_grad():
            for images, masks_od, masks_oc in tqdm(self.val_loader, desc='Validating'):
                images = images.to(self.device)
                masks_od = masks_od.to(self.device)
                masks_oc = masks_oc.to(self.device)

                pred = self.model(images)

                loss_od = self.criterion(pred, masks_od)
                loss_oc = self.criterion(pred, masks_oc)
                loss = (loss_od + loss_oc) / 2.0
                total_loss += loss.item()

                pred_sig = pred

                dice_od = calculate_dice(pred_sig, masks_od)
                dice_oc = calculate_dice(pred_sig, masks_oc)

                dice_od_list.append(dice_od)
                dice_oc_list.append(dice_oc)

        avg_loss = total_loss / len(self.val_loader)
        avg_dice_od = np.mean(dice_od_list)
        avg_dice_oc = np.mean(dice_oc_list)

        self.history['val_loss'].append(avg_loss)
        self.history['val_dice_od'].append(avg_dice_od)
        self.history['val_dice_oc'].append(avg_dice_oc)

        return avg_loss, avg_dice_od, avg_dice_oc


    def train(self, num_epochs, patience=15):

        best_val_loss = float('inf')
        patience_count = 0

        for epoch in range(num_epochs):
            train_loss = self.train_epoch()
            val_loss, dice_od, dice_oc = self.validate()

            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val Loss: {val_loss:.4f}")
            print(f"  Val Dice OD: {dice_od:.4f}")
            print(f"  Val Dice OC: {dice_oc:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_count = 0
                self._save_checkpoint(epoch, val_loss)
            else:
                patience_count += 1
                if patience_count >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        return self.history

    def _save_checkpoint(self, epoch, val_loss):

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'history': self.history
        }
        path = os.path.join(Config.CHECKPOINT_DIR, f'{self.exp_name}_best.pt')
        torch.save(checkpoint, path)
        print(f"  Checkpoint saved: {path}")

    def evaluate(self):
        """Evaluate on test set"""
        self.model.eval()
        metrics = {
            'dice_od': [], 'iou_od': [], 'sens_od': [], 'spec_od': [],
            'dice_oc': [], 'iou_oc': [], 'sens_oc': [], 'spec_oc': [],
            'vcdr_mae': []
        }

        with torch.no_grad():
            for images, masks_od, masks_oc in tqdm(self.test_loader, desc='Evaluating'):
                images = images.to(self.device)
                masks_od = masks_od.to(self.device)
                masks_oc = masks_oc.to(self.device)

                pred = self.model(images)

                if hasattr(self.model, 'activation') and self.model.activation is None:
                    pred = torch.sigmoid(pred)

                for i in range(images.shape[0]):
                    dice_od = calculate_dice(pred[i], masks_od[i])
                    iou_od = calculate_iou(pred[i], masks_od[i])
                    sens_od = calculate_sensitivity(pred[i], masks_od[i])
                    spec_od = calculate_specificity(pred[i], masks_od[i])

                    metrics['dice_od'].append(dice_od)
                    metrics['iou_od'].append(iou_od)
                    metrics['sens_od'].append(sens_od)
                    metrics['spec_od'].append(spec_od)

                    dice_oc = calculate_dice(pred[i], masks_oc[i])
                    iou_oc = calculate_iou(pred[i], masks_oc[i])
                    sens_oc = calculate_sensitivity(pred[i], masks_oc[i])
                    spec_oc = calculate_specificity(pred[i], masks_oc[i])

                    metrics['dice_oc'].append(dice_oc)
                    metrics['iou_oc'].append(iou_oc)
                    metrics['sens_oc'].append(sens_oc)
                    metrics['spec_oc'].append(spec_oc)

                    mask_od_np = masks_od[i].cpu().numpy()[0] > 0.5
                    mask_oc_np = masks_oc[i].cpu().numpy()[0] > 0.5
                    vcdr_gt = calculate_vcdr(mask_od_np, mask_oc_np)

                    mask_pred = (pred[i] > 0.5).cpu().numpy()[0]
                    vcdr_pred = calculate_vcdr(mask_pred, mask_pred)

                    metrics['vcdr_mae'].append(abs(vcdr_gt - vcdr_pred))

        return self._compute_statistics(metrics)

    @staticmethod
    def _compute_statistics(metrics):
        stats = {}
        for key, values in metrics.items():
            if len(values) > 0:
                stats[f'{key}_mean'] = np.mean(values)
                stats[f'{key}_std'] = np.std(values)
        return stats

In [None]:
def setup_environment():
    print("Setting up environment...")
    Config.ensure_dirs()

    np.random.seed(Config.SEED)
    torch.manual_seed(Config.SEED)
    torch.cuda.manual_seed(Config.SEED)

    print(f"Device: {Config.DEVICE}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")


def create_baseline_unet():
    print("Creating Baseline U-Net...")
    model = smp.Unet(
        encoder_name='resnet34',
        encoder_weights=None,
        in_channels=3,
        classes=1,
        activation='sigmoid'
    )
    return model


def create_resnet34_unet():
    """Create U-Net with ResNet34 encoder (pretrained)"""
    print("Creating U-Net + ResNet34 (pretrained)...")
    model = smp.Unet(
        encoder_name='resnet34',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
        activation='sigmoid'
    )
    return model


def create_attention_unet():
    print("Creating Attention U-Net...")
    model = AttentionUNet(
        in_channels=3,
        out_channels=1,
        features=[64, 128, 256, 512]
    )
    return model


def create_efficientnet_unet():
    print("Creating U-Net + EfficientNet-B4 (pretrained)...")
    model = smp.Unet(
        encoder_name='efficientnet-b4',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
        activation='sigmoid'
    )
    return model

In [None]:
def run_experiment(exp_name, model, loss_fn, train_loader, val_loader, test_loader):
    print(f"\n{'='*60}")
    print(f"Running Experiment: {exp_name}")
    print(f"{'='*60}")

    optimizer = optim.Adam(
        model.parameters(),
        lr=Config.LEARNING_RATE,
        weight_decay=Config.WEIGHT_DECAY
    )

    trainer = Trainer(
        model, train_loader, val_loader, test_loader,
        loss_fn, optimizer, Config.DEVICE, exp_name
    )

    history = trainer.train(num_epochs=Config.NUM_EPOCHS, patience=15)

    results = trainer.evaluate()

    results['exp_name'] = exp_name
    results['timestamp'] = datetime.now().isoformat()

    results_path = os.path.join(Config.RESULTS_DIR, f'{exp_name}_results.json')
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults for {exp_name}:")
    print(f"  OD Dice: {results.get('dice_od_mean', 0):.4f} ± {results.get('dice_od_std', 0):.4f}")
    print(f"  OC Dice: {results.get('dice_oc_mean', 0):.4f} ± {results.get('dice_oc_std', 0):.4f}")
    print(f"  VCDR MAE: {results.get('vcdr_mae_mean', 0):.4f} ± {results.get('vcdr_mae_std', 0):.4f}")

    return results, history


def create_dummy_dataset():
    print("Creating dummy dataset for testing...")

    dataset_path = Config.BASE_PATH
    refuge_path = os.path.join(dataset_path, 'REFUGE')
    os.makedirs(refuge_path, exist_ok=True)

    n_samples = 50

    img_dir = os.path.join(refuge_path, 'images')
    mask_od_dir = os.path.join(refuge_path, 'masks_od')
    mask_oc_dir = os.path.join(refuge_path, 'masks_oc')

    for d in [img_dir, mask_od_dir, mask_oc_dir]:
        os.makedirs(d, exist_ok=True)

    for i in range(n_samples):

        center_x = np.random.randint(200, 312)
        center_y = np.random.randint(200, 312)

        disc_radius = np.random.randint(60, 100)

        cup_ratio = np.random.uniform(0.4, 0.7)
        cup_radius = int(disc_radius * cup_ratio)

        img = np.random.randint(50, 150, (512, 512, 3), dtype=np.uint8)

        cv2.circle(img, (center_x, center_y), disc_radius,
                   (220, 180, 160), -1)

        cv2.circle(img, (center_x, center_y), cup_radius,
                   (180, 140, 120), -1)

        cv2.imwrite(os.path.join(img_dir, f'img_{i:03d}.jpg'), img)

        mask_od = np.zeros((512, 512), dtype=np.uint8)
        cv2.circle(mask_od, (center_x, center_y), disc_radius, 255, -1)
        cv2.imwrite(os.path.join(mask_od_dir, f'mask_{i:03d}.png'), mask_od)

        mask_oc = np.zeros((512, 512), dtype=np.uint8)
        cv2.circle(mask_oc, (center_x, center_y), cup_radius, 255, -1)
        cv2.imwrite(os.path.join(mask_oc_dir, f'mask_{i:03d}.png'), mask_oc)

    print(f"  Created {n_samples} varied dummy samples")
    print(f"  OD radius range: 60-100 pixels")
    print(f"  OC/OD ratio range: 0.4-0.7")
    print(f"  Center position: randomized")

    return img_dir, mask_od_dir, mask_oc_dir


def load_data():
    print("Loading dataset...")

    img_dir, mask_od_dir, mask_oc_dir = create_dummy_dataset()

    images = sorted([os.path.join(img_dir, f) for f in os.listdir(img_dir)])
    masks_od = sorted([os.path.join(mask_od_dir, f) for f in os.listdir(mask_od_dir)])
    masks_oc = sorted([os.path.join(mask_oc_dir, f) for f in os.listdir(mask_oc_dir)])

    n_total = len(images)
    n_train = int(n_total * Config.TRAIN_RATIO)
    n_val = int(n_total * Config.VAL_RATIO)

    indices = list(range(n_total))
    np.random.shuffle(indices)

    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]

    train_images = [images[i] for i in train_idx]
    train_masks_od = [masks_od[i] for i in train_idx]
    train_masks_oc = [masks_oc[i] for i in train_idx]

    val_images = [images[i] for i in val_idx]
    val_masks_od = [masks_od[i] for i in val_idx]
    val_masks_oc = [masks_oc[i] for i in val_idx]

    test_images = [images[i] for i in test_idx]
    test_masks_od = [masks_od[i] for i in test_idx]
    test_masks_oc = [masks_oc[i] for i in test_idx]

    train_dataset = GlaucomaDataset(train_images, train_masks_od, train_masks_oc)
    val_dataset = GlaucomaDataset(val_images, val_masks_od, val_masks_oc)
    test_dataset = GlaucomaDataset(test_images, test_masks_od, test_masks_oc)

    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)

    print(f"Dataset loaded:")
    print(f"  Train: {len(train_dataset)} samples")
    print(f"  Val: {len(val_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples")

    return train_loader, val_loader, test_loader

Run ALL Experiments

In [None]:
#EXPERIMENT 1: BASELINE U-NET (без pretrained весов)
class BaselineUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        self.model = smp.Unet(
            encoder_name="resnet34",
            encoder_weights=None,
            in_channels=in_channels,
            classes=out_channels,
            activation='sigmoid'
        )

    def forward(self, x):
        return self.model(x)
#EXPERIMENT 2: U-NET + RESNET34 (Transfer Learning)
def create_resnet_unet():
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,
        activation='sigmoid'
    )
    return model

#EXPERIMENT 3: ATTENTION U-NET
class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_l, 1, padding=0, bias=True),
            nn.BatchNorm2d(F_l)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_l, 1, padding=0, bias=True),
            nn.BatchNorm2d(F_l)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_l, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """
        g: gating signal от decoder (coarse scale)
        x: skip connection от encoder (fine scale)
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi


class AttentionUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.features = features

        # Encoder
        self.enc1 = self._conv_block(in_channels, features[0])
        self.pool1 = nn.MaxPool2d(2, 2)

        self.enc2 = self._conv_block(features[0], features[1])
        self.pool2 = nn.MaxPool2d(2, 2)

        self.enc3 = self._conv_block(features[1], features[2])
        self.pool3 = nn.MaxPool2d(2, 2)

        self.enc4 = self._conv_block(features[2], features[3])
        self.pool4 = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = self._conv_block(features[3], features[3] * 2)

        # Decoder with Attention Gates
        self.upconv4 = nn.ConvTranspose2d(features[3] * 2, features[3], 2, 2)
        self.att4 = AttentionGate(features[3], features[3])
        self.dec4 = self._conv_block(features[3] * 2, features[3])

        self.upconv3 = nn.ConvTranspose2d(features[3], features[2], 2, 2)
        self.att3 = AttentionGate(features[2], features[2])
        self.dec3 = self._conv_block(features[2] * 2, features[2])

        self.upconv2 = nn.ConvTranspose2d(features[2], features[1], 2, 2)
        self.att2 = AttentionGate(features[1], features[1])
        self.dec2 = self._conv_block(features[1] * 2, features[1])

        self.upconv1 = nn.ConvTranspose2d(features[1], features[0], 2, 2)
        self.att1 = AttentionGate(features[0], features[0])
        self.dec1 = self._conv_block(features[0] * 2, features[0])

        self.final = nn.Conv2d(features[0], out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _conv_block(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        enc1 = self.enc1(x)
        x = self.pool1(enc1)

        enc2 = self.enc2(x)
        x = self.pool2(enc2)

        enc3 = self.enc3(x)
        x = self.pool3(enc3)

        enc4 = self.enc4(x)
        x = self.pool4(enc4)

        bottleneck = self.bottleneck(x)

        x = self.upconv4(bottleneck)
        enc4 = self.att4(x, enc4)  # g=x (decoder), x=enc4 (encoder)
        x = torch.cat([x, enc4], dim=1)
        x = self.dec4(x)

        x = self.upconv3(x)
        enc3 = self.att3(x, enc3)
        x = torch.cat([x, enc3], dim=1)
        x = self.dec3(x)

        x = self.upconv2(x)
        enc2 = self.att2(x, enc2)
        x = torch.cat([x, enc2], dim=1)
        x = self.dec2(x)

        x = self.upconv1(x)
        enc1 = self.att1(x, enc1)
        x = torch.cat([x, enc1], dim=1)
        x = self.dec1(x)

        return self.sigmoid(self.final(x))

#EXPERIMENT 4: U-NET + EFFICIENTNET-B4
def create_efficientnet_unet():
    model = smp.Unet(
        encoder_name="efficientnet-b4",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,
        activation='sigmoid'
    )
    return model

model1 = BaselineUNet(in_channels=3, out_channels=1)
print(f"Experiment 1 - Baseline U-Net: {sum(p.numel() for p in model1.parameters())/1e6:.1f}M params")
model2 = create_resnet_unet()
print(f"Experiment 2 - U-Net + ResNet34: {sum(p.numel() for p in model2.parameters())/1e6:.1f}M params")
model3 = AttentionUNet(in_channels=3, out_channels=1)
print(f"Experiment 3 - Attention U-Net: {sum(p.numel() for p in model3.parameters())/1e6:.1f}M params")
model4 = create_efficientnet_unet()
print(f"Experiment 4 - U-Net + EfficientNet-B4: {sum(p.numel() for p in model4.parameters())/1e6:.1f}M params")

Experiment 1 - Baseline U-Net: 24.4M params


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

Experiment 2 - U-Net + ResNet34: 24.4M params
Experiment 3 - Attention U-Net: 31.7M params


config.json:   0%|          | 0.00/106 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/77.9M [00:00<?, ?B/s]

Experiment 4 - U-Net + EfficientNet-B4: 20.2M params


In [None]:
def run_all_experiments():
    setup_environment()

    train_loader, val_loader, test_loader = load_data()

    all_results = {}

    # EXPERIMENT 1: Baseline U-Net
    print("\n" + "="*70)
    print("EXPERIMENT 1: Baseline U-Net")
    print("="*70)

    model1 = create_baseline_unet()
    loss1 = CombinedLoss(alpha=0.5)
    results1, _ = run_experiment(
        "Baseline_UNet",
        model1, loss1,
        train_loader, val_loader, test_loader
    )
    all_results['Baseline_UNet'] = results1

    del model1, loss1
    torch.cuda.empty_cache()
    print("\n✓ GPU memory cleared after Experiment 1")

    # EXPERIMENT 2: U-Net + ResNet34

    print("\n" + "="*70)
    print("EXPERIMENT 2: U-Net + ResNet34")
    print("="*70)

    model2 = create_resnet34_unet()
    loss2 = CombinedLoss(alpha=0.5)
    results2, _ = run_experiment(
        "UNet_ResNet34",
        model2, loss2,
        train_loader, val_loader, test_loader
    )
    all_results['UNet_ResNet34'] = results2

    del model2, loss2
    torch.cuda.empty_cache()
    print("\n✓ GPU memory cleared after Experiment 2")

    # EXPERIMENT 3: Attention U-Net

    print("\n" + "="*70)
    print("EXPERIMENT 3: Attention U-Net")
    print("="*70)
    print("⚠ Using reduced batch size for this experiment")

    original_batch_size = Config.BATCH_SIZE
    Config.BATCH_SIZE = 2

    train_dataset = train_loader.dataset
    val_dataset = val_loader.dataset
    test_dataset = test_loader.dataset

    train_loader_small = DataLoader(train_dataset, batch_size=2, shuffle=True)
    val_loader_small = DataLoader(val_dataset, batch_size=2, shuffle=False)
    test_loader_small = DataLoader(test_dataset, batch_size=2, shuffle=False)

    model3 = create_attention_unet()
    loss3 = TverskyLoss(alpha=0.3, beta=0.7, gamma=1.33)
    results3, _ = run_experiment(
        "Attention_UNet",
        model3, loss3,
        train_loader_small, val_loader_small, test_loader_small
    )
    all_results['Attention_UNet'] = results3

    Config.BATCH_SIZE = original_batch_size

    del model3, loss3, train_loader_small, val_loader_small, test_loader_small
    torch.cuda.empty_cache()
    print("\n✓ GPU memory cleared after Experiment 3")

    # EXPERIMENT 4: U-Net + EfficientNet-B4

    print("\n" + "="*70)
    print("EXPERIMENT 4: U-Net + EfficientNet-B4")
    print("="*70)

    model4 = create_efficientnet_unet()
    loss4 = UnifiedFocalLoss(delta=0.6, gamma=2.0)
    results4, _ = run_experiment(
        "UNet_EfficientNet_B4",
        model4, loss4,
        train_loader, val_loader, test_loader
    )
    all_results['UNet_EfficientNet_B4'] = results4

    del model4, loss4
    torch.cuda.empty_cache()
    print("\n✓ GPU memory cleared after Experiment 4")

    # СВОДКА РЕЗУЛЬТАТОВ

    print("\n" + "="*70)
    print("EXPERIMENT COMPARISON")
    print("="*70)
    print(f"\n{'Model':<30} {'OD Dice':<12} {'OC Dice':<12} {'VCDR MAE':<12}")
    print("-" * 70)

    for model_name, results in all_results.items():
        od_dice = results.get('dice_od_mean', 0)
        oc_dice = results.get('dice_oc_mean', 0)
        vcdr_mae = results.get('vcdr_mae_mean', 0)
        print(f"{model_name:<30} {od_dice:<12.4f} {oc_dice:<12.4f} {vcdr_mae:<12.4f}")

    summary_path = os.path.join(Config.RESULTS_DIR, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    print(f"\n✓ Results saved to {Config.RESULTS_DIR}")
    print("="*70)

In [None]:
run_all_experiments()

Setting up environment...
Device: cuda
PyTorch version: 2.9.0+cu126
CUDA available: True
Loading dataset...
Creating dummy dataset for testing...
  Created 50 varied dummy samples
  OD radius range: 60-100 pixels
  OC/OD ratio range: 0.4-0.7
  Center position: randomized
Dataset loaded:
  Train: 35 samples
  Val: 7 samples
  Test: 8 samples

EXPERIMENT 1: Baseline U-Net
Creating Baseline U-Net...

Running Experiment: Baseline_UNet


Training: 100%|██████████| 9/9 [00:03<00:00,  2.29it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  7.78it/s]


Epoch 1/20
  Train Loss: 0.9264
  Val Loss: 0.9311
  Val Dice OD: 0.0192
  Val Dice OC: 0.0013
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.24it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.12it/s]


Epoch 2/20
  Train Loss: 0.9039
  Val Loss: 0.9225
  Val Dice OD: 0.0000
  Val Dice OC: 0.0000
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 10.34it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 22.11it/s]


Epoch 3/20
  Train Loss: 0.8927
  Val Loss: 0.8928
  Val Dice OD: 0.0031
  Val Dice OC: 0.0005
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 12.01it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 25.36it/s]


Epoch 4/20
  Train Loss: 0.8844
  Val Loss: 0.8740
  Val Dice OD: 0.9269
  Val Dice OC: 0.4348
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 12.95it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 27.15it/s]


Epoch 5/20
  Train Loss: 0.8790
  Val Loss: 0.8681
  Val Dice OD: 0.8967
  Val Dice OC: 0.3997
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 13.50it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.41it/s]


Epoch 6/20
  Train Loss: 0.8747
  Val Loss: 0.8660
  Val Dice OD: 0.8399
  Val Dice OC: 0.3646
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.04it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.67it/s]


Epoch 7/20
  Train Loss: 0.8711
  Val Loss: 0.8693
  Val Dice OD: 0.7567
  Val Dice OC: 0.3166


Training: 100%|██████████| 9/9 [00:00<00:00, 14.72it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 27.50it/s]


Epoch 8/20
  Train Loss: 0.8680
  Val Loss: 0.8711
  Val Dice OD: 0.7615
  Val Dice OC: 0.3201


Training: 100%|██████████| 9/9 [00:00<00:00, 14.57it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.88it/s]


Epoch 9/20
  Train Loss: 0.8652
  Val Loss: 0.8698
  Val Dice OD: 0.8153
  Val Dice OC: 0.3508


Training: 100%|██████████| 9/9 [00:00<00:00, 14.50it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 27.06it/s]


Epoch 10/20
  Train Loss: 0.8627
  Val Loss: 0.8679
  Val Dice OD: 0.8659
  Val Dice OC: 0.3814


Training: 100%|██████████| 9/9 [00:00<00:00, 13.56it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 17.79it/s]


Epoch 11/20
  Train Loss: 0.8595
  Val Loss: 0.8655
  Val Dice OD: 0.9017
  Val Dice OC: 0.4037
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 11.74it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 22.21it/s]


Epoch 12/20
  Train Loss: 0.8570
  Val Loss: 0.8629
  Val Dice OD: 0.9276
  Val Dice OC: 0.4207
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 11.59it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 23.34it/s]


Epoch 13/20
  Train Loss: 0.8551
  Val Loss: 0.8605
  Val Dice OD: 0.9405
  Val Dice OC: 0.4291
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 12.54it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 27.55it/s]


Epoch 14/20
  Train Loss: 0.8528
  Val Loss: 0.8582
  Val Dice OD: 0.9573
  Val Dice OC: 0.4405
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.05it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.27it/s]


Epoch 15/20
  Train Loss: 0.8508
  Val Loss: 0.8560
  Val Dice OD: 0.9724
  Val Dice OC: 0.4501
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 10.48it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 21.53it/s]


Epoch 16/20
  Train Loss: 0.8489
  Val Loss: 0.8540
  Val Dice OD: 0.9797
  Val Dice OC: 0.4556
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 12.14it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 25.34it/s]


Epoch 17/20
  Train Loss: 0.8465
  Val Loss: 0.8518
  Val Dice OD: 0.9819
  Val Dice OC: 0.4569
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 13.82it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 25.70it/s]


Epoch 18/20
  Train Loss: 0.8443
  Val Loss: 0.8495
  Val Dice OD: 0.9840
  Val Dice OC: 0.4585
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.37it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 27.14it/s]


Epoch 19/20
  Train Loss: 0.8420
  Val Loss: 0.8472
  Val Dice OD: 0.9867
  Val Dice OC: 0.4613
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 12.68it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 27.48it/s]


Epoch 20/20
  Train Loss: 0.8397
  Val Loss: 0.8453
  Val Dice OD: 0.9850
  Val Dice OC: 0.4592
  Checkpoint saved: /content/checkpoints/Baseline_UNet_best.pt


Evaluating: 100%|██████████| 2/2 [00:00<00:00, 12.24it/s]


Results for Baseline_UNet:
  OD Dice: 0.9861 ± 0.0040
  OC Dice: 0.4174 ± 0.0942
  VCDR MAE: 0.4798 ± 0.0754

✓ GPU memory cleared after Experiment 1

EXPERIMENT 2: U-Net + ResNet34
Creating U-Net + ResNet34 (pretrained)...






Running Experiment: UNet_ResNet34


Training: 100%|██████████| 9/9 [00:00<00:00, 13.91it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.57it/s]


Epoch 1/20
  Train Loss: 0.8790
  Val Loss: 0.9481
  Val Dice OD: 0.2061
  Val Dice OC: 0.0711
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 12.89it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.94it/s]


Epoch 2/20
  Train Loss: 0.8609
  Val Loss: 0.8962
  Val Dice OD: 0.6206
  Val Dice OC: 0.2682
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.08it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.45it/s]


Epoch 3/20
  Train Loss: 0.8491
  Val Loss: 0.8779
  Val Dice OD: 0.8452
  Val Dice OC: 0.3754
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.36it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 27.00it/s]


Epoch 4/20
  Train Loss: 0.8416
  Val Loss: 0.8542
  Val Dice OD: 0.9448
  Val Dice OC: 0.4340
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.00it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.38it/s]


Epoch 5/20
  Train Loss: 0.8354
  Val Loss: 0.8389
  Val Dice OD: 0.9669
  Val Dice OC: 0.4478
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 10.36it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 22.32it/s]


Epoch 6/20
  Train Loss: 0.8309
  Val Loss: 0.8326
  Val Dice OD: 0.9714
  Val Dice OC: 0.4509
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 11.31it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 21.34it/s]


Epoch 7/20
  Train Loss: 0.8274
  Val Loss: 0.8299
  Val Dice OD: 0.9827
  Val Dice OC: 0.4594
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.37it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.66it/s]


Epoch 8/20
  Train Loss: 0.8249
  Val Loss: 0.8281
  Val Dice OD: 0.9858
  Val Dice OC: 0.4620
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.56it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.25it/s]


Epoch 9/20
  Train Loss: 0.8231
  Val Loss: 0.8265
  Val Dice OD: 0.9879
  Val Dice OC: 0.4639
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 12.29it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 27.00it/s]


Epoch 10/20
  Train Loss: 0.8208
  Val Loss: 0.8249
  Val Dice OD: 0.9900
  Val Dice OC: 0.4682
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 13.78it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.93it/s]


Epoch 11/20
  Train Loss: 0.8191
  Val Loss: 0.8232
  Val Dice OD: 0.9900
  Val Dice OC: 0.4676
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 11.60it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 22.95it/s]


Epoch 12/20
  Train Loss: 0.8175
  Val Loss: 0.8219
  Val Dice OD: 0.9902
  Val Dice OC: 0.4671
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 11.77it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 24.33it/s]


Epoch 13/20
  Train Loss: 0.8158
  Val Loss: 0.8206
  Val Dice OD: 0.9900
  Val Dice OC: 0.4707
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 12.50it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.52it/s]


Epoch 14/20
  Train Loss: 0.8150
  Val Loss: 0.8194
  Val Dice OD: 0.9909
  Val Dice OC: 0.4703
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.04it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.79it/s]


Epoch 15/20
  Train Loss: 0.8134
  Val Loss: 0.8185
  Val Dice OD: 0.9889
  Val Dice OC: 0.4732
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.17it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 27.13it/s]


Epoch 16/20
  Train Loss: 0.8126
  Val Loss: 0.8177
  Val Dice OD: 0.9909
  Val Dice OC: 0.4711
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.24it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.67it/s]


Epoch 17/20
  Train Loss: 0.8114
  Val Loss: 0.8169
  Val Dice OD: 0.9897
  Val Dice OC: 0.4730
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.06it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.75it/s]


Epoch 18/20
  Train Loss: 0.8104
  Val Loss: 0.8159
  Val Dice OD: 0.9905
  Val Dice OC: 0.4725
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 12.98it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.84it/s]


Epoch 19/20
  Train Loss: 0.8095
  Val Loss: 0.8152
  Val Dice OD: 0.9904
  Val Dice OC: 0.4726
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Training: 100%|██████████| 9/9 [00:00<00:00, 14.35it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 26.44it/s]


Epoch 20/20
  Train Loss: 0.8082
  Val Loss: 0.8142
  Val Dice OD: 0.9891
  Val Dice OC: 0.4748
  Checkpoint saved: /content/checkpoints/UNet_ResNet34_best.pt


Evaluating: 100%|██████████| 2/2 [00:00<00:00, 10.18it/s]


Results for UNet_ResNet34:
  OD Dice: 0.9887 ± 0.0015
  OC Dice: 0.4321 ± 0.0976
  VCDR MAE: 0.4798 ± 0.0754






✓ GPU memory cleared after Experiment 2

EXPERIMENT 3: Attention U-Net
⚠ Using reduced batch size for this experiment
Creating Attention U-Net...

Running Experiment: Attention_UNet


Training: 100%|██████████| 18/18 [00:02<00:00,  7.61it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 15.59it/s]


Epoch 1/20
  Train Loss: 0.8087
  Val Loss: 0.8524
  Val Dice OD: 0.9383
  Val Dice OC: 0.4115
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.21it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 18.76it/s]


Epoch 2/20
  Train Loss: 0.8011
  Val Loss: 0.8386
  Val Dice OD: 0.9757
  Val Dice OC: 0.4180
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.86it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.45it/s]


Epoch 3/20
  Train Loss: 0.7964
  Val Loss: 0.8182
  Val Dice OD: 0.9638
  Val Dice OC: 0.4270
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.90it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.01it/s]


Epoch 4/20
  Train Loss: 0.7953
  Val Loss: 0.8164
  Val Dice OD: 0.9571
  Val Dice OC: 0.4226
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 10.39it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 18.64it/s]


Epoch 5/20
  Train Loss: 0.7957
  Val Loss: 0.8157
  Val Dice OD: 0.9589
  Val Dice OC: 0.4238
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.34it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.20it/s]


Epoch 6/20
  Train Loss: 0.7956
  Val Loss: 0.8154
  Val Dice OD: 0.9603
  Val Dice OC: 0.4247
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.80it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.02it/s]


Epoch 7/20
  Train Loss: 0.7917
  Val Loss: 0.8150
  Val Dice OD: 0.9654
  Val Dice OC: 0.4280
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.88it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.01it/s]


Epoch 8/20
  Train Loss: 0.7948
  Val Loss: 0.8147
  Val Dice OD: 0.9656
  Val Dice OC: 0.4281
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.30it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 19.33it/s]


Epoch 9/20
  Train Loss: 0.7915
  Val Loss: 0.8144
  Val Dice OD: 0.9681
  Val Dice OC: 0.4298
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.22it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.02it/s]


Epoch 10/20
  Train Loss: 0.7938
  Val Loss: 0.8142
  Val Dice OD: 0.9694
  Val Dice OC: 0.4307
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.86it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 18.66it/s]


Epoch 11/20
  Train Loss: 0.7935
  Val Loss: 0.8129
  Val Dice OD: 0.9697
  Val Dice OC: 0.4308
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.08it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 17.94it/s]


Epoch 12/20
  Train Loss: 0.7910
  Val Loss: 0.8142
  Val Dice OD: 0.9716
  Val Dice OC: 0.4321


Training: 100%|██████████| 18/18 [00:01<00:00, 11.74it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 19.89it/s]


Epoch 13/20
  Train Loss: 0.7906
  Val Loss: 0.8127
  Val Dice OD: 0.9734
  Val Dice OC: 0.4333
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.10it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.29it/s]


Epoch 14/20
  Train Loss: 0.7924
  Val Loss: 0.8129
  Val Dice OD: 0.9729
  Val Dice OC: 0.4330


Training: 100%|██████████| 18/18 [00:01<00:00, 11.92it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.18it/s]


Epoch 15/20
  Train Loss: 0.7899
  Val Loss: 0.8116
  Val Dice OD: 0.9770
  Val Dice OC: 0.4357
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.64it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 18.34it/s]


Epoch 16/20
  Train Loss: 0.7910
  Val Loss: 0.8129
  Val Dice OD: 0.9766
  Val Dice OC: 0.4354


Training: 100%|██████████| 18/18 [00:01<00:00, 11.28it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 18.94it/s]


Epoch 17/20
  Train Loss: 0.7889
  Val Loss: 0.8111
  Val Dice OD: 0.9777
  Val Dice OC: 0.4361
  Checkpoint saved: /content/checkpoints/Attention_UNet_best.pt


Training: 100%|██████████| 18/18 [00:01<00:00, 11.68it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.06it/s]


Epoch 18/20
  Train Loss: 0.7908
  Val Loss: 0.8113
  Val Dice OD: 0.9774
  Val Dice OC: 0.4359


Training: 100%|██████████| 18/18 [00:01<00:00, 11.87it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 19.61it/s]


Epoch 19/20
  Train Loss: 0.7886
  Val Loss: 0.8116
  Val Dice OD: 0.9779
  Val Dice OC: 0.4362


Training: 100%|██████████| 18/18 [00:01<00:00, 11.83it/s]
Validating: 100%|██████████| 4/4 [00:00<00:00, 20.08it/s]


Epoch 20/20
  Train Loss: 0.7895
  Val Loss: 0.8118
  Val Dice OD: 0.9805
  Val Dice OC: 0.4380


Evaluating: 100%|██████████| 4/4 [00:00<00:00, 16.90it/s]



Results for Attention_UNet:
  OD Dice: 0.9806 ± 0.0035
  OC Dice: 0.4133 ± 0.0949
  VCDR MAE: 0.4798 ± 0.0754

✓ GPU memory cleared after Experiment 3

EXPERIMENT 4: U-Net + EfficientNet-B4

Running Experiment: UNet_EfficientNet_B4


Training: 100%|██████████| 9/9 [00:21<00:00,  2.39s/it]
Validating: 100%|██████████| 2/2 [00:00<00:00, 10.60it/s]


Epoch 1/20
  Train Loss: 0.6229
  Val Loss: 0.6259
  Val Dice OD: 0.1543
  Val Dice OC: 0.0996
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  7.98it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.21it/s]


Epoch 2/20
  Train Loss: 0.6148
  Val Loss: 0.6224
  Val Dice OD: 0.1758
  Val Dice OC: 0.0560
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  7.79it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.78it/s]


Epoch 3/20
  Train Loss: 0.6110
  Val Loss: 0.6195
  Val Dice OD: 0.1411
  Val Dice OC: 0.0446
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  6.22it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 14.54it/s]


Epoch 4/20
  Train Loss: 0.6094
  Val Loss: 0.6182
  Val Dice OD: 0.1540
  Val Dice OC: 0.0537
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  6.26it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.81it/s]


Epoch 5/20
  Train Loss: 0.6092
  Val Loss: 0.6178
  Val Dice OD: 0.1726
  Val Dice OC: 0.0600
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  8.08it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 18.62it/s]


Epoch 6/20
  Train Loss: 0.6087
  Val Loss: 0.6173
  Val Dice OD: 0.1795
  Val Dice OC: 0.0619
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  8.03it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.99it/s]


Epoch 7/20
  Train Loss: 0.6084
  Val Loss: 0.6167
  Val Dice OD: 0.1691
  Val Dice OC: 0.0578
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  8.18it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.48it/s]


Epoch 8/20
  Train Loss: 0.6082
  Val Loss: 0.6161
  Val Dice OD: 0.1525
  Val Dice OC: 0.0518
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  8.16it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 19.52it/s]


Epoch 9/20
  Train Loss: 0.6079
  Val Loss: 0.6155
  Val Dice OD: 0.1392
  Val Dice OC: 0.0470
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  6.20it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 16.90it/s]


Epoch 10/20
  Train Loss: 0.6078
  Val Loss: 0.6149
  Val Dice OD: 0.1307
  Val Dice OC: 0.0440
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  7.99it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 16.71it/s]


Epoch 11/20
  Train Loss: 0.6074
  Val Loss: 0.6144
  Val Dice OD: 0.1261
  Val Dice OC: 0.0424
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  8.04it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.96it/s]


Epoch 12/20
  Train Loss: 0.6070
  Val Loss: 0.6140
  Val Dice OD: 0.1240
  Val Dice OC: 0.0416
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  8.23it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 18.25it/s]


Epoch 13/20
  Train Loss: 0.6072
  Val Loss: 0.6136
  Val Dice OD: 0.1228
  Val Dice OC: 0.0412
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  8.20it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.82it/s]


Epoch 14/20
  Train Loss: 0.6066
  Val Loss: 0.6133
  Val Dice OD: 0.1222
  Val Dice OC: 0.0410
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  6.36it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.92it/s]


Epoch 15/20
  Train Loss: 0.6066
  Val Loss: 0.6130
  Val Dice OD: 0.1219
  Val Dice OC: 0.0408
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  8.02it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.82it/s]


Epoch 16/20
  Train Loss: 0.6060
  Val Loss: 0.6128
  Val Dice OD: 0.1217
  Val Dice OC: 0.0408
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  7.95it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.69it/s]


Epoch 17/20
  Train Loss: 0.6060
  Val Loss: 0.6125
  Val Dice OD: 0.1216
  Val Dice OC: 0.0407
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  7.92it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.42it/s]


Epoch 18/20
  Train Loss: 0.6058
  Val Loss: 0.6123
  Val Dice OD: 0.1216
  Val Dice OC: 0.0407
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  6.28it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 13.68it/s]


Epoch 19/20
  Train Loss: 0.6055
  Val Loss: 0.6121
  Val Dice OD: 0.1215
  Val Dice OC: 0.0407
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Training: 100%|██████████| 9/9 [00:01<00:00,  6.42it/s]
Validating: 100%|██████████| 2/2 [00:00<00:00, 20.66it/s]


Epoch 20/20
  Train Loss: 0.6051
  Val Loss: 0.6119
  Val Dice OD: 0.1214
  Val Dice OC: 0.0407
  Checkpoint saved: /content/checkpoints/UNet_EfficientNet_B4_best.pt


Evaluating: 100%|██████████| 2/2 [00:00<00:00, 11.68it/s]


Results for UNet_EfficientNet_B4:
  OD Dice: 0.1344 ± 0.0469
  OC Dice: 0.0395 ± 0.0195
  VCDR MAE: 0.4798 ± 0.0754






✓ GPU memory cleared after Experiment 4

EXPERIMENT COMPARISON

Model                          OD Dice      OC Dice      VCDR MAE    
----------------------------------------------------------------------
Baseline_UNet                  0.9861       0.4174       0.4798      
UNet_ResNet34                  0.9887       0.4321       0.4798      
Attention_UNet                 0.9806       0.4133       0.4798      
UNet_EfficientNet_B4           0.1344       0.0395       0.4798      

✓ Results saved to /content/results
