# EXPERIMENT 3 (Improvement 2): Attention U-Net - Architecture Modification

**Neural network architecture modifications:**
1. **Attention Gates**: Attention mechanisms in skip connections
2. **Squeeze-and-Excitation (SE) Blocks**: Adaptive channel recalibration
3. **ASPP (Atrous Spatial Pyramid Pooling)**: Multi-scale context at bottleneck
4. **Focal Loss**: Added to the training loss to handle class imbalance (a loss change rather than an architectural one)

This is a substantial modification to the network topology, not just an increase in width/depth.

## 1. Installation and Imports

In [None]:
!pip install segmentation-models-pytorch albumentations opencv-python -q

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp
from skimage.draw import polygon

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

## 2. Configuration

In [None]:
# Resolve the dataset root: Google Drive when running inside Colab, otherwise
# a local copy of the dataset.
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    ROOT_DIR = '/content/drive/MyDrive/PapilaDB/'
    print("Drive mounted!")
except Exception:
    # `Exception` instead of a bare `except:` keeps the best-effort fallback
    # (missing package, failed mount) while letting KeyboardInterrupt and
    # SystemExit propagate.
    print("Running locally.")
    ROOT_DIR = '/content/PapilaDB/'

# Hyperparameters
BATCH_SIZE = 8          # samples per batch for both loaders
NUM_EPOCHS = 50         # number of passes over the training set
LEARNING_RATE = 1e-4    # base (decoder) learning rate; the encoder uses 0.1x this
IMG_SIZE = 512          # images and masks are resized to IMG_SIZE x IMG_SIZE

ENCODER = 'resnet50'           # encoder name passed to segmentation_models_pytorch
ENCODER_WEIGHTS = 'imagenet'   # pretrained weights requested for the encoder

## 3. Prepare Data

In [None]:
# Locations of the fundus photographs and of the expert contour files.
img_dir = f'{ROOT_DIR}FundusImages/'
contour_dir = f'{ROOT_DIR}ExpertsSegmentations/Contours/'

# Sorted listings make the image/contour pairing independent of the
# arbitrary order returned by os.listdir.
img_files = sorted(os.listdir(img_dir))
contour_files = sorted(os.listdir(contour_dir))

# Keep only the contour files whose name mentions "disc" (case-insensitive).
disc_contours = list(filter(lambda name: 'disc' in name.lower(), contour_files))

def get_pairs(image_files=None, contours=None, image_dir=None, contours_dir=None):
    """Pair each image file with the first contour file that matches it.

    A contour matches an image when the image file name without its extension
    occurs anywhere in the contour file name (substring test).  Only the first
    match in `contours` order is kept; images without a match are skipped.

    Args:
        image_files: image file names. Defaults to the module-level `img_files`.
        contours: candidate contour file names. Defaults to the module-level
            `disc_contours`.
        image_dir: directory joined in front of each image name. Defaults to
            the module-level `img_dir`.
        contours_dir: directory joined in front of each contour name. Defaults
            to the module-level `contour_dir`.

    Returns:
        A list of dicts with the keys 'image' and 'contour' holding the joined
        paths, in the order of `image_files`.
    """
    # Fall back to the notebook globals so that `get_pairs()` keeps working.
    image_files = img_files if image_files is None else image_files
    contours = disc_contours if contours is None else contours
    image_dir = img_dir if image_dir is None else image_dir
    contours_dir = contour_dir if contours_dir is None else contours_dir

    pairs = []
    for img_file in image_files:
        img_id = os.path.splitext(img_file)[0]
        # NOTE(review): substring matching assumes no image id is contained in
        # another image's id — confirm against the dataset naming scheme.
        match = next((cont for cont in contours if img_id in cont), None)
        if match is not None:
            pairs.append({
                'image': os.path.join(image_dir, img_file),
                'contour': os.path.join(contours_dir, match)
            })
    return pairs

# Build the image/contour pairs and report how many files were found and
# how many images ended up with a matching disc contour.
pairs = get_pairs()
print(f'Images: {len(img_files)} | Contours: {len(disc_contours)} | Pairs: {len(pairs)}')

## 4. Preprocessing and Data Augmentation

The same transforms as in Experiment 2 are used here so that the comparison between experiments is fair.

In [None]:
def apply_clahe_preprocessing(image, **kwargs):
    """Apply CLAHE to the lightness channel of an image and return it in RGB.

    The image is converted RGB -> LAB, CLAHE (clip limit 3.0, 8x8 tile grid)
    is applied to the L channel only, and the result is converted back to RGB.
    Extra keyword arguments are accepted and ignored.

    NOTE(review): presumably expects an 8-bit RGB array — confirm against the
    caller.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))
    equaliser = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge([equaliser.apply(lightness), chan_a, chan_b])
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

def get_train_transforms():
    """Build the training pipeline: CLAHE, resize, augmentation, normalisation.

    Order of application: CLAHE preprocessing (image only), resize to
    IMG_SIZE x IMG_SIZE, flips / 90-degree rotation / shift-scale-rotate,
    one optional spatial distortion, one optional noise-or-blur, one optional
    colour change, coarse dropout, normalisation and conversion to tensors.
    """
    preprocessing = [
        A.Lambda(image=apply_clahe_preprocessing),
        A.Resize(IMG_SIZE, IMG_SIZE),
    ]

    rigid = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=45, p=0.5),
    ]

    # Exactly one member of each OneOf group is applied when the group fires.
    distortion = A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, p=1.0),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=1.0),
        A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=1.0),
    ], p=0.4)

    noise_or_blur = A.OneOf([
        A.GaussNoise(var_limit=(10, 50)),
        A.GaussianBlur(blur_limit=3),
        A.MedianBlur(blur_limit=3),
    ], p=0.3)

    colour = A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        A.CLAHE(clip_limit=4),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20),
    ], p=0.4)

    occlusion = A.CoarseDropout(max_holes=8, max_height=32, max_width=32,
                                min_holes=1, min_height=8, min_width=8, fill_value=0, p=0.3)

    # ImageNet channel statistics, then conversion to a tensor.
    finalisation = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]

    return A.Compose(
        preprocessing + rigid + [distortion, noise_or_blur, colour, occlusion] + finalisation
    )

def get_val_transforms():
    """Build the deterministic validation pipeline (no augmentation).

    Applies CLAHE preprocessing to the image, resizes to IMG_SIZE x IMG_SIZE,
    normalises with the ImageNet channel statistics and converts to tensors.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    steps = [
        A.Lambda(image=apply_clahe_preprocessing),
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

## 5. Dataset

In [None]:
class OpticDiscDataset(Dataset):
    """Dataset yielding (image, binary disc mask) samples.

    Each entry of `pairs` is a dict with the keys 'image' (path of the image
    file) and 'contour' (path of a text file readable by `np.loadtxt`).  The
    mask is rasterised from the contour polygon at the image's resolution.

    Args:
        pairs: sequence of {'image': ..., 'contour': ...} dicts.
        transforms: optional callable invoked as
            `transforms(image=image, mask=mask)` that returns a mapping with
            the keys 'image' and 'mask'.  When None, the raw RGB array is
            returned as the image.
    """

    def __init__(self, pairs, transforms=None):
        self.pairs = pairs
        self.transforms = transforms

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _mask_to_target(mask):
        """Return `mask` as a float tensor with a leading channel axis.

        Accepts both a NumPy array (no transforms configured) and a tensor
        (output of the transform pipeline); the original code called
        `mask.float()` directly and failed on NumPy arrays.
        """
        return torch.as_tensor(mask).float().unsqueeze(0)

    def __getitem__(self, idx):
        pair = self.pairs[idx]
        image = np.array(Image.open(pair['image']).convert('RGB'))
        h, w = image.shape[:2]

        # Column 0 of the contour file is used as the x (column) coordinate
        # and column 1 as the y (row) coordinate of the polygon vertices.
        contour = np.loadtxt(pair['contour'])
        mask = np.zeros((h, w), dtype=np.uint8)
        rr, cc = polygon(contour[:, 1], contour[:, 0], mask.shape)
        mask[rr, cc] = 1

        # Explicit None check: do not rely on the truthiness of the transform.
        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        return image, self._mask_to_target(mask)

# Hold out 20% of the pairs for validation; the fixed seed keeps the split
# identical between runs.
train_pairs, val_pairs = train_test_split(pairs, test_size=0.2, random_state=42)

train_dataset = OpticDiscDataset(train_pairs, get_train_transforms())
val_dataset = OpticDiscDataset(val_pairs, get_val_transforms())

# Options shared by both loaders; only the training loader shuffles.
loader_options = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

print(f'Train: {len(train_dataset)} | Validation: {len(val_dataset)}')

## 6. Attention Modules

In [None]:
class AttentionGate(nn.Module):
    """
    Additive attention gate for a skip connection.

    A gating signal `g` and a skip feature map `x` are each projected to
    `F_int` channels with a 1x1 convolution + batch norm, summed, passed
    through ReLU and reduced to a single-channel sigmoid map that multiplies
    `x` element-wise.

    Ref: "Attention U-Net: Learning Where to Look for the Pancreas"

    Args:
        F_g: channels of the gating signal `g`.
        F_l: channels of the skip feature map `x`.
        F_int: channels of the intermediate projection.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise_projection(in_channels):
            # 1x1 convolution followed by batch norm, mapping to F_int channels.
            return nn.Sequential(
                nn.Conv2d(in_channels, F_int, kernel_size=1, bias=True),
                nn.BatchNorm2d(F_int)
            )

        self.W_g = pointwise_projection(F_g)
        self.W_x = pointwise_projection(F_l)

        # Collapses the joint features to one attention coefficient per pixel.
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return `x` scaled by the attention map computed from `g` and `x`."""
        target_size = x.shape[2:]
        if g.shape[2:] != target_size:
            # Bring the gating signal to the spatial size of the skip features.
            g = nn.functional.interpolate(
                g, size=target_size, mode='bilinear', align_corners=False
            )

        joint = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(joint)

        return x * attention

In [None]:
class SqueezeExcitation(nn.Module):
    """
    Squeeze-and-Excitation block: per-channel recalibration.

    Each channel is reduced to its spatial mean, passed through a two-layer
    bottleneck MLP ending in a sigmoid, and the resulting per-channel weights
    multiply the input feature map.

    Ref: "Squeeze-and-Excitation Networks"

    Args:
        channels: number of channels of the input feature map.
        reduction: bottleneck reduction ratio.  The hidden width is
            `channels // reduction`, clamped to at least 1 so that small
            channel counts do not produce an empty layer (which would make
            every gate a constant 0.5).
    """
    def __init__(self, channels, reduction=16):
        super().__init__()

        # Never allow a zero-width bottleneck when channels < reduction.
        hidden = max(1, channels // reduction)

        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return `x` with every channel scaled by its learned gate."""
        b, c, _, _ = x.shape
        y = self.squeeze(x).view(b, c)             # one mean value per channel
        y = self.excitation(y).view(b, c, 1, 1)    # one gate per channel
        return x * y.expand_as(x)

In [None]:
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling.

    Runs the input through parallel branches — a 1x1 convolution, one 3x3
    dilated convolution per entry of `rates`, and a global-average-pooling
    branch — concatenates their outputs along the channel axis and projects
    the result back to `out_channels` (1x1 conv, batch norm, ReLU, dropout 0.5).

    Ref: "DeepLab: Semantic Image Segmentation"

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by every branch and by the projection.
        rates: dilation rates of the 3x3 branches (any iterable of ints).
    """
    def __init__(self, in_channels, out_channels, rates=(6, 12, 18)):
        super().__init__()

        # Immutable default plus a local copy: a caller's list is never shared.
        rates = tuple(rates)

        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # padding == dilation keeps the spatial size unchanged for a 3x3 kernel.
        self.atrous_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            for rate in rates
        ])

        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # 1x1 branch + pooling branch + one branch per dilation rate.
        num_features = out_channels * (2 + len(rates))
        self.project = nn.Sequential(
            nn.Conv2d(num_features, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def _pooled_context(self, x):
        """Run the global-pooling branch and return its 1x1 feature map.

        In training mode, batch norm cannot compute batch statistics from a
        single value per channel, which is exactly what a batch of one sample
        yields after global pooling (PyTorch raises ValueError).  In that case
        the layer's running statistics are used instead and are not updated.
        """
        pool, conv, bn, act = self.global_pool
        pooled = conv(pool(x))
        if self.training and pooled.shape[0] == 1:
            pooled = nn.functional.batch_norm(
                pooled, bn.running_mean, bn.running_var, bn.weight, bn.bias,
                training=False, eps=bn.eps
            )
        else:
            pooled = bn(pooled)
        return act(pooled)

    def forward(self, x):
        """Return the fused multi-scale features at the input's spatial size."""
        size = x.shape[2:]
        features = [self.conv1x1(x)]
        features.extend(atrous_conv(x) for atrous_conv in self.atrous_convs)

        # Broadcast the pooled context back to the spatial size of the input.
        global_feat = nn.functional.interpolate(
            self._pooled_context(x), size=size, mode='bilinear', align_corners=False
        )
        features.append(global_feat)

        return self.project(torch.cat(features, dim=1))

## 7. Complete Attention U-Net

In [None]:
class AttentionUNet(nn.Module):
    """
    Attention U-Net with:
    - Pretrained encoder (ResNet50 by default) from segmentation_models_pytorch
    - ASPP at the bottleneck
    - Attention Gates on the four skip connections
    - SE blocks at the end of every decoder block

    `forward` returns raw logits with `classes` channels at the spatial size
    of the input.

    Args:
        encoder_name: encoder identifier passed to `smp.encoders.get_encoder`.
        encoder_weights: pretrained weights identifier for the encoder.
        in_channels: number of input image channels.
        classes: number of output channels of the segmentation head.
    """
    def __init__(self, encoder_name='resnet50', encoder_weights='imagenet',
                 in_channels=3, classes=1):
        super().__init__()

        self.encoder = smp.encoders.get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=5,
            weights=encoder_weights
        )

        # Channel counts of the encoder outputs, input first, bottleneck last.
        # For ResNet50: [3, 64, 256, 512, 1024, 2048]
        encoder_channels = self.encoder.out_channels
        print(f"Encoder channels: {encoder_channels}")

        # ASPP at the bottleneck halves the channel count.
        self.bottleneck_ch = encoder_channels[-1]  # 2048 for ResNet50
        self.aspp_out = self.bottleneck_ch // 2     # 1024 for ResNet50
        self.aspp = ASPP(self.bottleneck_ch, self.aspp_out)

        # Skip connections: encoder outputs without the bottleneck, deepest
        # first, limited to four levels (this drops the raw-input entry).
        # For ResNet50: [1024, 512, 256, 64]
        skip_channels = list(encoder_channels[:-1])[::-1][:4]
        print(f"Skip channels: {skip_channels}")

        decoder_channels = [256, 128, 64, 32, 16]
        num_skips = len(skip_channels)

        # Attention gates.  Gate 0 is driven by the ASPP output; gate i > 0 is
        # driven by the output of decoder block i-1.
        self.attention_gates = nn.ModuleList()
        self.attention_gates.append(
            AttentionGate(self.aspp_out, skip_channels[0], skip_channels[0] // 4)
        )
        for i in range(1, num_skips):
            self.attention_gates.append(
                AttentionGate(decoder_channels[i-1], skip_channels[i], skip_channels[i] // 4)
            )

        # Decoder blocks.  Block 0 consumes the ASPP output concatenated with
        # skip 0; block i > 0 consumes decoder output i-1 concatenated with
        # skip i.
        self.decoder_blocks = nn.ModuleList()
        in_ch = self.aspp_out + skip_channels[0]
        self.decoder_blocks.append(self._make_decoder_block(in_ch, decoder_channels[0]))
        for i in range(1, num_skips):
            in_ch = decoder_channels[i-1] + skip_channels[i]
            self.decoder_blocks.append(self._make_decoder_block(in_ch, decoder_channels[i]))

        # Last block (no skip connection)
        self.decoder_blocks.append(self._make_decoder_block(decoder_channels[3], decoder_channels[4]))

        self.segmentation_head = nn.Conv2d(decoder_channels[-1], classes, kernel_size=1)

        # Store for debug
        self.skip_channels = skip_channels
        self.decoder_channels = decoder_channels

    def _make_decoder_block(self, in_channels, out_channels):
        """Return two 3x3 conv + batch norm + ReLU stages followed by an SE block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            SqueezeExcitation(out_channels, reduction=16)
        )

    def forward(self, x):
        """Return segmentation logits for the image batch `x`."""
        input_size = x.shape[2:]

        # Encoder
        features = self.encoder(x)

        # Bottleneck with ASPP
        x = self.aspp(features[-1])

        # Skip connections, deepest first, excluding the bottleneck and the
        # first encoder output: [features[4], features[3], features[2], features[1]]
        skips = features[1:-1][::-1]

        # Decoder with attention gates.  The decoder features are upsampled to
        # the exact size of each skip instead of by a fixed factor of 2: the
        # result is the same when every stage exactly halves the resolution,
        # and the concatenation no longer fails when it does not.
        for gate, block, skip in zip(self.attention_gates, self.decoder_blocks, skips):
            x = nn.functional.interpolate(
                x, size=skip.shape[2:], mode='bilinear', align_corners=False
            )
            x = torch.cat([x, gate(x, skip)], dim=1)
            x = block(x)

        # Last upsample (no skip): restore the input resolution.
        x = nn.functional.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        x = self.decoder_blocks[-1](x)

        return self.segmentation_head(x)

## 8. Test Time Augmentation

In [None]:
class TestTimeAugmentation:
    """Average sigmoid predictions over flipped and rotated views of a batch.

    Seven views are evaluated in this order: the unmodified input, flips
    along dims [3], [2] and [2, 3], and rotations by 1, 2 and 3 quarter
    turns in the (2, 3) plane.  Each prediction is mapped back to the
    original orientation before averaging.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def __call__(self, image):
        """Return the mean of the de-augmented sigmoid predictions."""
        self.model.eval()
        plane = [2, 3]

        # (augment, undo) pairs; the identity view comes first.
        views = [(lambda t: t, lambda t: t)]
        for flip_dims in ([3], [2], [2, 3]):
            flip = lambda t, dims=flip_dims: torch.flip(t, dims=dims)
            views.append((flip, flip))  # a flip is its own inverse
        for quarter_turns in (1, 2, 3):
            views.append((
                lambda t, k=quarter_turns: torch.rot90(t, k=k, dims=plane),
                lambda t, k=quarter_turns: torch.rot90(t, k=-k, dims=plane),
            ))

        with torch.no_grad():
            predictions = [
                undo(torch.sigmoid(self.model(augment(image))))
                for augment, undo in views
            ]

        return torch.stack(predictions).mean(dim=0)

## 9. Model, Loss and Optimizer

In [None]:
# Build the network with the configured encoder, then move it to the
# selected device.
model = AttentionUNet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=1,
)
model = model.to(device)

def count_parameters(model):
    """Return the number of elements in the parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the architecture name and the number of trainable parameters.
trainable_count = count_parameters(model)
print('Model: Attention U-Net')
print(f'Trainable parameters: {trainable_count:,}')

In [None]:
# Loss components from segmentation_models_pytorch; `criterion` below mixes
# them with fixed weights.  Dice and Focal are configured in binary mode.
# NOTE(review): all three presumably consume raw logits (the BCE term is the
# with-logits variant) — confirm against the smp documentation.
dice_loss = smp.losses.DiceLoss(mode='binary')
bce_loss = smp.losses.SoftBCEWithLogitsLoss()
focal_loss = smp.losses.FocalLoss(mode='binary', alpha=0.25, gamma=2.0)

def criterion(pred, target):
    """Return the combined loss: 0.4 * BCE + 0.4 * Dice + 0.2 * Focal.

    `pred` and `target` are forwarded unchanged to each loss component.
    """
    total = 0.4 * bce_loss(pred, target)
    total = total + 0.4 * dice_loss(pred, target)
    total = total + 0.2 * focal_loss(pred, target)
    return total

def calc_metrics(pred, target, threshold=0.5):
    """Return (IoU, Dice) as Python floats for a batch of logits.

    The logits are passed through a sigmoid and binarised with `threshold`.
    Both metrics are computed over all elements of the batch at once, with a
    1e-6 smoothing term in numerator and denominator.
    """
    eps = 1e-6
    binary = (torch.sigmoid(pred) > threshold).float()

    overlap = (binary * target).sum()
    predicted_plus_true = binary.sum() + target.sum()

    iou = (overlap + eps) / (predicted_plus_true - overlap + eps)
    dice = (2 * overlap + eps) / (predicted_plus_true + eps)
    return iou.item(), dice.item()

# Discriminative learning rates: the encoder is trained at one tenth of the
# rate used for every other parameter.
encoder_params = list(model.encoder.parameters())
decoder_params = [
    param
    for param_name, param in model.named_parameters()
    if 'encoder' not in param_name
]

optimizer = optim.AdamW(
    [
        {'params': encoder_params, 'lr': LEARNING_RATE * 0.1},
        {'params': decoder_params, 'lr': LEARNING_RATE},
    ],
    weight_decay=1e-4,
)

# One-cycle schedule with one peak learning rate per parameter group, in the
# same order as the optimizer groups.  The total step count is
# epochs * steps_per_epoch, so the scheduler is meant to be stepped per batch.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=[LEARNING_RATE * 0.1, LEARNING_RATE],
    epochs=NUM_EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
    anneal_strategy='cos',
)

## 10. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, scheduler):
    """Train for one pass over `loader` and return the per-batch averages.

    For every batch: forward pass, loss, backward pass, gradient-norm
    clipping at 1.0, optimizer step and scheduler step.

    Returns:
        (mean loss, mean IoU, mean Dice), each averaged over the batches.
    """
    model.train()
    running = [0, 0, 0]  # accumulated loss, IoU, Dice

    for images, masks in tqdm(loader, desc='Train'):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # the schedule advances once per batch

        batch_iou, batch_dice = calc_metrics(outputs, masks)
        for slot, value in enumerate((loss.item(), batch_iou, batch_dice)):
            running[slot] += value

    num_batches = len(loader)
    return tuple(total / num_batches for total in running)

@torch.no_grad()
def validate(model, loader, criterion):
    """Evaluate `model` on `loader` without gradients.

    Returns:
        (mean loss, mean IoU, mean Dice), each averaged over the batches.
    """
    model.eval()
    running = [0, 0, 0]  # accumulated loss, IoU, Dice

    for images, masks in tqdm(loader, desc='Val'):
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)

        batch_iou, batch_dice = calc_metrics(outputs, masks)
        for slot, value in enumerate((loss.item(), batch_iou, batch_dice)):
            running[slot] += value

    num_batches = len(loader)
    return tuple(total / num_batches for total in running)

## 11. Training

In [None]:
# Per-epoch curves, one list per metric and split.
history = {
    key: []
    for key in ('train_loss', 'val_loss', 'train_iou', 'val_iou',
                'train_dice', 'val_dice')
}
best_dice = 0

banner = "="*60
print(banner)
print("EXPERIMENT 3 - Attention U-Net")
print(banner)

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')

    train_loss, train_iou, train_dice = train_epoch(model, train_loader, criterion, optimizer, scheduler)
    val_loss, val_iou, val_dice = validate(model, val_loader, criterion)

    epoch_values = {
        'train_loss': train_loss, 'val_loss': val_loss,
        'train_iou': train_iou, 'val_iou': val_iou,
        'train_dice': train_dice, 'val_dice': val_dice,
    }
    for key, value in epoch_values.items():
        history[key].append(value)

    # Learning rate of the second parameter group (the non-encoder parameters).
    current_lr = optimizer.param_groups[1]['lr']
    print(f'Train - Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f}')
    print(f'Val   - Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f}')
    print(f'LR: {current_lr:.2e}')

    # Checkpoint whenever the validation Dice improves on the best so far.
    improved = val_dice > best_dice
    if improved:
        best_dice = val_dice
        torch.save(model.state_dict(), 'best_exp3_attention_unet.pth')
        print(f'*** Model saved! Dice: {best_dice:.4f} ***')

print(f"\nBest Dice: {best_dice:.4f}")

## 12. Training Graphs

In [None]:
# One panel per metric, each showing the training and validation curves.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

panels = (
    ('train_loss', 'val_loss', 'Loss'),
    ('train_iou', 'val_iou', 'IoU'),
    ('train_dice', 'val_dice', 'Dice Score'),
)
for ax, (train_key, val_key, title) in zip(axes, panels):
    ax.plot(history[train_key], label='Train')
    ax.plot(history[val_key], label='Validation')
    ax.set_title(title)
    ax.legend()
    ax.set_xlabel('Epoch')
    ax.grid(True, alpha=0.3)

plt.suptitle('Experiment 3: Attention U-Net', fontsize=14)
plt.tight_layout()
plt.show()

## 13. Evaluation with TTA

In [None]:
# Restore the best checkpoint.  `map_location` lets a checkpoint saved on a
# GPU be loaded when only the CPU is available.
model.load_state_dict(torch.load('best_exp3_attention_unet.pth', map_location=device))
model.eval()

tta = TestTimeAugmentation(model, device)


def _iou_dice(pred_bin, mask):
    """Return (IoU, Dice) as floats for one binarised prediction and its mask.

    Uses a 1e-6 smoothing term in numerator and denominator.
    """
    intersection = (pred_bin * mask).sum()
    total = pred_bin.sum() + mask.sum()
    iou = (intersection + 1e-6) / (total - intersection + 1e-6)
    dice = (2 * intersection + 1e-6) / (total + 1e-6)
    return iou.item(), dice.item()


# Per-image scores, with and without test-time augmentation.
all_iou_no_tta = []
all_dice_no_tta = []
all_iou_tta = []
all_dice_tta = []

print("Evaluating (with and without TTA)...")

with torch.no_grad():
    for images, masks in tqdm(val_loader):
        images, masks = images.to(device), masks.to(device)

        # Score image by image so that mean and std are per-image statistics.
        for i in range(images.shape[0]):
            img = images[i:i+1]
            mask = masks[i:i+1]

            # Without TTA
            pred = torch.sigmoid(model(img))
            iou, dice = _iou_dice((pred > 0.5).float(), mask)
            all_iou_no_tta.append(iou)
            all_dice_no_tta.append(dice)

            # With TTA (the helper already returns averaged probabilities)
            pred_tta = tta(img)
            iou, dice = _iou_dice((pred_tta > 0.5).float(), mask)
            all_iou_tta.append(iou)
            all_dice_tta.append(dice)

print('\n' + '='*60)
print('RESULTS - EXPERIMENT 3 (Attention U-Net)')
print('='*60)
print('\nWithout TTA:')
print(f'  IoU:  {np.mean(all_iou_no_tta):.4f} +/- {np.std(all_iou_no_tta):.4f}')
print(f'  Dice: {np.mean(all_dice_no_tta):.4f} +/- {np.std(all_dice_no_tta):.4f}')
print('\nWith TTA:')
print(f'  IoU:  {np.mean(all_iou_tta):.4f} +/- {np.std(all_iou_tta):.4f}')
print(f'  Dice: {np.mean(all_dice_tta):.4f} +/- {np.std(all_dice_tta):.4f}')

## 14. Visualize Predictions

In [None]:
def predict_and_show(dataset, indices):
    """Plot image, ground truth, prediction and overlay for selected samples.

    Uses the module-level `model` and `device`; the model's train/eval mode is
    left as the caller set it.

    Args:
        dataset: indexable dataset returning (image tensor, mask tensor); the
            image is de-normalised with the ImageNet mean/std for display.
        indices: sample indices to show, one figure row each.  Nothing is
            drawn when it is empty.
    """
    if len(indices) == 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # the `axes[i, j]` indexing below works for any number of indices.
    fig, axes = plt.subplots(len(indices), 4, figsize=(20, 5*len(indices)), squeeze=False)

    for i, idx in enumerate(indices):
        img, mask = dataset[idx]

        with torch.no_grad():
            pred = model(img.unsqueeze(0).to(device))
            pred = torch.sigmoid(pred).cpu().squeeze().numpy()

        # Undo the normalisation and move channels last for matplotlib.
        img_np = img.numpy().transpose(1, 2, 0)
        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_np = np.clip(img_np, 0, 1)

        mask_np = mask.squeeze().numpy()
        pred_bin = (pred > 0.5).astype(np.float32)

        # Blend pure green at 50% into the predicted region.
        overlay = img_np.copy()
        predicted = pred_bin > 0.5
        overlay[predicted] = overlay[predicted] * 0.5 + np.array([0, 1, 0]) * 0.5

        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title('Image')
        axes[i, 1].imshow(mask_np, cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 2].imshow(pred_bin, cmap='gray')
        axes[i, 2].set_title('Prediction')
        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title('Overlay')

        for ax in axes[i]:
            ax.axis('off')

    plt.tight_layout()
    plt.show()

predict_and_show(val_dataset, [0, 1, 2, 3])

## 15. Save Results

In [None]:
import pickle

# Bundle the training curves, the per-image evaluation scores and the best
# validation Dice, then pickle them for later comparison.
results_exp3 = dict(
    history=history,
    all_iou_no_tta=all_iou_no_tta,
    all_dice_no_tta=all_dice_no_tta,
    all_iou_tta=all_iou_tta,
    all_dice_tta=all_dice_tta,
    best_dice=best_dice,
)

with open('results_exp3.pkl', 'wb') as f:
    pickle.dump(results_exp3, f)

print('Results saved to results_exp3.pkl')

---
## Summary of Architecture Modifications

| Component | Description |
|------------|----------|
| **Attention Gates** | Attention mechanisms in skip connections that allow the model to focus on relevant regions, suppressing irrelevant responses |
| **SE Blocks** | Squeeze-and-Excitation blocks for adaptive recalibration of feature channels |
| **ASPP** | Atrous Spatial Pyramid Pooling at bottleneck for multi-scale context capture |
| **Focal Loss** | Added to the loss function to handle class imbalance |
| **Discriminative LR** | Differentiated learning rate for encoder (lower) and decoder (higher) |
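| **OneCycleLR** | One-cycle learning-rate schedule stepped once per batch: warm-up over the first 10% of the steps, then cosine annealing; a more aggressive schedule intended to improve generalization |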

This is a substantial modification to the network topology, not just an increase in width or depth.