# Residual SE U-Net for Fetal Head Segmentation
## Training Notebook (Kaggle Compatible)

**Platform:** Kaggle Notebooks with GPU acceleration

### Architecture Overview

This notebook implements a **Residual SE U-Net** for medical image segmentation with the following components:

**Core Architecture:**
- U-Net encoder-decoder structure with skip connections
- Residual blocks (Conv2d → BatchNorm → ReLU → Conv2d → BatchNorm + skip connection)
- MaxPool2d downsampling (encoder) and ConvTranspose2d upsampling (decoder)

**Key Innovations:**

1. **Residual Connections:**
   - Identity shortcuts in every encoder/decoder block
   - Improved gradient flow for deeper networks
   - Mitigates the degradation problem that appears as networks get deeper

2. **Squeeze-and-Excitation (SE) Blocks:**
   - Channel-wise attention mechanism (reduction ratio: 16)
   - Applied after each residual block and on skip connections
   - Learns to emphasize informative features and suppress less useful ones
   - Global average pooling → FC layers → sigmoid gating

**Output:**
- Sigmoid activation applied in model (outputs probabilities [0, 1])
- Compatible with DiceBCELoss for balanced optimization

## 1. Environment Setup

### Kaggle Configuration

**Dataset:** `fhs-residual-se-unet` (must be added to notebook)

**Directory Structure:**
- **Project Root:** `/kaggle/input/fhs-residual-se-unet/` (read-only)
- **Outputs:** `/kaggle/working/results/`
  - Checkpoints, logs, predictions, and visualizations
  - Automatically available for download after training completes

**Steps:**
1. Verify dataset is attached to notebook
2. Install/upgrade dependencies (Albumentations 1.4.0, specific NumPy/SciPy versions)
3. Import required modules and verify CUDA availability

In [None]:
import os
import sys
from pathlib import Path

print("[Kaggle Setup]")

# Setup paths for Kaggle
project_root = Path('/kaggle/input/fhs-residual-se-unet')
output_root = Path('/kaggle/working')
cache_root = output_root / 'cache'

# Verify dataset exists
if not project_root.exists():
    raise RuntimeError(
        f"Dataset not found at {project_root}\n"
        f"Please add the 'fhs-residual-se-unet' dataset to your Kaggle notebook."
    )

if not (project_root / 'accuracy_focus').exists():
    raise RuntimeError(
        f"'accuracy_focus' folder not found in {project_root}\n"
        f"Please ensure your dataset structure is correct."
    )

print(f"Project root: {project_root} (read-only)")
print(f"Output root: {output_root} (writable)")

# Add project to path
sys.path.insert(0, str(project_root))

print(f"\n✓ Environment setup complete")

In [None]:
# Install required packages for Kaggle
print("Installing required packages...")

# Force-reinstall pinned versions so they override Kaggle's pre-installed packages
!pip install --force-reinstall --no-cache-dir -q \
    "numpy==1.26.4" \
    "scipy==1.11.4" \
    "scikit-learn==1.5.1" \
    "albumentations==1.4.0" \
    "opencv-python-headless==4.9.0.80" \
    "PyYAML>=5.4" \
    "tqdm>=4.62"

print("\n✓ Packages installed successfully")

In [None]:
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Import from project structure
from accuracy_focus.improved_unet.src.models.residual_se_unet.residual_se_unet_model import ResidualSEUNet
from accuracy_focus.standard_unet.src.losses import DiceBCELoss
from shared.src.data import HC18Dataset
from shared.src.metrics.segmentation_metrics import dice_coefficient, iou_score, pixel_accuracy
from shared.src.utils.visualization import save_prediction_grid, visualize_sample
from shared.src.utils.transforms import get_transforms

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Load Configuration

In [None]:
# Load configuration (use improved_unet config as template)
config_path = project_root / 'accuracy_focus' / 'improved_unet' / 'configs' / 'residual_se_unet_config.yaml'

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Adjust output paths for Kaggle environment
print(f"Adjusting paths for Kaggle environment...")
config['logging']['checkpoint_dir'] = str(output_root / 'results' / 'checkpoints')
config['logging']['log_dir'] = str(output_root / 'results' / 'logs')
config['logging']['prediction_dir'] = str(output_root / 'results' / 'predictions')
config['logging']['visualization_dir'] = str(output_root / 'results' / 'visualizations')

print(f"  Outputs will be saved to: {output_root / 'results'}")

print("\nConfiguration loaded:")
print(f"  Model: {config['model']['name']}")
print(f"  Base Filters: {config['model']['base_filters']}")
print(f"  SE Reduction Ratio: {config['model']['reduction_ratio']}")
print(f"  Learning Rate: {config['training']['optimizer']['lr']}")
print(f"  Loss Function: {config['loss']['name']}")
print(f"  Batch Size: {config['training']['batch_size']}")
print(f"  Epochs: {config['training']['num_epochs']}")

## 3. Model Initialization

### Residual SE U-Net Architecture Details

**Encoder Path:**
- 4 downsampling blocks: 64 → 128 → 256 → 512 filters
- Each block: ResidualBlockSE (Conv2d + BN + ReLU + Conv2d + BN + skip) + SE
- Downsampling: MaxPool2d (2×2, stride=2)

**Bottleneck:**
- ResidualBlockSE with 1024 filters
- Captures highest-level semantic features with channel attention

**Decoder Path:**
- 4 upsampling blocks: 512 → 256 → 128 → 64 filters
- Upsampling: ConvTranspose2d (2×2, stride=2)
- Each block: ResidualBlockSE × 1 after skip connection concatenation
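
To make the channel and resolution progression concrete, the sketch below traces `(channels, spatial size)` through the encoder, bottleneck, and decoder for a 256×256 input. `unet_shapes` is a hypothetical helper, not part of the project code. It assumes each MaxPool2d halves the spatial size and each level doubles the filter count, as listed above.

```python
def unet_shapes(input_size=256, base_filters=64, depth=4):
    """Trace (channels, spatial_size) per level of a U-Net-style network."""
    encoder = []
    size, channels = input_size, base_filters
    for _ in range(depth):
        encoder.append((channels, size))  # resolution of the block output, before pooling
        size //= 2                        # MaxPool2d(2x2, stride=2) halves H and W
        channels *= 2                     # filters double at every level
    bottleneck = (channels, size)
    # Decoder mirrors the encoder: ConvTranspose2d doubles H and W, filters halve
    decoder = list(reversed(encoder))
    return encoder, bottleneck, decoder


encoder, bottleneck, decoder = unet_shapes()
print("encoder:   ", encoder)
print("bottleneck:", bottleneck)
print("decoder:   ", decoder)
```

With the defaults, the bottleneck operates on 1024 channels at 16×16 resolution.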

**Squeeze-and-Excitation (SE) Mechanism:**
- **Applied to:** Each residual block output + skip connections
- **Operation:** GlobalAvgPool → FC (reduce) → ReLU → FC (expand) → Sigmoid
- **Reduction ratio:** 16 (e.g., 512 channels → 32 → 512)
- **Purpose:** Recalibrate channel-wise feature responses adaptively
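
The SE operation can be illustrated without any framework. The following is a minimal plain-Python sketch of squeeze, excitation, and scaling on nested lists. `se_gate` and its weight arguments are hypothetical names used only for illustration, biases are omitted, and the project's actual SE module is not shown in this notebook, so details may differ.

```python
import math


def se_gate(feature_maps, w_reduce, w_expand):
    """Apply SE gating to C channels, each given as an H x W nested list.

    w_reduce: (C // r) x C weight matrix, w_expand: C x (C // r) weight matrix.
    """
    # Squeeze: global average pooling gives one scalar per channel
    squeezed = [
        sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in feature_maps
    ]
    # Excitation: FC (reduce) -> ReLU -> FC (expand) -> sigmoid
    hidden = [max(0.0, sum(w * s for w, s in zip(row, squeezed))) for row in w_reduce]
    gates = [
        1.0 / (1.0 + math.exp(-sum(w * h for w, h in zip(row, hidden))))
        for row in w_expand
    ]
    # Scale: every pixel of channel c is multiplied by that channel's gate
    scaled = [[[v * g for v in row] for row in ch] for ch, g in zip(feature_maps, gates)]
    return scaled, gates


# Two 2x2 channels with reduction ratio 2, so the hidden layer has a single unit
maps = [[[1.0, 1.0], [1.0, 1.0]], [[3.0, 3.0], [3.0, 3.0]]]
scaled, gates = se_gate(maps, w_reduce=[[0.5, 0.5]], w_expand=[[0.0], [1.0]])
print(gates)
```

Ignoring biases, the two FC layers add roughly `2 * C * C / r` weights per SE block, which is why a reduction ratio of 16 keeps the overhead small relative to the convolutions.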

**Residual Connections:**
- Identity mappings in all encoder/decoder blocks
- Enables training of very deep networks (50+ layers possible)
- Gradient flows directly through shortcuts
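
The gradient-flow claim can be written out in one line. As a scalar sketch that ignores any activation applied after the addition, let $F$ denote the residual branch (the two Conv2d + BatchNorm stages) and $\mathcal{L}$ the loss:

$$
y = x + F(x), \qquad \frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\left(1 + \frac{\partial F}{\partial x}\right)
$$

The constant term $1$ comes from the shortcut, so part of the gradient reaches $x$ unchanged even when $\partial F / \partial x$ is small.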

**Output Layer:**
- Conv2d (1×1) to single channel
- **Sigmoid activation applied** → outputs probabilities [0, 1]
- Compatible with DiceBCELoss for hybrid optimization

In [None]:
# Set device
device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize model
model = ResidualSEUNet(
    in_channels=config['model']['in_channels'],
    out_channels=config['model']['out_channels'],
    base_channels=config['model']['base_filters'],
    reduction_ratio=config['model']['reduction_ratio']
)
model = model.to(device)

# Count parameters
from accuracy_focus.improved_unet.src.models.residual_se_unet.residual_se_unet_model import count_parameters
total_params, trainable_params = count_parameters(model)

# Model summary
print(f"\nResidual SE U-Net Architecture:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{total_params * 4 / (1024**2):.2f} MB (float32)")
print(f"  Input channels: {config['model']['in_channels']}")
print(f"  Output channels: {config['model']['out_channels']}")
print(f"  Base filters: {config['model']['base_filters']}")
print(f"  SE reduction ratio: {config['model']['reduction_ratio']}")

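# Test forward pass (eval mode + no_grad so random input does not update BatchNorm statistics)
model.eval()
with torch.no_grad():
    test_input = torch.randn(1, config['model']['in_channels'], 256, 256, device=device)
    test_output = model(test_input)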
print(f"\nTest forward pass:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"  Output range: [{test_output.min().item():.4f}, {test_output.max().item():.4f}]")

## 4. Loss Function and Optimization

### Loss Function: DiceBCELoss

**Hybrid Loss Design:**
- **Dice Loss (80%):** Optimizes region overlap (DSC metric)
- **BCE Loss (20%):** Optimizes pixel-wise classification

**Why Hybrid Loss?**
- **Dice component:** Handles extreme class imbalance (2-10% foreground pixels typical)
- **BCE component:** Provides pixel-level gradient signals for sharp boundaries
- **Weighted combination:** Balances global structure (Dice) with local details (BCE)

**Expected input:**
- Model outputs probabilities [0, 1] (sigmoid already applied)
- Targets are binary masks {0, 1}
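
As an illustration of how the two terms combine, here is a plain-Python sketch on flattened pixel lists. It assumes the loss has the form `dice_weight * (1 - Dice) + bce_weight * BCE` with the smoothing constant added to both numerator and denominator. The project's `DiceBCELoss` is imported from the package and not shown here, so its exact formulation may differ.

```python
import math


def dice_bce_loss(probs, targets, dice_weight=0.8, bce_weight=0.2, smooth=1e-6):
    """Weighted sum of soft Dice loss and binary cross-entropy on flat lists."""
    # Soft Dice: overlap between predicted probabilities and the binary target
    intersection = sum(p * t for p, t in zip(probs, targets))
    dice = (2.0 * intersection + smooth) / (sum(probs) + sum(targets) + smooth)
    # BCE: clamp probabilities to avoid log(0)
    eps = 1e-7
    bce = -sum(
        t * math.log(max(p, eps)) + (1.0 - t) * math.log(max(1.0 - p, eps))
        for p, t in zip(probs, targets)
    ) / len(probs)
    return dice_weight * (1.0 - dice) + bce_weight * bce


target = [1.0, 0.0, 1.0, 0.0]
print(dice_bce_loss([1.0, 0.0, 1.0, 0.0], target))  # perfect prediction
print(dice_bce_loss([0.5, 0.5, 0.5, 0.5], target))  # maximally uncertain prediction
```

The uncertain prediction is penalised by both terms: Dice drops to about 0.5 and BCE equals ln 2 per pixel.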

### Optimizer: Adam
- Adaptive learning rate per parameter (lr=1e-3)
- Momentum terms: betas=(0.9, 0.999)
- Weight decay: Configurable L2 regularization

### Learning Rate Scheduler: ReduceLROnPlateau
- Monitors validation Dice coefficient (mode='max')
- Reduces LR by factor=0.1 when performance plateaus
- Patience: Number of epochs without improvement before reduction
- Minimum LR: Prevents learning rate from becoming too small
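
The plateau logic can be simulated in a few lines. This is a simplified sketch of the `mode='max'` behaviour, not PyTorch's implementation: it ignores the `threshold` and `cooldown` options, and `simulate_plateau_schedule` is a hypothetical helper. As in PyTorch, the rate is reduced once the number of non-improving epochs exceeds `patience`.

```python
def simulate_plateau_schedule(scores, lr=1e-3, factor=0.1, patience=2, min_lr=1e-6):
    """Return the learning rate in effect after each epoch's validation score."""
    best = float("-inf")
    bad_epochs = 0
    lrs = []
    for score in scores:
        if score > best:          # mode='max': higher validation Dice is better
            best = score
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # plateau detected: shrink LR, never below min_lr
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        lrs.append(lr)
    return lrs


val_dice = [0.90, 0.92, 0.92, 0.91, 0.92, 0.93]
print(simulate_plateau_schedule(val_dice))
```

In this run the score fails to beat 0.92 for three consecutive epochs, so the rate drops by a factor of 10 at the fifth epoch.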

In [None]:
# Loss function (DiceBCELoss - Combined Dice + BCE)
loss_config = config['loss']
dice_weight_config = loss_config.get('dice_weight', 0.8)
bce_weight_config = loss_config.get('bce_weight', 0.2)
smooth_config = loss_config.get('smooth', 1.0e-6)

criterion = DiceBCELoss(
    dice_weight=dice_weight_config,
    bce_weight=bce_weight_config,
    smooth=smooth_config
)
print(f"Loss Function: {loss_config['name']}")
print(f"  Dice weight: {dice_weight_config}")
print(f"  BCE weight: {bce_weight_config}")
print(f"  Smooth parameter: {smooth_config}")

# Optimizer (Adam)
optimizer_config = config['training']['optimizer']
optimizer = optim.Adam(
    model.parameters(),
    lr=optimizer_config['lr'],
    betas=tuple(optimizer_config['betas']),
    eps=optimizer_config['eps'],
    weight_decay=optimizer_config['weight_decay']
)
print(f"\nOptimizer: Adam")
print(f"  Learning rate: {optimizer_config['lr']}")

# Learning rate scheduler
scheduler_config = config['training']['scheduler']
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=scheduler_config['mode'],
    factor=scheduler_config['factor'],
    patience=scheduler_config['patience'],
    min_lr=scheduler_config['min_lr']
)
print(f"\nScheduler: ReduceLROnPlateau")
print(f"  Mode: {scheduler_config['mode']}")
print(f"  Factor: {scheduler_config['factor']}")
print(f"  Patience: {scheduler_config['patience']}")

## 5. Dataset and Data Loaders

### Dataset: HC18 Grand Challenge
- **Source:** Fetal head circumference ultrasound images
- **Split:** Training / Validation / Test sets
- **Image size:** 256×256 pixels (grayscale)
- **Normalization:** Division by 255.0 → [0, 1] range

### Data Augmentation Strategy

**Training Augmentations (applied on-the-fly):**
- **HorizontalFlip** (p=0.5): Mirror symmetry
- **VerticalFlip** (p=0.5): Additional spatial variation
- **Rotation** (±20°, p=0.5): Orientation invariance
- **ShiftScaleRotate** (p=0.5):
  - Translation: ±10% (shift_limit=0.1)
  - Scaling: ±10% (scale_limit=0.1)
  - Combined transformations

**Validation/Test:** Preprocessing only (resize + normalize)

**Implementation:**
- Dynamic augmentation: New variations generated each epoch
- Synchronized transforms: Image-mask pairs augmented identically
- Library: Albumentations (highly optimized)
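
Synchronized augmentation means one random decision is drawn per sample and applied to both the image and its mask. Albumentations handles this internally when image and mask are passed in the same call. The plain-Python sketch below only illustrates the idea for a horizontal flip; `random_hflip_pair` is a hypothetical helper, not part of the project.

```python
import random


def random_hflip_pair(image, mask, p=0.5, rng=random):
    """Flip image and mask together with probability p (both are lists of rows)."""
    if rng.random() < p:  # one draw decides the fate of both arrays
        image = [row[::-1] for row in image]
        mask = [row[::-1] for row in mask]
    return image, mask


image = [[1, 2], [3, 4]]
mask = [[0, 1], [0, 0]]
flipped_image, flipped_mask = random_hflip_pair(image, mask, p=1.0)
print(flipped_image, flipped_mask)
```

Drawing separate random numbers for the image and the mask would misalign the labels, which is the failure this pattern avoids.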

**DataLoader Configuration:**
- Batch size: From config (typically 8-16)
- num_workers: 0 (Kaggle compatibility, avoids multiprocessing issues)
- pin_memory: True (faster GPU transfer)
- Shuffle: True (training only)

In [None]:
data_config = config['data']
training_config = config['training']

# Helper to build paths
def get_path(config_path):
    """Helper to handle both absolute and relative paths"""
    p = Path(config_path)
    if p.is_absolute():
        return str(p)
    else:
        return str(project_root / config_path)

# Create augmentation transforms
print("Creating augmentation transforms...")
train_transform = get_transforms(height=256, width=256, is_train=True)
val_transform = get_transforms(height=256, width=256, is_train=False)
print("  Train transform: WITH augmentation (flips, Rotation, ShiftScaleRotate)")
print("  Val transform: WITHOUT augmentation (resize + normalize only)")

# Create datasets - using HC18Dataset for on-the-fly augmentation
print("\nCreating training dataset...")
train_dataset = HC18Dataset(
    image_dir=get_path(data_config['train_images']),
    mask_dir=get_path(data_config['train_masks']),
    transform=train_transform
)

print("Creating validation dataset...")
val_dataset = HC18Dataset(
    image_dir=get_path(data_config['val_images']),
    mask_dir=get_path(data_config['val_masks']),
    transform=val_transform
)

print("Creating test dataset...")
test_dataset = HC18Dataset(
    image_dir=get_path(data_config['test_images']),
    mask_dir=get_path(data_config['test_masks']),
    transform=val_transform
)

# Use num_workers=0 for Kaggle to avoid multiprocessing issues
num_workers = 0
print(f"\nDataLoader settings:")
print(f"  num_workers: {num_workers} (disabled for Kaggle)")
print(f"  Batch size: {training_config['batch_size']}")

# Data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=training_config['batch_size'],
    shuffle=True,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

val_loader = DataLoader(
    val_dataset,
    batch_size=training_config['batch_size'],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

test_loader = DataLoader(
    test_dataset,
    batch_size=training_config['batch_size'],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

print(f"\n{'='*60}")
print(f"Datasets Ready:")
print(f"{'='*60}")
print(f"  Train samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")
print(f"  Test samples: {len(test_dataset)}")
print(f"  Batch size: {training_config['batch_size']}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")
print(f"  Train augmentation: ENABLED (on-the-fly)")
print(f"  Val/Test augmentation: DISABLED")
print(f"{'='*60}\n")

## 6. Data Validation

### Pre-Training Checks

Verify data integrity before training:
1. **Batch shapes:** Ensure correct tensor dimensions
2. **Value ranges:** Images [0, 1], masks {0, 1}
3. **Mask binarization:** Confirm masks are strictly binary
4. **Foreground ratio:** Check if ~2-10% (typical for fetal head)
5. **Visual inspection:** Display sample augmentations

In [None]:
# Get a batch of training data
sample_images, sample_masks = next(iter(train_loader))

print(f"Sample batch:")
print(f"  Images shape: {sample_images.shape}")
print(f"  Masks shape: {sample_masks.shape}")
print(f"  Image range: [{sample_images.min():.4f}, {sample_images.max():.4f}]")
print(f"  Mask range: [{sample_masks.min():.4f}, {sample_masks.max():.4f}]")
print(f"  Mask unique values: {torch.unique(sample_masks)}")
print(f"  Mask mean (foreground fraction): {sample_masks.mean():.4f}")

# CRITICAL CHECK: Ensure masks are binary {0, 1}
if not torch.all((sample_masks == 0) | (sample_masks == 1)):
    print("\n⚠️  WARNING: Masks are not binary! Check preprocessing.")
else:
    print("\n✓ Masks are properly binary {0, 1}")

# Check if masks have reasonable foreground ratio (2-10% typical for fetal head)
fg_ratio = sample_masks.mean().item()
if fg_ratio < 0.01 or fg_ratio > 0.3:
    print(f"⚠️  WARNING: Unusual foreground ratio: {fg_ratio:.2%} (typical 2-10%; accepted range 1-30%)")
else:
    print(f"✓ Foreground ratio looks reasonable: {fg_ratio:.2%}")

# Visualize first sample
visualize_sample(sample_images[0], sample_masks[0])

## 7. Training and Validation Functions

### Training Loop (per epoch)
1. Set model to training mode
2. Iterate through batches with progress bar
3. Forward pass: model(images) → predictions
4. Compute loss: DiceBCELoss(predictions, masks)
5. Backward pass + optimizer step
6. Track running loss statistics

### Validation Loop (per epoch)
1. Set model to evaluation mode
2. Disable gradient computation (torch.no_grad)
3. Forward pass on validation set
4. Compute loss and metrics:
   - **Dice Coefficient (DSC):** Primary metric
   - **IoU Score:** Intersection over Union
   - **Pixel Accuracy:** Overall correctness
5. Return average metrics across all samples
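
For reference, the three metrics can be computed from the confusion counts of a thresholded prediction. The sketch below works on flat binary lists and assumes a small smoothing constant for Dice and IoU. The project's `dice_coefficient`, `iou_score`, and `pixel_accuracy` are imported from the shared package and not shown here, so their exact definitions may differ.

```python
def binary_metrics(pred, target, smooth=1e-6):
    """Return (dice, iou, pixel_accuracy) for two flat lists of 0/1 values."""
    tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 1)
    tn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 0)
    dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)  # 2|A∩B| / (|A| + |B|)
    iou = (tp + smooth) / (tp + fp + fn + smooth)           # |A∩B| / |A∪B|
    pa = (tp + tn) / len(pred)                              # fraction of correct pixels
    return dice, iou, pa


dice, iou, pa = binary_metrics(pred=[1, 1, 0, 0], target=[1, 0, 0, 0])
print(f"Dice={dice:.4f} IoU={iou:.4f} PA={pa:.4f}")
```

Pixel accuracy counts true negatives, so it stays high on images dominated by background; this is one reason Dice is used as the primary metric.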

### Early Stopping Strategy
- Monitor validation Dice coefficient
- Save best model checkpoint when Dice improves
- Stop training if no improvement for `patience` epochs
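
The stopping rule can be checked in isolation against a list of validation scores. This sketch follows the same rule as the training cell later in the notebook (strict improvement over the best Dice, stop once the stale count reaches `patience`). `early_stopping_epoch` is a hypothetical helper for illustration only.

```python
def early_stopping_epoch(val_dice_history, patience):
    """Return (zero-based stop epoch or None, best Dice seen)."""
    best = 0.0
    stale = 0
    for epoch, dice in enumerate(val_dice_history):
        if dice > best:   # strict improvement resets the counter
            best, stale = dice, 0
        else:
            stale += 1
        if stale >= patience:
            return epoch, best
    return None, best     # patience was never exhausted


print(early_stopping_epoch([0.90, 0.95, 0.94, 0.95, 0.93], patience=3))
```

Note that matching the previous best (0.95 at the fourth epoch) does not count as an improvement under this rule.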

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """
    Train for one epoch
    """
    model.train()
    running_loss = 0.0
    
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for batch_idx, (images, masks) in enumerate(pbar):
        images = images.to(device)
        masks = masks.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        
        # Calculate loss
        loss = criterion(outputs, masks)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Update statistics
        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})
    
    epoch_loss = running_loss / len(dataloader)
    return epoch_loss


def validate(model, dataloader, criterion, device, epoch):
    """
    Validate the model
    """
    model.eval()
    running_loss = 0.0
    dice_scores = []
    iou_scores = []
    pa_scores = []
    
    with torch.no_grad():
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Val]", leave=False)
        for batch_idx, (images, masks) in enumerate(pbar):
            images = images.to(device)
            masks = masks.to(device)
            
            # Forward pass
            outputs = model(images)
            
            # Calculate loss
            loss = criterion(outputs, masks)
            running_loss += loss.item()
            
            # Calculate metrics
            preds = (outputs > 0.5).float()
            
            for i in range(images.size(0)):
                dice = dice_coefficient(preds[i], masks[i])
                iou = iou_score(preds[i], masks[i])
                pa = pixel_accuracy(preds[i], masks[i])
                
                dice_scores.append(dice.item())
                iou_scores.append(iou.item())
                pa_scores.append(pa.item())
            
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'dice': f"{np.mean(dice_scores):.4f}"
            })
    
    # Calculate average metrics
    val_loss = running_loss / len(dataloader)
    val_dice = np.mean(dice_scores)
    val_iou = np.mean(iou_scores)
    val_pa = np.mean(pa_scores)
    
    return val_loss, val_dice, val_iou, val_pa


print("Training functions defined!")

## 8. Training Execution

### Training Configuration
- **Epochs:** From config (typically 50-100)
- **Early stopping:** Patience for convergence
- **Checkpoint saving:** Best model based on validation Dice
- **Learning rate scheduling:** Automatic reduction on plateau

### Monitored Metrics
- **Train Loss:** DiceBCELoss on training set
- **Val Loss:** DiceBCELoss on validation set
- **Val Dice:** Primary evaluation metric (DSC)
- **Val IoU:** Intersection over Union score
- **Val PA:** Pixel accuracy
- **Learning Rate:** Current optimizer learning rate

### Training Progress Indicators
- 🏆 New best model saved
- ⬇️ Learning rate reduced
- ⚠️ No improvement warning
- ⛔ Early stopping triggered

In [None]:
print(f"{'='*60}")
print(f"Starting Training - Residual SE U-Net")
print(f"{'='*60}")

# Training configuration
num_epochs = config['training']['num_epochs']
patience = config['training']['early_stopping_patience']

# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'val_dice': [],
    'val_iou': [],
    'val_pa': [],
    'lr': []
}

best_dice = 0.0
epochs_without_improvement = 0

print(f"Epochs: {num_epochs}")
print(f"Early Stopping Patience: {patience}")
print(f"{'='*60}\n")

# Main training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 70)
    
    # Train
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
    
    # Validate
    val_loss, val_dice, val_iou, val_pa = validate(model, val_loader, criterion, device, epoch)
    
    # Update learning rate
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_dice)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)
    history['val_iou'].append(val_iou)
    history['val_pa'].append(val_pa)
    history['lr'].append(current_lr)
    
    # Print epoch summary
    print(f"\n{'='*100}")
    dice_indicator = ' 🏆' if val_dice > best_dice else ''
    lr_change = f' ⬇️ (reduced from {old_lr:.6f})' if current_lr < old_lr else ''
    
    # Print metrics
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}{dice_indicator}")
    print(f"Val mIoU: {val_iou:.4f} | Val mPA: {val_pa:.4f}")
    print(f"LR: {current_lr:.6f}{lr_change}")
    print(f"{'='*100}")
    
    # Check for improvement
    is_best = val_dice > best_dice
    if is_best:
        best_dice = val_dice
        epochs_without_improvement = 0
        
        # Save best model
        checkpoint_dir = Path(get_path(config['logging']['checkpoint_dir']))
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        best_model_path = checkpoint_dir / 'best_model_residual_se_unet_v2.pth'
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_dice': best_dice,
            'history': history,
            'config': config
        }, best_model_path)
        
        print(f"  → Saved best model (Dice: {best_dice:.4f})")
    else:
        epochs_without_improvement += 1
        print(f"  ⚠️  No improvement for {epochs_without_improvement}/{patience} epochs")
    
    # Early stopping
    if epochs_without_improvement >= patience:
        print(f"\n{'='*70}")
        print(f"⛔ EARLY STOPPING TRIGGERED")
        print(f"{'='*70}")
        print(f"  Stopped at epoch: {epoch+1}")
        print(f"  Best Dice Score:  {best_dice:.4f}")
        print(f"  Patience limit:   {patience} epochs without improvement")
        print(f"{'='*70}")
        break

# Final summary
print(f"\n{'='*60}")
print(f"Training Completed!")
print(f"{'='*60}")
print(f"Best Validation Dice: {best_dice:.4f}")
print(f"Best Validation IoU:  {max(history['val_iou']):.4f}")
print(f"Best Validation PA:   {max(history['val_pa']):.4f}")
print(f"{'='*60}\n")

## 9. Training Visualization

### Learning Curves

**Plot 1 - Loss Curves:**
- Train vs Validation loss over epochs
- Monitors overfitting (divergence between curves)

**Plot 2 - Dice Coefficient:**
- Validation Dice over epochs
- Horizontal line marks best performance

**Plot 3 - IoU Score:**
- Validation IoU over epochs
- Secondary segmentation quality metric

**Plot 4 - Learning Rate Schedule:**
- Learning rate evolution (log scale)
- Shows ReduceLROnPlateau reductions

**Saved to:** `/kaggle/working/results/logs/training_curves.png`

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Loss curves
axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0, 0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch', fontsize=12)
axes[0, 0].set_ylabel('Loss (Dice + BCE)', fontsize=12)
axes[0, 0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0, 0].legend(fontsize=11)
axes[0, 0].grid(True, alpha=0.3)

# Dice coefficient
axes[0, 1].plot(history['val_dice'], label='Val Dice', color='green', linewidth=2)
axes[0, 1].axhline(y=best_dice, color='red', linestyle='--', label=f'Best: {best_dice:.4f}')
axes[0, 1].set_xlabel('Epoch', fontsize=12)
axes[0, 1].set_ylabel('Dice Coefficient', fontsize=12)
axes[0, 1].set_title('Validation Dice Coefficient', fontsize=14, fontweight='bold')
axes[0, 1].legend(fontsize=11)
axes[0, 1].grid(True, alpha=0.3)

# IoU
axes[1, 0].plot(history['val_iou'], label='Val IoU', color='orange', linewidth=2)
axes[1, 0].axhline(y=max(history['val_iou']), color='red', linestyle='--', 
                   label=f"Best: {max(history['val_iou']):.4f}")
axes[1, 0].set_xlabel('Epoch', fontsize=12)
axes[1, 0].set_ylabel('IoU Score', fontsize=12)
axes[1, 0].set_title('Validation IoU Score', fontsize=14, fontweight='bold')
axes[1, 0].legend(fontsize=11)
axes[1, 0].grid(True, alpha=0.3)

# Learning rate
axes[1, 1].plot(history['lr'], label='Learning Rate', color='red', linewidth=2)
axes[1, 1].set_xlabel('Epoch', fontsize=12)
axes[1, 1].set_ylabel('Learning Rate', fontsize=12)
axes[1, 1].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
axes[1, 1].legend(fontsize=11)
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_yscale('log')

plt.tight_layout()

# Save before plt.show(): with the inline backend, show() closes the figure
# and a later savefig() would write an empty image
log_dir = Path(get_path(config['logging']['log_dir']))
log_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(log_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
print(f"Training curves saved to {log_dir / 'training_curves.png'}")

plt.show()

## 10. Model Evaluation and Visualization

### Evaluation Process
1. Load best checkpoint (highest validation Dice)
2. Run inference on test set
3. Apply threshold (0.5) to probabilities
4. Compute per-sample metrics (Dice, IoU, PA)
5. Visualize predictions with ground truth

### Prediction Visualization
- **Column 1:** Input ultrasound image
- **Column 2:** Ground truth segmentation mask
- **Column 3:** Model prediction with metrics overlay

**Saved to:** `/kaggle/working/results/predictions/sample_predictions.png`

In [None]:
# Load best model
checkpoint_path = Path(get_path(config['logging']['checkpoint_dir'])) / 'best_model_residual_se_unet_v2.pth'

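# weights_only=False is required because the checkpoint also stores history and config
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)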
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"Best Dice Score: {checkpoint['best_dice']:.4f}")

# Get one batch of test samples
test_images, test_masks = next(iter(test_loader))
test_images = test_images.to(device)

# Generate predictions
with torch.no_grad():
    test_probs = model(test_images)
    test_preds = (test_probs > 0.5).float()

# Visualize predictions
num_samples = min(4, len(test_images))
fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples))

if num_samples == 1:
    axes = axes.reshape(1, -1)

for i in range(num_samples):
    # Move to CPU and convert to numpy
    img = test_images[i, 0].cpu().numpy()
    mask = test_masks[i, 0].numpy()
    pred = test_preds[i, 0].cpu().numpy()
    
    # Calculate metrics for this sample (ensure both tensors on same device)
    pred_tensor = test_preds[i].cpu()
    mask_tensor = test_masks[i].to(pred_tensor.device)
    dice = dice_coefficient(pred_tensor, mask_tensor).item()
    iou = iou_score(pred_tensor, mask_tensor).item()
    pa = pixel_accuracy(pred_tensor, mask_tensor).item()
    
    # Input image
    axes[i, 0].imshow(img, cmap='gray')
    axes[i, 0].set_title('Input Image', fontsize=12, fontweight='bold')
    axes[i, 0].axis('off')
    
    # Ground truth
    axes[i, 1].imshow(mask, cmap='gray')
    axes[i, 1].set_title('Ground Truth', fontsize=12, fontweight='bold')
    axes[i, 1].axis('off')
    
    # Prediction
    axes[i, 2].imshow(pred, cmap='gray')
    axes[i, 2].set_title(f'Prediction\nDice: {dice:.4f} | IoU: {iou:.4f} | PA: {pa:.4f}', 
                         fontsize=10, fontweight='bold')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

# Save predictions
pred_dir = Path(get_path(config['logging']['prediction_dir']))
pred_dir.mkdir(parents=True, exist_ok=True)
save_prediction_grid(test_images[:4].cpu(), test_masks[:4], test_preds[:4].cpu(), 
                    str(pred_dir / 'sample_predictions.png'), num_samples=4)

## 11. Experiment Summary

### Model Architecture: Residual SE U-Net

**Key Components:**
1. **Residual Blocks:** Identity shortcuts for improved gradient flow
2. **SE Blocks:** Channel-wise attention (reduction ratio 16)
3. **Skip Connections:** Encoder features with SE recalibration
4. **Activation:** ReLU (hidden layers), Sigmoid (output)
5. **Normalization:** BatchNorm after each convolution

**Model Specifications:**
- Input: 1-channel grayscale (256×256)
- Output: 1-channel probability map [0, 1]
- Base filters: 64
- Depth: 5 levels (4 encoder + bottleneck)

### Training Configuration

**Loss Function:** DiceBCELoss (0.8 Dice + 0.2 BCE)
- Combines region overlap optimization with pixel-wise supervision
- Handles class imbalance effectively

**Optimizer:** Adam
- Learning rate: 1e-3
- Weight decay: Configurable L2 regularization

**Data Augmentation:**
- Horizontal/Vertical flip, Rotation (±20°)
- ShiftScaleRotate (±10% translation/scale)

### Why Residual SE U-Net?

**Advantages over Standard U-Net:**
1. **Residual connections:** Deeper networks without degradation
2. **Channel attention:** Automatic feature recalibration
3. **Better gradient flow:** Faster convergence, higher accuracy
4. **Feature reuse:** More efficient parameter usage

**Comparison to Attention U-Net:**
- **Attention U-Net:** Spatial attention on skip connections
- **Residual SE U-Net:** Channel attention everywhere + residual learning
- **Trade-off:** More parameters, potentially better feature learning

### Expected Performance

**Target Metrics:**
- **Dice Coefficient:** >97.5%
- **IoU Score:** >95.0%
- **Pixel Accuracy:** >99.0%
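
The Dice and IoU targets are related: for a single binary mask, each is a fixed function of the other. The short sketch below converts between them. The helper names are hypothetical, and the identity holds per mask, so averages over many samples need not satisfy it exactly.

```python
def dice_to_iou(dice):
    """IoU implied by a Dice score for one binary mask: IoU = D / (2 - D)."""
    return dice / (2.0 - dice)


def iou_to_dice(iou):
    """Dice implied by an IoU score for one binary mask: D = 2 * IoU / (1 + IoU)."""
    return 2.0 * iou / (1.0 + iou)


print(f"Dice 0.975 -> IoU {dice_to_iou(0.975):.4f}")
print(f"IoU  0.950 -> Dice {iou_to_dice(0.950):.4f}")
```

A Dice of 0.975 corresponds to an IoU of about 0.951, so the two targets above are roughly equivalent in difficulty.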

---

### Kaggle-Specific Notes

**Output Locations:**
- Checkpoints: `/kaggle/working/results/checkpoints/`
- Logs: `/kaggle/working/results/logs/`
- Predictions: `/kaggle/working/results/predictions/`
- Visualizations: `/kaggle/working/results/visualizations/`

**Download Results:**
All outputs are automatically available in Kaggle's Output tab after the notebook finishes executing.

**GPU Usage:**
Enable GPU acceleration in Kaggle notebook settings for optimal training speed (~10-20 min per epoch with GPU vs hours on CPU).