# SVAMITVA Feature Extraction — Universal Training Pipeline

**Objective:** Train a unified multi-task segmentation model capable of extracting 10 shapefile classes from Drone Imagery. 

This notebook supports training on a DGX A100 Server (CUDA) and local Apple Silicon (MPS/CPU) environments automatically.

## Environment Workflow
1. Put data in a folder structured like `DATA/MAP1/*.tif`.
2. Execute the Setup (Cell 1).
3. Execute Training (Cell 2).
4. Wait for Universal `MAP5_best.pt` to be stored in the `./checkpoints/` directory.

---
## CELL 1 — Setup & Environment Configuration (Run Once)
This cell dynamically checks variables, imports required codebase files (`feature_extractor.py`, etc.), and allocates resources.

In [1]:
import os
import sys
import time
from pathlib import Path

# Ensure PyTorch handles high memory loads on DGX natively
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast  # Handles PyTorch 2.4/2.5+ safely

# ── Workspace Directory Checking ──────────────────────────────────────────
NOTEBOOK_DIR = Path.cwd()
if str(NOTEBOOK_DIR) not in sys.path:
    sys.path.insert(0, str(NOTEBOOK_DIR))

DGX_DATA_DIR = Path("/jupyter/sods.user04/DATA")
LOCAL_DATA_DIR = NOTEBOOK_DIR / "DATA"
FALLBACK_LOCAL = NOTEBOOK_DIR.parent / "DATA"

if DGX_DATA_DIR.exists():
    DATA_DIR = DGX_DATA_DIR
    ENV_TYPE = "DGX GPU Cluster"
elif LOCAL_DATA_DIR.exists():
    DATA_DIR = LOCAL_DATA_DIR
    ENV_TYPE = "Local PC"
elif FALLBACK_LOCAL.exists():
    DATA_DIR = FALLBACK_LOCAL
    ENV_TYPE = "Local PC (Parent Dir)"
else:
    print(f"[WARNING] No DATA folder found directly! Assuming DGX/Linux mount failure. Will artificially default to {DGX_DATA_DIR}, please create a /DATA directory containing MAP1, MAP2, etc.")
    DATA_DIR = LOCAL_DATA_DIR
    ENV_TYPE = "Missing / Artificial"

CKPT_DIR = NOTEBOOK_DIR / "checkpoints"
CKPT_DIR.mkdir(parents=True, exist_ok=True)
(NOTEBOOK_DIR / "logs").mkdir(exist_ok=True)

# ── Device Allocation (CUDA / MPS / CPU) ──────────────────────────────────
if torch.cuda.is_available():
    device = torch.device("cuda")
    amp_device = "cuda"
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    amp_device = "cpu" # AutoCast is mostly disabled/flaky for MPS, using CPU autocast for safety
else:
    device = torch.device("cpu")
    amp_device = "cpu"

print(f"Environment  : {ENV_TYPE}")
print(f"DATA dir     : {DATA_DIR}  (exists={DATA_DIR.exists()})")
print(f"Device       : {device}\n")
if device.type == "cuda":
    for i in range(torch.cuda.device_count()):
        g = torch.cuda.get_device_properties(i)
        print(f"  GPU {i}: {g.name}  ({g.total_memory/1e9:.1f} GB)")

# ── Training Configuration ────────────────────────────────────────────────
CONFIG = dict(
    backbone         = "resnet50",
    pretrained       = True,   # Automatically fetch generic image weights locally
    image_size       = 512,
    batch_size       = 8 if device.type == "cuda" else 4,   # Reduce for local debugging
    epochs_per_map   = 50,
    learning_rate    = 2e-4,
    weight_decay     = 1e-4,
    num_workers      = 0,
    mixed_precision  = device.type == "cuda",
    gradient_clip    = 1.0,
    
    # Priority multi-task loss weightings
    building_weight        = 1.0,
    roof_weight            = 0.5,
    road_weight            = 0.8,
    waterbody_weight       = 0.8,
    road_centerline_weight = 0.7,
    waterbody_line_weight  = 0.7,
    waterbody_point_weight = 0.9,
    utility_line_weight    = 0.7,
    utility_poly_weight    = 0.8,
    bridge_weight          = 1.0,
    railway_weight         = 0.9,
)

TARGET_KEYS = [
    "building_mask", "road_mask", "road_centerline_mask",
    "waterbody_mask", "waterbody_line_mask", "waterbody_point_mask",
    "utility_line_mask", "utility_poly_mask",
    "bridge_mask", "railway_mask", "roof_type_mask",
]

print("\nImports and Configuration Complete ✓")

  import pynvml  # type: ignore[import]


Environment  : Local PC
DATA dir     : /Users/aaronr/Downloads/geoii-main/DATA  (exists=True)
Device       : mps


Imports and Configuration Complete ✓


---
## CELL 2 — Training Engine Core
Initializes building blocks from the native `.py` codebase: dataset loading, architecture, logic scaling, focal loss algorithms, backwards passes.

In [2]:
from models.feature_extractor import FeatureExtractorModel
from models.losses import MultiTaskLoss
from training.metrics import MetricTracker
from data.dataset import SvamitvaDataset
from data.augmentation import get_train_transforms

def move_targets(batch):
    return {k: batch[k].to(device) for k in TARGET_KEYS if k in batch}

def build_model(load_from: Path = None):
    m = FeatureExtractorModel(
        backbone=CONFIG["backbone"],
        pretrained=CONFIG["pretrained"],
        num_roof_classes=5,
    )
    if load_from and load_from.exists():
        state = torch.load(load_from, map_location="cpu")
        m.load_state_dict(state["model"], strict=False)
        print(f"Loaded weights from checkpoint: {load_from.name}")
    if torch.cuda.device_count() > 1:
        m = nn.DataParallel(m)
    return m.to(device)


def train_map(map_name: str, resume_from: Path = None):
    map_dir  = DATA_DIR / map_name
    best_out = CKPT_DIR / f"{map_name}_best.pt"
    last_out = CKPT_DIR / f"{map_name}_latest.pt"

    if not map_dir.exists():
        print(f"[SKIPPING/ERROR] Folder NOT FOUND: {map_dir}")
        return best_out if best_out.exists() else None

    print(f"\n{'='*70}")
    print(f"  Training Engine Activating for Region : {map_name}")
    print(f"  Loading prior state weights           : {resume_from.name if resume_from and resume_from.exists() else 'RANDOM / SCRATCH'}")
    print(f"{'='*70}")

    # Model, Engine Setup
    model_w   = build_model(load_from=resume_from)
    inner     = model_w.module if isinstance(model_w, nn.DataParallel) else model_w

    # Dataset Setup
    try:
        ds = SvamitvaDataset(
            root_dir   = DATA_DIR,
            image_size = CONFIG["image_size"],
            transform  = get_train_transforms(CONFIG["image_size"]),
            mode       = "train",
        )
        ds.samples = [s for s in ds.samples if s["map_name"] == map_name]
        print(f"  Generated dynamically cached tiles    : {len(ds)}")
    except Exception as e:
        print(f"Failed to load dataset for {map_name}: {e}")
        return None

    loader = DataLoader(ds, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=CONFIG["num_workers"], pin_memory=(device.type == "cuda"))

    # Loss Setup (Includes BinaryFocalLoss)
    loss_fn = MultiTaskLoss(**{k: v for k, v in CONFIG.items() if k.endswith("_weight")}).to(device)
    optimizer = torch.optim.AdamW(model_w.parameters(), lr=CONFIG["learning_rate"], weight_decay=CONFIG["weight_decay"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG["epochs_per_map"], eta_min=1e-6)
    
    # Mixed precision tracking (only cuda natively fully handles this reliably in PyTorch amp currently)
    use_amp = device.type == "cuda"
    scaler = GradScaler(enabled=use_amp, device='cuda' if use_amp else None)

    # Loop
    best_iou  = 0.0
    for epoch in range(1, CONFIG["epochs_per_map"] + 1):
        model_w.train()
        tracker  = MetricTracker()
        run_loss = 0.0
        n_steps  = 0
        t0       = time.time()

        for batch in loader:
            images  = batch["image"].to(device)
            targets = move_targets(batch)
            optimizer.zero_grad(set_to_none=True)
            
            # Forward Pass
            with autocast(device_type=amp_device, enabled=use_amp):
                preds         = model_w(images)
                total_loss, loss_dict = loss_fn(preds, targets)
            
            # NaN Guard & Backward Pass
            if not torch.isfinite(total_loss):
                print(f"      [NaN SKIP] skipping bad gradients in step...")
                continue 

            if use_amp:
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                if CONFIG["gradient_clip"] > 0:
                    nn.utils.clip_grad_norm_(model_w.parameters(), CONFIG["gradient_clip"])
                optimizer.step()

            run_loss += total_loss.item()
            tracker.update(preds, targets)
            n_steps += 1

        scheduler.step()
        m        = tracker.compute()
        avg_loss = run_loss / max(n_steps, 1)
        avg_iou  = m.get("avg_iou", 0.0)

        print((f"  Epoch {epoch:2d}/{CONFIG['epochs_per_map']} | "
               f"loss: {avg_loss:.4f} | iou: {avg_iou:.4f} | "
               f"time: {time.time()-t0:.0f}s"))

        # Checkpointing
        ckpt = {
            "model"    : inner.state_dict(),
            "epoch"    : epoch,
            "map_name" : map_name,
            "best_iou" : best_iou,
            "metrics"  : m,
        }
        torch.save(ckpt, last_out)
        if avg_iou > best_iou:
            best_iou = avg_iou
            ckpt["best_iou"] = best_iou
            torch.save(ckpt, best_out)
            print(f"    → New Best Checkpoint saved! IoU = {best_iou:.4f}")

    print(f"\n  [MAP COMPLETE] {map_name} finished.  Best IoU={best_iou:.4f}")
    return best_out

def analyse_checkpoint(ckpt_path: Path):
    if not ckpt_path.exists():
        print(f"Checkpoint missing: {ckpt_path}")
        return
    st = torch.load(ckpt_path, map_location="cpu")
    m  = st.get("metrics", {})
    print(f"\n{'─'*60}")
    print(f"  Checkpoint : {ckpt_path.name}")
    print(f"  MAP trained: {st.get('map_name','?')} | Epoch {st.get('epoch','?')} | Best IoU {st.get('best_iou',0):.4f}")
    print(f"  Per-task IoU:")
    for k, v in m.items():
        if k.endswith("_iou") and k != "avg_iou":
            bar = "█" * int(v * 20)
            print(f"    {k:30s} {v:.4f}  {bar}")
    print(f"{'─'*60}\n")
print("Training Core Active ✓")

Training Core Active ✓


---
## CELL 3 — Executive Training Suite 
Run this block sequentially. It handles progressive multi-map knowledge assimilation (

In [3]:
# Execute progressively. It handles missing maps gracefully, carrying over checkpoints across regions 
cpt1 = train_map("MAP1", resume_from=None)
if cpt1: analyse_checkpoint(cpt1)

cpt2 = train_map("MAP2", resume_from=cpt1)
if cpt2: analyse_checkpoint(cpt2)

cpt3 = train_map("MAP3", resume_from=cpt2)
if cpt3: analyse_checkpoint(cpt3)

cpt4 = train_map("MAP4", resume_from=cpt3)
if cpt4: analyse_checkpoint(cpt4)

cpt5 = train_map("MAP5", resume_from=cpt4)
if cpt5: analyse_checkpoint(cpt5)

print("\n*** SVAMITVA UNIVERSAL PIPELINE EXHAUSTED ***")
print("Final Model Weights successfully saved in ./checkpoints/")

[SKIPPING/ERROR] Folder NOT FOUND: /Users/aaronr/Downloads/geoii-main/DATA/MAP1
[SKIPPING/ERROR] Folder NOT FOUND: /Users/aaronr/Downloads/geoii-main/DATA/MAP2
[SKIPPING/ERROR] Folder NOT FOUND: /Users/aaronr/Downloads/geoii-main/DATA/MAP3
[SKIPPING/ERROR] Folder NOT FOUND: /Users/aaronr/Downloads/geoii-main/DATA/MAP4
[SKIPPING/ERROR] Folder NOT FOUND: /Users/aaronr/Downloads/geoii-main/DATA/MAP5

*** SVAMITVA UNIVERSAL PIPELINE EXHAUSTED ***
Final Model Weights successfully saved in ./checkpoints/
