# Off-Road Semantic Scene Segmentation - Model Training
## Track 2: Desert Environment Segmentation with U-Net

This notebook implements model training, evaluation, and inference for the desert semantic segmentation challenge.

**Evaluation Metric:** Intersection over Union (IoU)

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import models

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import os
from pathlib import Path
import time
from datetime import datetime

# Import from preprocessing notebook
# (In practice, you would import or redefine the Dataset class here)

# Select the compute device once; everything below moves tensors and the model to it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name of CUDA device 0 for the run log.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Training Configuration

In [None]:
# --- Optimisation hyperparameters ---
BATCH_SIZE = 8
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# --- Model / input geometry ---
NUM_CLASSES = 10
HEIGHT, WIDTH = 512, 512

# --- Output locations (created up front so later saves cannot fail on a missing directory) ---
CHECKPOINT_DIR = './checkpoints'
OUTPUT_DIR = './outputs'
for _directory in (CHECKPOINT_DIR, OUTPUT_DIR):
    os.makedirs(_directory, exist_ok=True)

# --- Run settings ---
SAVE_INTERVAL = 5  # a periodic checkpoint is written every SAVE_INTERVAL epochs
NUM_WORKERS = 4  # worker count intended for the DataLoaders

# Echo the key settings so they appear in the notebook output.
print("\n".join([
    "Training Configuration:",
    f"  Batch Size: {BATCH_SIZE}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Learning Rate: {LEARNING_RATE}",
    f"  Image Size: {HEIGHT}x{WIDTH}",
]))

## 3. Model Architecture - U-Net

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the second convolution.
        mid_channels: Channel count between the two convolutions. A falsy
            value (``None`` or ``0``) means "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        def conv_bn_relu(c_in, c_out):
            # padding=1 keeps the spatial size for a 3x3 kernel; the conv has
            # no bias and is followed directly by a BatchNorm layer.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Layers are built in forward order and stored under the same
        # attribute name, so state_dict keys stay `double_conv.<index>.*`.
        self.double_conv = nn.Sequential(
            *conv_bn_relu(in_channels, hidden),
            *conv_bn_relu(hidden, out_channels),
        )

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as one Sequential under the original attribute name so
        # checkpoint keys (`maxpool_conv.*`) are unchanged.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` by a factor of 2, then apply the double convolution."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, ``DoubleConv``.

    Args:
        in_channels: Channel count after concatenating the skip tensor with
            the upsampled tensor.
        out_channels: Channel count produced by the stage.
        bilinear: If true, upsample with fixed bilinear interpolation;
            otherwise use a learned ``ConvTranspose2d`` that halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # Interpolation keeps the channel count, so the DoubleConv gets a
            # narrower middle layer (in_channels // 2).
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size and fuse the two.

        Args:
            x1: Feature map from the deeper stage.
            x2: Skip-connection feature map from the encoder.

        Returns:
            Output of the ``DoubleConv`` applied to ``cat([x2, upsampled x1])``.
        """
        upsampled = self.up(x1)

        # Tensors are indexed as (N, C, H, W): dim 2 is height, dim 3 is width.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2

        # F.pad order for the last two dims is (left, right, top, bottom).
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip features first, upsampled features second, along the channel axis.
        return self.conv(torch.cat([x2, upsampled], dim=1))


class UNet(nn.Module):
    """U-Net for semantic segmentation: 4 downsampling and 4 upsampling stages.

    Args:
        n_channels: Number of input image channels.
        n_classes: Number of output channels (one logit map per class).
        bilinear: Passed to every ``Up`` stage; also halves the width of the
            deepest encoder stage and of the first three decoder outputs.
    """

    def __init__(self, n_channels=3, n_classes=10, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        # Channel widths are divided by 2 in bilinear mode, by 1 otherwise.
        shrink = 2 if bilinear else 1

        # Encoder path
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)

        # Decoder path
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # 1x1 convolution mapping the 64 decoder channels to class logits
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Return raw class logits for the input batch ``x``."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        # Each decoder stage consumes the previous output plus the encoder
        # feature map of matching depth, deepest first.
        decoded = self.up1(bottleneck, skip4)
        for up_stage, skip in ((self.up2, skip3), (self.up3, skip2), (self.up4, skip1)):
            decoded = up_stage(decoded, skip)

        return self.outc(decoded)


# Build the network (RGB input, one output channel per class) and place it on the selected device
model = UNet(n_channels=3, n_classes=NUM_CLASSES, bilinear=True).to(device)

# Report the model size: all parameters, and the subset that receives gradients
total_params = sum(map(torch.numel, model.parameters()))
trainable_params = sum(
    map(torch.numel, filter(lambda param: param.requires_grad, model.parameters()))
)
print("\nModel: U-Net")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 4. Loss Functions and Metrics

In [None]:
class CombinedLoss(nn.Module):
    """Weighted sum of cross-entropy loss and soft Dice loss.

    Args:
        weight: Optional per-class weights forwarded to ``nn.CrossEntropyLoss``.
        ce_weight: Multiplier applied to the cross-entropy term.
        dice_weight: Multiplier applied to the Dice term.
    """

    def __init__(self, weight=None, ce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(weight=weight)
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        """Compute the combined loss.

        Args:
            pred: Logits with the class axis at dim 1 and two trailing
                spatial dims, i.e. (B, C, H, W).
            target: Integer class indices of shape (B, H, W).

        Returns:
            Scalar tensor ``ce_weight * CE + dice_weight * (1 - mean Dice)``.
        """
        eps = 1e-7  # keeps the Dice ratio finite when a class has no mass at all
        spatial_dims = (2, 3)
        n_classes = pred.shape[1]

        probs = F.softmax(pred, dim=1)
        # one_hot appends the class axis last; move it to dim 1 to line up with probs.
        onehot = F.one_hot(target, num_classes=n_classes).permute(0, 3, 1, 2).float()

        # Soft overlap and total mass per (sample, class)
        overlap = (probs * onehot).sum(dim=spatial_dims)
        total_mass = probs.sum(dim=spatial_dims) + onehot.sum(dim=spatial_dims)
        dice_scores = (2. * overlap + eps) / (total_mass + eps)

        ce_term = self.ce_weight * self.ce_loss(pred, target)
        dice_term = self.dice_weight * (1 - dice_scores.mean())
        return ce_term + dice_term


def calculate_iou(pred, target, num_classes):
    """Compute per-class and mean Intersection over Union for a batch.

    Args:
        pred: Either class logits/scores of shape (B, C, H, W), reduced with
            ``argmax`` over dim 1, or already-decoded class indices (B, H, W).
        target: Ground-truth class indices (B, H, W).
        num_classes: Class ids ``0 .. num_classes - 1`` are evaluated.

    Returns:
        Tuple ``(mean_iou, class_iou)``: ``class_iou`` is a list of floats with
        ``nan`` for classes absent from both prediction and target;
        ``mean_iou`` averages the non-nan entries (0.0 when there are none).
    """
    labels = torch.argmax(pred, dim=1) if pred.dim() == 4 else pred

    per_class = []
    for class_id in range(num_classes):
        in_pred = labels == class_id
        in_target = target == class_id

        overlap = torch.logical_and(in_pred, in_target).sum().float()
        combined = torch.logical_or(in_pred, in_target).sum().float()

        # A class that appears nowhere has no defined IoU; mark it nan so it
        # is excluded from the mean instead of counting as 0.
        if combined != 0:
            per_class.append((overlap / combined).item())
        else:
            per_class.append(float('nan'))

    defined = [score for score in per_class if not np.isnan(score)]
    if not defined:
        return 0.0, per_class
    return np.mean(defined), per_class


# Training objective: cross-entropy and Dice contribute equally (0.5 each)
criterion = CombinedLoss(dice_weight=0.5, ce_weight=0.5)
print("Loss function: Combined CE + Dice Loss")

## 5. Training and Validation Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; switched to training mode.
        dataloader: Iterable of ``(images, masks)`` batches supporting ``len()``.
        criterion: Loss callable taking ``(outputs, masks)``.
        optimizer: Optimizer stepped once per batch.
        device: Device both tensors are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_iou)``: the per-batch loss and per-batch
        mean IoU, each averaged over the number of batches.
    """
    model.train()

    loss_total, iou_total = 0.0, 0.0

    progress = tqdm(dataloader, desc='Training')
    for batch_images, batch_masks in progress:
        batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)

        # Standard step: clear gradients, forward, loss, backward, update
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_masks)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        # The metric is computed without gradient tracking
        with torch.no_grad():
            batch_iou = calculate_iou(logits, batch_masks, NUM_CLASSES)[0]

        loss_total += loss_value
        iou_total += batch_iou

        progress.set_postfix({'loss': f'{loss_value:.4f}', 'iou': f'{batch_iou:.4f}'})

    num_batches = len(dataloader)
    return loss_total / num_batches, iou_total / num_batches


def validate_epoch(model, dataloader, criterion, device):
    """Evaluate the model over ``dataloader`` without updating weights.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable of ``(images, masks)`` batches supporting ``len()``.
        criterion: Loss callable taking ``(outputs, masks)``.
        device: Device both tensors are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_iou, class_iou_avg)``: loss and mean IoU
        averaged over batches, plus one averaged IoU per class (0.0 for a
        class that never produced a defined IoU).
    """
    model.eval()

    loss_total = 0.0
    iou_total = 0.0
    per_class_history = [[] for _ in range(NUM_CLASSES)]

    with torch.no_grad():
        progress = tqdm(dataloader, desc='Validation')
        for batch_images, batch_masks in progress:
            batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)

            logits = model(batch_images)
            loss_value = criterion(logits, batch_masks).item()
            batch_iou, batch_class_ious = calculate_iou(logits, batch_masks, NUM_CLASSES)

            loss_total += loss_value
            iou_total += batch_iou

            # Only defined (non-nan) class scores enter the per-class average
            for bucket, score in zip(per_class_history, batch_class_ious):
                if np.isnan(score):
                    continue
                bucket.append(score)

            progress.set_postfix({'loss': f'{loss_value:.4f}', 'iou': f'{batch_iou:.4f}'})

    num_batches = len(dataloader)

    class_iou_avg = []
    for bucket in per_class_history:
        class_iou_avg.append(np.mean(bucket) if bucket else 0.0)

    return loss_total / num_batches, iou_total / num_batches, class_iou_avg

## 6. Training Loop

In [None]:
# NOTE: You need to create train_loader and val_loader from your datasets
# This assumes you've run the preprocessing notebook or imported the datasets

# Example (uncomment and modify):
# from your_preprocessing_module import train_dataset, val_dataset
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# mode='max': the scheduler is meant to be stepped with validation IoU (higher is better);
# the LR is halved after 5 epochs without improvement.
# `verbose=True` is intentionally not passed: the argument is deprecated since PyTorch 2.2
# (newer releases may no longer accept it - confirm against the installed version), and the
# training loop already prints the current learning rate every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

# Training history: one entry is appended to each list per epoch
history = {
    'train_loss': [],
    'train_iou': [],
    'val_loss': [],
    'val_iou': [],
    'learning_rates': []
}

# Best validation IoU seen so far and the (1-based) epoch at which it occurred
best_val_iou = 0.0
best_epoch = 0

print("\n" + "="*60)
print("TRAINING START")
print("="*60 + "\n")

# Uncomment to run training (remove the surrounding triple quotes)
'''
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 60)
    
    # Train
    train_loss, train_iou = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_iou, class_ious = validate_epoch(model, val_loader, criterion, device)
    
    # Update learning rate (the scheduler tracks validation IoU)
    scheduler.step(val_iou)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Record history
    history['train_loss'].append(train_loss)
    history['train_iou'].append(train_iou)
    history['val_loss'].append(val_loss)
    history['val_iou'].append(val_iou)
    history['learning_rates'].append(current_lr)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f} | Train IoU: {train_iou:.4f}")
    print(f"  Val Loss:   {val_loss:.4f} | Val IoU:   {val_iou:.4f}")
    print(f"  Learning Rate: {current_lr:.2e}")
    
    # Save best model
    # Metrics are stored as plain Python floats: the IoU values are NumPy scalars, and
    # torch.load with weights_only=True (the default from PyTorch 2.6) refuses to
    # unpickle NumPy scalars.
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        best_epoch = epoch + 1
        torch.save({
            'epoch': epoch,  # 0-based; best_epoch above is 1-based
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_iou': float(val_iou),
            'class_ious': [float(iou) for iou in class_ious],
        }, os.path.join(CHECKPOINT_DIR, 'best_model.pth'))
        print(f"  ✓ New best model saved! (IoU: {val_iou:.4f})")
    
    # Save checkpoint periodically
    if (epoch + 1) % SAVE_INTERVAL == 0:
        checkpoint_name = f'checkpoint_epoch_{epoch+1}.pth'
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_iou': float(val_iou),
        }, os.path.join(CHECKPOINT_DIR, checkpoint_name))
        # Report the file name that was actually written
        print(f"  Checkpoint saved: {checkpoint_name}")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)
print(f"Best Validation IoU: {best_val_iou:.4f} at epoch {best_epoch}")
'''

## 7. Visualize Training Progress

In [None]:
def plot_training_history(history):
    """Draw loss, IoU and learning-rate curves side by side, then save and show them.

    Args:
        history: Dict with list values under the keys ``train_loss``,
            ``val_loss``, ``train_iou``, ``val_iou`` and ``learning_rates``.

    The figure is written to ``OUTPUT_DIR/training_curves.png``.
    """
    fig, (loss_ax, iou_ax, lr_ax) = plt.subplots(1, 3, figsize=(18, 5))

    # The loss and IoU panels share one layout: train vs. validation curves.
    for ax, key, label in ((loss_ax, 'loss', 'Loss'), (iou_ax, 'iou', 'IoU')):
        ax.plot(history[f'train_{key}'], label=f'Train {label}', marker='o')
        ax.plot(history[f'val_{key}'], label=f'Val {label}', marker='s')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        ax.set_title(f'Training and Validation {label}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Learning rate on a log axis
    lr_ax.plot(history['learning_rates'], marker='o', color='green')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.set_yscale('log')
    lr_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    target_path = os.path.join(OUTPUT_DIR, 'training_curves.png')
    plt.savefig(target_path, dpi=300, bbox_inches='tight')
    plt.show()

# Uncomment to plot (after training)
# plot_training_history(history)

## 8. Model Evaluation on Validation Set

In [None]:
def detailed_evaluation(model, dataloader, device, class_names):
    """Collect per-batch class IoUs over ``dataloader`` and summarise them per class.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable of ``(images, masks)`` batches.
        device: Device both tensors are moved to.
        class_names: Names paired positionally with class ids; pairing stops
            at the shorter of ``class_names`` and ``NUM_CLASSES``.

    Returns:
        List of dicts with keys ``class``, ``iou_mean``, ``iou_std`` and
        ``count`` (number of batches in which the class had a defined IoU).
        Classes with no defined IoU get mean and std of 0.0.
    """
    model.eval()

    per_class_samples = [[] for _ in range(NUM_CLASSES)]

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc='Evaluating'):
            images = images.to(device)
            masks = masks.to(device)

            _, batch_ious = calculate_iou(model(images), masks, NUM_CLASSES)

            # nan marks a class absent from the batch; skip it
            for bucket, score in zip(per_class_samples, batch_ious):
                if not np.isnan(score):
                    bucket.append(score)

    def summarise(class_name, samples):
        # Mean/std over the batches where the class occurred, zeros otherwise
        has_samples = bool(samples)
        return {
            'class': class_name,
            'iou_mean': np.mean(samples) if has_samples else 0.0,
            'iou_std': np.std(samples) if has_samples else 0.0,
            'count': len(samples),
        }

    return [
        summarise(class_name, samples)
        for class_name, samples in zip(class_names, per_class_samples)
    ]


def print_evaluation_results(results):
    """Print class-wise IoU statistics and the overall mIoU as a table.

    Args:
        results: Iterable of dicts with keys ``class``, ``iou_mean``,
            ``iou_std`` and ``count`` (as produced by ``detailed_evaluation``).

    The overall mIoU averages ``iou_mean`` over every class that was observed
    at least once (``count > 0``). A class that is present but scored exactly
    0 therefore lowers the mIoU instead of being silently dropped. When no
    class was observed, 0.0 is printed.
    """
    print("\n" + "="*70)
    print("EVALUATION RESULTS - CLASS-WISE IoU")
    print("="*70)
    print(f"{'Class':<20} {'Mean IoU':<15} {'Std Dev':<15} {'Samples'}")
    print("-"*70)

    observed_ious = []
    for result in results:
        print(f"{result['class']:<20} {result['iou_mean']:<15.4f} {result['iou_std']:<15.4f} {result['count']}")
        # Filter on presence, not on the score: an IoU of 0 for an observed
        # class is a real (bad) result and must count toward the mean.
        if result['count'] > 0:
            observed_ious.append(result['iou_mean'])

    # np.mean([]) would yield nan plus a RuntimeWarning; report 0.0 instead
    overall_miou = float(np.mean(observed_ious)) if observed_ious else 0.0

    print("-"*70)
    print(f"{'Overall mIoU':<20} {overall_miou:<15.4f}")
    print("="*70 + "\n")

# Load best model and evaluate
# Uncomment to run evaluation (remove the surrounding triple quotes)
'''
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
# NOTE(review): from PyTorch 2.6, torch.load defaults to weights_only=True, which rejects
# checkpoints containing NumPy scalars (e.g. metrics stored as np.float64). If loading a
# trusted, locally produced checkpoint fails for that reason, pass weights_only=False.
checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, 'best_model.pth'), map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

CLASS_NAMES = ['Trees', 'Lush Bushes', 'Dry Grass', 'Dry Bushes', 
               'Ground Clutter', 'Flowers', 'Logs', 'Rocks', 'Landscape', 'Sky']

results = detailed_evaluation(model, val_loader, device, CLASS_NAMES)
print_evaluation_results(results)
'''

## 9. Visualization of Predictions

In [None]:
def visualize_predictions(model, dataset, device, num_samples=5, class_colors=None):
    """Plot input, ground truth, prediction and overlay for random dataset samples.

    Each figure is saved to ``OUTPUT_DIR/prediction_<idx>.png`` and then shown.

    Args:
        model: Segmentation network; switched to eval mode and called on a
            single-image batch.
        dataset: Indexable collection yielding ``(image, mask)`` pairs.
            Assumes CPU tensors, with the image normalised using the
            mean/std constants below - TODO confirm against the
            preprocessing notebook.
        device: Device the image is moved to for inference.
        num_samples: Number of distinct random samples to draw; capped at
            ``len(dataset)``.
        class_colors: Optional per-class RGB colours in the 0-255 range,
            indexable by class id. Defaults to matplotlib's ``tab10`` palette.
    """
    model.eval()

    if class_colors is None:
        class_colors = plt.cm.tab10(np.linspace(0, 1, NUM_CLASSES))[:, :3] * 255

    # Sampling without replacement raises ValueError when more samples are
    # requested than the dataset holds, so cap the request first.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # Normalisation constants do not change between samples; build them once.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for idx in indices:
        idx = int(idx)  # plain int for dataset indexing and the file name
        image, mask_gt = dataset[idx]

        # Predict
        with torch.no_grad():
            image_input = image.unsqueeze(0).to(device)
            output = model(image_input)
            pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

        # Denormalize image back to [0, 1] and move channels last for imshow
        image_vis = torch.clamp(image * std + mean, 0, 1)
        image_np = image_vis.permute(1, 2, 0).numpy()

        # Create colored masks
        mask_gt_np = mask_gt.numpy()

        mask_gt_color = np.zeros((*mask_gt_np.shape, 3), dtype=np.uint8)
        mask_pred_color = np.zeros((*pred.shape, 3), dtype=np.uint8)

        for cls in range(NUM_CLASSES):
            mask_gt_color[mask_gt_np == cls] = class_colors[cls]
            mask_pred_color[pred == cls] = class_colors[cls]

        # Calculate IoU for this sample (batch axis added to match calculate_iou)
        iou, _ = calculate_iou(
            torch.from_numpy(pred).unsqueeze(0),
            torch.from_numpy(mask_gt_np).unsqueeze(0),
            NUM_CLASSES
        )

        # Plot
        fig, axes = plt.subplots(1, 4, figsize=(20, 5))

        axes[0].imshow(image_np)
        axes[0].set_title('Input Image')
        axes[0].axis('off')

        axes[1].imshow(mask_gt_color)
        axes[1].set_title('Ground Truth')
        axes[1].axis('off')

        axes[2].imshow(mask_pred_color)
        axes[2].set_title(f'Prediction (IoU: {iou:.3f})')
        axes[2].axis('off')

        # Overlay: 60% image, 40% predicted colours
        overlay = image_np * 0.6 + mask_pred_color / 255 * 0.4
        axes[3].imshow(overlay)
        axes[3].set_title('Overlay')
        axes[3].axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, f'prediction_{idx}.png'), dpi=150, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)

# Uncomment to visualize
# visualize_predictions(model, val_dataset, device, num_samples=5)

## 10. Inference on Test Set

In [None]:
def predict_test_set(model, test_loader, device, output_dir):
    """Run inference over ``test_loader`` and write one class-index PNG per image.

    Args:
        model: Segmentation network; switched to eval mode.
        test_loader: Iterable yielding ``(images, filenames)`` batches.
        device: Device the images are moved to for inference.
        output_dir: Directory for the PNGs; created if missing.

    Each prediction is saved as ``<output_dir>/<stem>_pred.png``, where
    ``<stem>`` is the source file name without directory or extension. Pixel
    values are the predicted class ids stored as uint8 (assumes fewer than
    256 classes).
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    print("Generating test predictions...")

    with torch.no_grad():
        for images, filenames in tqdm(test_loader):
            images = images.to(device)
            outputs = model(images)
            # Class-index maps as uint8 CPU tensors, ready for PNG encoding
            predictions = torch.argmax(outputs, dim=1).to(torch.uint8).cpu()

            # Save predictions
            for pred, filename in zip(predictions, filenames):
                # basename keeps the output inside output_dir even when the
                # loader yields paths rather than bare file names
                stem = os.path.splitext(os.path.basename(filename))[0]
                save_path = os.path.join(output_dir, f"{stem}_pred.png")
                # The original called PIL's `Image`, which this notebook never
                # imports. write_png expects (C, H, W); one channel yields a
                # grayscale PNG.
                torchvision.io.write_png(pred.unsqueeze(0), save_path)

    print(f"Test predictions saved to {output_dir}")

# Uncomment to generate test predictions
# test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
# predict_test_set(model, test_loader, device, os.path.join(OUTPUT_DIR, 'test_predictions'))

## 11. Save Final Model and Metadata

In [None]:
# Gather the run configuration and headline results into one JSON-serialisable record.
# The best-result fields fall back to None when the names were never defined in this session.
metadata = dict([
    ('model', 'U-Net'),
    ('num_classes', NUM_CLASSES),
    ('image_size', [HEIGHT, WIDTH]),
    ('batch_size', BATCH_SIZE),
    ('num_epochs', NUM_EPOCHS),
    ('learning_rate', LEARNING_RATE),
    ('optimizer', 'AdamW'),
    ('loss_function', 'Combined CE + Dice'),
    ('best_val_iou', globals().get('best_val_iou')),
    ('best_epoch', globals().get('best_epoch')),
    ('training_date', datetime.now().strftime('%Y-%m-%d %H:%M:%S')),
])

# Persist the record next to the other run outputs
with open(os.path.join(OUTPUT_DIR, 'training_metadata.json'), 'w') as f:
    f.write(json.dumps(metadata, indent=2))

print("Training metadata saved!")

## Summary and Next Steps

This notebook provides:
1. ✅ U-Net model architecture for semantic segmentation
2. ✅ Combined loss function (CE + Dice)
3. ✅ IoU metric calculation
4. ✅ Training and validation loops
5. ✅ Visualization tools
6. ✅ Test set inference

**For Competition Submission:**
1. Train the model on the full training set
2. Evaluate on the validation set
3. Generate predictions for the unseen test set
4. Analyze failure cases
5. Document your methodology

**Potential Improvements:**
- Try different architectures (DeepLabV3+, SegFormer)
- Use pretrained encoders (ResNet, EfficientNet)
- Experiment with loss functions (Focal Loss, Tversky Loss)
- Apply test-time augmentation
- Ensemble multiple models
- Post-processing (e.g., conditional random fields, CRF)