# Attention U-Net Training Notebook (Google Colab)
## Fetal Head Segmentation in Ultrasound Images

**Compatible with:** Google Colab

This notebook implements an **Attention U-Net** architecture with:
1. Standard U-Net encoder/decoder with convolutional blocks
2. **Attention Gates** applied to skip connections before concatenation
3. Spatial attention mechanism to focus on relevant features
4. Improved feature selection through gating mechanism
5. Aggressive data augmentation for better generalization
6. Gradient clipping and learning-rate warm-up

**Key Innovation:** Attention Gates weight the skip connections from the encoder, allowing the network to focus on relevant spatial regions and suppress irrelevant activations.

## 1. Setup and Imports

**Google Colab Environment:**
- Clone project from GitHub: `https://github.com/TrinhThaiSonDHQT/Fetal-Head-Segmentation`
- Project cloned to `/content/Fetal-Head-Segmentation/`
- Outputs saved to `/content/Fetal-Head-Segmentation/results/` (can be downloaded after training)

In [None]:
# Clone the GitHub repository
import os

# Check if already cloned
if not os.path.exists('/content/Fetal-Head-Segmentation'):
    print("Cloning repository from GitHub...")
    !git clone https://github.com/TrinhThaiSonDHQT/Fetal-Head-Segmentation.git
    print("✓ Repository cloned successfully")
else:
    print("✓ Repository already exists")

In [None]:
import os
import sys
from pathlib import Path

print("[Google Colab Setup]")

# Setup paths for Google Colab
project_root = Path('/content/Fetal-Head-Segmentation')
output_root = project_root / 'results'
cache_root = output_root / 'cache'

# Verify project exists
if not project_root.exists():
    raise RuntimeError(
        f"Project not found at {project_root}\n"
        f"Please run the previous cell to clone the repository from GitHub."
    )

if not (project_root / 'accuracy_focus').exists():
    raise RuntimeError(
        f"'accuracy_focus' folder not found in {project_root}\n"
        f"Please ensure the repository was cloned correctly."
    )

print(f"Project root: {project_root}")
print(f"Output root: {output_root}")

# Add project to path
sys.path.insert(0, str(project_root))

print(f"\n✓ Environment setup complete")

In [None]:
# Install required packages for Google Colab
print("Installing required packages...")

# Colab has most packages pre-installed (PyTorch, NumPy, Matplotlib, OpenCV)
# Pin Albumentations to 1.3.1 for compatibility with both Colab and Kaggle
!pip install -q albumentations==1.3.1

print("\n✓ Packages installed successfully")

In [None]:
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Import project modules
from accuracy_focus.attention_unet.src.losses.dice_loss import DiceLoss
from accuracy_focus.attention_unet.src.models.attention_unet import AttentionUNet

from shared.src.data import HC18Dataset
from shared.src.metrics.segmentation_metrics import dice_coefficient, iou_score, pixel_accuracy
from shared.src.utils.visualization import save_prediction_grid, visualize_sample
from shared.src.utils.transforms.aggressive_transforms import get_aggressive_transforms

# Set random seeds for reproducibility (Albumentations draws from Python's `random` module)
import random
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
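The seeding step can be packaged as a reusable helper. The sketch below shows one way to write it. `seed_everything` is a hypothetical name and is not part of the project's `shared` package. The optional cuDNN flags trade speed for repeatable convolution algorithms. Even with them set, some CUDA kernels remain nondeterministic, so GPU runs may still differ slightly.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python, NumPy and PyTorch RNGs (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # seeds the CPU and every CUDA device
    torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA
    if deterministic:
        # Prefer repeatable cuDNN algorithms over the fastest ones
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


seed_everything(42)
a = torch.rand(3)
seed_everything(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True
```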

## 2. Load Configuration

In [None]:
# Load configuration
config_path = project_root / 'accuracy_focus' / 'attention_unet' / 'configs' / 'attention_unet_v2_config.yaml'

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Adjust output paths for Google Colab environment
print(f"Adjusting paths for Google Colab environment...")
config['logging']['checkpoint_dir'] = str(output_root / 'checkpoints')
config['logging']['log_dir'] = str(output_root / 'logs')
config['logging']['prediction_dir'] = str(output_root / 'predictions')
config['logging']['visualization_dir'] = str(output_root / 'visualizations')

print(f"  Outputs will be saved to: {output_root}")
print(f"  (download from Colab files panel after training)")

print("\nConfiguration loaded:")
print(f"  Model: {config['model']['name']}")
print(f"  Base Filters: {config['model']['base_filters']}")
print(f"  Learning Rate: {config['training']['optimizer']['lr']}")
print(f"  Loss Function: {config['loss']['name']}")
print(f"  Batch Size: {config['training']['batch_size']}")
print(f"  Epochs: {config['training']['num_epochs']}")

## 3. Initialize Model

The Attention U-Net has:
- **Encoder**: 4 downsampling blocks with ConvBlock (64 → 128 → 256 → 512 channels)
- **Bottleneck**: ConvBlock (1024 channels)
- **Decoder**: 4 upsampling blocks with ConvBlock (512 → 256 → 128 → 64 channels)
- **Attention Gates**: Applied to skip connections before concatenation with decoder features
- **Skip Connections**: Attention-weighted features from encoder concatenated with decoder
- **Activation**: ReLU in convolutional blocks, Sigmoid output
- **Spatial Attention**: Attention gates focus on relevant spatial regions
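An attention gate of the additive kind used in Attention U-Nets works as follows:

- The encoder skip features $x$ and the decoder gating signal $g$ are each projected with a 1×1 convolution.
- The two projections are summed and passed through ReLU.
- The result is projected to a single channel and squashed with a sigmoid, giving per-pixel coefficients $\alpha \in (0, 1)$.
- The gate outputs $\alpha \cdot x$.

The sketch below is not the project's `AttentionUNet` code. Several details are assumptions:

- the class name `AdditiveAttentionGate`;
- the intermediate channel count;
- upsampling $g$ to the skip resolution (some implementations instead downsample $x$ and upsample $\alpha$).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAttentionGate(nn.Module):
    """alpha = sigmoid(psi(relu(W_x x + W_g g))); output = alpha * x (hypothetical sketch)."""

    def __init__(self, x_channels: int, g_channels: int, inter_channels: int):
        super().__init__()
        self.W_x = nn.Conv2d(x_channels, inter_channels, kernel_size=1, bias=False)
        self.W_g = nn.Conv2d(g_channels, inter_channels, kernel_size=1, bias=True)
        self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1, bias=True)

    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        # x: encoder skip features (B, Cx, H, W); g: decoder gating signal (B, Cg, h, w)
        g_proj = self.W_g(g)
        if g_proj.shape[-2:] != x.shape[-2:]:
            # Bring the gating signal to the skip connection's spatial size
            g_proj = F.interpolate(g_proj, size=x.shape[-2:], mode="bilinear", align_corners=False)
        # One coefficient per pixel, shape (B, 1, H, W), values in (0, 1)
        alpha = torch.sigmoid(self.psi(F.relu(self.W_x(x) + g_proj)))
        return x * alpha  # broadcast over the channel dimension


gate = AdditiveAttentionGate(x_channels=64, g_channels=128, inter_channels=32)
skip = torch.randn(2, 64, 64, 64)    # encoder features at 64x64
gating = torch.randn(2, 128, 32, 32) # coarser decoder features
out = gate(skip, gating)
print(out.shape)  # torch.Size([2, 64, 64, 64])
```

Because $\alpha$ is a single channel, the gate rescales where the skip features pass through, not which channels do. That is the "spatial" part of the attention.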

In [None]:
# Set device
device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize model (outputs probabilities with sigmoid)
model = AttentionUNet(
    in_channels=config['model']['in_channels'],
    out_channels=config['model']['out_channels'],
    base_filters=config['model']['base_filters'],
    use_sigmoid=True 
)
model = model.to(device)

# Count parameters
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

total_params, trainable_params = count_parameters(model)

# Model summary
print(f"\nAttention U-Net Architecture:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{total_params * 4 / (1024**2):.2f} MB (float32)")
print(f"  Input channels: {config['model']['in_channels']}")
print(f"  Output channels: {config['model']['out_channels']}")
print(f"  Base filters: {config['model']['base_filters']}")

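# Test forward pass in eval mode under no_grad, for two reasons:
# - any BatchNorm running statistics are not updated with random noise;
# - no autograd graph is built.
# train_one_epoch calls model.train() again before training.
model.eval()
with torch.no_grad():
    test_input = torch.randn(1, config['model']['in_channels'], 256, 256).to(device)
    test_output = model(test_input)
print(f"\nTest forward pass:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"  Output range: [{test_output.min().item():.4f}, {test_output.max().item():.4f}]")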

## 4. Setup Loss and Optimizer

- **Loss**: DiceLoss (Sorensen-Dice Loss)
- **Optimizer**: Adam with learning rate from config
- **Scheduler**: ReduceLROnPlateau (monitors validation Dice)
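The soft Dice loss is defined for predicted probabilities $p_i$, binary labels $g_i$ and smoothing constant $s$:

$$\mathcal{L} = 1 - \frac{2\sum_i p_i g_i + s}{\sum_i p_i + \sum_i g_i + s}$$

The project's `DiceLoss` source is not shown here, so the sketch below is an assumption. It computes the score per sample and averages over the batch. Some implementations instead pool the whole batch into one score.

```python
import torch


def soft_dice_loss(probs: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """1 - (2*sum(p*g) + s) / (sum(p) + sum(g) + s), per sample, averaged over the batch.

    probs:  sigmoid outputs in [0, 1], shape (B, 1, H, W)
    target: binary masks {0, 1}, same shape
    """
    probs = probs.flatten(start_dim=1)
    target = target.flatten(start_dim=1).float()
    intersection = (probs * target).sum(dim=1)
    denominator = probs.sum(dim=1) + target.sum(dim=1)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - dice.mean()


mask = torch.zeros(2, 1, 16, 16)
mask[:, :, :8, :] = 1.0                         # top half is foreground
print(soft_dice_loss(mask, mask).item())        # 0.0 (perfect overlap)
print(soft_dice_loss(1.0 - mask, mask).item())  # close to 1 (no overlap)
```

The `smooth` term keeps the ratio defined when both prediction and mask are empty. It also softens the loss for very small foreground regions.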

In [None]:
# Loss function: soft Dice loss (less sensitive to foreground/background imbalance than pixel-wise BCE)
loss_config = config['loss']
smooth_param = loss_config['smooth']
criterion = DiceLoss(smooth=smooth_param)
print(f"Loss Function: {loss_config['name']}")
print(f"  Smooth parameter: {smooth_param}")

# Optimizer (Adam)
optimizer_config = config['training']['optimizer']
optimizer = optim.Adam(
    model.parameters(),
    lr=optimizer_config['lr'],
    betas=tuple(optimizer_config['betas']),
    eps=optimizer_config['eps'],
    weight_decay=optimizer_config['weight_decay']
)
print(f"\nOptimizer: Adam")
print(f"  Learning rate: {optimizer_config['lr']}")
print(f"  Weight decay: {optimizer_config['weight_decay']}")

# Learning rate scheduler
scheduler_config = config['training']['scheduler']
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=scheduler_config['mode'],
    factor=scheduler_config['factor'],
    patience=scheduler_config['patience'],
    min_lr=scheduler_config['min_lr']
)
print(f"\nScheduler: ReduceLROnPlateau")
print(f"  Mode: {scheduler_config['mode']}")
print(f"  Factor: {scheduler_config['factor']}")
print(f"  Patience: {scheduler_config['patience']}")
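
Because the scheduler monitors validation Dice, and higher Dice is better, `mode` in the config has to be `'max'`. The toy comparison below shows what goes wrong otherwise. It uses a throwaway `nn.Linear` and made-up Dice values, not the project's config. With `mode='min'`, a steadily improving Dice is treated as a worsening metric and the learning rate gets cut.

```python
import torch.nn as nn
import torch.optim as optim


def lr_after(mode: str, dice_history, patience: int = 1) -> float:
    """Step a fresh ReduceLROnPlateau through dice_history and return the final LR."""
    model = nn.Linear(4, 1)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode=mode, factor=0.5, patience=patience)
    for dice in dice_history:
        sched.step(dice)
    return opt.param_groups[0]["lr"]


rising = [0.70, 0.75, 0.80, 0.85]  # Dice improving every epoch
print(lr_after("max", rising))  # 0.001: each improvement resets the patience counter
print(lr_after("min", rising))  # 0.0005: rising Dice looks like a worsening loss
```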

## 5. Prepare Data Loaders

**Preprocessing** (applied to all images):
- Images normalized by dividing by 255.0
- Resized to 256×256 pixels
- Converted to PyTorch tensors (C, H, W)

**Augmentation** (training only - applied on-the-fly):
- Horizontal & Vertical flip (p=0.5)
- Rotation (±20°, p=0.5)
- ShiftScaleRotate: Translation (±10%), Scale (±10%), p=0.5
- Elastic deformation simulates tissue movement in ultrasound
- Grid distortion creates localized warping
- Gaussian noise simulates sensor noise
- CLAHE enhances local contrast

**Note:** Augmentations are applied dynamically during training. Fresh augmented samples are generated every epoch for better model generalization. Validation uses only preprocessing without augmentation.
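
On-the-fly augmentation works because the randomness lives in the dataset's `__getitem__`. Every pass of the `DataLoader` therefore draws new transforms. The torch-only sketch below uses random flips as a stand-in for the Albumentations pipeline that `HC18Dataset` applies. `ToyAugmentedDataset` is hypothetical. The key detail is that each random decision is applied to both the image and its mask, so the two stay aligned.

```python
import torch
from torch.utils.data import Dataset


class ToyAugmentedDataset(Dataset):
    """Applies random flips in __getitem__, so each epoch sees a new random variant."""

    def __init__(self, images: torch.Tensor, masks: torch.Tensor, p: float = 0.5):
        self.images, self.masks, self.p = images, masks, p

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int):
        image, mask = self.images[idx], self.masks[idx]
        # Draw one decision per transform and apply it to BOTH image and mask
        if torch.rand(1).item() < self.p:
            image, mask = image.flip(-1), mask.flip(-1)  # horizontal flip
        if torch.rand(1).item() < self.p:
            image, mask = image.flip(-2), mask.flip(-2)  # vertical flip
        return image, mask


images = torch.rand(4, 1, 16, 16)
masks = (images > 0.5).float()
dataset = ToyAugmentedDataset(images, masks)
img, msk = dataset[0]
print(torch.equal(msk, (img > 0.5).float()))  # True: image and mask flipped together
```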

In [None]:
data_config = config['data']
training_config = config['training']

# Helper to build paths
def get_path(config_path):
    """Helper to handle both absolute and relative paths"""
    p = Path(config_path)
    if p.is_absolute():
        return str(p)
    else:
        return str(project_root / config_path)

# Create augmentation transforms
print("Creating augmentation transforms...")
train_transform = get_aggressive_transforms(height=256, width=256, is_train=True)
val_transform = get_aggressive_transforms(height=256, width=256, is_train=False)
print("  Train transform: WITH aggressive augmentation")
print("  Val transform: WITHOUT augmentation (resize + normalize only)")

# Create datasets - using HC18Dataset for on-the-fly augmentation
print("\nCreating training dataset...")
train_dataset = HC18Dataset(
    image_dir=get_path(data_config['train_images']),
    mask_dir=get_path(data_config['train_masks']),
    transform=train_transform
)

print("Creating validation dataset...")
val_dataset = HC18Dataset(
    image_dir=get_path(data_config['val_images']),
    mask_dir=get_path(data_config['val_masks']),
    transform=val_transform
)

# Colab's standard runtime typically has 2 vCPUs, so use 2 DataLoader workers
num_workers = 2
print(f"\nDataLoader settings:")
print(f"  num_workers: {num_workers}")
print(f"  Batch size: {training_config['batch_size']}")

# Data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=training_config['batch_size'],
    shuffle=True,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

val_loader = DataLoader(
    val_dataset,
    batch_size=training_config['batch_size'],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

print(f"\n{'='*60}")
print(f"Datasets Ready:")
print(f"{'='*60}")
print(f"  Train samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")
print(f"  Batch size: {training_config['batch_size']}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Train augmentation: ENABLED (on-the-fly)")
print(f"  Val augmentation: DISABLED")
print(f"{'='*60}\n")

## 6. Visualize Sample Data

In [None]:
# Get a batch of training data
sample_images, sample_masks = next(iter(train_loader))

print(f"Sample batch:")
print(f"  Images shape: {sample_images.shape}")
print(f"  Masks shape: {sample_masks.shape}")
print(f"  Image range: [{sample_images.min():.4f}, {sample_images.max():.4f}]")
print(f"  Mask range: [{sample_masks.min():.4f}, {sample_masks.max():.4f}]")
print(f"  Mask unique values: {torch.unique(sample_masks)}")
print(f"  Mask mean (% foreground): {sample_masks.mean():.4f}")

# CRITICAL CHECK: Ensure masks are binary {0, 1}
if not torch.all((sample_masks == 0) | (sample_masks == 1)):
    print("\n⚠️  WARNING: Masks are not binary! Check preprocessing.")
else:
    print("\n✓ Masks are properly binary {0, 1}")

# Visualize first sample
visualize_sample(sample_images[0], sample_masks[0])
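
If the binary check above fails, one common cause is how the masks were resized or stored:

- resizing or warping with a smoothing interpolation blends 0s and 1s along the boundary;
- saving masks in a lossy format has the same effect.

The toy tensors below illustrate the resizing case with `torch.nn.functional.interpolate`. This is not the project's pipeline.

```python
import torch
import torch.nn.functional as F

mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1.0  # a square foreground region

# Bilinear resizing blends 0s and 1s at the region boundary...
resized = F.interpolate(mask, size=(12, 12), mode="bilinear", align_corners=False)
print(torch.unique(resized).numel() > 2)  # True: intermediate values appear

# ...so either re-threshold the mask or resize it with nearest-neighbour interpolation
binary = (resized > 0.5).float()
nearest = F.interpolate(mask, size=(12, 12), mode="nearest")
```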

## 7. Training Functions

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """
    Train for one epoch
    """
    model.train()
    running_loss = 0.0
    
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for batch_idx, (images, masks) in enumerate(pbar):
        images = images.to(device)
        masks = masks.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        
        # Calculate loss
        loss = criterion(outputs, masks)
        
        # Backward pass
        loss.backward()

        # Gradient clipping: rescale gradients so their global L2 norm is at most 1.0 (guards against exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        
        # Update statistics
        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})
    
    epoch_loss = running_loss / len(dataloader)
    return epoch_loss


def validate(model, dataloader, criterion, device, epoch):
    """
    Validate the model
    """
    model.eval()
    running_loss = 0.0
    dice_scores = []
    iou_scores = []
    pa_scores = []
    
    with torch.no_grad():
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Val]", leave=False)
        for batch_idx, (images, masks) in enumerate(pbar):
            images = images.to(device)
            masks = masks.to(device)
            
            # Forward pass
            outputs = model(images)
            
            # Calculate loss
            loss = criterion(outputs, masks)
            running_loss += loss.item()
            
            # Calculate metrics
            preds = (outputs > 0.5).float()
            
            for i in range(images.size(0)):
                dice = dice_coefficient(preds[i], masks[i])
                iou = iou_score(preds[i], masks[i])
                pa = pixel_accuracy(preds[i], masks[i])
                
                dice_scores.append(dice.item())
                iou_scores.append(iou.item())
                pa_scores.append(pa.item())
            
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'dice': f"{np.mean(dice_scores):.4f}"
            })
    
    # Calculate average metrics
    val_loss = running_loss / len(dataloader)
    val_dice = np.mean(dice_scores)
    val_iou = np.mean(iou_scores)
    val_pa = np.mean(pa_scores)
    
    return val_loss, val_dice, val_iou, val_pa


print("Training functions defined!")
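
The per-sample metrics come from `shared.src.metrics.segmentation_metrics`, whose source is not shown here. Below is a sketch of the standard definitions, written in terms of true positives (TP), false positives (FP) and false negatives (FN):

- Dice $= 2TP/(2TP+FP+FN)$
- IoU $= TP/(TP+FP+FN)$
- Pixel accuracy is the fraction of pixels where prediction and mask agree.

The function name `binary_metrics` and the `eps` handling are assumptions and may differ from the project's functions.

```python
import torch


def binary_metrics(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-7):
    """Dice, IoU and pixel accuracy for one binary prediction/mask pair (any shape)."""
    pred = pred.flatten().float()
    target = target.flatten().float()
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()
    dice = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    iou = (tp + eps) / (tp + fp + fn + eps)
    pixel_acc = (pred == target).float().mean()
    return dice.item(), iou.item(), pixel_acc.item()


target = torch.tensor([[1., 1.], [0., 0.]])
pred = torch.tensor([[1., 0.], [0., 0.]])  # one TP, one FN, two TN
dice, iou, pa = binary_metrics(pred, target)
print(f"{dice:.4f} {iou:.4f} {pa:.4f}")  # 0.6667 0.5000 0.7500
```

For a single sample, Dice $= 2\,\mathrm{IoU}/(1+\mathrm{IoU})$, so the two metrics rank individual predictions identically. Their averages over a dataset can still differ. With this `eps` convention, an empty prediction on an empty mask scores 1.0.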

## 8. Training Loop

In [None]:
print(f"{'='*60}")
print(f"Starting Training - Attention U-Net")
print(f"{'='*60}")

# Learning Rate Warm-up
base_lr_config = config['training']['optimizer']['lr']
def get_lr(epoch, base_lr=base_lr_config, warmup_epochs=10):
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    else:
        return base_lr

# Training configuration
num_epochs = config['training']['num_epochs']
patience = config['training']['early_stopping_patience']

# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'val_dice': [],
    'val_iou': [],
    'val_pa': [],
    'lr': []
}

best_dice = 0.0
epochs_without_improvement = 0

print(f"Epochs: {num_epochs}")
print(f"Early Stopping Patience: {patience}")
print(f"Learning Rate Warm-up: 10 epochs")
print(f"{'='*60}\n")

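# Main training loop
for epoch in range(num_epochs):
    # Apply learning-rate warm-up only during the warm-up epochs. Afterwards the LR is
    # left to ReduceLROnPlateau; resetting it every epoch would undo its reductions.
    if epoch < 10:
        for param_group in optimizer.param_groups:
            param_group['lr'] = get_lr(epoch)
    current_lr = optimizer.param_groups[0]['lr']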
    
    # Train
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
    
    # Validate
    val_loss, val_dice, val_iou, val_pa = validate(model, val_loader, criterion, device, epoch)
    
    # Update learning rate with scheduler (after warm-up)
    old_lr = current_lr
    if epoch >= 10:  # Only apply scheduler after warm-up
        scheduler.step(val_dice)
        current_lr = optimizer.param_groups[0]['lr']
    
    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)
    history['val_iou'].append(val_iou)
    history['val_pa'].append(val_pa)
    history['lr'].append(current_lr)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"{'='*100}")
    dice_indicator = ' 🏆' if val_dice > best_dice else ''
    lr_change = f' ⬇️ (reduced from {old_lr:.6f})' if current_lr < old_lr else ''
    warmup_indicator = ' 🔥 (warm-up)' if epoch < 10 else ''
    
    # Print metrics
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}{dice_indicator}")
    print(f"Val mIoU: {val_iou:.4f} | Val mPA: {val_pa:.4f}")
    print(f"LR: {current_lr:.6f}{lr_change}{warmup_indicator}")
    print(f"{'='*100}")
    
    # Check for improvement
    is_best = val_dice > best_dice
    if is_best:
        best_dice = val_dice
        epochs_without_improvement = 0
        
        # Save best model
        checkpoint_dir = Path(get_path(config['logging']['checkpoint_dir']))
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        best_model_path = checkpoint_dir / 'best_model_attention_unet_v3.pth'
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_dice': best_dice,
            'history': history,
            'config': config
        }, best_model_path)
        
        print(f"  → Saved best model (Dice: {best_dice:.4f})")
    else:
        epochs_without_improvement += 1
        print(f"  ⚠️  No improvement for {epochs_without_improvement}/{patience} epochs")
    
    # Early stopping
    if epochs_without_improvement >= patience:
        print(f"\n{'='*70}")
        print(f"⛔ EARLY STOPPING TRIGGERED")
        print(f"{'='*70}")
        print(f"  Stopped at epoch: {epoch+1}")
        print(f"  Best Dice Score:  {best_dice:.4f}")
        print(f"  Patience limit:   {patience} epochs without improvement")
        print(f"{'='*70}")
        break

# Final summary
print(f"\n{'='*60}")
print(f"Training Completed!")
print(f"{'='*60}")
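# Report IoU/PA from the same epoch as the best Dice (the epoch of the saved checkpoint)
best_epoch = int(np.argmax(history['val_dice']))
print(f"Best Validation Dice: {best_dice:.4f} (epoch {best_epoch + 1})")
print(f"Val IoU at that epoch: {history['val_iou'][best_epoch]:.4f}")
print(f"Val PA at that epoch:  {history['val_pa'][best_epoch]:.4f}")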
print(f"{'='*60}\n")
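
Linear warm-up can be combined with `ReduceLROnPlateau` by setting the LR manually only while warming up. That way the plateau scheduler's reductions persist. If the LR were overwritten with the base value every epoch, each reduction would be undone at the start of the next epoch. The small simulation below demonstrates the interplay. It uses a toy model and made-up values: `BASE_LR`, `WARMUP_EPOCHS`, `patience=1` and a constant Dice are all assumptions.

```python
import torch.nn as nn
import torch.optim as optim

BASE_LR, WARMUP_EPOCHS = 1e-3, 3  # illustrative values, not the project's config

model = nn.Linear(4, 1)
optimizer = optim.Adam(model.parameters(), lr=BASE_LR)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=1)

lr_used = []  # LR actually used in each epoch
for epoch in range(8):
    if epoch < WARMUP_EPOCHS:
        # Linear warm-up: only touch the LR while warming up
        for group in optimizer.param_groups:
            group["lr"] = BASE_LR * (epoch + 1) / WARMUP_EPOCHS
    lr_used.append(optimizer.param_groups[0]["lr"])
    val_dice = 0.9  # pretend validation Dice has plateaued
    if epoch >= WARMUP_EPOCHS:
        scheduler.step(val_dice)

print([f"{lr:.2e}" for lr in lr_used])
```

The LR ramps up over three epochs and holds at `BASE_LR`. The scheduler then halves it after the plateau, and the halved value stays in effect.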

## 9. Plot Training Curves

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Loss curves
axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0, 0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch', fontsize=12)
axes[0, 0].set_ylabel('Loss (Dice)', fontsize=12)
axes[0, 0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0, 0].legend(fontsize=11)
axes[0, 0].grid(True, alpha=0.3)

# Dice coefficient
axes[0, 1].plot(history['val_dice'], label='Val Dice', color='green', linewidth=2)
axes[0, 1].axhline(y=best_dice, color='red', linestyle='--', label=f'Best: {best_dice:.4f}')
axes[0, 1].set_xlabel('Epoch', fontsize=12)
axes[0, 1].set_ylabel('Dice Coefficient', fontsize=12)
axes[0, 1].set_title('Validation Dice Coefficient', fontsize=14, fontweight='bold')
axes[0, 1].legend(fontsize=11)
axes[0, 1].grid(True, alpha=0.3)

# IoU
axes[1, 0].plot(history['val_iou'], label='Val IoU', color='orange', linewidth=2)
axes[1, 0].axhline(y=max(history['val_iou']), color='red', linestyle='--', 
                   label=f"Best: {max(history['val_iou']):.4f}")
axes[1, 0].set_xlabel('Epoch', fontsize=12)
axes[1, 0].set_ylabel('IoU Score', fontsize=12)
axes[1, 0].set_title('Validation IoU Score', fontsize=14, fontweight='bold')
axes[1, 0].legend(fontsize=11)
axes[1, 0].grid(True, alpha=0.3)

# Learning rate
axes[1, 1].plot(history['lr'], label='Learning Rate', color='red', linewidth=2)
axes[1, 1].set_xlabel('Epoch', fontsize=12)
axes[1, 1].set_ylabel('Learning Rate', fontsize=12)
axes[1, 1].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
axes[1, 1].legend(fontsize=11)
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_yscale('log')

plt.tight_layout()

# Save the figure BEFORE plt.show(): notebook backends close the figure on show(),
# so saving afterwards writes a blank image
log_dir = Path(get_path(config['logging']['log_dir']))
log_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(log_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
print(f"Training curves saved to {log_dir / 'training_curves.png'}")

plt.show()

## 10. Visualize Predictions

In [None]:
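# Load best model. map_location lets a GPU-trained checkpoint load on any device.
# weights_only=False is needed on recent PyTorch (2.6+ defaults to True), because the
# checkpoint also stores the config, history and NumPy scalars. Only disable it for
# checkpoints you created yourself.
checkpoint_path = Path(get_path(config['logging']['checkpoint_dir'])) / 'best_model_attention_unet_v3.pth'
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()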
model.eval()

print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"Best Dice Score: {checkpoint['best_dice']:.4f}")

# Get validation samples
val_images, val_masks = next(iter(val_loader))
val_images = val_images.to(device)

# Generate predictions
with torch.no_grad():
    # Model outputs probabilities (sigmoid already applied)
    val_probs = model(val_images)
    val_preds = (val_probs > 0.5).float()

# Visualize predictions
num_samples = min(4, len(val_images))
fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples))

if num_samples == 1:
    axes = axes.reshape(1, -1)

for i in range(num_samples):
    # Move to CPU and convert to numpy
    img = val_images[i, 0].cpu().numpy()
    mask = val_masks[i, 0].numpy()
    pred = val_preds[i, 0].cpu().numpy()
    
    # Calculate metrics for this sample (both tensors on the CPU)
    dice = dice_coefficient(val_preds[i].cpu(), val_masks[i]).item()
    iou = iou_score(val_preds[i].cpu(), val_masks[i]).item()
    
    # Input image
    axes[i, 0].imshow(img, cmap='gray')
    axes[i, 0].set_title('Input Image', fontsize=12, fontweight='bold')
    axes[i, 0].axis('off')
    
    # Ground truth
    axes[i, 1].imshow(mask, cmap='gray')
    axes[i, 1].set_title('Ground Truth', fontsize=12, fontweight='bold')
    axes[i, 1].axis('off')
    
    # Prediction
    axes[i, 2].imshow(pred, cmap='gray')
    axes[i, 2].set_title(f'Prediction\nDice: {dice:.4f} | IoU: {iou:.4f}', 
                         fontsize=12, fontweight='bold')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

# Save predictions
pred_dir = Path(get_path(config['logging']['prediction_dir']))
pred_dir.mkdir(parents=True, exist_ok=True)
save_prediction_grid(val_images[:4].cpu(), val_masks[:4], val_preds[:4].cpu(), 
                    str(pred_dir / 'sample_predictions.png'), num_samples=4)
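
If a Colab session disconnects mid-training, the saved checkpoint can be used to resume. The helper below is a sketch; the name `load_training_state` is hypothetical. It reuses the dictionary keys from the `torch.save` call in the training loop. The round trip uses a toy `nn.Linear` in place of the Attention U-Net.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim


def load_training_state(path, model, optimizer=None, scheduler=None, device="cpu"):
    """Restore model (and optionally optimizer/scheduler); return the next epoch index.

    Expects the keys used in this notebook's checkpoint: 'model_state_dict',
    'optimizer_state_dict', 'scheduler_state_dict', 'epoch'.
    """
    # weights_only=False because the checkpoint may hold plain Python/NumPy objects;
    # only use it on files you created yourself
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    if scheduler is not None:
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
    return ckpt["epoch"] + 1


# Round-trip demo with a toy model
src = nn.Linear(4, 1)
opt = optim.Adam(src.parameters(), lr=1e-3)
sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max")
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "ckpt.pth"
    torch.save({"epoch": 4,
                "model_state_dict": src.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
                "scheduler_state_dict": sched.state_dict()}, path)
    dst = nn.Linear(4, 1)
    dst_opt = optim.Adam(dst.parameters(), lr=1e-3)
    start_epoch = load_training_state(path, dst, dst_opt)
print(start_epoch)  # 5
```

To actually resume, also restore `history` and `best_dice` from the checkpoint and start the loop at the returned epoch. This notebook writes the checkpoint only when Dice improves, so it holds the best epoch rather than the most recent one.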

## 11. Summary

### Key Innovations:

1. **Attention Gates on Skip Connections**: Spatial attention mechanism weights encoder features before concatenation
   - Focuses on relevant spatial regions
   - Suppresses irrelevant activations
   - Improves feature selection

2. **Standard U-Net Backbone**: Proven encoder-decoder architecture with skip connections

3. **Improved Feature Selection**: Attention mechanism helps the network focus on the fetal head region

4. **Aggressive Data Augmentation**: Elastic deformation, grid distortion, Gaussian noise, CLAHE

5. **Training Enhancements**: Gradient clipping (max_norm=1.0) and learning rate warm-up (10 epochs)

### Architecture Highlights:

- **Encoder**: 4 stages with ConvBlock (64→128→256→512)
- **Bottleneck**: ConvBlock (1024 channels)
- **Decoder**: 4 stages with ConvBlock + Attention Gates (512→256→128→64)
- **Total Parameters**: ~34M parameters (~131 MB)

### Training Configuration:

- **Loss**: DiceLoss (Sorensen-Dice)
- **Optimizer**: Adam (lr from config, weight decay)
- **Scheduler**: ReduceLROnPlateau (monitors validation Dice)
- **Augmentation**: Aggressive transforms (on-the-fly)
- **Gradient Clipping**: max_norm=1.0
- **LR Warm-up**: 10 epochs

### Expected Performance:

The Attention Gates should improve segmentation accuracy by:
- Better boundary detection through spatial attention
- Improved focus on relevant regions (fetal head)
- Reduced false positives by suppressing background features

### Results:

View the training curves and predictions above. Compare with:
- **Standard U-Net**: Baseline performance
- **Residual SE U-Net**: Channel attention approach
- **ASPP-Enhanced models**: Multi-scale feature extraction

---

### Google Colab Specific Notes:

**Setup:**
- Repository cloned from: `https://github.com/TrinhThaiSonDHQT/Fetal-Head-Segmentation`
- All project files loaded automatically from GitHub

**Outputs:**
- Results saved to: `/content/Fetal-Head-Segmentation/results/`
- Download results: Files panel (left sidebar) → right-click folder → Download
- Best model: `best_model_attention_unet_v3.pth` (in `results/checkpoints/`)

**Tips:**
- Use GPU runtime: Runtime → Change runtime type → T4 GPU
- Keep this tab open during training (or enable background execution)
- Download results before closing the notebook (files are deleted when the runtime stops)
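
Downloading a results folder file by file can be tedious. One option is to zip the folder first and trigger a single download. In the sketch below:

- `archive_folder` and the archive path are hypothetical;
- `shutil.make_archive` is standard library;
- `google.colab.files.download` exists only inside Colab, hence the guarded import.

```python
import shutil
from pathlib import Path


def archive_folder(folder, archive_base):
    """Zip `folder` into '<archive_base>.zip' and return the archive's path."""
    folder = Path(folder)
    if not folder.is_dir():
        raise FileNotFoundError(f"No folder at {folder}")
    # Keep archive_base outside `folder`, otherwise the zip would include itself
    return shutil.make_archive(str(archive_base), "zip", root_dir=str(folder))


RESULTS_DIR = Path("/content/Fetal-Head-Segmentation/results")
if RESULTS_DIR.is_dir():
    archive = archive_folder(RESULTS_DIR, "/content/attention_unet_results")
    try:
        from google.colab import files  # only available inside Colab
        files.download(archive)
    except ImportError:
        print(f"Archive written to {archive}")
```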

### Next Steps:

1. Compare with Standard U-Net and other variants
2. Analyze attention maps to understand what the network focuses on
3. Test on HC18 test set
4. Experiment with different attention configurations
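
For analysing attention maps (step 2), PyTorch forward hooks can record intermediate outputs without modifying the model. The sketch below captures the output of every submodule of a given type; a toy `nn.Sequential` stands in for the Attention U-Net.

Which module to hook in `AttentionUNet` is an assumption:

- If its gates compute the coefficients with a functional `torch.sigmoid` call, hook the gate modules instead, or have them return α.
- An `nn.Sigmoid` output layer would be captured as well.

```python
import torch
import torch.nn as nn


def capture_outputs(model: nn.Module, module_types):
    """Register forward hooks that store each matching submodule's output by name."""
    captured, handles = {}, []
    for name, module in model.named_modules():
        if isinstance(module, module_types):
            # Bind `name` as a default argument so each hook keeps its own key
            def hook(mod, inputs, output, name=name):
                captured[name] = output.detach().cpu()
            handles.append(module.register_forward_hook(hook))
    return captured, handles


# Toy stand-in: pretend the Sigmoid produces the attention coefficients
toy = nn.Sequential(nn.Conv2d(1, 1, kernel_size=1), nn.Sigmoid())
maps, handles = capture_outputs(toy, (nn.Sigmoid,))
with torch.no_grad():
    toy(torch.randn(1, 1, 8, 8))
for h in handles:
    h.remove()  # detach hooks when done
print({k: tuple(v.shape) for k, v in maps.items()})  # {'1': (1, 1, 8, 8)}
```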