In [None]:
# Colab-only setup: mount Google Drive at /content/drive so the dataset and
# checkpoint paths used later in this notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from scipy.ndimage import label

# Configuration matching original training
class FineTuningConfig:
    """Static settings for the fine-tuning run: paths, hyper-parameters and device."""

    # Checkpoint whose state dict is loaded into the UNet before fine-tuning.
    pretrained_path = "/content/drive/MyDrive/BPEYE/best_seg.pth"
    # NOTE(review): the "DATASET " folder name in the four paths below ends with
    # a space -- confirm it matches the real Drive folder name.
    train_img_dir = "/content/drive/MyDrive/BPEYE/DATASET /GLAUCOMA/REFUGE2/train/images"
    train_mask_dir = "/content/drive/MyDrive/BPEYE/DATASET /GLAUCOMA/REFUGE2/train/mask"
    val_img_dir = "/content/drive/MyDrive/BPEYE/DATASET /GLAUCOMA/REFUGE2/val/images"
    val_mask_dir = "/content/drive/MyDrive/BPEYE/DATASET /GLAUCOMA/REFUGE2/val/mask"
    # Directory that receives the best fine-tuned checkpoint (created on demand).
    output_dir = "/content/drive/MyDrive/BPEYE/fine_tuned_models"

    # Training parameters (matching original)
    lr = 1e-4  # Adam learning rate
    batch_size = 8
    num_workers = 4  # DataLoader worker processes
    total_epoch = 20  # Fewer epochs for fine-tuning
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Spatial size handed to the dataset's resize calls for images and masks.
    output_size = (256, 256)

    # Fine-tuning specific
    freeze_encoder_epochs = 5  # First 5 epochs with frozen encoder

# Shared settings instance read by the data-loading and training functions.
config = FineTuningConfig()

# Fixed FlexibleDataLoader (same as before but simplified)
class FlexibleDataLoader:
    """Pair image files with mask files that share the same base name.

    Both directories are listed (non-recursively) for files whose extension is
    in the accepted list. An image and a mask form a pair when their file names
    are equal once the extension is removed. The pairing happens once, in the
    constructor, and the number of pairs found is printed.

    Args:
        image_dir: Directory holding the images.
        mask_dir: Directory holding the masks.
        image_extensions: Accepted image extensions (with leading dot); they
            are lower-cased. ``None`` selects a built-in list.
        mask_extensions: Same, for the masks.
    """

    def __init__(self, image_dir, mask_dir, image_extensions=None, mask_extensions=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir

        # Caller-supplied extensions are normalised to lower case because file
        # extensions are lower-cased before the membership test.
        self.image_extensions = (
            ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
            if image_extensions is None
            else [ext.lower() for ext in image_extensions]
        )
        self.mask_extensions = (
            ['.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif']
            if mask_extensions is None
            else [ext.lower() for ext in mask_extensions]
        )

        self.image_paths, self.mask_paths = self._match_files()
        print(f"Found {len(self.image_paths)} matched image-mask pairs")

    def _get_files_with_extensions(self, directory, extensions):
        """Return the sorted full paths of non-hidden files with an accepted extension.

        A missing directory is reported with a printed warning and yields an
        empty list.
        """
        if not os.path.exists(directory):
            print(f"Warning: Directory {directory} does not exist")
            return []

        selected = [
            os.path.join(directory, entry)
            for entry in os.listdir(directory)
            if not entry.startswith('.')
            and os.path.splitext(entry)[1].lower() in extensions
        ]
        selected.sort()
        return selected

    def _match_files(self):
        """Return ``(image_paths, mask_paths)`` for every base name present in both folders."""

        def index_by_stem(paths):
            # When several files share a stem, the last one in sorted order wins.
            return {os.path.splitext(os.path.basename(p))[0]: p for p in paths}

        images = index_by_stem(
            self._get_files_with_extensions(self.image_dir, self.image_extensions)
        )
        masks = index_by_stem(
            self._get_files_with_extensions(self.mask_dir, self.mask_extensions)
        )

        # Pairs follow the (sorted) order of the image files.
        shared_stems = [stem for stem in images if stem in masks]
        return (
            [images[stem] for stem in shared_stems],
            [masks[stem] for stem in shared_stems],
        )

    def get_data(self):
        """Return the matched ``(image_paths, mask_paths)`` lists."""
        return self.image_paths, self.mask_paths

# Dataset class (adapted from original)
class GlaucomaDataset(Dataset):
    """Dataset of fundus images with two-channel optic disc / optic cup targets.

    Each item is ``(img, seg)``: ``img`` is the RGB image as a float tensor and
    ``seg`` stacks two binary masks, channel 0 for the optic disc and channel 1
    for the optic cup. Both are resized to ``output_size``.

    Args:
        image_paths: Image file paths.
        mask_paths: Mask file paths, index-aligned with ``image_paths``.
        output_size: Target size passed to the resize calls.
    """

    def __init__(self, image_paths, mask_paths, output_size=(256, 256)):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.output_size = output_size

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, binarise and resize the pair at position ``idx``."""
        # Load image; the context manager releases the file handle right away.
        with Image.open(self.image_paths[idx]) as raw_img:
            img = raw_img.convert('RGB')
        img = transforms.functional.to_tensor(img)
        img = transforms.functional.resize(img, self.output_size, interpolation=Image.BILINEAR)

        # Load mask - expected grey levels: OD=grey(128), OC=black(0), background=white(255).
        # Converting to single-channel "L" keeps 8-bit greyscale masks unchanged
        # while making palette / RGB / RGBA masks comparable against the same
        # grey levels (otherwise the array would hold palette indices or carry
        # an extra channel axis and break the resize below).
        with Image.open(self.mask_paths[idx], mode='r') as raw_mask:
            mask = np.array(raw_mask.convert('L'))

        # Convert to binary masks:
        # OD: grey areas (128) -> channel 0
        # OC: black areas (0) -> channel 1
        # NOTE(review): the disc channel only marks pixels equal to 128, so cup
        # pixels (0) are not part of it -- confirm this matches the targets the
        # pretrained model was trained on and what the vCDR metric expects.
        od = (mask == 128).astype(np.float32)  # Optic disc
        oc = (mask == 0).astype(np.float32)    # Optic cup

        # Add a leading channel axis, then resize with nearest-neighbour so the
        # masks stay strictly binary.
        od = torch.from_numpy(od[None, :, :])
        oc = torch.from_numpy(oc[None, :, :])
        od = transforms.functional.resize(od, self.output_size, interpolation=Image.NEAREST)
        oc = transforms.functional.resize(oc, self.output_size, interpolation=Image.NEAREST)
        seg = torch.cat([od, oc], dim=0)

        return img, seg

# UNet Model (same as original)
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; a falsy value
            (``None`` or 0) means "same as ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # The layer order fixes the state_dict keys (double_conv.0/.1/.3/.4),
        # which checkpoints depend on.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages; padding=1 keeps the spatial size."""
        out = self.double_conv(x)
        return out

class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as a Sequential named maxpool_conv so state_dict keys stay the same.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` by a factor of two, then convolve to ``out_channels``."""
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder step: bilinear 2x upsampling, skip concatenation, then DoubleConv.

    ``in_channels`` is the channel count after concatenating the upsampled
    tensor with the skip connection.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of skip tensor ``x2`` and fuse both."""
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled tensor (H, then W).
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2

        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` without changing its spatial size."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """Seven-level U-Net returning per-pixel sigmoid probabilities.

    Args:
        n_channels: Channels of the input image.
        n_classes: Number of output channels (one probability map each).

    Attributes:
        epoch: Counter of completed epochs; it is a plain attribute, so it is
            not part of ``state_dict()``.
        encoder_modules: The encoder sub-modules toggled by
            ``freeze_encoder`` / ``unfreeze_encoder``.
    """

    def __init__(self, n_channels=3, n_classes=2):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.epoch = 0

        # Encoder: channel width doubles per level up to 2048; the deepest
        # level keeps 2048 (4096 halved) so the first decoder concat is 4096.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.down5 = Down(1024, 2048)
        self.down6 = Down(2048, 2048)

        # Decoder: each step takes (upsampled + skip) channels and emits half
        # the nominal width, except the last one which emits 64.
        self.up1 = Up(4096, 1024)
        self.up2 = Up(2048, 512)
        self.up3 = Up(1024, 256)
        self.up4 = Up(512, 128)
        self.up5 = Up(256, 64)
        self.up6 = Up(128, 64)
        self.output_layer = OutConv(64, n_classes)

        # Encoder modules, grouped for freezing/unfreezing.
        self.encoder_modules = [
            self.inc, self.down1, self.down2, self.down3,
            self.down4, self.down5, self.down6
        ]

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then sigmoid."""
        # Collect every encoder output; the last one is the bottleneck.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3,
                     self.down4, self.down5, self.down6):
            features.append(down(features[-1]))

        # Walk back up, consuming the skip connections deepest-first.
        out = features.pop()
        for up in (self.up1, self.up2, self.up3,
                   self.up4, self.up5, self.up6):
            out = up(out, features.pop())

        return torch.sigmoid(self.output_layer(out))

    def _set_encoder_requires_grad(self, flag):
        """Set ``requires_grad`` on every encoder parameter."""
        for module in self.encoder_modules:
            for param in module.parameters():
                param.requires_grad = flag

    def freeze_encoder(self):
        """Freeze encoder parameters"""
        self._set_encoder_requires_grad(False)

    def unfreeze_encoder(self):
        """Unfreeze encoder parameters"""
        self._set_encoder_requires_grad(True)

# Metrics and utilities (from original code)
# Added to the disc diameter in vertical_cup_to_disc_ratio to avoid division by zero.
EPS = 1e-7

def compute_dice_coef(input, target):
    """Return the Dice coefficient averaged over the samples of a batch.

    ``input`` and ``target`` are indexed as ``[k, :, :]``, i.e. one 2-D mask
    per sample along the first dimension.
    """
    batch_size = input.shape[0]
    total = 0
    for k in range(batch_size):
        total = total + dice_coef_sample(input[k, :, :], target[k, :, :])
    return total / batch_size

def dice_coef_sample(input, target):
    """Return the Dice coefficient between two masks as a 0-dim tensor.

    Computes ``2 * sum(input * target) / (sum(input) + sum(target))`` over all
    elements.

    Args:
        input: Predicted mask tensor.
        target: Ground-truth mask tensor with the same number of elements.

    Returns:
        A scalar tensor. When both masks are empty the result is 1 (perfect
        agreement) rather than the NaN that 0/0 would produce, so that one
        empty sample cannot poison averaged metrics.
    """
    iflat = input.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    numerator = 2. * (iflat * tflat).sum()
    denominator = iflat.sum() + tflat.sum()
    if denominator == 0:
        # Nothing predicted and nothing to find: count it as a perfect match.
        return torch.ones_like(numerator)
    return numerator / denominator

def vertical_diameter(binary_segmentation):
    """Return, per sample, the largest column sum of a stack of binary masks.

    The input is summed along axis 1 and the maximum is then taken along the
    remaining second axis, giving one value per entry of axis 0.
    # assumes the input is laid out as (batch, rows, cols) -- TODO confirm
    """
    column_totals = np.sum(binary_segmentation, axis=1)
    return column_totals.max(axis=1)

def vertical_cup_to_disc_ratio(od, oc):
    """Return the vertical cup-to-disc ratio for each sample.

    Divides the vertical diameter of the cup masks ``oc`` by that of the disc
    masks ``od``; ``EPS`` is added to the denominator so an empty disc does not
    divide by zero.
    """
    disc_diameter = vertical_diameter(od)
    cup_diameter = vertical_diameter(oc)
    ratio = cup_diameter / (disc_diameter + EPS)
    return ratio

def compute_vCDR_error(pred_od, pred_oc, gt_od, gt_oc):
    """Return the mean absolute vCDR error plus the per-sample ratios.

    Returns:
        ``(vCDR_err, pred_vCDR, gt_vCDR)`` where ``vCDR_err`` is the mean of
        ``|gt_vCDR - pred_vCDR|`` over the samples.
    """
    gt_vCDR = vertical_cup_to_disc_ratio(gt_od, gt_oc)
    pred_vCDR = vertical_cup_to_disc_ratio(pred_od, pred_oc)
    absolute_gap = np.abs(gt_vCDR - pred_vCDR)
    return np.mean(absolute_gap), pred_vCDR, gt_vCDR

def refine_seg(pred):
    """Keep only the largest connected component of each mask in a batch.

    Args:
        pred: Tensor holding one binary 2-D mask per entry of its first
            dimension; non-zero values are foreground.

    Returns:
        A float32 CPU tensor of the same shape with 1 on the largest connected
        component of each mask and 0 elsewhere. A mask without any foreground
        stays empty (all zeros).
    """
    # SciPy works on NumPy arrays, which requires a detached CPU tensor.
    np_pred = pred.detach().cpu().numpy()

    largest_ccs = []
    for i in range(np_pred.shape[0]):
        labeled, _ = label(np_pred[i, :, :])
        # Pixel count per component; index 0 (the background) is dropped.
        component_sizes = np.bincount(labeled.ravel())[1:]
        if len(component_sizes) == 0:
            # No foreground at all: `labeled == 0` would select the whole image,
            # so build an explicitly empty mask instead.
            largest_cc = np.zeros(labeled.shape, dtype=bool)
        else:
            # Component labels start at 1, hence the +1 on the argmax index.
            largest_cc = labeled == np.argmax(component_sizes) + 1
        largest_ccs.append(torch.tensor(largest_cc, dtype=torch.float32))

    return torch.stack(largest_ccs)

# Setup data loaders
def setup_data_loaders():
    """Build the training and validation DataLoaders described by ``config``.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    print("Setting up data loaders...")

    # (image dir, mask dir, shuffle) for the training and validation splits.
    splits = (
        (config.train_img_dir, config.train_mask_dir, True),
        (config.val_img_dir, config.val_mask_dir, False),
    )

    # Match image/mask files for both splits first (each prints its pair count).
    matched = [
        FlexibleDataLoader(img_dir, mask_dir).get_data()
        for img_dir, mask_dir, _ in splits
    ]

    loaders = []
    for (image_paths, mask_paths), (_, _, shuffle) in zip(matched, splits):
        dataset = GlaucomaDataset(image_paths, mask_paths, config.output_size)
        loaders.append(
            DataLoader(
                dataset,
                batch_size=config.batch_size,
                shuffle=shuffle,
                num_workers=config.num_workers,
                pin_memory=True,
            )
        )

    train_loader, val_loader = loaders
    return train_loader, val_loader

# Fine-tuning function (based on original training loop)
def _segmentation_metrics(probs, seg_gts):
    """Return ``(OD Dice, OC Dice, vCDR error)`` for one batch.

    Args:
        probs: Model output with the optic disc in channel 0 and the optic cup
            in channel 1.
        seg_gts: Ground-truth masks with the same channel layout.
    """
    # Threshold each channel at 0.5 and keep its largest connected component.
    # refine_seg runs on the CPU, hence the round trip back to the device.
    pred_od = refine_seg((probs[:, 0, :, :] >= 0.5).type(torch.int8).cpu()).to(config.device)
    pred_oc = refine_seg((probs[:, 1, :, :] >= 0.5).type(torch.int8).cpu()).to(config.device)
    gt_od = seg_gts[:, 0, :, :].type(torch.int8)
    gt_oc = seg_gts[:, 1, :, :].type(torch.int8)

    dsc_od = compute_dice_coef(pred_od, gt_od).item()
    dsc_oc = compute_dice_coef(pred_oc, gt_oc).item()

    vcdr_error, _, _ = compute_vCDR_error(
        pred_od.cpu().numpy(), pred_oc.cpu().numpy(),
        gt_od.cpu().numpy(), gt_oc.cpu().numpy()
    )
    return dsc_od, dsc_oc, vcdr_error


def fine_tune_model():
    """Fine-tune the pretrained UNet and keep the best checkpoint.

    The encoder is frozen for the first ``config.freeze_encoder_epochs`` epochs
    and trained together with the decoder afterwards. After every epoch the
    model is validated; whenever the sum of the validation Dice scores (optic
    disc + optic cup) improves, the state dict is written to
    ``config.output_dir/fine_tuned_best.pth``.

    Returns:
        None. The function returns early (after printing a message) when the
        pretrained checkpoint is missing or a data loader has no batches.
    """
    print("Starting fine-tuning process...")

    # Setup data
    train_loader, val_loader = setup_data_loaders()
    nb_train_batches = len(train_loader)
    nb_val_batches = len(val_loader)
    if nb_train_batches == 0 or nb_val_batches == 0:
        # Without batches nothing can be trained or scored, and no checkpoint
        # would ever be saved; stop instead of looping over empty epochs.
        print("❌ No training or validation batches found; check the dataset directories.")
        return

    # Setup model
    model = UNet(n_channels=3, n_classes=2).to(config.device)

    # Load pretrained weights
    if not os.path.exists(config.pretrained_path):
        print(f"❌ Pretrained model not found at {config.pretrained_path}")
        return
    model.load_state_dict(torch.load(config.pretrained_path, map_location=config.device))
    print(f"✅ Loaded pretrained model from {config.pretrained_path}")

    # The model already applies a sigmoid, so plain BCE on its output is used.
    seg_loss = torch.nn.BCELoss(reduction='mean')
    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)
    model_path = os.path.join(config.output_dir, 'fine_tuned_best.pth')

    # Best sum of validation Dice scores (OD + OC) seen so far.
    best_val_score = 0.

    print(f"Training batches: {nb_train_batches}, Validation batches: {nb_val_batches}")

    while model.epoch < config.total_epoch:
        # Freeze encoder for first few epochs
        # NOTE(review): freezing only clears requires_grad; model.train() below
        # still leaves the encoder's BatchNorm layers in training mode, so their
        # running statistics keep updating -- confirm this is intended.
        if model.epoch < config.freeze_encoder_epochs:
            model.freeze_encoder()
            print(f"Epoch {model.epoch + 1}: Encoder FROZEN")
        else:
            model.unfreeze_encoder()
            if model.epoch == config.freeze_encoder_epochs:
                print(f"Epoch {model.epoch + 1}: Encoder UNFROZEN")

        # Per-epoch accumulators; every batch contributes value / nb_batches.
        train_loss, val_loss = 0., 0.
        train_dsc_od, val_dsc_od = 0., 0.
        train_dsc_oc, val_dsc_oc = 0., 0.
        train_vCDR_error, val_vCDR_error = 0., 0.

        ############
        # TRAINING #
        ############
        model.train()
        for k, (imgs, seg_gts) in enumerate(train_loader, start=1):
            imgs, seg_gts = imgs.to(config.device), seg_gts.to(config.device)

            # Forward pass
            probs = model(imgs)
            loss = seg_loss(probs, seg_gts)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() / nb_train_batches

            # Metrics are for monitoring only and need no gradients.
            with torch.no_grad():
                dsc_od, dsc_oc, vCDR_error = _segmentation_metrics(probs, seg_gts)
            train_dsc_od += dsc_od / nb_train_batches
            train_dsc_oc += dsc_oc / nb_train_batches
            train_vCDR_error += vCDR_error / nb_train_batches

            # Progress (carriage return keeps it on a single console line)
            print(f'Epoch {model.epoch+1}, iter {k}/{nb_train_batches}, loss {loss.item():.6f}' + ' '*20, end='\r')

        ##############
        # VALIDATION #
        ##############
        model.eval()
        with torch.no_grad():
            for k, (imgs, seg_gts) in enumerate(val_loader, start=1):
                imgs, seg_gts = imgs.to(config.device), seg_gts.to(config.device)

                # Forward pass
                probs = model(imgs)
                val_loss += seg_loss(probs, seg_gts).item() / nb_val_batches

                dsc_od, dsc_oc, vCDR_error = _segmentation_metrics(probs, seg_gts)
                val_dsc_od += dsc_od / nb_val_batches
                val_dsc_oc += dsc_oc / nb_val_batches
                val_vCDR_error += vCDR_error / nb_val_batches

                print(f'Validation iter {k}/{nb_val_batches}' + ' '*50, end='\r')

        # Print epoch results
        print(f'FINE-TUNING epoch {model.epoch+1}' + ' '*50)
        print(f'LOSSES: {train_loss:.4f} (train), {val_loss:.4f} (val)')
        print(f'OD segmentation (Dice Score): {train_dsc_od:.4f} (train), {val_dsc_od:.4f} (val)')
        print(f'OC segmentation (Dice Score): {train_dsc_oc:.4f} (train), {val_dsc_oc:.4f} (val)')
        print(f'vCDR error: {train_vCDR_error:.4f} (train), {val_vCDR_error:.4f} (val)')

        # Save model if best validation performance is reached
        current_val_score = val_dsc_od + val_dsc_oc
        if current_val_score > best_val_score:
            torch.save(model.state_dict(), model_path)
            best_val_score = current_val_score
            print(f'✅ Best validation score reached: {current_val_score:.4f}. Model saved.')

        print('_'*50)

        # End of epoch
        model.epoch += 1

    print("Fine-tuning completed!")
    print(f"Best validation score: {best_val_score:.4f}")
    print(f"Models saved in: {config.output_dir}")

# Run fine-tuning only when this code is executed directly (not on import).
if __name__ == "__main__":
    fine_tune_model()

Starting fine-tuning process...
Setting up data loaders...
Found 400 matched image-mask pairs
Found 400 matched image-mask pairs




✅ Loaded pretrained model from /content/drive/MyDrive/BPEYE/best_seg.pth
Training batches: 50, Validation batches: 50
Epoch 1: Encoder FROZEN
FINE-TUNING epoch 1                                                  
LOSSES: 0.0049 (train), 0.0062 (val)
OD segmentation (Dice Score): 0.8856 (train), 0.8354 (val)
OC segmentation (Dice Score): 0.8648 (train), 0.7645 (val)
vCDR error: 0.0843 (train), 0.1187 (val)
✅ Best validation score reached: 1.5999. Model saved.
__________________________________________________
Epoch 2: Encoder FROZEN
FINE-TUNING epoch 2                                                  
LOSSES: 0.0042 (train), 0.0057 (val)
OD segmentation (Dice Score): 0.8956 (train), 0.8506 (val)
OC segmentation (Dice Score): 0.8744 (train), 0.7928 (val)
vCDR error: 0.0885 (train), 0.1010 (val)
✅ Best validation score reached: 1.6434. Model saved.
__________________________________________________
Epoch 3: Encoder FROZEN
FINE-TUNING epoch 3                                                 