# Attention U-Net Training Notebook (Kaggle)
## Fetal Head Segmentation in Ultrasound Images

**Compatible with:** Kaggle Notebooks

This notebook implements an **Attention U-Net** architecture with:
1. Standard U-Net encoder/decoder with convolutional blocks
2. **Attention Gates** applied to skip connections before concatenation
3. Spatial attention mechanism to focus on relevant features
4. Improved feature selection through gating mechanism

**Key Innovation:** Attention Gates weight the skip connections from the encoder, allowing the network to focus on relevant spatial regions and suppress irrelevant activations.
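
As a rough illustration of the gating idea, the sketch below shows a minimal additive attention gate of the kind commonly used in Attention U-Net. It is not the `AttentionUNet` implementation imported later in this notebook, whose internals are not shown here. The class name `AttentionGateSketch`, its argument names, and the assumption that the gating signal has already been resized to the skip connection's resolution are all illustrative.

```python
import torch
import torch.nn as nn


class AttentionGateSketch(nn.Module):
    """Additive attention gate (illustrative; names are hypothetical)."""

    def __init__(self, gate_channels, skip_channels, inter_channels):
        super().__init__()
        # 1x1 convolutions project both inputs into a shared intermediate space
        self.w_gate = nn.Conv2d(gate_channels, inter_channels, kernel_size=1)
        self.w_skip = nn.Conv2d(skip_channels, inter_channels, kernel_size=1)
        # psi collapses the intermediate features into a single attention map
        self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, gate, skip):
        # Assumes gate and skip have the same spatial size
        attn = self.relu(self.w_gate(gate) + self.w_skip(skip))
        alpha = self.sigmoid(self.psi(attn))  # shape (N, 1, H, W)
        return skip * alpha  # attention map is broadcast over channels


gate = torch.randn(2, 128, 32, 32)  # decoder feature used as gating signal
skip = torch.randn(2, 64, 32, 32)   # encoder feature from the skip connection
ag = AttentionGateSketch(gate_channels=128, skip_channels=64, inter_channels=32)
gated = ag(gate, skip)
print(gated.shape)
```

The gated tensor keeps the shape of the skip connection, so it can be concatenated with the decoder features exactly as an ungated skip connection would be.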

## 1. Setup and Imports

**Kaggle Environment:**
- Uses `/kaggle/input/fhs-attention-unet` for read-only project data
- Uses `/kaggle/working/` for writable outputs (checkpoints, logs, predictions)

In [None]:
import os
import sys
from pathlib import Path

print("[Kaggle Setup]")

# Setup paths for Kaggle
project_root = Path('/kaggle/input/fhs-attention-unet')
output_root = Path('/kaggle/working')
cache_root = output_root / 'cache'

# Verify dataset exists
if not project_root.exists():
    raise RuntimeError(
        f"Dataset not found at {project_root}\n"
        f"Please add the 'fhs-attention-unet' dataset to your Kaggle notebook."
    )

if not (project_root / 'accuracy_focus').exists():
    raise RuntimeError(
        f"'accuracy_focus' folder not found in {project_root}\n"
        f"Please ensure your dataset structure is correct."
    )

print(f"Project root: {project_root} (read-only)")
print(f"Output root: {output_root} (writable)")

# Add project to path
sys.path.insert(0, str(project_root))

print(f"\n✓ Environment setup complete")

In [None]:
# Install required packages for Kaggle
print("Installing required packages...")

# Force-reinstall pinned versions so they override Kaggle's pre-installed packages
!pip install --force-reinstall --no-cache-dir -q \
    "numpy==1.26.4" \
    "scipy==1.11.4" \
    "scikit-learn==1.5.1" \
    "albumentations==1.4.0" \
    "opencv-python-headless==4.9.0.80" \
    "PyYAML>=5.4" \
    "tqdm>=4.62"

print("\n✓ Packages installed successfully")

In [None]:
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# ===== ACTIVE IMPORTS (DiceLoss + sigmoid-output model) =====
from accuracy_focus.attention_unet.src.losses.dice_loss import DiceLoss
from accuracy_focus.attention_unet.src.models.attention_unet import AttentionUNet

# ===== ALTERNATIVE IMPORTS (commented out: BCEWithLogits + logits model) =====
# from accuracy_focus.attention_unet.src.losses.bce_logits import DiceBCEWithLogitsLoss
# from accuracy_focus.attention_unet.src.models.attention_unet.bce_logits_attention_unet import AttentionUNetLogits

from shared.src.data import HC18Dataset
from shared.src.metrics.segmentation_metrics import dice_coefficient, iou_score, pixel_accuracy
from shared.src.utils.visualization import save_prediction_grid, visualize_sample
from shared.src.utils.transforms import get_transforms

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Load Configuration

In [None]:
# Load configuration
config_path = project_root / 'accuracy_focus' / 'attention_unet' / 'configs' / 'attention_unet_config.yaml'

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Adjust output paths for Kaggle environment
print(f"Adjusting paths for Kaggle environment...")
config['logging']['checkpoint_dir'] = str(output_root / 'results' / 'checkpoints')
config['logging']['log_dir'] = str(output_root / 'results' / 'logs')
config['logging']['prediction_dir'] = str(output_root / 'results' / 'predictions')
config['logging']['visualization_dir'] = str(output_root / 'results' / 'visualizations')

print(f"  Outputs will be saved to: {output_root / 'results'}")

print("\nConfiguration loaded:")
print(f"  Model: {config['model']['name']}")
print(f"  Base Filters: {config['model']['base_filters']}")
print(f"  Learning Rate: {config['training']['optimizer']['lr']}")
print(f"  Loss Function: {config['loss']['name']}")
print(f"  Batch Size: {config['training']['batch_size']}")
print(f"  Epochs: {config['training']['num_epochs']}")

## 3. Initialize Model

The Attention U-Net has:
- **Encoder**: 4 downsampling blocks with ConvBlock (64 → 128 → 256 → 512 channels)
- **Bottleneck**: ConvBlock (1024 channels)
- **Decoder**: 4 upsampling blocks with ConvBlock (512 → 256 → 128 → 64 channels)
- **Attention Gates**: Applied to skip connections before concatenation with decoder features
- **Skip Connections**: Attention-weighted features from encoder concatenated with decoder
- **Activation**: ReLU in convolutional blocks, Sigmoid output
- **Spatial Attention**: Attention gates focus on relevant spatial regions
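
Because the model instantiated in this section ends in a sigmoid, its outputs are already probabilities in [0, 1]. The small standard-library check below (with made-up probability values) illustrates why such outputs should be thresholded directly: the sigmoid of any value in [0, 1] lies between 0.5 and about 0.731, so applying a second sigmoid before a 0.5 threshold marks nearly every pixel as foreground.

```python
import math


def sigmoid(x):
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-x))


# Probabilities that a sigmoid-output model might emit for four pixels
probs = [0.0, 0.02, 0.4, 0.98]

# Appropriate for a sigmoid-output model: threshold the probabilities directly
direct = [p > 0.5 for p in probs]

# Problematic: a second sigmoid squeezes [0, 1] into [0.5, ~0.731]
twice = [sigmoid(p) for p in probs]
double = [t > 0.5 for t in twice]

print(direct)  # [False, False, False, True]
print(double)  # [False, True, True, True]
```

For a model that outputs logits instead, the sigmoid is needed exactly once, and thresholding the resulting probability at 0.5 is equivalent to thresholding the logit at 0.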

In [None]:
# Set device
device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ===== ACTIVE MODEL (AttentionUNet with sigmoid output) =====
model = AttentionUNet(
    in_channels=config['model']['in_channels'],
    out_channels=config['model']['out_channels'],
    base_filters=config['model']['base_filters'],
    use_sigmoid=True  # Sigmoid applied inside the model: outputs are probabilities
)

# ===== ALTERNATIVE MODEL (commented out: AttentionUNetLogits, no sigmoid) =====
# model = AttentionUNetLogits(
#     in_channels=config['model']['in_channels'],
#     out_channels=config['model']['out_channels'],
#     base_filters=config['model']['base_filters']
#     # No sigmoid - outputs logits for BCEWithLogitsLoss
# )
model = model.to(device)

# Count parameters
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

total_params, trainable_params = count_parameters(model)

# Model summary
print(f"\nAttention U-Net Architecture (Sigmoid Version):")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{total_params * 4 / (1024**2):.2f} MB (float32)")
print(f"  Input channels: {config['model']['in_channels']}")
print(f"  Output channels: {config['model']['out_channels']}")
print(f"  Base filters: {config['model']['base_filters']}")
print(f"  Output type: PROBABILITIES (sigmoid applied in model)")

# Test forward pass (no gradients are needed for a shape check)
test_input = torch.randn(1, config['model']['in_channels'], 256, 256).to(device)
with torch.no_grad():
    test_output = model(test_input)
print(f"\nTest forward pass:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"  Output range (probabilities): [{test_output.min().item():.4f}, {test_output.max().item():.4f}]")

## 4. Setup Loss and Optimizer

- **Loss**: DiceLoss (Sørensen-Dice loss)
- **Optimizer**: Adam with learning rate from config
- **Scheduler**: ReduceLROnPlateau (monitors validation Dice)
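
For reference, the Dice loss is one minus the Dice coefficient, and IoU is tied to Dice by IoU = Dice / (2 - Dice). The plain-Python sketch below works on flattened lists to make the arithmetic visible. It is an assumption about the general form of these quantities, not the code of the imported `DiceLoss` or `shared.src.metrics`, which may differ (for example in how they reduce over a batch). The function names are hypothetical.

```python
def dice_coefficient_flat(pred, target, smooth=1e-6):
    """Dice = (2 * intersection + s) / (sum(pred) + sum(target) + s)."""
    intersection = sum(p * g for p, g in zip(pred, target))
    return (2.0 * intersection + smooth) / (sum(pred) + sum(target) + smooth)


def dice_loss_flat(pred, target, smooth=1e-6):
    """Loss is 0 for a perfect overlap and approaches 1 for no overlap."""
    return 1.0 - dice_coefficient_flat(pred, target, smooth)


def iou_flat(pred, target, smooth=1e-6):
    """IoU = intersection / union, with the same smoothing term."""
    intersection = sum(p * g for p, g in zip(pred, target))
    union = sum(pred) + sum(target) - intersection
    return (intersection + smooth) / (union + smooth)


pred = [1, 1, 1, 0, 0, 0]    # three predicted foreground pixels
target = [1, 1, 0, 1, 0, 0]  # three true foreground pixels, two shared
dice = dice_coefficient_flat(pred, target)
iou = iou_flat(pred, target)
loss = dice_loss_flat(pred, target)
print(round(dice, 4), round(iou, 4), round(loss, 4))  # 0.6667 0.5 0.3333
```

In this form the smoothing term keeps the ratio defined when both the prediction and the mask are empty, in which case the Dice coefficient evaluates to 1.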

In [None]:
# ===== ACTIVE LOSS FUNCTION (DiceLoss) =====
loss_config = config['loss']
criterion = DiceLoss(smooth=1e-6)
print(f"Loss Function: {loss_config['name']}")
print(f"  Smooth parameter: 1e-6")

# ===== ALTERNATIVE LOSS FUNCTION (commented out: DiceBCEWithLogitsLoss) =====
# loss_config = config['loss']
# criterion = DiceBCEWithLogitsLoss(
#     dice_weight=0.5,
#     bce_weight=0.5,
#     pos_weight=None,  # Will be auto-computed from first batch
#     auto_weight=True,  # Automatic pos_weight calculation
#     smooth=1e-6
# )
# print(f"Loss Function: DiceBCEWithLogitsLoss (NEW - handles class imbalance)")
# print(f"  Dice weight: 0.5")
# print(f"  BCE weight: 0.5")
# print(f"  Smooth parameter: 1e-6")
# print(f"  Auto pos_weight: True (computed from first batch)")
# print(f"  → This loss handles extreme class imbalance automatically")

# Optimizer (Adam)
optimizer_config = config['training']['optimizer']
optimizer = optim.Adam(
    model.parameters(),
    lr=optimizer_config['lr'],
    betas=tuple(optimizer_config['betas']),
    eps=optimizer_config['eps'],
    weight_decay=optimizer_config['weight_decay']
)
print(f"\nOptimizer: Adam")
print(f"  Learning rate: {optimizer_config['lr']}")
print(f"  Weight decay: {optimizer_config['weight_decay']}")

# Learning rate scheduler
scheduler_config = config['training']['scheduler']
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=scheduler_config['mode'],
    factor=scheduler_config['factor'],
    patience=scheduler_config['patience'],
    min_lr=scheduler_config['min_lr']
)
print(f"\nScheduler: ReduceLROnPlateau")
print(f"  Mode: {scheduler_config['mode']}")
print(f"  Factor: {scheduler_config['factor']}")
print(f"  Patience: {scheduler_config['patience']}")

## 5. Prepare Data Loaders

**Preprocessing** (applied to all images):
- Images normalized by dividing by 255.0
- Resized to 256×256 pixels
- Converted to PyTorch tensors (C, H, W)

**Augmentation** (training only - applied on-the-fly):
- Horizontal & Vertical flip (p=0.5)
- Rotation (±20°, p=0.5)
- ShiftScaleRotate: Translation (±10%), Scale (±10%), p=0.5

**Note:** Augmentations are applied dynamically during training. Fresh augmented samples are generated every epoch for better model generalization. Validation uses only preprocessing without augmentation.
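
The essential property of on-the-fly augmentation for segmentation is that each random geometric transform is drawn once per sample and applied identically to the image and its mask, so the two stay aligned. The project's `get_transforms` helper presumably takes care of this (its code is not shown here). The plain-Python sketch below only illustrates the idea with flips on nested lists, and `paired_random_flip` is a hypothetical name.

```python
import random


def paired_random_flip(image, mask, rng, p=0.5):
    """Flip image and mask together so they stay aligned."""
    if rng.random() < p:  # horizontal flip: reverse each row
        image = [row[::-1] for row in image]
        mask = [row[::-1] for row in mask]
    if rng.random() < p:  # vertical flip: reverse the row order
        image = image[::-1]
        mask = mask[::-1]
    return image, mask


image = [[0.1, 0.9], [0.2, 0.8]]  # pixel values already scaled to [0, 1]
mask = [[0, 1], [0, 1]]           # foreground sits under the bright pixels

rng = random.Random(42)
for epoch in range(3):
    # A fresh random draw each epoch yields a different view of the same sample
    aug_image, aug_mask = paired_random_flip(image, mask, rng)
    print(epoch, aug_image, aug_mask)
```

Whatever combination of flips is drawn, the mask's foreground pixels still coincide with the bright image pixels, which is the invariant any paired augmentation pipeline has to preserve.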

In [None]:
data_config = config['data']
training_config = config['training']

# Helper to build paths
def get_path(config_path):
    """Helper to handle both absolute and relative paths"""
    p = Path(config_path)
    if p.is_absolute():
        return str(p)
    else:
        return str(project_root / config_path)

# Create augmentation transforms
print("Creating augmentation transforms...")
train_transform = get_transforms(height=256, width=256, is_train=True)
val_transform = get_transforms(height=256, width=256, is_train=False)
print("  Train transform: WITH augmentation (HorizontalFlip, Rotation, ShiftScaleRotate)")
print("  Val transform: WITHOUT augmentation (resize + normalize only)")

# Create datasets - using HC18Dataset for on-the-fly augmentation
print("\nCreating training dataset...")
train_dataset = HC18Dataset(
    image_dir=get_path(data_config['train_images']),
    mask_dir=get_path(data_config['train_masks']),
    transform=train_transform
)

print("Creating validation dataset...")
val_dataset = HC18Dataset(
    image_dir=get_path(data_config['val_images']),
    mask_dir=get_path(data_config['val_masks']),
    transform=val_transform
)

# Use num_workers=0 for Kaggle to avoid multiprocessing issues
num_workers = 0
print(f"\nDataLoader settings:")
print(f"  num_workers: {num_workers} (disabled for Kaggle)")
print(f"  Batch size: {training_config['batch_size']}")

# Data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=training_config['batch_size'],
    shuffle=True,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

val_loader = DataLoader(
    val_dataset,
    batch_size=training_config['batch_size'],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=training_config['pin_memory']
)

print(f"\n{'='*60}")
print(f"Datasets Ready:")
print(f"{'='*60}")
print(f"  Train samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")
print(f"  Batch size: {training_config['batch_size']}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Train augmentation: ENABLED (on-the-fly)")
print(f"  Val augmentation: DISABLED")
print(f"{'='*60}\n")

## 6. Visualize Sample Data

In [None]:
# Get a batch of training data
sample_images, sample_masks = next(iter(train_loader))

print(f"Sample batch:")
print(f"  Images shape: {sample_images.shape}")
print(f"  Masks shape: {sample_masks.shape}")
print(f"  Image range: [{sample_images.min():.4f}, {sample_images.max():.4f}]")
print(f"  Mask range: [{sample_masks.min():.4f}, {sample_masks.max():.4f}]")
print(f"  Mask unique values: {torch.unique(sample_masks)}")
print(f"  Mask mean (foreground fraction): {sample_masks.mean():.4f}")

# CRITICAL CHECK: Ensure masks are binary {0, 1}
if not torch.all((sample_masks == 0) | (sample_masks == 1)):
    print("\n⚠️  WARNING: Masks are not binary! Check preprocessing.")
else:
    print("\n✓ Masks are properly binary {0, 1}")

# Sanity-check the foreground ratio (2-10% is typical for fetal head masks;
# anything outside 1-30% is flagged as suspicious)
fg_ratio = sample_masks.mean().item()
if fg_ratio < 0.01 or fg_ratio > 0.3:
    print(f"⚠️  WARNING: Unusual foreground ratio: {fg_ratio:.2%} (typical 2-10%, flagged outside 1-30%)")
else:
    print(f"✓ Foreground ratio looks reasonable: {fg_ratio:.2%}")

# Visualize first sample
visualize_sample(sample_images[0], sample_masks[0])

## 7. Training Functions

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """
    Train for one epoch
    """
    model.train()
    running_loss = 0.0
    
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for batch_idx, (images, masks) in enumerate(pbar):
        images = images.to(device)
        masks = masks.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        
        # Calculate loss
        loss = criterion(outputs, masks)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Update statistics
        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})
    
    epoch_loss = running_loss / len(dataloader)
    return epoch_loss


def validate(model, dataloader, criterion, device, epoch):
    """
    Validate the model
    """
    model.eval()
    running_loss = 0.0
    dice_scores = []
    iou_scores = []
    pa_scores = []
    
    with torch.no_grad():
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Val]", leave=False)
        for batch_idx, (images, masks) in enumerate(pbar):
            images = images.to(device)
            masks = masks.to(device)
            
            # Forward pass
            outputs = model(images)
            
            # Calculate loss
            loss = criterion(outputs, masks)
            running_loss += loss.item()
            
            # Calculate metrics
            # The active model (use_sigmoid=True) already outputs probabilities,
            # so threshold directly; a logits model would need torch.sigmoid first
            preds = (outputs > 0.5).float()
            
            for i in range(images.size(0)):
                dice = dice_coefficient(preds[i], masks[i])
                iou = iou_score(preds[i], masks[i])
                pa = pixel_accuracy(preds[i], masks[i])
                
                dice_scores.append(dice.item())
                iou_scores.append(iou.item())
                pa_scores.append(pa.item())
            
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'dice': f"{np.mean(dice_scores):.4f}"
            })
    
    # Calculate average metrics
    val_loss = running_loss / len(dataloader)
    val_dice = np.mean(dice_scores)
    val_iou = np.mean(iou_scores)
    val_pa = np.mean(pa_scores)
    
    return val_loss, val_dice, val_iou, val_pa


print("Training functions defined!")

## 8. Training Loop

In [None]:
print(f"{'='*60}")
print(f"Starting Training - Attention U-Net")
print(f"{'='*60}")

# Training configuration
num_epochs = config['training']['num_epochs']
patience = config['training']['early_stopping_patience']

# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'val_dice': [],
    'val_iou': [],
    'val_pa': [],
    'lr': []
}

best_dice = 0.0
epochs_without_improvement = 0

print(f"Epochs: {num_epochs}")
print(f"Early Stopping Patience: {patience}")
print(f"{'='*60}\n")

# Main training loop
for epoch in range(num_epochs):    
    # Train
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
    
    # Validate
    val_loss, val_dice, val_iou, val_pa = validate(model, val_loader, criterion, device, epoch)
    
    # Update learning rate
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_dice)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)
    history['val_iou'].append(val_iou)
    history['val_pa'].append(val_pa)
    history['lr'].append(current_lr)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"{'='*100}")
    dice_indicator = ' 🏆' if val_dice > best_dice else ''
    lr_change = f' ⬇️ (reduced from {old_lr:.6f})' if current_lr < old_lr else ''
    
    # Print metrics
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}{dice_indicator}")
    print(f"Val mIoU: {val_iou:.4f} | Val mPA: {val_pa:.4f}")
    if lr_change: print(f"LR: {current_lr:.6f}{lr_change}")
    print(f"{'='*100}")
    
    # Check for improvement
    is_best = val_dice > best_dice
    if is_best:
        best_dice = val_dice
        epochs_without_improvement = 0
        
        # Save best model
        checkpoint_dir = Path(get_path(config['logging']['checkpoint_dir']))
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        best_model_path = checkpoint_dir / 'best_model_attention_unet.pth'
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_dice': best_dice,
            'history': history,
            'config': config
        }, best_model_path)
        
        print(f"  → Saved best model (Dice: {best_dice:.4f})")
    else:
        epochs_without_improvement += 1
        print(f"  ⚠️  No improvement for {epochs_without_improvement}/{patience} epochs")
    
    # Early stopping
    if epochs_without_improvement >= patience:
        print(f"\n{'='*70}")
        print(f"⛔ EARLY STOPPING TRIGGERED")
        print(f"{'='*70}")
        print(f"  Stopped at epoch: {epoch+1}")
        print(f"  Best Dice Score:  {best_dice:.4f}")
        print(f"  Patience limit:   {patience} epochs without improvement")
        print(f"{'='*70}")
        break

# Final summary
print(f"\n{'='*60}")
print(f"Training Completed!")
print(f"{'='*60}")
print(f"Best Validation Dice: {best_dice:.4f}")
print(f"Best Validation IoU:  {max(history['val_iou']):.4f}")
print(f"Best Validation PA:   {max(history['val_pa']):.4f}")
print(f"{'='*60}\n")

## 9. Plot Training Curves

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Loss curves
axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0, 0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch', fontsize=12)
axes[0, 0].set_ylabel('Loss (Dice)', fontsize=12)
axes[0, 0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0, 0].legend(fontsize=11)
axes[0, 0].grid(True, alpha=0.3)

# Dice coefficient
axes[0, 1].plot(history['val_dice'], label='Val Dice', color='green', linewidth=2)
axes[0, 1].axhline(y=best_dice, color='red', linestyle='--', label=f'Best: {best_dice:.4f}')
axes[0, 1].set_xlabel('Epoch', fontsize=12)
axes[0, 1].set_ylabel('Dice Coefficient', fontsize=12)
axes[0, 1].set_title('Validation Dice Coefficient', fontsize=14, fontweight='bold')
axes[0, 1].legend(fontsize=11)
axes[0, 1].grid(True, alpha=0.3)

# IoU
axes[1, 0].plot(history['val_iou'], label='Val IoU', color='orange', linewidth=2)
axes[1, 0].axhline(y=max(history['val_iou']), color='red', linestyle='--', 
                   label=f"Best: {max(history['val_iou']):.4f}")
axes[1, 0].set_xlabel('Epoch', fontsize=12)
axes[1, 0].set_ylabel('IoU Score', fontsize=12)
axes[1, 0].set_title('Validation IoU Score', fontsize=14, fontweight='bold')
axes[1, 0].legend(fontsize=11)
axes[1, 0].grid(True, alpha=0.3)

# Learning rate
axes[1, 1].plot(history['lr'], label='Learning Rate', color='red', linewidth=2)
axes[1, 1].set_xlabel('Epoch', fontsize=12)
axes[1, 1].set_ylabel('Learning Rate', fontsize=12)
axes[1, 1].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
axes[1, 1].legend(fontsize=11)
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_yscale('log')

plt.tight_layout()

# Save the figure before plt.show(), which can clear it in notebook backends
log_dir = Path(get_path(config['logging']['log_dir']))
log_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(log_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Training curves saved to {log_dir / 'training_curves.png'}")

## 10. Visualize Predictions

In [None]:
# Load best model
checkpoint_path = Path(get_path(config['logging']['checkpoint_dir'])) / 'best_model_attention_unet.pth'
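# map_location keeps loading robust if the device changed since saving;
# weights_only=False because the checkpoint also stores history and config objects
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)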
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"Best Dice Score: {checkpoint['best_dice']:.4f}")

# Get validation samples
val_images, val_masks = next(iter(val_loader))
val_images = val_images.to(device)

# Generate predictions
with torch.no_grad():
    # The active model (use_sigmoid=True) already outputs probabilities,
    # so threshold directly; a logits model would need torch.sigmoid first
    val_probs = model(val_images)
    val_preds = (val_probs > 0.5).float()

# Visualize predictions
num_samples = min(4, len(val_images))
fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples))

if num_samples == 1:
    axes = axes.reshape(1, -1)

for i in range(num_samples):
    # Move to CPU and convert to numpy
    img = val_images[i, 0].cpu().numpy()
    mask = val_masks[i, 0].numpy()
    pred = val_preds[i, 0].cpu().numpy()
    
    # Calculate metrics for this sample (both tensors on CPU to avoid a device mismatch)
    dice = dice_coefficient(val_preds[i].cpu(), val_masks[i]).item()
    iou = iou_score(val_preds[i].cpu(), val_masks[i]).item()
    
    # Input image
    axes[i, 0].imshow(img, cmap='gray')
    axes[i, 0].set_title('Input Image', fontsize=12, fontweight='bold')
    axes[i, 0].axis('off')
    
    # Ground truth
    axes[i, 1].imshow(mask, cmap='gray')
    axes[i, 1].set_title('Ground Truth', fontsize=12, fontweight='bold')
    axes[i, 1].axis('off')
    
    # Prediction
    axes[i, 2].imshow(pred, cmap='gray')
    axes[i, 2].set_title(f'Prediction\nDice: {dice:.4f} | IoU: {iou:.4f}', 
                         fontsize=12, fontweight='bold')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

# Save predictions
pred_dir = Path(get_path(config['logging']['prediction_dir']))
pred_dir.mkdir(parents=True, exist_ok=True)
save_prediction_grid(val_images[:4].cpu(), val_masks[:4], val_preds[:4].cpu(), 
                    str(pred_dir / 'sample_predictions.png'), num_samples=4)

## 11. Summary

### Key Innovations:

1. **Attention Gates on Skip Connections**: Spatial attention mechanism weights encoder features before concatenation
   - Focuses on relevant spatial regions
   - Suppresses irrelevant activations
   - Improves feature selection

2. **Standard U-Net Backbone**: Proven encoder-decoder architecture with skip connections

3. **Improved Feature Selection**: Attention mechanism helps the network focus on the fetal head region

### Architecture Highlights:

- **Encoder**: 4 stages with ConvBlock (64→128→256→512)
- **Bottleneck**: ConvBlock (1024 channels)
- **Decoder**: 4 stages with ConvBlock + Attention Gates (512→256→128→64)
- **Total Parameters**: ~34M parameters (~131 MB)
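
The channel widths and feature-map sizes listed above follow from the base filter count and the number of downsampling stages. The sketch below assumes that each encoder stage doubles the channels and halves the spatial resolution, and that the input is 256×256 as in this notebook. `unet_plan` is a hypothetical helper, not part of the project code.

```python
def unet_plan(base_filters=64, depth=4, input_size=256):
    """Return (channels, spatial size) per encoder stage, bottleneck, and decoder stage."""
    encoder = [(base_filters * 2 ** i, input_size // 2 ** i) for i in range(depth)]
    bottleneck = (base_filters * 2 ** depth, input_size // 2 ** depth)
    # The decoder mirrors the encoder in reverse order
    decoder = list(reversed(encoder))
    return encoder, bottleneck, decoder


encoder, bottleneck, decoder = unet_plan()
print("encoder:   ", encoder)     # [(64, 256), (128, 128), (256, 64), (512, 32)]
print("bottleneck:", bottleneck)  # (1024, 16)
print("decoder:   ", decoder)     # [(512, 32), (256, 64), (128, 128), (64, 256)]
```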

### Training Configuration:

- **Loss**: DiceLoss (Sørensen-Dice loss)
- **Optimizer**: Adam (lr from config, weight decay)
- **Scheduler**: ReduceLROnPlateau (monitors validation Dice)
- **Augmentation**: HorizontalFlip, Rotation, ShiftScaleRotate (on-the-fly)
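
To make the scheduler's behaviour concrete, here is a simplified pure-Python model of reduce-on-plateau logic in `max` mode: the learning rate is multiplied by `factor` once the monitored metric has failed to improve for more than `patience` consecutive epochs. It leaves out details of PyTorch's `ReduceLROnPlateau` such as the improvement threshold and cooldown. The class name, hyperparameters, and Dice values are made up for illustration and are not taken from the notebook's config.

```python
class PlateauSketch:
    """Simplified reduce-on-plateau rule for a metric that should increase."""

    def __init__(self, lr, factor=0.5, patience=2, min_lr=1e-6):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric):
        if metric > self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce only after more than `patience` epochs without improvement
        if self.bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.bad_epochs = 0
        return self.lr


sched = PlateauSketch(lr=1e-3)
val_dice = [0.80, 0.85, 0.84, 0.85, 0.83, 0.86]  # illustrative values only
lrs = [sched.step(d) for d in val_dice]
print(lrs)  # [0.001, 0.001, 0.001, 0.001, 0.0005, 0.0005]
```

Here the third, fourth, and fifth epochs all fail to beat the best Dice of 0.85, so the rate is halved at the fifth epoch. This is also why the training loop passes `val_dice` to `scheduler.step` and why the configured mode needs to be `max` for a metric where higher is better.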

### Expected Performance:

The Attention Gates should improve segmentation accuracy by:
- Better boundary detection through spatial attention
- Improved focus on relevant regions (fetal head)
- Reduced false positives by suppressing background features

### Results:

View the training curves and predictions above. Compare with:
- **Standard U-Net**: Baseline performance
- **Residual SE U-Net**: Channel attention approach
- **ASPP-Enhanced models**: Multi-scale feature extraction

---

### Kaggle-Specific Notes:

**Kaggle:**
- Outputs saved to `/kaggle/working/results/`
- Automatically available for download after notebook finishes
- Best model: `best_model_attention_unet.pth`

### Next Steps:

1. Compare with Standard U-Net and other variants
2. Analyze attention maps to understand what the network focuses on
3. Test on HC18 test set
4. Experiment with different attention configurations