# SVAMITVA Feature Extraction — Universal Training Pipeline

**Objective:** Train a unified multi-task segmentation model that extracts 10 shapefile classes (plus a building roof-type head) from drone imagery.

Automatically detects environment: DGX A100 (CUDA) or Apple Silicon (MPS/CPU).

## Workflow
1. Put each map in its own folder, `DATA/MAP1/` … `DATA/MAP5/`, each containing the `*.tif` imagery and its `*.shp` shapefiles
2. Run **Cell 1** (Setup — run once)
3. Run **Cell 2** (Training Engine — run once)
4. Run **Cell 3** (Execute Training — run to completion)
5. Each map writes `./checkpoints/<MAP>_best.pt` and `./checkpoints/<MAP>_latest.pt`; the final model is `./checkpoints/MAP5_best.pt`

---
## Cell 1 — Setup & Environment Configuration
Detects hardware, sets DATA path, creates config.

In [None]:
import os
import sys
import time
from pathlib import Path

# Must be set before any CUDA context is created
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast

# ── Workspace path setup ────────────────────────────────────────────────────
NOTEBOOK_DIR = Path.cwd()
# Make the project packages imported in Cell 2 (models/, data/, training/)
# importable when the notebook runs from the repository root.
if str(NOTEBOOK_DIR) not in sys.path:
    sys.path.insert(0, str(NOTEBOOK_DIR))

# Auto-detect DATA directory.
# SVAMITVA_DATA_DIR, when set, takes precedence so the notebook can run on a
# different host/user without editing this cell; otherwise fall back to the
# known DGX path, then ./DATA, then ../DATA.
ENV_DATA_DIR   = os.environ.get("SVAMITVA_DATA_DIR")
DGX_DATA_DIR   = Path("/jupyter/sods.user04/DATA")
LOCAL_DATA_DIR = NOTEBOOK_DIR / "DATA"
FALLBACK_LOCAL = NOTEBOOK_DIR.parent / "DATA"

if ENV_DATA_DIR:
    DATA_DIR, ENV_TYPE = Path(ENV_DATA_DIR).expanduser(), "Custom (SVAMITVA_DATA_DIR)"
elif DGX_DATA_DIR.exists():
    DATA_DIR, ENV_TYPE = DGX_DATA_DIR, "DGX GPU Cluster"
elif LOCAL_DATA_DIR.exists():
    DATA_DIR, ENV_TYPE = LOCAL_DATA_DIR, "Local PC"
elif FALLBACK_LOCAL.exists():
    DATA_DIR, ENV_TYPE = FALLBACK_LOCAL, "Local PC (Parent Dir)"
else:
    print("[WARNING] No DATA folder found. Defaulting to LOCAL_DATA_DIR.")
    DATA_DIR, ENV_TYPE = LOCAL_DATA_DIR, "Missing / Artificial"

# Output folders: checkpoints/ receives *_best.pt / *_latest.pt, logs/ is created for log output.
CKPT_DIR = NOTEBOOK_DIR / "checkpoints"
CKPT_DIR.mkdir(parents=True, exist_ok=True)
(NOTEBOOK_DIR / "logs").mkdir(exist_ok=True)

# ── Device detection ────────────────────────────────────────────────────────
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
# `amp_device` is the device_type handed to autocast; off-CUDA it is "cpu"
# because MPS autocast is flaky.
if torch.cuda.is_available():
    device = torch.device("cuda")
    amp_device = "cuda"
else:
    _mps_usable = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    device = torch.device("mps" if _mps_usable else "cpu")
    amp_device = "cpu"

# Summarise where we are running and what data/hardware was picked up.
print(f"Environment : {ENV_TYPE}")
print(f"DATA dir    : {DATA_DIR}  (exists={DATA_DIR.exists()})")
print(f"Device      : {device}")
if device.type == "cuda":
    # One line per visible GPU: name and total memory in GB (1e9 bytes).
    gpu_count = torch.cuda.device_count()
    for gpu_idx in range(gpu_count):
        props = torch.cuda.get_device_properties(gpu_idx)
        print(f"  GPU {gpu_idx}: {props.name}  ({props.total_memory/1e9:.1f} GB)")

# ── Training configuration ───────────────────────────────────────────────────
# All hyper-parameters for every map in the progressive run.
# Every key ending in "_weight" is a per-task loss weight forwarded to MultiTaskLoss.
CONFIG = {
    # ── Model ──
    "backbone":               "resnet50",
    "pretrained":             True,
    "image_size":             512,
    # ── Optimisation ──
    "batch_size":             8 if device.type == "cuda" else 4,  # smaller batches off-CUDA
    "epochs_per_map":         50,
    "learning_rate":          2e-4,
    "weight_decay":           1e-4,
    "num_workers":            0,
    "mixed_precision":        device.type == "cuda",  # AMP flag: on by default only for CUDA
    "gradient_clip":          1.0,                    # <= 0 disables clipping
    # ── Per-task loss weights ──
    "building_weight":        1.0,
    "roof_weight":            0.5,
    "road_weight":            0.8,
    "waterbody_weight":       0.8,
    "road_centerline_weight": 0.7,
    "waterbody_line_weight":  0.7,
    "waterbody_point_weight": 0.9,
    "utility_line_weight":    0.7,
    "utility_poly_weight":    0.8,
    "bridge_weight":          1.0,
    "railway_weight":         0.9,
}

# Batch keys holding ground-truth masks; only those present in a batch are moved to the device.
TARGET_KEYS = [
    # polygon / line / point layers
    "building_mask",
    "road_mask",
    "road_centerline_mask",
    "waterbody_mask",
    "waterbody_line_mask",
    "waterbody_point_mask",
    "utility_line_mask",
    "utility_poly_mask",
    "bridge_mask",
    "railway_mask",
    # roof-type labels
    "roof_type_mask",
]

print("\nSetup complete ✓")

---
## Cell 2 — Training Engine Core
Loads model, dataset, loss, and defines `train_map` and `analyse_checkpoint`.

In [None]:
from models.feature_extractor import FeatureExtractor
from models.losses import MultiTaskLoss
from training.metrics import MetricTracker
from data.dataset import SvamitvaDataset
from data.augmentation import get_train_transforms

def move_targets(batch):
    """Collect the target masks present in *batch* and move them to ``device``.

    Only keys listed in TARGET_KEYS are considered; keys absent from the batch
    are skipped. Key order follows TARGET_KEYS.
    """
    on_device = {}
    for key in TARGET_KEYS:
        if key not in batch:
            continue
        on_device[key] = batch[key].to(device)
    return on_device

def build_model(load_from: Path = None):
    """Create the multi-task FeatureExtractor, optionally warm-started from disk.

    Args:
        load_from: Optional checkpoint path. It may be a wrapped checkpoint
            (weights under ``"model"`` or ``"model_state_dict"``) or a raw
            state_dict. A path that does not exist is reported and ignored, so
            the model keeps its freshly constructed weights.

    Returns:
        The model moved to ``device``. It is wrapped in ``nn.DataParallel``
        when more than one CUDA device is visible.
    """
    m = FeatureExtractor(
        backbone=CONFIG["backbone"],
        pretrained=CONFIG["pretrained"],
        num_roof_classes=5,
    )
    if load_from:
        load_from = Path(load_from)
        if not load_from.exists():
            print(f"[WARNING] Checkpoint not found, keeping fresh weights: {load_from}")
        else:
            # weights_only=False unpickles arbitrary objects: only load checkpoints
            # produced by this pipeline or another trusted source.
            state = torch.load(load_from, map_location="cpu", weights_only=False)
            # Support both wrapped checkpoints and a raw state_dict.
            weights = state.get("model") or state.get("model_state_dict") or state
            # A state_dict taken from an nn.DataParallel wrapper prefixes every key
            # with "module."; strip it so the keys match the bare model.
            prefix = "module."
            if weights and all(k.startswith(prefix) for k in weights):
                weights = {k[len(prefix):]: v for k, v in weights.items()}
            result = m.load_state_dict(weights, strict=False)
            print(f"Loaded weights: {load_from.name}")
            # strict=False tolerates architecture drift, but it would otherwise hide
            # mismatches (e.g. a checkpoint that loaded nothing), so report them.
            if result.missing_keys or result.unexpected_keys:
                print(
                    f"  [WARNING] missing keys: {len(result.missing_keys)}"
                    f" | unexpected keys: {len(result.unexpected_keys)}"
                )
    if torch.cuda.device_count() > 1:
        m = nn.DataParallel(m)
    return m.to(device)


def _build_map_loader(map_name: str):
    """Return a shuffled training DataLoader over the tiles of *map_name*, or None.

    Returns None when the dataset cannot be constructed, or when the map has no
    tiles (``DataLoader(shuffle=True)`` raises on an empty dataset).
    """
    try:
        ds = SvamitvaDataset(
            root_dir=DATA_DIR,
            image_size=CONFIG["image_size"],
            transform=get_train_transforms(CONFIG["image_size"]),
            mode="train",
        )
        # The dataset is built over the whole DATA_DIR; keep only this map's tiles.
        ds.samples = [s for s in ds.samples if s["map_name"] == map_name]
        print(f"  Tiles: {len(ds)}")
    except Exception as e:  # notebook boundary: report and let the caller skip this map
        print(f"Dataset load failed for {map_name}: {e}")
        return None

    if len(ds) == 0:
        print(f"[SKIP] No tiles found for {map_name}")
        return None

    return DataLoader(
        ds,
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        num_workers=CONFIG["num_workers"],
        pin_memory=(device.type == "cuda"),
    )


def _train_one_epoch(model_w, loader, loss_fn, optimizer, scaler, use_amp: bool, epoch: int):
    """Run one optimisation pass over *loader*; return ``(mean_loss, metrics_dict)``."""
    model_w.train()
    tracker  = MetricTracker()
    run_loss = 0.0
    n_steps  = 0
    clip     = CONFIG["gradient_clip"]

    for step, batch in enumerate(loader):
        images  = batch["image"].to(device)
        targets = move_targets(batch)
        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type=amp_device, enabled=use_amp):
            preds         = model_w(images)
            total_loss, _ = loss_fn(preds, targets)

        if not torch.isfinite(total_loss):
            # Skip the update entirely so one bad batch cannot poison the weights.
            print(f"  [NaN SKIP] epoch {epoch} step {step}")
            continue

        if use_amp:
            scaler.scale(total_loss).backward()
            if clip > 0:
                # Unscale first so the clip threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model_w.parameters(), clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            total_loss.backward()
            if clip > 0:
                nn.utils.clip_grad_norm_(model_w.parameters(), clip)
            optimizer.step()

        run_loss += total_loss.item()
        tracker.update(preds, targets)
        n_steps += 1

    return run_loss / max(n_steps, 1), tracker.compute()


def train_map(map_name: str, resume_from: Path = None):
    """Train on one map folder, optionally warm-starting from a checkpoint.

    ``<map>_latest.pt`` is written every epoch. ``<map>_best.pt`` is written
    whenever the epoch's ``avg_iou`` improves; the first epoch always counts
    as an improvement, so a best checkpoint exists after any completed epoch.
    NOTE(review): IoU is measured on the (augmented) training loader, since
    there is no held-out split here, so "best" means best training IoU.

    Args:
        map_name: Sub-folder of DATA_DIR to train on (e.g. "MAP1").
        resume_from: Checkpoint to initialise from. None, or a path that does
            not exist, keeps the freshly built model's weights.

    Returns:
        Path to this map's best checkpoint. Returns None when the dataset could
        not be loaded or had no tiles, or when no best checkpoint exists.
    """
    map_dir  = DATA_DIR / map_name
    best_out = CKPT_DIR / f"{map_name}_best.pt"
    last_out = CKPT_DIR / f"{map_name}_latest.pt"

    if not map_dir.exists():
        print(f"[SKIP] Folder not found: {map_dir}")
        return best_out if best_out.exists() else None

    print(f"\n{'='*70}")
    print(f"  Region     : {map_name}")
    print(f"  Checkpoint : {resume_from.name if resume_from and resume_from.exists() else 'RANDOM / SCRATCH'}")
    print(f"{'='*70}")

    # Build the data first so a map with no usable tiles does not pay for model loading.
    loader = _build_map_loader(map_name)
    if loader is None:
        return None

    model_w = build_model(load_from=resume_from)
    # Save the bare module's weights so checkpoints carry no "module." prefix.
    inner   = model_w.module if isinstance(model_w, nn.DataParallel) else model_w

    epochs    = CONFIG["epochs_per_map"]
    loss_fn   = MultiTaskLoss(**{k: v for k, v in CONFIG.items() if k.endswith("_weight")}).to(device)
    optimizer = torch.optim.AdamW(model_w.parameters(), lr=CONFIG["learning_rate"], weight_decay=CONFIG["weight_decay"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
    # Honour the mixed_precision switch; AMP is only ever enabled on CUDA.
    use_amp   = bool(CONFIG.get("mixed_precision", False)) and device.type == "cuda"
    scaler    = GradScaler(device="cuda", enabled=use_amp)

    # -inf, not 0.0: otherwise a run whose IoU never rises above 0 would never
    # write <map>_best.pt and the next map would silently start from scratch.
    best_iou = float("-inf")
    for epoch in range(1, epochs + 1):
        t0 = time.time()
        avg_loss, m = _train_one_epoch(model_w, loader, loss_fn, optimizer, scaler, use_amp, epoch)
        scheduler.step()
        avg_iou = m.get("avg_iou", 0.0)

        print(f"  Epoch {epoch:2d}/{epochs} | loss: {avg_loss:.4f} | iou: {avg_iou:.4f} | {time.time()-t0:.0f}s")

        # Update the running best before saving so both files carry the current value.
        is_best = avg_iou > best_iou
        if is_best:
            best_iou = avg_iou
        ckpt = {"model": inner.state_dict(), "epoch": epoch, "map_name": map_name, "best_iou": best_iou, "metrics": m}
        torch.save(ckpt, last_out)
        if is_best:
            torch.save(ckpt, best_out)
            print(f"    → New best! IoU = {best_iou:.4f}")

    print(f"\n  [DONE] {map_name}  Best IoU={best_iou:.4f}")
    return best_out if best_out.exists() else None


def analyse_checkpoint(ckpt_path: Path):
    """Print a per-class IoU summary of a checkpoint written by ``train_map``.

    Reads ``map_name``, ``epoch``, ``best_iou`` and ``metrics`` from the saved
    dict. Every metric ending in ``_iou`` except ``avg_iou`` is drawn as a
    20-cell bar. Values are clamped to [0, 1] for the bar, and non-finite
    values (e.g. NaN for a class absent from the data) get an empty bar instead
    of crashing. A missing or None path is reported and the function returns.
    """
    import math  # local: only needed for the isfinite guard below

    if not ckpt_path or not ckpt_path.exists():
        print(f"Checkpoint missing: {ckpt_path}")
        return
    # weights_only=False unpickles arbitrary objects: only use on trusted checkpoints.
    st = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    m  = st.get("metrics") or {}
    best_iou = st.get("best_iou")
    if best_iou is None:
        best_iou = 0
    print(f"\n{'─'*60}")
    print(f"  Checkpoint : {ckpt_path.name}")
    print(f"  MAP: {st.get('map_name','?')} | Epoch {st.get('epoch','?')} | Best IoU {best_iou:.4f}")
    for k, v in m.items():
        if not k.endswith("_iou") or k == "avg_iou":
            continue
        v = float(v)
        # int(nan * 20) raises ValueError, so non-finite values get no bar.
        cells = int(min(max(v, 0.0), 1.0) * 20) if math.isfinite(v) else 0
        bar = "█" * cells
        print(f"  {k:30s} {v:.4f}  {bar}")
    print(f"{'─'*60}\n")

print("Training engine ready ✓")

---
## Cell 3 — Execute Training
Progressive multi-map training. Each MAP inherits weights from the previous one.

In [None]:
# Progressive training: each map warm-starts from the most recent map that
# actually produced a checkpoint. If a map is skipped or fails (train_map
# returns None), the chain carries the previous weights forward instead of
# silently restarting the next map from fresh weights.
MAP_NAMES = ["MAP1", "MAP2", "MAP3", "MAP4", "MAP5"]

map_checkpoints = {}
prev_ckpt = None
for map_name in MAP_NAMES:
    ckpt_path = train_map(map_name, resume_from=prev_ckpt)
    map_checkpoints[map_name] = ckpt_path
    if ckpt_path:
        analyse_checkpoint(ckpt_path)
        prev_ckpt = ckpt_path

# Keep the per-map names available to later cells (one per entry in MAP_NAMES).
cpt1, cpt2, cpt3, cpt4, cpt5 = (map_checkpoints[name] for name in MAP_NAMES)

print("\n*** SVAMITVA UNIVERSAL PIPELINE COMPLETE ***")
print("Final weights saved in ./checkpoints/")