# SVAMITVA — Local Mac Training Pipeline (MAPC Sub-Maps)

**Target:** Apple Silicon MPS or CPU.  
**DATA path:** `/Users/aaronr/Desktop/DATA/MAPC` (pre-clipped 512×512 sub-maps)

Discovers the `MAP1.*`, `MAP2.*`, … sub-map folders, groups them by parent map, and trains on each sub-map in turn for rapid local prototyping.

---
## Cell 1 — Setup

In [2]:
import os
import sys
import time
from pathlib import Path
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

NOTEBOOK_DIR = Path.cwd()
if str(NOTEBOOK_DIR) not in sys.path:
    # Make the project packages (models/, data/, training/) importable from the notebook.
    sys.path.insert(0, str(NOTEBOOK_DIR))

# Root folder holding the MAPC sub-map directories (MAP1.1, MAP1.2, ...).
# Set SVAMITVA_DATA_DIR to run on another machine without editing the notebook;
# the default is the original local path.
DATA_DIR = Path(
    os.environ.get("SVAMITVA_DATA_DIR", "/Users/aaronr/Desktop/DATA/MAPC")
).expanduser()
CKPT_DIR = NOTEBOOK_DIR / "checkpoints"
CKPT_DIR.mkdir(parents=True, exist_ok=True)
(NOTEBOOK_DIR / "logs").mkdir(exist_ok=True)

# Device: prefer the Apple Silicon GPU (MPS), otherwise fall back to CPU.
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    print("Apple Silicon GPU (MPS) ✅")
else:
    device = torch.device("cpu")
    print("Running on CPU ⚠️")

print(f"DATA: {DATA_DIR}  (exists={DATA_DIR.exists()})")

# Training hyper-parameters. Every key ending in "_weight" is forwarded as a
# keyword argument to MultiTaskLoss by the training engine.
CONFIG = dict(
    backbone              = "resnet34",  # Lighter for local test
    pretrained            = True,
    image_size            = 256,         # Smaller for local test
    batch_size            = 4,
    epochs_per_map        = 5,           # Fewer epochs for local test
    learning_rate         = 2e-4,
    weight_decay          = 1e-4,
    num_workers           = 0,
    mixed_precision       = False,       # MPS AMP unstable
    gradient_clip         = 1.0,
    building_weight       = 1.0,
    roof_weight           = 0.5,
    road_weight           = 0.8,
    waterbody_weight      = 0.8,
    road_centerline_weight= 0.7,
    waterbody_line_weight = 0.7,
    waterbody_point_weight= 0.9,
    utility_line_weight   = 0.7,
    utility_poly_weight   = 0.8,
    bridge_weight         = 1.0,
    railway_weight        = 0.9,
)

# Batch keys holding supervision targets; those present in a batch are moved to `device`.
TARGET_KEYS = [
    "building_mask", "road_mask", "road_centerline_mask",
    "waterbody_mask", "waterbody_line_mask", "waterbody_point_mask",
    "utility_line_mask", "utility_poly_mask",
    "bridge_mask", "railway_mask", "roof_type_mask",
]

print("\nSetup complete ✓")

Apple Silicon GPU (MPS) ✅
DATA: /Users/aaronr/Desktop/DATA/MAPC  (exists=True)

Setup complete ✓


---
## Cell 2 — Training Engine

In [None]:
from pathlib import Path
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from models.feature_extractor import FeatureExtractor
from models.losses import MultiTaskLoss
from training.metrics import MetricTracker
from data.dataset import SvamitvaDataset
from data.augmentation import get_train_transforms

# One best-so-far file and one most-recent file, shared by every parent map
# and sub-map trained in this session.
BEST_CKPT = CKPT_DIR.joinpath("MAP_best.pt")
LATEST_CKPT = CKPT_DIR.joinpath("MAP_latest.pt")


def move_targets(batch):
    """Return the TARGET_KEYS entries found in `batch`, each moved to `device`."""
    present = (key for key in TARGET_KEYS if key in batch)
    return {key: batch[key].to(device) for key in present}


def build_model(load_from: Path = None):
    """Build a FeatureExtractor from CONFIG, optionally warm-start it, and move it to `device`.

    Args:
        load_from: Optional checkpoint path. Accepts a bare state_dict (as
            written by save_inference_checkpoint) or a dict wrapping one under
            "model" or "model_state_dict" (as written by
            save_training_checkpoint). A path that does not exist is reported
            and training starts from the freshly built weights.

    Returns:
        The model, on `device`.
    """
    m = FeatureExtractor(
        backbone=CONFIG["backbone"],
        pretrained=CONFIG["pretrained"],
        num_roof_classes=5,
    )

    if load_from is not None and not load_from.exists():
        print(f"  [WARN] Checkpoint not found, starting from scratch: {load_from}")
    elif load_from is not None:
        state = torch.load(load_from, map_location="cpu")
        # Explicit key lookup rather than an `a or b or c` chain: `or` calls
        # bool() on the value, which raises for multi-element tensors and
        # treats an empty dict as missing.
        weights = state
        for key in ("model", "model_state_dict"):
            if key in state:
                weights = state[key]
                break
        # strict=False tolerates partial matches, but report what was skipped
        # so an incompatible checkpoint is not silently ignored.
        result = m.load_state_dict(weights, strict=False)
        print(f"  Loaded weights from: {load_from.name}")
        if result.missing_keys or result.unexpected_keys:
            print(
                f"  [WARN] {len(result.missing_keys)} missing / "
                f"{len(result.unexpected_keys)} unexpected keys when loading"
            )

    return m.to(device)


# Cache of the dataset index over DATA_DIR. SvamitvaDataset(root_dir=DATA_DIR)
# indexes every sub-map folder under DATA_DIR (see its per-folder INFO logs), so
# rebuilding it for each sub-map made a full run quadratic in the number of
# sub-maps. Keyed on (DATA_DIR, image_size) so changing either in the notebook
# forces a rebuild. Folders added on disk mid-session are not picked up.
_DATASET_CACHE = {}


def _submap_dataset(sub_name: str):
    """Return the cached training dataset with `samples` restricted to `sub_name`."""
    key = (str(DATA_DIR), CONFIG["image_size"])
    if _DATASET_CACHE.get("key") != key:
        full = SvamitvaDataset(
            root_dir=DATA_DIR,
            image_size=CONFIG["image_size"],
            transform=get_train_transforms(CONFIG["image_size"]),
            mode="train",
        )
        by_map = {}
        for sample in full.samples:
            by_map.setdefault(sample["map_name"], []).append(sample)
        _DATASET_CACHE.clear()
        _DATASET_CACHE.update(key=key, dataset=full, by_map=by_map)

    ds = _DATASET_CACHE["dataset"]
    # Re-point the shared dataset at this sub-map's samples; this is the same
    # filtering the old per-call construction did. It is safe only because a
    # single DataLoader (num_workers=0) iterates the dataset at a time.
    ds.samples = list(_DATASET_CACHE["by_map"].get(sub_name, []))
    return ds


def train_submap(sub_name: str, model_w, optimizer, scheduler, scaler_state, best_iou):
    """Train on one sub-map (e.g. MAP1.42) for CONFIG["epochs_per_map"] epochs.

    After every epoch the latest checkpoint is written. When the epoch's IoU
    beats `best_iou`, the best checkpoint is written as well.

    Args:
        sub_name: Sub-map folder name under DATA_DIR.
        model_w: Model to train (updated in place).
        optimizer: Optimizer over `model_w`'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        scaler_state: Not used by the training loop. It is passed back
            unchanged when the sub-map is skipped and returned as None after
            training.
        best_iou: Best IoU seen so far.

    Returns:
        (model_w, optimizer, scheduler, best_iou, scaler_state-or-None)

    Raises:
        KeyboardInterrupt: re-raised after an emergency inference checkpoint
            is written to LATEST_CKPT.
    """
    sub_dir = DATA_DIR / sub_name
    if not sub_dir.exists():
        print(f"  [SKIP] {sub_dir} not found")
        return model_w, optimizer, scheduler, best_iou, scaler_state

    ds = _submap_dataset(sub_name)
    if not ds.samples:
        print(f"  [SKIP] {sub_name}: 0 tiles")
        return model_w, optimizer, scheduler, best_iou, scaler_state

    loader = DataLoader(
        ds, batch_size=CONFIG["batch_size"],
        shuffle=True, num_workers=0, pin_memory=False
    )

    loss_fn = MultiTaskLoss(
        **{k: v for k, v in CONFIG.items() if k.endswith("_weight")}
    ).to(device)

    for epoch in range(1, CONFIG["epochs_per_map"] + 1):
        model_w.train()
        tracker = MetricTracker()
        run_loss, n_steps, n_skipped, t0 = 0.0, 0, 0, time.time()

        try:
            for batch in loader:
                images = batch["image"].to(device)
                targets = move_targets(batch)
                optimizer.zero_grad(set_to_none=True)

                # Drop batches with NaN/Inf inputs or loss instead of letting
                # them corrupt the weights. They are counted and reported below.
                if not torch.isfinite(images).all():
                    n_skipped += 1
                    continue

                preds = model_w(images)
                total_loss, _ = loss_fn(preds, targets)

                if not torch.isfinite(total_loss):
                    n_skipped += 1
                    continue

                total_loss.backward()
                if CONFIG["gradient_clip"] > 0:
                    nn.utils.clip_grad_norm_(model_w.parameters(), CONFIG["gradient_clip"])
                optimizer.step()

                run_loss += total_loss.item()
                tracker.update(preds, targets)
                n_steps += 1

        except KeyboardInterrupt:
            print(f"\n  [INTERRUPT] Saving emergency checkpoint...")
            save_inference_checkpoint(model_w, LATEST_CKPT)
            raise

        scheduler.step()
        metrics = tracker.compute()
        avg_loss = run_loss / max(n_steps, 1)
        avg_iou = metrics.get("avg_iou", 0.0)

        skipped_note = f" | skipped: {n_skipped}" if n_skipped else ""
        print(
            f"    Epoch {epoch}/{CONFIG['epochs_per_map']} | "
            f"loss: {avg_loss:.4f} | iou: {avg_iou:.4f} | {time.time()-t0:.0f}s"
            f"{skipped_note}"
        )

        # Always save latest
        save_training_checkpoint(model_w, optimizer, epoch, sub_name, best_iou, LATEST_CKPT)
        save_inference_checkpoint(model_w, LATEST_CKPT)

        # NOTE(review): avg_iou is measured on this sub-map's training tiles
        # only, so "best" compares IoU across different sub-maps' data.
        if avg_iou > best_iou:
            best_iou = avg_iou
            save_training_checkpoint(model_w, optimizer, epoch, sub_name, best_iou, BEST_CKPT)
            save_inference_checkpoint(model_w, BEST_CKPT)
            print(f"    → New best! IoU = {best_iou:.4f}")

    return model_w, optimizer, scheduler, best_iou, None


def train_parent_map(parent_map: str, resume_from: Path = None):
    """Train each sub-map of `parent_map` (MAP1.1, MAP1.2, …) in numeric order.

    One model, optimizer and LR schedule are shared by all sub-maps of the
    parent map. All checkpoints go to MAP_best.pt / MAP_latest.pt.

    Args:
        parent_map: Parent map prefix, e.g. "MAP1".
        resume_from: Optional checkpoint to warm-start from. If its
            ".train.pt" companion can be read, it supplies the starting
            best IoU.

    Returns:
        BEST_CKPT if it exists, else LATEST_CKPT. If no sub-maps are found,
        `resume_from` is returned unchanged.
    """
    prefix = parent_map + "."
    # Keep only folders with a numeric suffix. A folder such as "MAP1.old"
    # would otherwise crash the numeric sort.
    sub_maps = sorted(
        (d.name for d in DATA_DIR.iterdir()
         if d.is_dir() and d.name.startswith(prefix)
         and d.name[len(prefix):].isdecimal()),
        key=lambda n: int(n[len(prefix):]),  # sort numerically: 1,2,…,3180
    )
    if not sub_maps:
        print(f"[SKIP] No sub-maps for {parent_map}")
        return resume_from

    print(f"\n{'='*70}")
    print(f"  Parent map : {parent_map}  ({len(sub_maps)} sub-maps)")
    print(f"  Sub-maps   : {sub_maps[0]} → {sub_maps[-1]}")
    print(f"  Resume     : {resume_from.name if resume_from and resume_from.exists() else 'SCRATCH'}")
    print(f"  Saving to  : {BEST_CKPT.name} / {LATEST_CKPT.name}")
    print(f"{'='*70}")

    model_w = build_model(load_from=resume_from)

    optimizer = torch.optim.AdamW(
        model_w.parameters(),
        lr=CONFIG["learning_rate"],
        weight_decay=CONFIG["weight_decay"]
    )
    # train_submap steps this scheduler once per epoch for every sub-map.
    # CosineAnnealingLR is periodic beyond T_max, so T_max must cover the
    # whole parent map. With T_max = epochs_per_map the LR oscillated between
    # learning_rate and eta_min on every sub-map.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=CONFIG["epochs_per_map"] * len(sub_maps), eta_min=1e-6
    )

    # Load best_iou from existing checkpoint if resuming
    best_iou = 0.0
    if resume_from and resume_from.exists():
        train_ckpt = resume_from.with_suffix(".train.pt")
        try:
            st = torch.load(train_ckpt, map_location="cpu")
            best_iou = st.get("best_iou", 0.0)
            print(f"  Resuming with best_iou = {best_iou:.4f}")
        except Exception as exc:  # best-effort: fall back to 0.0, but say so
            print(f"  [WARN] Could not read best_iou from {train_ckpt.name} ({exc}); using 0.0")

    for i, sub_name in enumerate(sub_maps, 1):
        print(f"\n  [{i}/{len(sub_maps)}] Training {sub_name} …")
        model_w, optimizer, scheduler, best_iou, _ = train_submap(
            sub_name, model_w, optimizer, scheduler, None, best_iou
        )

    print(f"\n  ✅ {parent_map} complete — best IoU: {best_iou:.4f}")
    # NOTE(review): the next parent map resumes from the *best* weights, which
    # may come from an early sub-map, rather than the final (latest) weights.
    # Confirm this is intended.
    return BEST_CKPT if BEST_CKPT.exists() else LATEST_CKPT


# ----------------------------
# CHECKPOINT HELPERS
# ----------------------------

def _atomic_torch_save(obj, path):
    """torch.save `obj` to `path` through a temp file and an atomic rename.

    A crash or Ctrl-C mid-write, such as during the emergency save in the
    training loop, leaves the previous checkpoint intact instead of a
    truncated file.
    """
    import os

    path = Path(path)
    tmp = path.with_name(path.name + ".tmp")
    try:
        torch.save(obj, tmp)
        os.replace(tmp, path)  # atomic rename on the same filesystem
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise


def save_training_checkpoint(model, optimizer, epoch, map_name, best_iou, path):
    """Write a resumable checkpoint to `path.with_suffix(".train.pt")`.

    Stores model and optimizer state plus the epoch, sub-map name and best IoU.
    A wrapper exposing `.module` (e.g. nn.DataParallel) is unwrapped first.
    """
    inner = getattr(model, "module", model)
    ckpt = {
        "model_state_dict": inner.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "map_name": map_name,
        "best_iou": best_iou,
    }
    _atomic_torch_save(ckpt, path.with_suffix(".train.pt"))


def save_inference_checkpoint(model, path):
    """Write the bare model state_dict, with every tensor moved to CPU, to `path`."""
    inner = getattr(model, "module", model)
    _atomic_torch_save({k: v.cpu() for k, v in inner.state_dict().items()}, path)


print("Training engine ready ✓")

Training engine ready ✓


---
## Cell 3 — Execute Training
Groups MAPC sub-maps (MAP1.1, MAP1.2, …) by parent (MAP1, MAP2, …) and trains the parent maps one after another, each starting from the checkpoint produced by the previous one.

In [4]:
from collections import Counter

# Discover sub-map folders (e.g. "MAP1.123") and group them by parent map ("MAP1").
sub_folders = sorted([
    d.name for d in DATA_DIR.iterdir()
    if d.is_dir() and d.name.startswith("MAP") and "." in d.name
])

# parent name -> number of sub-map folders, in a single pass.
sub_counts = Counter(name.split(".")[0] for name in sub_folders)


def _natural_key(name):
    """Sort key that orders "MAP2" before "MAP10" (plain sorting puts MAP10 first)."""
    head = name.rstrip("0123456789")
    tail = name[len(head):]
    return (head, int(tail) if tail else -1)


parent_maps = sorted(sub_counts, key=_natural_key)
print(f"Found {len(parent_maps)} parent maps: {[f'{p} ({sub_counts[p]} sub-maps)' for p in parent_maps]}")

# Each parent map resumes from the checkpoint the previous one produced.
prev_ckpt = None
for p_name in parent_maps:
    print(f"\n⏳ Training {p_name} ({sub_counts[p_name]} sub-maps)...")
    ckpt = train_parent_map(p_name, resume_from=prev_ckpt)
    if ckpt and ckpt.exists():
        prev_ckpt = ckpt
        print(f"✅ {p_name} done — checkpoint: {ckpt}")
    else:
        print(f"❌ {p_name} failed")

print("\n*** LOCAL TRAINING (MAPC) COMPLETE ***")

Found 2 parent maps: ['MAP1 (3180 sub-maps)', 'MAP2 (4369 sub-maps)']

⏳ Training MAP1 (3180 sub-maps)...

  Region     : MAP1  (3180 sub-maps)
  Sub-maps   : MAP1.1 … MAP1.999
  Checkpoint : SCRATCH


INFO:data.preprocessing:OrthophotoPreprocessor: target CRS = EPSG:32643
INFO:data.preprocessing:ShapefileAnnotationParser initialised
INFO:data.dataset:[MAP1.1] Orthophoto.tif | tasks: building, road, road_centerline
INFO:data.dataset:  → 1 tiles from MAP1.1 (512×512px)
INFO:data.dataset:[MAP1.1000] Orthophoto.tif | tasks: building
INFO:data.dataset:  → 1 tiles from MAP1.1000 (512×512px)
INFO:data.dataset:[MAP1.1002] Orthophoto.tif | tasks: building
INFO:data.dataset:  → 1 tiles from MAP1.1002 (512×512px)
INFO:data.dataset:[MAP1.1003] Orthophoto.tif | tasks: building, road, road_centerline
INFO:data.dataset:  → 1 tiles from MAP1.1003 (512×512px)
INFO:data.dataset:[MAP1.1004] Orthophoto.tif | tasks: building, road, road_centerline
INFO:data.dataset:  → 1 tiles from MAP1.1004 (512×512px)
INFO:data.dataset:[MAP1.1005] Orthophoto.tif | tasks: building, road, road_centerline
INFO:data.dataset:  → 1 tiles from MAP1.1005 (512×512px)
INFO:data.dataset:[MAP1.1006] Orthophoto.tif | tasks: buildi

  Tiles: 2434 from 3180 sub-maps


INFO:data.preprocessing:  [building] mask positive pixels: 243684
INFO:data.preprocessing:  [building] mask positive pixels: 238663
INFO:data.preprocessing:  [building] mask positive pixels: 211734
INFO:data.preprocessing:  [building] mask positive pixels: 158054
INFO:data.preprocessing:  [building] mask positive pixels: 12093
INFO:data.preprocessing:  [building] mask positive pixels: 200621
INFO:data.preprocessing:  [road] mask positive pixels: 156
INFO:data.preprocessing:  [road] mask positive pixels: 51891
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 67622
INFO:data.preprocessing:  [building] mask positive pixels: 875
INFO:data.preprocessing:  [road] mask positive pixels: 65608
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 117785
INFO:data.preprocessing:  [road] mask positive pixels: 119378
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 112126
INFO:data.preprocessing:  [building] mask positive pixels: 29977
INFO:data.prepro

Epoch  1/5 | loss: 3.8041 | iou: 0.1073 | 225s


INFO:data.preprocessing:  [building] mask positive pixels: 262144
INFO:data.preprocessing:  [building] mask positive pixels: 147708
INFO:data.preprocessing:  [building] mask positive pixels: 27170
INFO:data.preprocessing:  [road] mask positive pixels: 68499
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 79026
INFO:data.preprocessing:  [utility_poly] mask positive pixels: 5800
INFO:data.preprocessing:  [utility_line] mask positive pixels: 10723
INFO:data.preprocessing:  [road] mask positive pixels: 39183
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 27679


→ New best! IoU = 0.1073


INFO:data.preprocessing:  [building] mask positive pixels: 144982
INFO:data.preprocessing:  [road] mask positive pixels: 52029
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 85266
INFO:data.preprocessing:  [building] mask positive pixels: 35524
INFO:data.preprocessing:  [road] mask positive pixels: 69077
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 80252
INFO:data.preprocessing:  [waterbody_line] mask positive pixels: 82725
INFO:data.preprocessing:  [building] mask positive pixels: 184375
INFO:data.preprocessing:  [road] mask positive pixels: 3739
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 19381
INFO:data.preprocessing:  [building] mask positive pixels: 182044
INFO:data.preprocessing:  [road] mask positive pixels: 26432
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 69133
INFO:data.preprocessing:  [building] mask positive pixels: 159587
INFO:data.preprocessing:  [road] mask positive pixels: 11442
INFO:dat


[INTERRUPT] Saving emergency checkpoint...


KeyboardInterrupt: 