# SVAMITVA Feature Extraction — DGX A100 Training

**Workflow:** Train one MAP at a time. After each MAP, inspect the metrics/checkpoint, then run the next MAP's cell.

```
workspace/
├── DATA/
│   ├── MAP1/  ← tif + shapefiles
│   ├── MAP2/
│   ├── MAP3/
│   ├── MAP4/
│   └── MAP5/
└── svamitva_model/
    ├── SVAMITVA_Final.ipynb
    └── checkpoints/
        ├── MAP1_best.pt      ← saved after MAP1 training
        ├── MAP2_best.pt      ← starts from MAP1 weights
        ├── MAP3_best.pt      ← starts from MAP2 weights
        ├── MAP4_best.pt
        └── MAP5_best.pt      ← final universal model
```

**How to use:**
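1. Run **Setup** (Cell 1) once per kernel session. Re-run it after any kernel restart, because every later cell depends on the paths, config and helper functions it defines.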
2. Run **MAP1 Training** (Cell 2) → inspect output → run **MAP1 Analysis** (Cell 3)
3. Run **MAP2 Training** (Cell 4), and so on
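4. The last MAP's checkpoint is the **universal model**. It has been fine-tuned on every area in sequence (MAP1 → MAP5), but each stage trains only on its own MAP, so the later areas may dominate. Check it on earlier MAPs (e.g. with the Cell 12 preview) before relying on it.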


---
## CELL 1 — Setup (run once)

In [None]:
import os, sys, time
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn as nn
from pathlib import Path
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader

# ── Paths ─────────────────────────────────────────────────────────────────────
NOTEBOOK_DIR = Path.cwd()                     # .../svamitva_model
# Absolute DATA path on the DGX. Set the SVAMITVA_DATA_DIR environment variable
# to train from a different location without editing this cell.
DATA_DIR     = Path(os.environ.get("SVAMITVA_DATA_DIR", "/jupyter/sods.user04/DATA"))
CKPT_DIR     = NOTEBOOK_DIR / "checkpoints"
CKPT_DIR.mkdir(parents=True, exist_ok=True)
(NOTEBOOK_DIR / "logs").mkdir(exist_ok=True)

# Make the project packages imported below (models, data, training) importable.
if str(NOTEBOOK_DIR) not in sys.path:
    sys.path.insert(0, str(NOTEBOOK_DIR))

# ── GPU info ─────────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Notebook dir : {NOTEBOOK_DIR}")
print(f"DATA dir     : {DATA_DIR}  (exists={DATA_DIR.exists()})")
print(f"Checkpoints  : {CKPT_DIR}")
print(f"Device       : {device}")
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        g = torch.cuda.get_device_properties(i)
        print(f"  GPU {i}: {g.name}  ({g.total_memory/1e9:.1f} GB)")

# ── Dataset validation ───────────────────────────────────────────────────────
# (shapefile glob pattern, task name) pairs. This cell only *reports* which
# ones are missing per MAP folder; it does not stop on a missing shapefile.
SHP_PATTERNS = [
    ("Build_up.shp",          "building"),
    ("Road.shp",              "road"),
    ("Road_centre_line.shp",  "road_centerline"),
    ("Waterbody_1.shp",       "waterbody"),
    ("Waterbody_line_1.shp",  "waterbody_line"),
    ("Waterbody_point_1.shp", "waterbody_point"),
    ("Utility_1.shp",         "utility_line"),
    ("Utility_poly_1.shp",    "utility_poly"),
    # bridge and railway not present in current dataset
]

# Explicit raise instead of `assert`, so the check is never stripped (python -O)
# and the error type says what went wrong.
if not DATA_DIR.is_dir():
    raise FileNotFoundError(f"DATA not found: {DATA_DIR}")
map_dirs = sorted([d for d in DATA_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")])
print(f"\nMAP folders found: {[d.name for d in map_dirs]}")
for md in map_dirs:
    tifs = list(md.glob("*.tif")) + list(md.glob("*.tiff"))
    missing = [task for pat, task in SHP_PATTERNS if not list(md.glob(pat))]
    status = f"TIF={len(tifs)}  missing_shp={missing if missing else 'none'}"
    print(f"  {md.name}: {status}")

# ── Shared training config ───────────────────────────────────────────────────
# Batch keys moved to the device as targets. Keys absent from a batch (e.g.
# bridge/railway, which have no shapefiles yet) are skipped by move_targets.
TARGET_KEYS = [
    "building_mask", "road_mask", "road_centerline_mask",
    "waterbody_mask", "waterbody_line_mask", "waterbody_point_mask",
    "utility_line_mask", "utility_poly_mask",
    "bridge_mask", "railway_mask", "roof_type_mask",
]

CONFIG = dict(
    backbone         = "resnet50",
    pretrained       = False,   # no external URL downloads on DGX
    image_size       = 512,
    batch_size       = 8,      # A100 80GB — increase to 24 if memory allows
    epochs_per_map   = 50,
    learning_rate    = 2e-4,
    weight_decay     = 1e-4,
    num_workers      = 0,
    mixed_precision  = True,
    gradient_clip    = 1.0,
    # Loss weights (every key ending in "_weight" is forwarded to MultiTaskLoss)
    building_weight        = 1.0,
    roof_weight            = 0.5,
    road_weight            = 0.8,
    waterbody_weight       = 0.8,
    road_centerline_weight = 0.7,
    waterbody_line_weight  = 0.7,
    waterbody_point_weight = 0.9,
    utility_line_weight    = 0.7,
    utility_poly_weight    = 0.8,
    bridge_weight          = 1.0,
    railway_weight         = 0.9,
)

print("\nConfig ready ✓")

# ── Helper: train one MAP ────────────────────────────────────────────────────
from models.feature_extractor import FeatureExtractorModel
from models.losses import MultiTaskLoss
from training.metrics import MetricTracker
from data.dataset import SvamitvaDataset
from data.augmentation import get_train_transforms

def move_targets(batch):
    """Return a dict of the target tensors in *batch*, moved to ``device``.

    Keys are taken in ``TARGET_KEYS`` order; any key the batch lacks is skipped.
    """
    moved = {}
    for key in TARGET_KEYS:
        if key not in batch:
            continue
        moved[key] = batch[key].to(device)
    return moved


def build_model(load_from: Path = None):
    """Create the multi-task model, optionally warm-starting from a checkpoint.

    Args:
        load_from: Path (or path string) to a checkpoint written by
            ``train_map``: a dict whose ``"model"`` entry is a state_dict of the
            unwrapped model. ``None`` means start from random initialisation.
            If a path is given but the file does not exist, a warning is raised
            and training starts from random initialisation. In the original
            code this case was silent, which hid a broken MAP-to-MAP chain.

    Returns:
        The model moved to ``device``, wrapped in ``nn.DataParallel`` when more
        than one GPU is visible.
    """
    import warnings  # local: only used for the missing-checkpoint warning

    m = FeatureExtractorModel(
        backbone=CONFIG["backbone"],
        pretrained=CONFIG["pretrained"],
        num_roof_classes=5,
    )
    # Accept str as well as Path. An empty/None value means "no checkpoint".
    load_from = Path(load_from) if load_from else None

    if load_from is not None and load_from.exists():
        state = torch.load(load_from, map_location="cpu")
        m.load_state_dict(state["model"])
        print(f"Loaded weights from: {load_from.name}")
        print(f"  Trained on: {state.get('map_name','?')}  "
              f"epoch={state.get('epoch','?')}  "
              f"best_iou={state.get('best_iou',0):.4f}")
    else:
        if load_from is not None:
            # A requested-but-missing checkpoint usually means the previous MAP
            # never wrote its best.pt. Make that loud instead of silently
            # training from scratch.
            warnings.warn(
                f"Checkpoint not found: {load_from} — "
                "falling back to random initialisation",
                stacklevel=2,
            )
        print("Starting from random initialisation (no prior checkpoint)")
    if torch.cuda.device_count() > 1:
        m = nn.DataParallel(m)
    return m.to(device)


def _build_map_loader(map_name: str):
    """Build a shuffled training DataLoader restricted to one MAP's tiles."""
    ds = SvamitvaDataset(
        root_dir   = DATA_DIR,
        image_size = CONFIG["image_size"],
        transform  = get_train_transforms(CONFIG["image_size"]),
        mode       = "train",
    )
    # The dataset indexes every MAP under DATA_DIR; keep only this MAP's samples.
    ds.samples = [s for s in ds.samples if s["map_name"] == map_name]
    assert ds.samples, f"No samples found for {map_name} in {DATA_DIR}"
    print(f"  Tiles: {len(ds)}")
    return DataLoader(
        ds,
        batch_size  = CONFIG["batch_size"],
        shuffle     = True,
        num_workers = CONFIG["num_workers"],
        pin_memory  = True,
        # Only drop the ragged last batch when there is more than one batch.
        drop_last   = len(ds) > CONFIG["batch_size"],
    )


def _run_epoch(model_w, loader, loss_fn, optimizer, scaler, amp_enabled):
    """Run one training epoch.

    Returns:
        ``(tracker, summed_loss, n_steps, n_skipped)``. ``n_skipped`` counts
        batches whose loss was non-finite and which were not back-propagated.
    """
    model_w.train()
    tracker   = MetricTracker()
    run_loss  = 0.0
    n_steps   = 0
    n_skipped = 0

    for batch in loader:
        images  = batch["image"].to(device)
        targets = move_targets(batch)
        optimizer.zero_grad()
        with autocast("cuda", enabled=amp_enabled):
            preds = model_w(images)
            total_loss, loss_dict = loss_fn(preds, targets)
        # NaN guard: check BEFORE backward so a bad batch never reaches the weights.
        if not torch.isfinite(total_loss):
            bad = [k for k, v in loss_dict.items()
                   if not torch.isfinite(torch.as_tensor(v))]
            print(f"      [NaN SKIP] bad heads={bad}")
            optimizer.zero_grad()
            n_skipped += 1
            continue
        scaler.scale(total_loss).backward()
        if CONFIG["gradient_clip"] > 0:
            # Unscale first so the clip threshold applies to the true gradient norm.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model_w.parameters(), CONFIG["gradient_clip"])
        scaler.step(optimizer)
        scaler.update()
        run_loss += total_loss.item()
        tracker.update(preds, targets)
        n_steps += 1

    return tracker, run_loss, n_steps, n_skipped


def train_map(map_name: str, resume_from: Path = None):
    """Train for ``CONFIG['epochs_per_map']`` epochs on a single MAP.

    Args:
        map_name: Folder name under ``DATA_DIR`` (e.g. ``"MAP1"``).
        resume_from: Path to a previous MAP's ``*_best.pt`` whose weights are
            loaded before training; ``None`` trains from random initialisation.

    Returns:
        Path to this MAP's best checkpoint. The best checkpoint is always
        written at least once (on the first epoch), so the returned path
        exists whenever at least one epoch ran. Later epochs replace it only
        when their training ``avg_iou`` is finite and strictly higher.

    Side effects:
        Writes ``{map_name}_latest.pt`` every epoch, ``{map_name}_best.pt`` on
        improvement, and rewrites ``logs/{map_name}_training.log`` after every
        epoch so the log survives an interrupted run.
    """
    import math  # local: used for the non-finite IoU guard

    map_dir  = DATA_DIR / map_name
    best_out = CKPT_DIR / f"{map_name}_best.pt"
    last_out = CKPT_DIR / f"{map_name}_latest.pt"
    log_path = NOTEBOOK_DIR / "logs" / f"{map_name}_training.log"

    assert map_dir.exists(), f"MAP folder not found: {map_dir}"

    print(f"\n{'='*65}")
    print(f"  Training on: {map_name}")
    print(f"  Loading weights from: {resume_from.name if resume_from and resume_from.exists() else 'scratch'}")
    print(f"{'='*65}")

    # ── Model ─────────────────────────────────────────────────────────────
    model_w = build_model(load_from=resume_from)
    # Checkpoints store the unwrapped model so they load with or without DataParallel.
    inner   = model_w.module if isinstance(model_w, nn.DataParallel) else model_w

    # ── Data ──────────────────────────────────────────────────────────────
    loader = _build_map_loader(map_name)

    # ── Loss / optimiser / scheduler ──────────────────────────────────────
    loss_fn = MultiTaskLoss(
        **{k: v for k, v in CONFIG.items() if k.endswith("_weight")}
    ).to(device)

    optimizer = torch.optim.AdamW(
        model_w.parameters(),
        lr=CONFIG["learning_rate"],
        weight_decay=CONFIG["weight_decay"],
    )
    n_epochs  = CONFIG["epochs_per_map"]
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_epochs, eta_min=1e-6
    )
    # AMP only makes sense on CUDA; enabling it on CPU just produces warnings.
    amp_enabled = bool(CONFIG["mixed_precision"]) and device.type == "cuda"
    scaler      = GradScaler("cuda", enabled=amp_enabled)

    # ── Training loop ──────────────────────────────────────────────────────
    # -inf plus a "first epoch" flag guarantees best.pt is written at least
    # once. The old 0.0 start never saved a best checkpoint if avg_iou stayed
    # at 0 or was NaN.
    best_iou   = float("-inf")
    best_epoch = None
    log_lines  = []

    for epoch in range(1, n_epochs + 1):
        t0 = time.time()
        tracker, run_loss, n_steps, n_skipped = _run_epoch(
            model_w, loader, loss_fn, optimizer, scaler, amp_enabled
        )

        scheduler.step()
        m        = tracker.compute()
        avg_loss = run_loss / max(n_steps, 1)
        avg_iou  = float(m.get("avg_iou", 0.0))
        lr_now   = scheduler.get_last_lr()[0]

        line = (f"  Epoch {epoch:3d}/{n_epochs}  "
                f"loss={avg_loss:.4f}  avg_iou={avg_iou:.4f}  "
                f"lr={lr_now:.2e}  t={time.time()-t0:.0f}s")
        if n_skipped:
            line += f"  skipped={n_skipped}"
        print(line)
        log_lines.append(line)
        log_path.write_text("\n".join(log_lines))   # persist progress every epoch

        # Decide on improvement BEFORE saving, so latest.pt records the current best.
        score    = avg_iou if math.isfinite(avg_iou) else float("-inf")
        improved = best_epoch is None or score > best_iou
        if improved:
            best_iou, best_epoch = score, epoch

        # ── Checkpoint ────────────────────────────────────────────────────
        ckpt = {
            "model"      : inner.state_dict(),
            "epoch"      : epoch,
            "map_name"   : map_name,
            "avg_iou"    : avg_iou,
            "best_iou"   : best_iou,
            "best_epoch" : best_epoch,
            "metrics"    : m,
            "config"     : CONFIG,
        }
        torch.save(ckpt, last_out)   # always overwrite latest

        if improved:
            torch.save(ckpt, best_out)  # save best
            print(f"    ★ Best checkpoint updated  ({map_name}_best.pt  iou={best_iou:.4f})")

    # Final write covers the zero-epoch case; otherwise it matches the last per-epoch write.
    log_path.write_text("\n".join(log_lines))

    print(f"\n  {map_name} done.  Best IoU={best_iou:.4f}")
    print(f"  Best checkpoint : {best_out}")
    print(f"  Training log    : {log_path}")
    return best_out


def analyse_checkpoint(ckpt_path: Path, map_name: str = None):
    """Print a summary of a checkpoint written by ``train_map``.

    Args:
        ckpt_path: Path to a ``*_best.pt`` / ``*_latest.pt`` file.
        map_name: Optional expected MAP name. If given and it differs from the
            ``map_name`` stored in the checkpoint, a warning line is printed.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist.
    """
    import math  # local: used to guard the IoU bar against NaN/inf

    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Not found: {ckpt_path}")
    st = torch.load(ckpt_path, map_location="cpu")
    m  = st.get("metrics") or {}
    stored_map = st.get("map_name", "?")

    print(f"\n{'─'*60}")
    print(f"  Checkpoint : {ckpt_path.name}")
    print(f"  MAP trained: {stored_map}  "
          f"epoch={st.get('epoch','?')}  "
          f"best_iou={st.get('best_iou',0):.4f}")
    if map_name is not None and stored_map != map_name:
        print(f"  WARNING    : expected MAP {map_name!r}, "
              f"checkpoint was trained on {stored_map!r}")
    print(f"  avg_iou    : {m.get('avg_iou',0):.4f}")
    print("  Per-task IoU:")
    for k, v in m.items():
        if k.endswith("_iou") and k != "avg_iou":
            v = float(v)
            # int(nan) raises ValueError, so non-finite IoUs get an empty bar.
            # Clamping keeps the bar between 0 and 20 characters.
            bar = "█" * int(min(max(v, 0.0), 1.0) * 20) if math.isfinite(v) else ""
            print(f"    {k:30s} {v:.4f}  {bar}")
    print(f"{'─'*60}\n")

print("\nAll helpers ready. Proceed to MAP1 training cell.")


---
## CELL 2 — Train MAP1 (from scratch)

In [None]:
# MAP1 has no predecessor, so it trains from random initialisation (resume_from defaults to None).
map1_best = train_map(map_name="MAP1")


## CELL 3 — Analyse MAP1 Weights

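Review the per-task IoU before proceeding to MAP2. These metrics are computed on the MAP1 training tiles (there is no held-out split), and `MAP1_best.pt` is the epoch with the highest training `avg_iou`.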

In [None]:
# Rebuild the path so this cell works even when Cell 2 was not run in this session.
map1_best = CKPT_DIR / "MAP1_best.pt"
analyse_checkpoint(map1_best)

# Optional: show the tail of the per-epoch training log.
map1_log = NOTEBOOK_DIR / "logs" / "MAP1_training.log"
recent_epochs = map1_log.read_text().strip().split("\n")[-10:]
print("Last 10 epochs:")
print("\n".join(recent_epochs))


---
## CELL 4 — Train MAP2 (from MAP1 weights)

In [None]:
# Warm-start MAP2 from the best MAP1 weights.
map1_ckpt = CKPT_DIR / "MAP1_best.pt"
map2_best = train_map(map_name="MAP2", resume_from=map1_ckpt)


## CELL 5 — Analyse MAP2 Weights

In [None]:
analyse_checkpoint(CKPT_DIR / "MAP2_best.pt")

# Tail of the MAP2 per-epoch training log.
map2_log = NOTEBOOK_DIR / "logs" / "MAP2_training.log"
recent_epochs = map2_log.read_text().strip().split("\n")[-10:]
print("Last 10 epochs:")
print("\n".join(recent_epochs))


---
## CELL 6 — Train MAP3 (from MAP2 weights)

In [None]:
# Warm-start MAP3 from the best MAP2 weights.
map2_ckpt = CKPT_DIR / "MAP2_best.pt"
map3_best = train_map(map_name="MAP3", resume_from=map2_ckpt)


## CELL 7 — Analyse MAP3 Weights

In [None]:
analyse_checkpoint(CKPT_DIR / "MAP3_best.pt")

# Tail of the MAP3 per-epoch training log.
map3_log = NOTEBOOK_DIR / "logs" / "MAP3_training.log"
recent_epochs = map3_log.read_text().strip().split("\n")[-10:]
print("Last 10 epochs:")
print("\n".join(recent_epochs))


---
## CELL 8 — Train MAP4 (from MAP3 weights)

In [None]:
# Warm-start MAP4 from the best MAP3 weights.
map3_ckpt = CKPT_DIR / "MAP3_best.pt"
map4_best = train_map(map_name="MAP4", resume_from=map3_ckpt)


## CELL 9 — Analyse MAP4 Weights

In [None]:
analyse_checkpoint(CKPT_DIR / "MAP4_best.pt")

# Tail of the MAP4 per-epoch training log.
map4_log = NOTEBOOK_DIR / "logs" / "MAP4_training.log"
recent_epochs = map4_log.read_text().strip().split("\n")[-10:]
print("Last 10 epochs:")
print("\n".join(recent_epochs))


---
## CELL 10 — Train MAP5 (from MAP4 weights)

In [None]:
# Warm-start MAP5 (the last stage) from the best MAP4 weights.
map4_ckpt = CKPT_DIR / "MAP4_best.pt"
map5_best = train_map(map_name="MAP5", resume_from=map4_ckpt)


## CELL 11 — Analyse MAP5 Weights (Final Universal Model)

In [None]:
analyse_checkpoint(CKPT_DIR / "MAP5_best.pt")

# Tail of the MAP5 per-epoch training log.
map5_log = NOTEBOOK_DIR / "logs" / "MAP5_training.log"
recent_epochs = map5_log.read_text().strip().split("\n")[-10:]
print("Last 10 epochs:")
print("\n".join(recent_epochs))

banner = "=" * 65
print("\n" + banner)
print("TRAINING COMPLETE — MAP5_best.pt is the final universal model")
print("It has learned features from all 5 MAP areas.")
print(banner)


---
## CELL 12 — Inference Preview (any MAP)

In [None]:
import matplotlib.pyplot as plt
from models.feature_extractor import FeatureExtractorModel
from data.preprocessing import OrthophotoPreprocessor

# ── Change these as needed ────────────────────────────────────────────────
PREVIEW_MAP  = "MAP1"           # which MAP's TIF to preview
CKPT_TO_USE  = "MAP5_best.pt"  # which checkpoint to load (use final for best results)

ckpt_path = CKPT_DIR / CKPT_TO_USE
assert ckpt_path.exists(), f"Run training first: {ckpt_path}"

state = torch.load(ckpt_path, map_location=device)
print(f"Loaded: {CKPT_TO_USE}  (trained on {state.get('map_name','?')}, best_iou={state.get('best_iou',0):.4f})")

infer = FeatureExtractorModel(
    backbone=CONFIG["backbone"], pretrained=False, num_roof_classes=5
).to(device)
infer.load_state_dict(state["model"])
infer.eval()

# Accept both .tif and .tiff, matching the dataset validation in the setup cell.
preview_dir = DATA_DIR / PREVIEW_MAP
tifs = sorted(list(preview_dir.glob("*.tif")) + list(preview_dir.glob("*.tiff")))
assert tifs, f"No TIF in {preview_dir}"

# Use the training tile size instead of a hard-coded 512.
size = CONFIG["image_size"]
prep  = OrthophotoPreprocessor()
image, _ = prep.load_orthophoto(tifs[0], target_size=(size, size))
# NOTE(review): the permute assumes `image` is (H, W, C). Also confirm whether
# the training transforms normalise inputs; if so, the same normalisation
# should be applied here before inference.
img_t = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)

with torch.no_grad():
    preds = infer(img_t)

keys = ["building_mask", "road_mask", "waterbody_mask",
        "road_centerline_mask", "utility_line_mask",
        "bridge_mask", "railway_mask"]
# Skip heads the model does not output instead of failing with KeyError.
absent = [k for k in keys if k not in preds]
if absent:
    print(f"Skipping heads not produced by the model: {absent}")
keys = [k for k in keys if k in preds]

# Grid sized from the number of panels (input image + one per head), 4 per row.
n_panels = len(keys) + 1
n_cols   = 4
n_rows   = -(-n_panels // n_cols)   # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(22, 5.5 * n_rows), squeeze=False)
axes = axes.flatten()
axes[0].imshow(image); axes[0].set_title(f"{PREVIEW_MAP}: {tifs[0].name}"); axes[0].axis("off")
for i, k in enumerate(keys, 1):
    pm = torch.sigmoid(preds[k]).squeeze().cpu().numpy()
    axes[i].imshow(pm, cmap="jet", vmin=0, vmax=1)
    axes[i].set_title(k.replace("_mask", "")); axes[i].axis("off")
for ax in axes[n_panels:]:
    ax.axis("off")

plt.suptitle(f"Model: {CKPT_TO_USE}  |  Input: {PREVIEW_MAP}/{tifs[0].name}", fontsize=12)
plt.tight_layout()
out = CKPT_DIR / f"{PREVIEW_MAP}_inference_preview.png"
plt.savefig(out, dpi=150, bbox_inches="tight")
plt.show()
print(f"Saved to {out}")
