# SVAMITVA — Local Mac Training Pipeline

**Target:** Apple Silicon MPS or CPU.  
**DATA path:** `/Users/aaronr/Desktop/DATA`

Trains on MAP1 only for rapid local prototyping with KMeans Tile Filtering enabled.

---
## Cell 1 — Setup

In [1]:
import os
import sys
import time
from pathlib import Path
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Put the notebook's working directory on the import path (once) so the
# project-local packages imported in later cells can be resolved.
NOTEBOOK_DIR = Path.cwd()
if not any(entry == str(NOTEBOOK_DIR) for entry in sys.path):
    sys.path.insert(0, str(NOTEBOOK_DIR))

# Input data location and output folders (created on demand).
DATA_DIR = Path("/Users/aaronr/Desktop/DATA")
CKPT_DIR = NOTEBOOK_DIR.joinpath("checkpoints")
CKPT_DIR.mkdir(parents=True, exist_ok=True)
NOTEBOOK_DIR.joinpath("logs").mkdir(exist_ok=True)

# Device: use MPS when the backend is both built and available, else CPU.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cpu"
)
print(
    "Apple Silicon GPU (MPS) \u2705"
    if device.type == "mps"
    else "Running on CPU \u26a0\ufe0f"
)

print(f"DATA: {DATA_DIR}  (exists={DATA_DIR.exists()})")

# Hyper-parameters for the local run. Keys ending in "_weight" are the
# per-task loss weights handed to the loss constructor in the training cell.
CONFIG = {
    # Model / input
    "backbone": "resnet34",        # lighter backbone for local testing
    "pretrained": True,
    "image_size": 256,             # smaller tiles for local testing
    # Optimisation
    "batch_size": 4,
    "epochs_per_map": 5,           # fewer epochs for local testing
    "learning_rate": 2e-4,
    "weight_decay": 1e-4,
    "num_workers": 0,
    "mixed_precision": False,      # original note: AMP is unstable on MPS
    "gradient_clip": 1.0,
    # Per-task loss weights
    "building_weight": 1.0,
    "roof_weight": 0.5,
    "road_weight": 0.8,
    "waterbody_weight": 0.8,
    "road_centerline_weight": 0.7,
    "waterbody_line_weight": 0.7,
    "waterbody_point_weight": 0.9,
    "utility_line_weight": 0.7,
    "utility_poly_weight": 0.8,
    "bridge_weight": 1.0,
    "railway_weight": 0.9,
}

# Batch keys that hold target masks: one "<task>_mask" entry per task, in
# this fixed order.
TARGET_KEYS = [
    f"{task}_mask"
    for task in (
        "building", "road", "road_centerline",
        "waterbody", "waterbody_line", "waterbody_point",
        "utility_line", "utility_poly",
        "bridge", "railway", "roof_type",
    )
]

print("\nSetup complete \u2713")

Apple Silicon GPU (MPS) ✅
DATA: /Users/aaronr/Desktop/DATA  (exists=True)

Setup complete ✓


---
## Cell 2 — Training Engine

In [5]:
from pathlib import Path
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from models.feature_extractor import FeatureExtractor
from models.losses import MultiTaskLoss
from training.metrics import MetricTracker
from data.dataset import SvamitvaDataset
from data.augmentation import get_train_transforms


def move_targets(batch):
    """Return the target masks found in ``batch``, moved to ``device``.

    Only keys listed in ``TARGET_KEYS`` are considered; keys absent from
    ``batch`` are skipped. The result follows the order of ``TARGET_KEYS``.
    """
    moved = {}
    for key in TARGET_KEYS:
        if key in batch:
            moved[key] = batch[key].to(device)
    return moved


def build_model(load_from: Path = None):
    """Create the ``FeatureExtractor`` and optionally warm-start it.

    Args:
        load_from: Optional checkpoint path. It is used only when it is given
            and exists on disk; otherwise the freshly built model is returned.
            The file may be a raw state dict or a dict holding the weights
            under ``"model"`` or ``"model_state_dict"``.

    Returns:
        The model moved to the global ``device``.
    """
    m = FeatureExtractor(
        backbone=CONFIG["backbone"],
        pretrained=CONFIG["pretrained"],
        num_roof_classes=5,
    )

    if load_from and load_from.exists():
        # Load onto CPU first; the model is moved to `device` at the end.
        state = torch.load(load_from, map_location="cpu")
        weights = state.get("model") or state.get("model_state_dict") or state

        # strict=False tolerates key mismatches, so report them instead of
        # silently claiming a full load.
        result = m.load_state_dict(weights, strict=False)
        print(f"Loaded: {load_from.name}")

        # getattr guards against a load_state_dict override returning None.
        missing = list(getattr(result, "missing_keys", []) or [])
        unexpected = list(getattr(result, "unexpected_keys", []) or [])
        if missing or unexpected:
            print(
                f"  [WARN] partial load: {len(missing)} missing, "
                f"{len(unexpected)} unexpected keys"
            )

    return m.to(device)


def train_map(map_name: str, resume_from: Path = None):
    """Train the multi-task model on the tiles of a single map region.

    Builds the model (optionally warm-started from ``resume_from``), trains
    for ``CONFIG["epochs_per_map"]`` epochs with AdamW and cosine LR
    annealing, and writes "latest" and "best" checkpoints into ``CKPT_DIR``
    after every epoch. "Best" is chosen by the ``"avg_iou"`` metric
    accumulated over the training batches of the epoch.

    Args:
        map_name: Name of the map directory under ``DATA_DIR``; also used to
            keep only dataset samples whose ``"map_name"`` matches.
        resume_from: Optional checkpoint path forwarded to ``build_model``.

    Returns:
        Path of the best inference checkpoint (``<map_name>_best.pt``). If
        the map directory is missing or has no tiles, that path is returned
        only when a file already exists there, otherwise ``None``.

    Raises:
        KeyboardInterrupt, Exception: re-raised from the batch loop after an
            emergency training checkpoint has been written.
    """
    map_dir = DATA_DIR / map_name
    best_out = CKPT_DIR / f"{map_name}_best.pt"
    last_out = CKPT_DIR / f"{map_name}_latest.pt"

    if not map_dir.exists():
        print(f"[SKIP] {map_dir} not found")
        return best_out if best_out.exists() else None

    print(f"\n{'='*70}")
    print(f"  Region     : {map_name}")
    print(f"  Checkpoint : {resume_from.name if resume_from and resume_from.exists() else 'SCRATCH'}")
    print(f"{'='*70}")

    model_w = build_model(load_from=resume_from)

    # The dataset is built over the whole data root, then narrowed to this map.
    ds = SvamitvaDataset(
        root_dir=DATA_DIR,
        image_size=CONFIG["image_size"],
        transform=get_train_transforms(CONFIG["image_size"]),
        mode="train",
    )
    ds.samples = [s for s in ds.samples if s["map_name"] == map_name]

    print(f"  Tiles: {len(ds)}")

    # Nothing to train on: stop before building a shuffled loader over an
    # empty dataset and before overwriting checkpoints with an untrained model.
    if not ds.samples:
        print(f"[SKIP] no tiles found for {map_name}")
        return best_out if best_out.exists() else None

    loader = DataLoader(
        ds,
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        num_workers=CONFIG["num_workers"],  # honour the configured value
        pin_memory=False,
    )

    # Every CONFIG entry ending in "_weight" becomes a loss keyword argument.
    loss_fn = MultiTaskLoss(
        **{k: v for k, v in CONFIG.items() if k.endswith("_weight")}
    ).to(device)

    optimizer = torch.optim.AdamW(
        model_w.parameters(),
        lr=CONFIG["learning_rate"],
        weight_decay=CONFIG["weight_decay"],
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=CONFIG["epochs_per_map"],
        eta_min=1e-6,
    )

    best_iou = 0.0

    for epoch in range(1, CONFIG["epochs_per_map"] + 1):

        model_w.train()
        tracker = MetricTracker()
        run_loss = 0.0
        n_steps = 0
        t0 = time.time()

        try:
            for batch in loader:

                images = batch["image"].to(device)
                targets = move_targets(batch)

                optimizer.zero_grad(set_to_none=True)

                # Skip batches with non-finite inputs rather than poisoning
                # the weights.
                if not torch.isfinite(images).all():
                    print("  [NaN SKIP] NaN in images")
                    continue

                preds = model_w(images)
                total_loss, _ = loss_fn(preds, targets)

                if not torch.isfinite(total_loss):
                    print(f"  [NaN SKIP] epoch {epoch} loss NaN")
                    continue

                total_loss.backward()

                if CONFIG["gradient_clip"] > 0:
                    nn.utils.clip_grad_norm_(
                        model_w.parameters(),
                        CONFIG["gradient_clip"]
                    )

                optimizer.step()

                run_loss += total_loss.item()
                tracker.update(preds, targets)
                n_steps += 1

        except KeyboardInterrupt:
            print("\n[INTERRUPT] Saving emergency checkpoint...")
            save_training_checkpoint(model_w, optimizer, epoch, map_name, best_iou, last_out)
            raise

        except Exception as e:
            print(f"\n[ERROR] {e}. Saving emergency checkpoint...")
            save_training_checkpoint(model_w, optimizer, epoch, map_name, best_iou, last_out)
            raise

        scheduler.step()

        metrics = tracker.compute()
        # max(..., 1) avoids division by zero when every batch was skipped.
        avg_loss = run_loss / max(n_steps, 1)
        avg_iou = metrics.get("avg_iou", 0.0)

        print(
            f"Epoch {epoch:2d}/{CONFIG['epochs_per_map']} | "
            f"loss: {avg_loss:.4f} | "
            f"iou: {avg_iou:.4f} | "
            f"{time.time()-t0:.0f}s"
        )

        # Update best_iou first so the "latest" training checkpoint records
        # the current best rather than the previous epoch's value.
        improved = avg_iou > best_iou
        if improved:
            best_iou = avg_iou

        # Save training checkpoint (large file)
        save_training_checkpoint(
            model_w, optimizer, epoch, map_name, best_iou, last_out
        )

        # Save inference-only weights (small file)
        save_inference_checkpoint(model_w, last_out)

        if improved:
            save_training_checkpoint(
                model_w, optimizer, epoch, map_name, best_iou, best_out
            )
            save_inference_checkpoint(model_w, best_out)
            print(f"→ New best! IoU = {best_iou:.4f}")

    print(f"\n[DONE] {map_name}  Best IoU={best_iou:.4f}")
    return best_out


# ----------------------------
# CHECKPOINT HELPERS
# ----------------------------

def save_training_checkpoint(model, optimizer, epoch, map_name, best_iou, path):
    """Write a resumable training checkpoint derived from ``path``.

    The output file is ``path`` with its final suffix replaced by
    ``.train.pt`` (e.g. ``MAP1_best.pt`` -> ``MAP1_best.train.pt``).

    Args:
        model: Model to save; if it has a ``module`` attribute (wrapper), the
            wrapped module's state dict is stored instead.
        optimizer: Optimizer whose state dict is stored.
        epoch: Epoch number recorded in the checkpoint.
        map_name: Map/region name recorded in the checkpoint.
        best_iou: Best IoU so far, recorded in the checkpoint.
        path: Base checkpoint path (``Path`` or ``str``).
    """
    inner = model.module if hasattr(model, "module") else model

    ckpt = {
        "model_state_dict": inner.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "map_name": map_name,
        "best_iou": best_iou,
    }

    target = Path(path).with_suffix(".train.pt")

    # Write to a sibling temp file and rename it into place. This helper runs
    # from the interrupt/error handlers, and a second interrupt during an
    # in-place write would truncate the previous good checkpoint.
    tmp = target.with_name(target.name + ".tmp")
    try:
        torch.save(ckpt, tmp)
        tmp.replace(target)
    except BaseException:
        # Remove the partial temp file, then let the original error propagate.
        if tmp.exists():
            tmp.unlink()
        raise


def save_inference_checkpoint(model, path):
    """Write inference-only weights (a plain state dict on CPU) to ``path``.

    Args:
        model: Model to save; if it has a ``module`` attribute (wrapper), the
            wrapped module's state dict is stored instead.
        path: Destination file (``Path`` or ``str``).
    """
    inner = model.module if hasattr(model, "module") else model
    cpu_state = {k: v.cpu() for k, v in inner.state_dict().items()}

    target = Path(path)

    # Write to a sibling temp file and rename it into place so an interrupt
    # or crash mid-write cannot truncate the existing checkpoint.
    tmp = target.with_name(target.name + ".tmp")
    try:
        torch.save(cpu_state, tmp)
        tmp.replace(target)
    except BaseException:
        # Remove the partial temp file, then let the original error propagate.
        if tmp.exists():
            tmp.unlink()
        raise


print("Training engine ready ✓")

Training engine ready ✓


---
## Cell 3 — Execute Training
Trains exactly one map (MAP1) locally, as requested by the user.

In [6]:
# Train MAP1 from scratch: resume_from is left at its default of None.
cpt1 = train_map(map_name="MAP1")
print("\n*** LOCAL TRAINING COMPLETE ***")


  Region     : MAP1
  Checkpoint : SCRATCH


INFO:data.preprocessing:OrthophotoPreprocessor: target CRS = EPSG:32643
INFO:data.preprocessing:ShapefileAnnotationParser initialised
INFO:data.dataset:[MAP1] MAP1.tif | tasks: building, road, road_centerline, waterbody, waterbody_line, waterbody_point, utility_poly, utility_line
INFO:data.dataset:  → 2786 tiles from MAP1 (27390×26259px)
INFO:data.dataset:[MAP2] MAP2.tif | tasks: building, road, road_centerline, waterbody, waterbody_line, waterbody_point, utility_poly, utility_line
INFO:data.dataset:  → 3933 tiles from MAP2 (54365×76309px)
INFO:data.dataset:[train] dataset ready: 6719 tile samples


  Tiles: 2786


INFO:data.preprocessing:  [building] mask positive pixels: 47972
INFO:data.preprocessing:  [building] mask positive pixels: 115877
INFO:data.preprocessing:  [building] mask positive pixels: 90380
INFO:data.preprocessing:  [building] mask positive pixels: 231769
INFO:data.preprocessing:  [road] mask positive pixels: 20197
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 50788
INFO:data.preprocessing:  [building] mask positive pixels: 227320
INFO:data.preprocessing:  [building] mask positive pixels: 164159
INFO:data.preprocessing:  [road] mask positive pixels: 23368
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 27414
INFO:data.preprocessing:  [building] mask positive pixels: 178477
INFO:data.preprocessing:  [building] mask positive pixels: 123617
INFO:data.preprocessing:  [road] mask positive pixels: 60438
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 53186
INFO:data.preprocessing:  [building] mask positive pixels: 239528
INFO:data

Epoch  1/5 | loss: 3.6878 | iou: 0.0978 | 825s


INFO:data.preprocessing:  [road] mask positive pixels: 43211
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 50682


→ New best! IoU = 0.0978


INFO:data.preprocessing:  [building] mask positive pixels: 45888
INFO:data.preprocessing:  [road] mask positive pixels: 57504
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 50948
INFO:data.preprocessing:  [utility_poly] mask positive pixels: 228
INFO:data.preprocessing:  [utility_line] mask positive pixels: 7645
INFO:data.preprocessing:  [building] mask positive pixels: 212008
INFO:data.preprocessing:  [road] mask positive pixels: 1220
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 694
INFO:data.preprocessing:  [building] mask positive pixels: 203105
INFO:data.preprocessing:  [building] mask positive pixels: 167909
INFO:data.preprocessing:  [road] mask positive pixels: 47778
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 55737
INFO:data.preprocessing:  [building] mask positive pixels: 90465
INFO:data.preprocessing:  [building] mask positive pixels: 11284
INFO:data.preprocessing:  [road] mask positive pixels: 57414
INFO:data.prepr

Epoch  2/5 | loss: 2.9933 | iou: 0.1435 | 819s
→ New best! IoU = 0.1435


INFO:data.preprocessing:  [building] mask positive pixels: 250786
INFO:data.preprocessing:  [building] mask positive pixels: 88738
INFO:data.preprocessing:  [road] mask positive pixels: 13109
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 25626
INFO:data.preprocessing:  [building] mask positive pixels: 179212
INFO:data.preprocessing:  [building] mask positive pixels: 8262
INFO:data.preprocessing:  [building] mask positive pixels: 94963
INFO:data.preprocessing:  [building] mask positive pixels: 67621
INFO:data.preprocessing:  [road] mask positive pixels: 22396
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 15458
INFO:data.preprocessing:  [building] mask positive pixels: 132421
INFO:data.preprocessing:  [road] mask positive pixels: 40068
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 38625
INFO:data.preprocessing:  [building] mask positive pixels: 81978
INFO:data.preprocessing:  [road] mask positive pixels: 57852
INFO:data.preproce

Epoch  3/5 | loss: 2.4803 | iou: 0.1620 | 816s


INFO:data.preprocessing:  [road] mask positive pixels: 4836
INFO:data.preprocessing:  [waterbody_line] mask positive pixels: 18540


→ New best! IoU = 0.1620


INFO:data.preprocessing:  [building] mask positive pixels: 162948
INFO:data.preprocessing:  [building] mask positive pixels: 27699
INFO:data.preprocessing:  [building] mask positive pixels: 99557
INFO:data.preprocessing:  [building] mask positive pixels: 116186
INFO:data.preprocessing:  [road] mask positive pixels: 43160
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 50863
INFO:data.preprocessing:  [building] mask positive pixels: 77485
INFO:data.preprocessing:  [waterbody] mask positive pixels: 125843
INFO:data.preprocessing:  [building] mask positive pixels: 155533
INFO:data.preprocessing:  [building] mask positive pixels: 842
INFO:data.preprocessing:  [road] mask positive pixels: 22948
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 11980
INFO:data.preprocessing:  [building] mask positive pixels: 212182
INFO:data.preprocessing:  [road_centerline] mask positive pixels: 69
INFO:data.preprocessing:  [building] mask positive pixels: 219722
INFO:data.


[INTERRUPT] Saving emergency checkpoint...


KeyboardInterrupt: 