# Optic Disc Segmentation

Segmenting the optic disc in fundus images with a U-Net whose ResNet encoder is pretrained on ImageNet (via `segmentation_models_pytorch`).

## 1. Installation and Imports

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp
from skimage.draw import polygon

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

## 2. Configuration

In [None]:
# ==== Data location ====
# On Google Colab the dataset is read from Google Drive, which is mounted here.
# Outside Colab the mount is skipped so the notebook still runs; point the
# PAPILA_ROOT environment variable at a local copy of the dataset instead.
import os

try:
    from google.colab import drive
except ImportError:  # not running on Google Colab
    drive = None

if drive is not None:
    drive.mount('/content/drive')

# Adjust the path to where the dataset is located in your Drive
# Example: if you placed the PapilaDB folder in "My Drive/PapilaDB"
ROOT_DIR = os.environ.get('PAPILA_ROOT', '/content/drive/MyDrive/PapilaDB/')

# Check if it exists
print(f"ROOT_DIR exists: {os.path.exists(ROOT_DIR)}")
if os.path.exists(ROOT_DIR):
    print(f"Contents: {os.listdir(ROOT_DIR)}")

# Training hyper-parameters.
BATCH_SIZE = 8
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
IMG_SIZE = 512  # images and masks are resized to IMG_SIZE x IMG_SIZE

# Encoder backbone for segmentation_models_pytorch and its pretrained weights.
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'

## 3. Prepare Data

In [None]:
# os.path.join works whether or not ROOT_DIR ends with a separator; the final ''
# keeps the trailing separator the directory strings had before.
img_dir = os.path.join(ROOT_DIR, 'FundusImages', '')
contour_dir = os.path.join(ROOT_DIR, 'ExpertsSegmentations', 'Contours', '')

img_files = sorted(os.listdir(img_dir))
contour_files = sorted(os.listdir(contour_dir))

# Keep only contour files whose name mentions 'disc' (case-insensitive).
disc_contours = [f for f in contour_files if 'disc' in f.lower()]

print(f'Images: {len(img_files)}')
print(f'Disc contours: {len(disc_contours)}')

In [None]:
# Create image-contour pairs
def get_pairs(images=None, contours=None, image_dir=None, contours_dir=None):
    """Pair every fundus image with the first disc-contour file that names it.

    A contour matches an image when the image's stem (file name without
    extension) occurs in the contour file name as a whole token, i.e. not
    directly preceded or followed by a letter or digit. A plain substring
    test would pair image ``1.jpg`` with ``10_disc.txt``. Contours are
    scanned in the order given, so when several match, the first one wins.

    Args:
        images: image file names; defaults to the notebook-level ``img_files``.
        contours: contour file names; defaults to ``disc_contours``.
        image_dir: directory joined to each image name; defaults to ``img_dir``.
        contours_dir: directory joined to each contour name; defaults to
            ``contour_dir``.

    Returns:
        A list of ``{'image': path, 'contour': path}`` dicts in image order.
        Images without a matching contour are skipped.
    """
    import re

    # Fall back to the notebook globals so the original zero-argument call still works.
    images = img_files if images is None else images
    contours = disc_contours if contours is None else contours
    image_dir = img_dir if image_dir is None else image_dir
    contours_dir = contour_dir if contours_dir is None else contours_dir

    pairs = []
    for img_file in images:
        img_id = os.path.splitext(img_file)[0]
        # Whole-token match: no letter/digit immediately around the id.
        token = re.compile(r'(?<![A-Za-z0-9])' + re.escape(img_id) + r'(?![A-Za-z0-9])')
        match = next((c for c in contours if token.search(c)), None)
        if match is not None:
            pairs.append({
                'image': os.path.join(image_dir, img_file),
                'contour': os.path.join(contours_dir, match),
            })
    return pairs

# Build the (image path, contour path) pairs shared by the train/val datasets.
pairs = get_pairs()
print('Pairs found: {}'.format(len(pairs)))

## 4. Data Augmentation (Albumentations)

In [None]:
def get_train_transforms():
    """Return the albumentations pipeline applied to training samples.

    Steps, in order: resize to IMG_SIZE x IMG_SIZE, random flips/rotations
    and shift-scale-rotate, one of noise/blur (p=0.3), one of
    brightness-contrast/CLAHE/hue-saturation (p=0.3), ImageNet
    normalisation, and conversion to a tensor.
    """
    geometric = [
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        # NOTE(review): newer albumentations releases recommend A.Affine over ShiftScaleRotate.
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.5),
    ]
    noise_or_blur = A.OneOf(
        [
            # NOTE(review): recent albumentations versions replace `var_limit` with
            # `std_range`; confirm the installed version still honours this argument.
            A.GaussNoise(var_limit=(10, 50)),
            A.GaussianBlur(blur_limit=3),
            A.MedianBlur(blur_limit=3),
        ],
        p=0.3,
    )
    colour = A.OneOf(
        [
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.CLAHE(clip_limit=2),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10),
        ],
        p=0.3,
    )
    finalize = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(geometric + [noise_or_blur, colour] + finalize)

def get_val_transforms():
    """Return the deterministic validation/inference pipeline: resize, normalise, tensorise."""
    steps = [
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)

## 5. Dataset

In [None]:
class OpticDiscDataset(Dataset):
    """Fundus images paired with binary optic-disc masks rasterised from contours.

    Each element of ``pairs`` is a dict with an ``'image'`` path and a
    ``'contour'`` path. The contour file is read with ``np.loadtxt``:
    column 0 is used as the x (column) coordinate and column 1 as the y
    (row) coordinate of the polygon vertices.

    ``transforms`` is an optional callable invoked as
    ``transforms(image=..., mask=...)`` that returns a dict with ``'image'``
    and ``'mask'``, e.g. an albumentations ``Compose``.

    ``__getitem__`` returns ``(image, mask)``. ``image`` is a CHW tensor and
    ``mask`` a float tensor of shape ``(1, H, W)`` holding 0/1. Without
    transforms, the image keeps its original dtype and is not normalised.
    """

    def __init__(self, pairs, transforms=None):
        self.pairs = pairs
        self.transforms = transforms

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pair = self.pairs[idx]

        # Context manager releases the file handle once the pixels are decoded.
        with Image.open(pair['image']) as pil_img:
            image = np.array(pil_img.convert('RGB'))

        mask = self._contour_to_mask(pair['contour'], image.shape[:2])

        if self.transforms:
            transformed = self.transforms(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        # Also covers transforms=None, where image/mask are still numpy arrays.
        return self._to_tensors(image, mask)

    @staticmethod
    def _contour_to_mask(contour_path, shape):
        """Rasterise the polygon stored in ``contour_path`` into a uint8 0/1 mask of ``shape``."""
        # ndmin=2 keeps a one-line file 2-D so the column indexing below still works.
        contour = np.loadtxt(contour_path, ndmin=2)
        mask = np.zeros(shape, dtype=np.uint8)
        rr, cc = polygon(contour[:, 1], contour[:, 0], shape)
        mask[rr, cc] = 1
        return mask

    @staticmethod
    def _to_tensors(image, mask):
        """Return ``image`` as a CHW tensor and ``mask`` as a float ``(1, H, W)`` tensor.

        Tensors are passed through unchanged. A numpy HWC image is transposed
        to CHW without rescaling or changing its dtype.
        """
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        mask = torch.as_tensor(mask)
        return image, mask.float().unsqueeze(0)

In [None]:
# 80/20 train/validation split of the image-contour pairs, seeded for reproducibility.
# NOTE(review): the split is per image; if the dataset holds several images of the
# same patient (e.g. both eyes), a grouped split would avoid leakage — confirm.
train_pairs, val_pairs = train_test_split(pairs, test_size=0.2, random_state=42)

train_dataset = OpticDiscDataset(train_pairs, get_train_transforms())
val_dataset = OpticDiscDataset(val_pairs, get_val_transforms())

# Only the training loader shuffles; both use two worker processes.
train_loader, val_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_train, num_workers=2)
    for ds, is_train in ((train_dataset, True), (val_dataset, False))
)

print(f'Train: {len(train_dataset)} | Validation: {len(val_dataset)}')

## 6. Visualize Samples

In [None]:
def show_sample(dataset, idx=0):
    """Show ``dataset[idx]`` as image, mask and green mask overlay side by side.

    The image tensor is assumed to be CHW and ImageNet-normalised (as the
    pipelines in this notebook produce); it is de-normalised for display.
    """
    image, target = dataset[idx]

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    rgb = np.clip(image.numpy().transpose(1, 2, 0) * std + mean, 0, 1)

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(rgb)
    ax[0].set_title('Image')
    ax[1].imshow(target.squeeze(), cmap='gray')
    ax[1].set_title('Mask')

    # Blend masked pixels 50/50 with pure green.
    inside = target.squeeze().numpy() > 0.5
    blended = rgb.copy()
    blended[inside] = blended[inside] * 0.5 + np.array([0, 1, 0]) * 0.5
    ax[2].imshow(blended)
    ax[2].set_title('Overlay')

    for panel in ax:
        panel.axis('off')
    plt.tight_layout()
    plt.show()

show_sample(train_dataset, 0)

## 7. U-Net Model with ResNet (SMP)

In [None]:
# U-Net from segmentation_models_pytorch with a pretrained ENCODER backbone and a
# single output channel. activation=None keeps raw logits; the sigmoid is applied
# inside the loss and the metrics.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=1,
    activation=None,
).to(device)

print(f'Model: U-Net with encoder {ENCODER}')
print(f'Weights: {ENCODER_WEIGHTS}')

In [None]:
# Alternative: Use other SMP models
# model = smp.DeepLabV3Plus(encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=1)
# model = smp.FPN(encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=1)
# model = smp.PSPNet(encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS, classes=1)

## 8. Loss and Metrics

In [None]:
# Equal-weight sum of binary Dice and BCE from SMP, both fed the model's raw logits.
dice_loss = smp.losses.DiceLoss(mode='binary')
bce_loss = smp.losses.SoftBCEWithLogitsLoss()

def criterion(pred, target):
    """Return ``0.5 * BCE + 0.5 * Dice`` for logits ``pred`` against mask ``target``."""
    bce_term = bce_loss(pred, target)
    dice_term = dice_loss(pred, target)
    return 0.5 * bce_term + 0.5 * dice_term

# Metrics
def calc_metrics(pred, target, threshold=0.5, eps=1e-6):
    """Return ``(iou, dice)`` as Python floats for logits ``pred`` vs. mask ``target``.

    ``pred`` holds raw logits. They are passed through a sigmoid and
    binarised at ``threshold``. Intersection and union are pooled over every
    element of the tensors given (e.g. a whole batch, not per image).
    ``eps`` is added to numerator and denominator, so two empty masks score
    1.0.

    Runs under ``torch.no_grad()``: metrics need no gradients, and this
    avoids building an autograd graph when called on training outputs.
    """
    with torch.no_grad():
        pred_bin = (torch.sigmoid(pred) > threshold).float()

        intersection = (pred_bin * target).sum()
        pred_sum = pred_bin.sum()
        target_sum = target.sum()

        iou = (intersection + eps) / (pred_sum + target_sum - intersection + eps)
        dice = (2 * intersection + eps) / (pred_sum + target_sum + eps)

    return iou.item(), dice.item()

In [None]:
# AdamW (weight decay 1e-4) with a cosine learning-rate schedule spanning NUM_EPOCHS steps.
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

## 9. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` and report mean metrics.

    Moves each batch to the notebook-level ``device`` and scores it with
    ``calc_metrics``. Loss, IoU and Dice are averaged over samples: each
    batch is weighted by its size, so a shorter final batch does not count
    as much as a full one.

    Returns:
        ``(mean_loss, mean_iou, mean_dice)`` as Python floats.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = total_iou = total_dice = 0.0
    n_samples = 0

    for images, masks in tqdm(loader, desc='Train'):
        images, masks = images.to(device), masks.to(device)
        batch_size = images.shape[0]

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Metrics do not need the autograd graph.
        iou, dice = calc_metrics(outputs.detach(), masks)
        total_loss += loss.item() * batch_size
        total_iou += iou * batch_size
        total_dice += dice * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError('train_epoch: loader yielded no batches')
    return total_loss / n_samples, total_iou / n_samples, total_dice / n_samples


@torch.no_grad()
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients and report mean metrics.

    Moves each batch to the notebook-level ``device`` and scores it with
    ``calc_metrics``. Loss, IoU and Dice are averaged over samples (each
    batch weighted by its size).

    Returns:
        ``(mean_loss, mean_iou, mean_dice)`` as Python floats.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = total_iou = total_dice = 0.0
    n_samples = 0

    for images, masks in tqdm(loader, desc='Val'):
        images, masks = images.to(device), masks.to(device)
        batch_size = images.shape[0]

        outputs = model(images)
        loss = criterion(outputs, masks)
        iou, dice = calc_metrics(outputs, masks)

        total_loss += loss.item() * batch_size
        total_iou += iou * batch_size
        total_dice += dice * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError('validate: loader yielded no batches')
    return total_loss / n_samples, total_iou / n_samples, total_dice / n_samples

## 10. Training

In [None]:
# Per-epoch curves, keyed '<split>_<metric>' (train_loss, val_loss, train_iou, ...).
history = {
    f'{split}_{metric}': []
    for metric in ('loss', 'iou', 'dice')
    for split in ('train', 'val')
}
best_dice = 0

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')

    train_loss, train_iou, train_dice = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_iou, val_dice = validate(model, val_loader, criterion)
    scheduler.step()  # one cosine-schedule step per epoch

    for key, value in (
        ('train_loss', train_loss), ('val_loss', val_loss),
        ('train_iou', train_iou), ('val_iou', val_iou),
        ('train_dice', train_dice), ('val_dice', val_dice),
    ):
        history[key].append(value)

    print(f'Train - Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f}')
    print(f'Val   - Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f}')

    # Checkpoint only when validation Dice improves on the best seen so far.
    if val_dice <= best_dice:
        continue
    best_dice = val_dice
    torch.save(model.state_dict(), 'best_optic_disc_model.pth')
    print(f'*** Model saved! Dice: {best_dice:.4f} ***')

## 11. Training Graphs

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One panel per metric, each comparing the train and validation curves.
panels = (('loss', 'Loss'), ('iou', 'IoU'), ('dice', 'Dice Score'))
for ax, (key, title) in zip(axes, panels):
    ax.plot(history[f'train_{key}'], label='Train')
    ax.plot(history[f'val_{key}'], label='Validation')
    ax.set_title(title)
    ax.legend()
    ax.set_xlabel('Epoch')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 12. Visualize Predictions

In [None]:
# Load the best checkpoint written during training. map_location lets the
# checkpoint load even if the runtime's device differs from the one used to save it.
model.load_state_dict(torch.load('best_optic_disc_model.pth', map_location=device))
model.eval()

def predict_and_show(dataset, indices):
    """Plot image, ground truth, thresholded prediction and overlay for each index.

    Runs the notebook-level ``model`` on ``device`` for every
    ``dataset[idx]`` and draws one row of four panels per index. Images are
    assumed to be ImageNet-normalised CHW tensors and are de-normalised for
    display. Does nothing when ``indices`` is empty.
    """
    indices = list(indices)
    if not indices:
        return

    # squeeze=False keeps ``axes`` 2-D, so row indexing also works for a single index.
    fig, axes = plt.subplots(len(indices), 4, figsize=(20, 5 * len(indices)), squeeze=False)
    # Ensure inference mode even if the model was left in train mode.
    model.eval()

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for row, idx in zip(axes, indices):
        img, mask = dataset[idx]

        with torch.no_grad():
            logits = model(img.unsqueeze(0).to(device))
            pred = torch.sigmoid(logits).cpu().squeeze().numpy()

        img_np = np.clip(img.numpy().transpose(1, 2, 0) * std + mean, 0, 1)
        mask_np = mask.squeeze().numpy()
        pred_bin = (pred > 0.5).astype(np.float32)

        # Blend predicted pixels 50/50 with pure green.
        overlay = img_np.copy()
        hit = pred_bin > 0.5
        overlay[hit] = overlay[hit] * 0.5 + np.array([0, 1, 0]) * 0.5

        panels = (
            (img_np, 'Image', None),
            (mask_np, 'Ground Truth', 'gray'),
            (pred_bin, 'Prediction', 'gray'),
            (overlay, 'Overlay', None),
        )
        for ax, (data, title, cmap) in zip(row, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

predict_and_show(val_dataset, [0, 1, 2, 3])

## 13. Inference on New Image

In [None]:
def segment_image(image_path, model, img_size=None):
    """Segment the optic disc in a new image and return per-pixel probabilities.

    Args:
        image_path: path to an image readable by PIL; it is converted to RGB.
        model: segmentation network; assumed to output a single-channel
            logit map, like the U-Net built in this notebook (classes=1,
            activation=None).
        img_size: side length of the square network input. ``None``
            (default) uses the notebook-level ``IMG_SIZE`` so inference
            matches the validation preprocessing. Previously this argument
            was silently ignored.

    Returns:
        ``float64`` numpy array of shape ``(H, W)``, where ``H, W`` is the
        original image size, holding sigmoid probabilities in ``[0, 1]``.
    """
    size = IMG_SIZE if img_size is None else img_size

    # Load and preprocess; the context manager closes the file once decoded.
    with Image.open(image_path) as pil_img:
        image = np.array(pil_img.convert('RGB'))
    original_size = image.shape[:2]

    # Same steps as the validation pipeline, but honouring the requested size.
    transform = A.Compose([
        A.Resize(size, size),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    img_tensor = transform(image=image)['image'].unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(img_tensor))
        # Resize back in float precision; going through 8-bit PIL quantised
        # the probabilities to 256 levels.
        probs = nn.functional.interpolate(
            probs, size=original_size, mode='bilinear', align_corners=False
        )

    return probs[0, 0].cpu().numpy().astype(np.float64)

# Usage example
# mask = segment_image('path/to/image.jpg', model)

## 14. Final Evaluation

In [None]:
# Per-image IoU/Dice of the best checkpoint on the validation split.
# map_location lets the checkpoint load even if the device differs from the one used to save it.
model.load_state_dict(torch.load('best_optic_disc_model.pth', map_location=device))
model.eval()

all_iou = []
all_dice = []

with torch.no_grad():
    for images, masks in val_loader:
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)

        # Score each image on its own; calc_metrics pools over whatever it is given.
        for out, target in zip(outputs, masks):
            iou, dice = calc_metrics(out.unsqueeze(0), target.unsqueeze(0))
            all_iou.append(iou)
            all_dice.append(dice)

print('=== Results on Validation Set ===')
print(f'IoU  - Mean: {np.mean(all_iou):.4f} | Std: {np.std(all_iou):.4f}')
print(f'Dice - Mean: {np.mean(all_dice):.4f} | Std: {np.std(all_dice):.4f}')