# EXPERIMENT 2 (Improvement 1): Preprocessing, Data Augmentation and Deep Supervision

**Improvements implemented:**
1. **CLAHE**: Contrast Limited Adaptive Histogram Equalization to enhance structures
2. **Medical Data Augmentation**: ElasticTransform, GridDistortion, OpticalDistortion
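3. **Deep Supervision**: Auxiliary segmentation outputs computed from intermediate encoder features (1/4 and 1/8 resolution), upsampled to the main output size and added to the training loss with smaller weights
4. **Test-Time Augmentation (TTA)**: Averages the predictions obtained for the original image, its flips and its 90° rotations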
5. **CoarseDropout**: Regularization simulating occlusions

## 1. Installation and Imports

In [None]:
!pip install segmentation-models-pytorch albumentations opencv-python -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp
from skimage.draw import polygon

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

## 2. Configuration

In [None]:
# ==== OPTION 1: Mount Google Drive ====
# Inside Colab the dataset is read from the mounted Drive. Anywhere else the
# import (or the mount) fails and we fall back to a local copy of the dataset.
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    ROOT_DIR = '/content/drive/MyDrive/PapilaDB/'
    print("Drive mounted!")
except Exception:
    # `Exception` (rather than a bare `except:`) keeps the best-effort
    # fallback but lets KeyboardInterrupt/SystemExit propagate, so interrupting
    # the interactive mount prompt aborts instead of silently switching paths.
    print("Running locally.")
    ROOT_DIR = '/content/PapilaDB/'

print(f"ROOT_DIR: {ROOT_DIR}")

# Hyperparameters
BATCH_SIZE = 8
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
IMG_SIZE = 512  # images and masks are resized to IMG_SIZE x IMG_SIZE

# Encoder backbone and pretrained weights passed to the segmentation model
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'

## 3. Prepare Data

In [None]:
# Dataset folders under ROOT_DIR: fundus photographs and expert contour files.
img_dir = ROOT_DIR + 'FundusImages/'
contour_dir = ROOT_DIR + 'ExpertsSegmentations/Contours/'

# Directory listings, sorted in place so later pairing is deterministic.
img_files = os.listdir(img_dir)
img_files.sort()
contour_files = os.listdir(contour_dir)
contour_files.sort()

# Keep only the contour files whose name mentions "disc" (case-insensitive).
disc_contours = list(filter(lambda name: 'disc' in name.lower(), contour_files))

def get_pairs():
    """
    Pair every image with the first disc contour whose file name contains its ID.

    Reads the module-level ``img_files``, ``disc_contours``, ``img_dir`` and
    ``contour_dir``. The image ID is the file name without its extension, and
    matching is a plain substring test against each contour file name.
    Images without any matching contour are left out.

    Returns:
        list of dicts with keys ``'image'`` and ``'contour'`` holding the
        joined paths, in the order of ``img_files``.
    """
    def _first_contour_for(img_id):
        # Only the first hit (in disc_contours order) is used per image.
        return next((name for name in disc_contours if img_id in name), None)

    matched = []
    for img_file in img_files:
        stem, _ext = os.path.splitext(img_file)
        contour_name = _first_contour_for(stem)
        if contour_name is None:
            continue
        matched.append({
            'image': os.path.join(img_dir, img_file),
            'contour': os.path.join(contour_dir, contour_name),
        })
    return matched

# Build the image/contour pairs once and report how many images, disc
# contours and matched pairs were found.
pairs = get_pairs()
print(f'Imagens: {len(img_files)} | Contornos: {len(disc_contours)} | Pares: {len(pairs)}')

Imagens: 488 | Contornos: 976 | Pares: 488


## 4. CLAHE Preprocessing

In [None]:
def apply_clahe_preprocessing(image, **kwargs):
    """
    Enhance local contrast with CLAHE (Contrast Limited Adaptive Histogram
    Equalization) applied to the lightness channel only.

    The image is converted from RGB to LAB, the L channel is equalized with
    clipLimit=3.0 on an 8x8 tile grid, and the result is converted back to
    RGB; the a/b channels pass through unchanged.

    Args:
        image: RGB image accepted by ``cv2.cvtColor``.
        **kwargs: ignored; accepted so the function can be called with extra
            keyword arguments.

    Returns:
        The contrast-enhanced RGB image.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))

    equalizer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge([equalizer.apply(lightness), chan_a, chan_b])

    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

In [None]:
# Visualize CLAHE effect
def compare_clahe_effect(pairs, idx=0):
    """
    Show one image next to its CLAHE-enhanced version.

    Args:
        pairs: sequence of dicts with an ``'image'`` path entry.
        idx: position in ``pairs`` of the image to display.
    """
    original = np.array(Image.open(pairs[idx]['image']).convert('RGB'))
    panels = (
        ('Original Image', original),
        ('With CLAHE (Enhanced Contrast)', apply_clahe_preprocessing(original)),
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    for ax, (title, picture) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle('CLAHE Preprocessing Effect', fontsize=14)
    plt.tight_layout()
    plt.show()

compare_clahe_effect(pairs, 0)

## 5. Advanced Data Augmentation

In [None]:
def get_train_transforms():
    """
    Build the training augmentation pipeline.

    Stages, in order:
    - CLAHE preprocessing (always applied) followed by a resize to IMG_SIZE
    - flips, 90-degree rotations and shift/scale/rotate
    - one of ElasticTransform / GridDistortion / OpticalDistortion (p=0.4)
    - one of noise or blur (p=0.3)
    - one of brightness-contrast / CLAHE / hue-saturation / gamma (p=0.4)
    - CoarseDropout for regularization (p=0.3)
    - normalization and conversion to tensors

    NOTE(review): several keyword names used here (e.g. ``var_limit``,
    ``max_holes``, ``shift_limit``) may have been renamed in newer
    albumentations releases -- confirm against the installed version.
    """
    norm_mean = (0.485, 0.456, 0.406)
    norm_std = (0.229, 0.224, 0.225)

    # CLAHE preprocessing (always applied), then bring everything to one size
    preprocessing = [
        A.Lambda(image=apply_clahe_preprocessing),
        A.Resize(IMG_SIZE, IMG_SIZE),
    ]

    # Geometric augmentations
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=45, p=0.5),
    ]

    # Medical image-specific deformations (at most one per sample)
    deformation = A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, p=1.0),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=1.0),
        A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=1.0),
    ], p=0.4)

    # Color/intensity augmentations
    noise_or_blur = A.OneOf([
        A.GaussNoise(var_limit=(10, 50)),
        A.GaussianBlur(blur_limit=3),
        A.MedianBlur(blur_limit=3),
        A.MotionBlur(blur_limit=3),
    ], p=0.3)
    photometric = A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        A.CLAHE(clip_limit=4),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20),
        A.RandomGamma(gamma_limit=(80, 120)),
    ], p=0.4)

    # Regularization: CoarseDropout (simulates occlusions)
    occlusion = A.CoarseDropout(max_holes=8, max_height=32, max_width=32,
                                min_holes=1, min_height=8, min_width=8,
                                fill_value=0, p=0.3)

    finalize = [
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ]

    return A.Compose([
        *preprocessing,
        *geometric,
        deformation,
        noise_or_blur,
        photometric,
        occlusion,
        *finalize,
    ])

def get_val_transforms():
    """
    Build the deterministic validation pipeline: CLAHE preprocessing, resize
    to IMG_SIZE, normalization and conversion to tensors (no random
    augmentation).
    """
    steps = [
        A.Lambda(image=apply_clahe_preprocessing),
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)

## 6. Dataset

In [None]:
class OpticDiscDataset(Dataset):
    """
    Dataset of fundus images paired with binary optic-disc masks.

    Each item of ``pairs`` is a dict with an ``'image'`` path and a
    ``'contour'`` path. The contour file is loaded with ``np.loadtxt`` and
    rasterized into a mask the size of the image; column 0 of the contour is
    used as the horizontal (column) coordinate and column 1 as the vertical
    (row) coordinate.

    Args:
        pairs: sequence of ``{'image': path, 'contour': path}`` dicts.
        transforms: optional callable invoked as
            ``transforms(image=image, mask=mask)`` and returning a dict with
            ``'image'`` and ``'mask'`` entries. When ``None`` the raw RGB
            array is returned as the image.
    """

    def __init__(self, pairs, transforms=None):
        self.pairs = pairs
        self.transforms = transforms

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _to_mask_tensor(mask):
        """Return ``mask`` as a float tensor with a leading channel axis."""
        # as_tensor accepts both NumPy arrays (no transforms configured) and
        # tensors (already converted by the transform pipeline).
        return torch.as_tensor(mask).float().unsqueeze(0)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` where mask is a float tensor of shape (1, H, W)."""
        pair = self.pairs[idx]
        # Context manager releases the file handle as soon as pixels are read.
        with Image.open(pair['image']) as img_file:
            image = np.array(img_file.convert('RGB'))
        h, w = image.shape[:2]

        # Fill the polygon described by the contour; polygon() takes rows
        # first, hence column 1 before column 0.
        contour = np.loadtxt(pair['contour'])
        mask = np.zeros((h, w), dtype=np.uint8)
        rr, cc = polygon(contour[:, 1], contour[:, 0], mask.shape)
        mask[rr, cc] = 1

        # Explicit None check: a pipeline object may define __len__ and be
        # falsy even though it should still be applied.
        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        return image, self._to_mask_tensor(mask)

In [None]:
# Split data: hold out 20% of the pairs for validation (fixed seed)
train_pairs, val_pairs = train_test_split(pairs, test_size=0.2, random_state=42)

# Training samples are augmented; validation samples are only preprocessed
train_dataset = OpticDiscDataset(train_pairs, transforms=get_train_transforms())
val_dataset = OpticDiscDataset(val_pairs, transforms=get_val_transforms())

# Both loaders share everything except shuffling
loader_options = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

print(f'Train: {len(train_dataset)} | Validation: {len(val_dataset)}')

## 7. Deep Supervision

In [None]:
class UNetWithDeepSupervision(nn.Module):
    """
    U-Net with Deep Supervision.

    Wraps an ``smp.Unet`` and adds two auxiliary segmentation heads on
    intermediate encoder features. A forward hook on the encoder captures
    those features during the normal forward pass, so the encoder runs only
    once per step (a second explicit encoder call would double the compute
    and update normalization statistics twice per batch).

    Args:
        encoder_name: backbone name passed to ``smp.Unet``.
        encoder_weights: pretrained weights identifier passed to ``smp.Unet``.
        in_channels: number of input image channels.
        classes: number of output channels of every head.

    Forward returns:
        ``(main_output, ds_out_1, ds_out_2)`` when the module is in training
        mode and ``self.training_mode`` is True; otherwise ``main_output``
        alone. All outputs are raw logits with the main output's spatial size.
    """
    def __init__(self, encoder_name='resnet50', encoder_weights='imagenet',
                 in_channels=3, classes=1):
        super().__init__()

        # Base U-Net model
        self.base_model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=None,
        )

        # Deep Supervision heads on encoder stages 2 and 3. Channel counts are
        # read from the encoder instead of being hard-coded, so backbones
        # other than ResNet-50 work too (ResNet-50 yields 256 and 512).
        encoder = self.base_model.encoder
        stage_channels = encoder.out_channels
        self.ds_head_1 = self._make_ds_head(stage_channels[2], classes)
        self.ds_head_2 = self._make_ds_head(stage_channels[3], classes)

        # Set to False to disable auxiliary outputs even while training
        self.training_mode = True

        # Filled by the hook on every encoder call, consumed in forward()
        self._encoder_features = None
        encoder.register_forward_hook(self._capture_encoder_features)

    @staticmethod
    def _make_ds_head(in_channels, classes):
        """Build one auxiliary head: 3x3 conv -> BN -> ReLU -> 1x1 conv."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, classes, kernel_size=1)
        )

    def _capture_encoder_features(self, module, inputs, output):
        """Forward hook: remember the encoder's list of stage features."""
        self._encoder_features = output

    def forward(self, x):
        # Pass through base model; the hook stores the encoder features
        main_output = self.base_model(x)

        # Take the features and drop the reference so the tensors are not
        # kept alive between steps.
        features, self._encoder_features = self._encoder_features, None

        if self.training_mode and self.training:
            # features: [x, stage1, stage2, stage3, stage4, stage5]
            # For ResNet50: approximate channels [3, 64, 256, 512, 1024, 2048]

            # Deep supervision on intermediate features
            ds_out_1 = self.ds_head_1(features[2])  # 1/4 resolution
            ds_out_2 = self.ds_head_2(features[3])  # 1/8 resolution

            # Resize to main output size
            target_size = main_output.shape[2:]
            ds_out_1 = nn.functional.interpolate(ds_out_1, size=target_size, mode='bilinear', align_corners=False)
            ds_out_2 = nn.functional.interpolate(ds_out_2, size=target_size, mode='bilinear', align_corners=False)

            return main_output, ds_out_1, ds_out_2

        return main_output

def deep_supervision_loss(outputs, target, criterion, weights=(1.0, 0.4, 0.2)):
    """
    Calculates the combined loss for deep supervision.

    Args:
        outputs: either a single prediction, or a tuple
            ``(main_out, aux_1, aux_2, ...)`` of predictions.
        target: ground truth passed to ``criterion`` for every prediction.
        criterion: callable ``criterion(prediction, target)`` returning a loss.
        weights: per-output loss weights; ``weights[0]`` applies to the main
            output and the following entries to the auxiliary outputs in
            order. Immutable default, so it cannot be altered between calls.

    Returns:
        ``(loss, main_out)``: the weighted sum of the per-output losses and
        the main prediction. For a non-tuple ``outputs`` this is simply
        ``(criterion(outputs, target), outputs)``.

    Raises:
        ValueError: if a tuple of outputs has more entries than ``weights``.
    """
    if not isinstance(outputs, tuple):
        return criterion(outputs, target), outputs

    if len(weights) < len(outputs):
        raise ValueError(
            f'got {len(outputs)} outputs but only {len(weights)} loss weights'
        )

    main_out = outputs[0]
    loss = weights[0] * criterion(main_out, target)
    for weight, aux_out in zip(weights[1:], outputs[1:]):
        loss = loss + weight * criterion(aux_out, target)
    return loss, main_out

print("Deep Supervision defined!")

## 8. Test Time Augmentation (TTA)

In [None]:
class TestTimeAugmentation:
    """
    Test Time Augmentation: predicts on the original input, three flipped
    versions and three 90-degree rotations, undoes each transform on the
    prediction, and averages the seven sigmoid probability maps.

    Args:
        model: callable returning logits for an input batch; switched to eval
            mode on every call.
        device: stored on the instance; not used by ``__call__``.
    """
    def __init__(self, model, device):
        self.model = model
        self.device = device

    def __call__(self, image):
        """Return the mean probability map; dims 2 and 3 of ``image`` are treated as spatial."""
        self.model.eval()

        def _probabilities(batch):
            return torch.sigmoid(self.model(batch))

        spatial_dims = [2, 3]
        # Horizontal flip, vertical flip, both flips. A flip is its own inverse.
        flip_axes = ([3], [2], [2, 3])

        with torch.no_grad():
            # Original
            predictions = [_probabilities(image)]

            for axes in flip_axes:
                flipped_pred = _probabilities(torch.flip(image, dims=axes))
                predictions.append(torch.flip(flipped_pred, dims=axes))

            # Rotations by 90, 180 and 270 degrees, rotated back afterwards
            for quarter_turns in (1, 2, 3):
                turned = torch.rot90(image, k=quarter_turns, dims=spatial_dims)
                turned_pred = _probabilities(turned)
                predictions.append(
                    torch.rot90(turned_pred, k=-quarter_turns, dims=spatial_dims)
                )

        return torch.stack(predictions).mean(dim=0)

## 9. Model, Loss and Optimizer

In [None]:
# Create the model, then move it to the selected device
model = UNetWithDeepSupervision(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=1,
)
model = model.to(device)

# Loss components (combined by criterion)
dice_loss = smp.losses.DiceLoss(mode='binary')
bce_loss = smp.losses.SoftBCEWithLogitsLoss()

def criterion(pred, target):
    """Return the equally weighted sum of the BCE and Dice losses."""
    bce_term = bce_loss(pred, target)
    dice_term = dice_loss(pred, target)
    return 0.5 * bce_term + 0.5 * dice_term

# Metrics
def calc_metrics(pred, target, threshold=0.5):
    """
    Compute IoU and Dice between thresholded predictions and the target.

    Args:
        pred: logits; a sigmoid is applied before thresholding.
        target: ground-truth tensor with the same shape as ``pred``.
        threshold: probability above which an element counts as foreground.

    Returns:
        ``(iou, dice)`` as Python floats, aggregated over all elements of the
        tensors. A 1e-6 smoothing term keeps both defined when prediction and
        target are empty.
    """
    eps = 1e-6
    pred_bin = (torch.sigmoid(pred) > threshold).float()

    intersection = (pred_bin * target).sum()
    total_area = pred_bin.sum() + target.sum()
    union = total_area - intersection

    iou = (intersection + eps) / (union + eps)
    dice = (2 * intersection + eps) / (total_area + eps)
    return iou.item(), dice.item()

# Optimizer with scheduler
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)
# Cosine annealing with warm restarts: first cycle of 10 scheduler steps,
# each following cycle twice as long, learning rate floor of 1e-6.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=2,
    eta_min=1e-6,
)

print(f'Modelo: U-Net com Deep Supervision')
print(f'Encoder: {ENCODER}')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

Modelo: U-Net com Deep Supervision
Encoder: resnet50


## 10. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer):
    """
    Run one training pass over ``loader``.

    For every batch: forward pass, deep-supervision loss, backward pass,
    gradient-norm clipping at 1.0 and an optimizer step. IoU and Dice are
    computed on the main output only.

    Returns:
        ``(loss, iou, dice)``, each averaged over the number of batches.
    """
    model.train()
    running = {'loss': 0, 'iou': 0, 'dice': 0}

    for images, masks in tqdm(loader, desc='Train'):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        loss, main_output = deep_supervision_loss(model(images), masks, criterion)
        loss.backward()

        # Clip before stepping so the update uses the clipped gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_iou, batch_dice = calc_metrics(main_output, masks)
        running['loss'] += loss.item()
        running['iou'] += batch_iou
        running['dice'] += batch_dice

    num_batches = len(loader)
    return (
        running['loss'] / num_batches,
        running['iou'] / num_batches,
        running['dice'] / num_batches,
    )

@torch.no_grad()
def validate(model, loader, criterion):
    """
    Evaluate ``model`` on ``loader`` in eval mode, without gradient tracking.

    The model output is passed directly to ``criterion`` and
    ``calc_metrics``, so the model is expected to return a single tensor in
    eval mode.

    Returns:
        ``(loss, iou, dice)``, each averaged over the number of batches.
    """
    model.eval()
    loss_sum = 0
    iou_sum = 0
    dice_sum = 0

    for images, masks in tqdm(loader, desc='Val'):
        images = images.to(device)
        masks = masks.to(device)

        logits = model(images)
        loss_sum += criterion(logits, masks).item()

        batch_iou, batch_dice = calc_metrics(logits, masks)
        iou_sum += batch_iou
        dice_sum += batch_dice

    num_batches = len(loader)
    return loss_sum / num_batches, iou_sum / num_batches, dice_sum / num_batches

## 11. Training

In [None]:
# Per-epoch metric curves, keyed by split and metric name
metric_names = ('train_loss', 'val_loss', 'train_iou', 'val_iou',
                'train_dice', 'val_dice')
history = {name: [] for name in metric_names}
best_dice = 0

banner = "="*60
print(banner)
print("EXPERIMENT 2 - CLAHE + Data Aug + Deep Supervision")
print(banner)

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')

    train_loss, train_iou, train_dice = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_iou, val_dice = validate(model, val_loader, criterion)
    # The learning-rate schedule advances once per epoch
    scheduler.step()

    epoch_values = (train_loss, val_loss, train_iou, val_iou, train_dice, val_dice)
    for name, value in zip(metric_names, epoch_values):
        history[name].append(value)

    print(f'Train - Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f}')
    print(f'Val   - Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f}')

    # Checkpoint whenever validation Dice improves on the best seen so far
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), 'best_exp2_model.pth')
        print(f'*** Model saved! Dice: {best_dice:.4f} ***')

print(f"\nBest Dice: {best_dice:.4f}")

## 12. Training Graphs

In [None]:
# One panel per metric, each showing the training and validation curves
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

panels = (
    ('train_loss', 'val_loss', 'Loss'),
    ('train_iou', 'val_iou', 'IoU'),
    ('train_dice', 'val_dice', 'Dice Score'),
)
for ax, (train_key, val_key, title) in zip(axes, panels):
    ax.plot(history[train_key], label='Train')
    ax.plot(history[val_key], label='Validation')
    ax.set_title(title)
    ax.legend()
    ax.set_xlabel('Epoch')
    ax.grid(True, alpha=0.3)

plt.suptitle('Experiment 2: CLAHE + Data Aug + Deep Supervision', fontsize=14)
plt.tight_layout()
plt.show()

## 13. Evaluation with TTA

In [None]:
# Load best model
# map_location lets a checkpoint saved on a GPU be restored on whatever
# device this session is using (including CPU-only runtimes).
model.load_state_dict(torch.load('best_exp2_model.pth', map_location=device))
model.eval()

tta = TestTimeAugmentation(model, device)


def _binary_iou_dice(prob, mask, threshold=0.5):
    """Return (IoU, Dice) as floats for one probability map against its mask."""
    pred_bin = (prob > threshold).float()
    intersection = (pred_bin * mask).sum()
    union = pred_bin.sum() + mask.sum() - intersection
    iou = (intersection + 1e-6) / (union + 1e-6)
    dice = (2 * intersection + 1e-6) / (pred_bin.sum() + mask.sum() + 1e-6)
    return iou.item(), dice.item()


# Evaluation: per-image scores, with and without TTA
all_iou_no_tta = []
all_dice_no_tta = []
all_iou_tta = []
all_dice_tta = []

print("Evaluating (with and without TTA)...")

with torch.no_grad():
    for images, masks in tqdm(val_loader):
        images, masks = images.to(device), masks.to(device)

        # Score one image at a time so mean/std are per-image statistics
        for i in range(images.shape[0]):
            img = images[i:i+1]
            mask = masks[i:i+1]

            # Without TTA
            iou, dice = _binary_iou_dice(torch.sigmoid(model(img)), mask)
            all_iou_no_tta.append(iou)
            all_dice_no_tta.append(dice)

            # With TTA (tta already returns averaged probabilities)
            iou, dice = _binary_iou_dice(tta(img), mask)
            all_iou_tta.append(iou)
            all_dice_tta.append(dice)

print('\n' + '='*60)
print('RESULTS - EXPERIMENT 2')
print('='*60)
print('\nWithout TTA:')
print(f'  IoU:  {np.mean(all_iou_no_tta):.4f} +/- {np.std(all_iou_no_tta):.4f}')
print(f'  Dice: {np.mean(all_dice_no_tta):.4f} +/- {np.std(all_dice_no_tta):.4f}')
print('\nWith TTA (7 augmentations):')
print(f'  IoU:  {np.mean(all_iou_tta):.4f} +/- {np.std(all_iou_tta):.4f}')
print(f'  Dice: {np.mean(all_dice_tta):.4f} +/- {np.std(all_dice_tta):.4f}')
# The "+" format flag prints the real sign, so a regression shows as
# "-0.12%" instead of "+-0.12%".
print(f'\nImprovement with TTA: {(np.mean(all_dice_tta) - np.mean(all_dice_no_tta))*100:+.2f}% Dice')

## 14. Visualize Predictions

In [None]:
def predict_and_show(dataset, indices):
    """
    Plot image, ground truth, binary prediction and overlay for given samples.

    Uses the module-level ``model`` and ``device``. The model is switched to
    eval mode so that it returns a single output tensor.

    Args:
        dataset: indexable returning ``(image, mask)`` tensors, where the
            image was normalized with the mean/std undone below.
        indices: dataset positions to display, one figure row each. Nothing
            is drawn when it is empty.
    """
    if len(indices) == 0:
        return

    model.eval()

    # squeeze=False keeps `axes` two-dimensional even for a single row
    fig, axes = plt.subplots(len(indices), 4, figsize=(20, 5*len(indices)),
                             squeeze=False)

    for i, idx in enumerate(indices):
        img, mask = dataset[idx]

        with torch.no_grad():
            pred = model(img.unsqueeze(0).to(device))
            pred = torch.sigmoid(pred).cpu().squeeze().numpy()

        # Undo the normalization so the image displays with natural colors
        img_np = img.numpy().transpose(1, 2, 0)
        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_np = np.clip(img_np, 0, 1)

        mask_np = mask.squeeze().numpy()
        pred_bin = (pred > 0.5).astype(np.float32)

        # Blend predicted pixels 50/50 with pure green
        overlay = img_np.copy()
        overlay[pred_bin > 0.5] = overlay[pred_bin > 0.5] * 0.5 + np.array([0, 1, 0]) * 0.5

        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title('Image')
        axes[i, 1].imshow(mask_np, cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 2].imshow(pred_bin, cmap='gray')
        axes[i, 2].set_title('Prediction')
        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title('Overlay')

        for ax in axes[i]:
            ax.axis('off')

    plt.tight_layout()
    plt.show()

predict_and_show(val_dataset, [0, 1, 2, 3])

## 15. Save Results

In [None]:
import pickle

# Bundle the training curves and the per-image evaluation scores
results_exp2 = dict(
    history=history,
    all_iou_no_tta=all_iou_no_tta,
    all_dice_no_tta=all_dice_no_tta,
    all_iou_tta=all_iou_tta,
    all_dice_tta=all_dice_tta,
    best_dice=best_dice,
)

# Serialize the bundle next to the notebook
with open('results_exp2.pkl', 'wb') as f:
    f.write(pickle.dumps(results_exp2))

print('Results saved to results_exp2.pkl')