# SVAMITVA Feature Extraction Model - Training Pipeline
### Production-Ready Training for 10+ Tasks on DGX Server

This notebook provides a well-organized and robust pipeline for training the SVAMITVA multi-task deep learning model. It is configured for the **DGX Server** environment and handles the 11 geospatial extraction tasks listed below.

**Tasks Covered:**
1. Building Mask
2. Roof Type (Classification)
3. Road Mask
4. Road Centerline
5. Waterbody Mask
6. Waterbody Line
7. Waterbody Point
8. Utility Line
9. Utility Polygon
10. Bridge Mask
11. Railway Mask

## 1. Setup & Imports
We start by initializing the environment, setting random seeds for reproducibility, and configuring the device (GPU/CPU).

In [None]:
import os
import sys
import time
import gc
import torch
import numpy as np
import random
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Ensure project root is in path
sys.path.append(os.path.abspath('.'))

# Local module imports
from models.feature_extractor import FeatureExtractor
from models.losses import MultiTaskLoss
from data.dataset import create_dataloaders
from training.config import TrainingConfig, get_config_from_args
from training.metrics import MetricTracker
from utils.checkpoint import CheckpointManager
from utils.logging_config import setup_logging

def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Applies the same seed to Python's ``random`` module, NumPy, and
    PyTorch (CPU generator and all CUDA devices), then switches cuDNN
    to deterministic mode with benchmark autotuning turned off.

    Args:
        seed: Value passed to each seeding function. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Favor run-to-run reproducibility over cuDNN's autotuned kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Fix all RNG seeds before anything stochastic is created.
set_seed(42)

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name of CUDA device 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration & Hyperparameters
Update the paths below to match the DGX server directory structure.

In [None]:
# --- DGX server paths: edit these to match your directory layout ---
DATA_DIR = "/jupyter/sods.user04/DATA/"
CHECKPOINT_DIR = "/jupyter/sods.user04/svamitva_model/checkpoints/"
LOGS_DIR = "/jupyter/sods.user04/svamitva_model/logs/"

# Make sure the output directories exist; already-present ones are left alone.
os.makedirs(LOGS_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Training hyperparameters, grouped by concern ---
config = TrainingConfig(
    # Data
    train_dir=DATA_DIR,
    image_size=512,
    batch_size=8,
    num_workers=0,  # DGX Jupyter best practice is 0 workers to avoid IPC errors
    # Model
    backbone="resnet50",
    # Optimization
    num_epochs=100,
    lr=1e-4,
    use_amp=True,  # Automatic Mixed Precision for faster DGX training
    # Output
    checkpoint_dir=CHECKPOINT_DIR,
)

# The training log is written to train.log inside LOGS_DIR.
log_file = os.path.join(LOGS_DIR, "train.log")
logger = setup_logging(log_file)
print(f"Configured for training on data in: {DATA_DIR}")

## 3. Data Loading
Initializing the dataloaders with automatic shapefile caching (loaded once per MAP folder).

In [None]:
print("Initializing dataloaders...")

# Loader settings come from the training config, except for the validation
# split, which is fixed here.
# NOTE(review): val_split=0.15 presumably holds out 15% of the samples for
# validation -- confirm against data.dataset.create_dataloaders.
loader_kwargs = {
    "train_dir": config.train_dir,
    "batch_size": config.batch_size,
    "image_size": config.image_size,
    "num_workers": config.num_workers,
    "val_split": 0.15,
}
train_loader, val_loader = create_dataloaders(**loader_kwargs)

print(f"Samples: {len(train_loader.dataset)} Training, {len(val_loader.dataset)} Validation")

## 4. Model & Loss Initialization
Loading the 10-head architecture with its multi-task focal loss.

In [None]:
# Build the network and the multi-task loss, then move each to the device.
model = FeatureExtractor(backbone=config.backbone)
model = model.to(device)

criterion = MultiTaskLoss()
criterion = criterion.to(device)

# AdamW over all model parameters; the learning rate comes from the config,
# the weight decay is fixed at 1e-4.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.lr,
    weight_decay=1e-4,
)

# Gradient scaler for mixed-precision training; inactive when use_amp is False.
scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)

# Checkpoints are managed under CHECKPOINT_DIR.
checkpoint_manager = CheckpointManager(CHECKPOINT_DIR)

print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.")

## 5. Main Training Loop
Features include live metric tracking, best model saving, and memory protection.

In [None]:
# Per-epoch curves that feed the live plots drawn at the end of every epoch.
history = {'train_loss': [], 'val_loss': [], 'val_iou': []}
start_epoch = 0

# Optional: resume from the latest checkpoint.
# The load call below is intentionally left commented out, so by default a
# discovered checkpoint is only reported and training still starts at epoch 0.
# Uncomment it to actually restore the model/optimizer state and start epoch.
checkpoint_path = checkpoint_manager.get_latest_checkpoint()
if checkpoint_path:
    print(f"Found checkpoint: {checkpoint_path}")
    # model, optimizer, start_epoch, _ = checkpoint_manager.load(checkpoint_path, model, optimizer)

for epoch in range(start_epoch, config.num_epochs):
    # ------------------------------------------------------------------
    # Training phase
    # ------------------------------------------------------------------
    model.train()
    train_tracker = MetricTracker()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs} [Train]")
    for images, targets in pbar:
        images = images.to(device)
        targets = {k: v.to(device) for k, v in targets.items()}

        optimizer.zero_grad()
        # Forward pass and loss run under autocast when AMP is enabled.
        with torch.cuda.amp.autocast(enabled=config.use_amp):
            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            total_loss = loss_dict['total_loss']

        # Skip the step for NaN *and* +/-inf losses: either would waste a
        # backward pass and would poison the running train-loss average.
        if not torch.isfinite(total_loss):
            print("Non-finite loss detected, skipping step")
            continue

        # Scaled backward pass; the scaler skips the optimizer step itself
        # if the unscaled gradients contain inf/NaN.
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_tracker.update(loss_dict, outputs, targets)
        pbar.set_postfix({'loss': f"{total_loss.item():.4f}"})

    # ------------------------------------------------------------------
    # Validation phase (no gradients, full precision)
    # ------------------------------------------------------------------
    model.eval()
    val_tracker = MetricTracker()
    with torch.no_grad():
        vbar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]")
        for images, targets in vbar:
            images = images.to(device)
            targets = {k: v.to(device) for k, v in targets.items()}

            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            val_tracker.update(loss_dict, outputs, targets)

    # ------------------------------------------------------------------
    # Summarize the epoch
    # ------------------------------------------------------------------
    epoch_train_loss = train_tracker.get_avg_loss()
    epoch_val_loss = val_tracker.get_avg_loss()
    epoch_val_iou = val_tracker.get_avg_iou()

    history['train_loss'].append(epoch_train_loss)
    history['val_loss'].append(epoch_val_loss)
    history['val_iou'].append(epoch_val_iou)

    # Save a checkpoint every epoch, passing the validation IoU along.
    # NOTE(review): how the manager uses the IoU (e.g. best-model tracking)
    # is defined in utils.checkpoint -- confirm there.
    checkpoint_manager.save(model, optimizer, epoch, epoch_val_iou)

    # ------------------------------------------------------------------
    # Live feedback: redraw the loss / IoU curves in place
    # ------------------------------------------------------------------
    clear_output(wait=True)
    fig = plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title('Loss History')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['val_iou'], label='Val Avg IoU', color='green')
    plt.title('Validation IoU')
    plt.legend()
    plt.show()
    # pyplot keeps every figure alive until it is closed; release this one so
    # figures do not pile up across epochs on backends that do not auto-close.
    plt.close(fig)

    print(f"Epoch {epoch+1} Summary:")
    print(f"  Train Loss: {epoch_train_loss:.4f}")
    print(f"  Val Loss:   {epoch_val_loss:.4f}")
    print(f"  Val IoU:    {epoch_val_iou:.4f}")

    # Memory cleanup between epochs.
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()