# SVAMITVA — DGX Training Pipeline (MAPC Sub-Maps)

**Target:** DGX Server — single GPU with the most free VRAM.
**DATA path:** `/jupyter/sods.user04/DATA/MAPC` (pre-clipped 512×512 sub-maps)
**Checkpoints:** `/jupyter/sods.user04/check/MAP_best.pt` / `MAP_latest.pt`

Trains each sub-map (MAP1.1, MAP1.2, …) individually in sequence. One global checkpoint. Asks permission before moving to the next parent map.

---
## Cell 1 — Setup

In [None]:
import os, sys, time, torch
import torch.nn as nn
from pathlib import Path
import numpy as np

# ── GPU: pick the one with the most free VRAM ────────────────────────────────
def get_best_gpu():
    """Return the index (as a string) of the GPU with the most free memory.

    Runs ``nvidia-smi --query-gpu=memory.free`` and picks the GPU reporting the
    largest value. Ties resolve to the lowest index. The value is divided by
    1024 for the printed "GB free" figure.

    Returns:
        The selected GPU index as a string, ready to be assigned to
        ``CUDA_VISIBLE_DEVICES``.

    Raises:
        RuntimeError: If ``nvidia-smi`` cannot be executed, exits with a
            non-zero status, or reports no parseable free-memory value.
    """
    import subprocess

    query = ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader']
    try:
        result = subprocess.run(
            query,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            encoding='utf-8', check=True,
        )
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError covers a missing binary; CalledProcessError a driver failure.
        raise RuntimeError(f"Could not query GPUs via nvidia-smi: {err}") from err

    # Keep the row position as the GPU index so that skipping an unparseable
    # row does not shift the indices of the GPUs that follow it.
    rows = [line.strip() for line in result.stdout.splitlines() if line.strip()]
    free_memories = {}
    for idx, row in enumerate(rows):
        try:
            free_memories[idx] = int(row)
        except ValueError:
            # Non-numeric placeholder for this GPU; it cannot be ranked.
            continue

    if not free_memories:
        raise RuntimeError("nvidia-smi reported no GPUs with a readable free-memory value")

    # max() returns the first maximal key, i.e. the lowest index on ties.
    best_idx = max(free_memories, key=free_memories.get)
    free_gb = free_memories[best_idx] / 1024
    print(f"  GPU {best_idx} selected \u2014 {free_gb:.1f} GB free (max of {len(free_memories)} GPUs)")
    return str(best_idx)

# nvidia-smi lists GPUs in PCI bus order, while CUDA's default enumeration is
# "fastest first". Pin CUDA to PCI order so the index picked from nvidia-smi
# output refers to the same physical device.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
gpu_id = get_best_gpu()
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# NOTE: these variables are read when CUDA is first initialised in the process;
# re-running this cell in a kernel that has already touched the GPU will not
# switch devices.

# Only the selected GPU is visible after masking, so plain "cuda" resolves to it.
device = torch.device("cuda")
# Let cuDNN benchmark convolution algorithms and cache the fastest one.
torch.backends.cudnn.benchmark = True

# ── Project root discovery ───────────────────────────────────────────────────
# Walk from the working directory up towards the filesystem root and adopt the
# first directory that holds the ``models`` package; if none does, stay on the
# working directory. The result is put at the front of sys.path so project
# packages can be imported.
_start_dir = Path.cwd()
NOTEBOOK_DIR = next(
    (
        candidate
        for candidate in (_start_dir, *_start_dir.parents)
        if (candidate / "models").exists()
        and (candidate / "models/__init__.py").exists()
    ),
    _start_dir,
)
if str(NOTEBOOK_DIR) not in sys.path:
    sys.path.insert(0, str(NOTEBOOK_DIR))

# ── Directories ──────────────────────────────────────────────────────────────
# Use the per-user data location when present, otherwise fall back to /DATA/MAPC.
_user_data_dir = Path("/jupyter/sods.user04/DATA/MAPC")
DATA_DIR = _user_data_dir if _user_data_dir.exists() else Path("/DATA/MAPC")

# Checkpoint and log directories are created up front if they do not exist.
CKPT_DIR = Path("/jupyter/sods.user04/check")
CKPT_DIR.mkdir(exist_ok=True, parents=True)
NOTEBOOK_DIR.joinpath("logs").mkdir(exist_ok=True)

# ── Config ───────────────────────────────────────────────────────────────────
# Hyper-parameters shared by every training cell. Keys ending in "_weight" are
# forwarded as keyword arguments to the multi-task loss.
CONFIG = {
    # Model
    "backbone": "resnet50",
    "pretrained": True,
    "image_size": 512,
    # Optimisation
    "batch_size": 16,        # tune to available VRAM (16 for 24GB, 32 for 40GB+)
    "epochs_per_map": 15,    # epochs spent on each sub-map
    "learning_rate": 3e-4,
    "weight_decay": 1e-4,
    "num_workers": 4,
    "mixed_precision": True,
    "gradient_clip": 0.5,
    # Per-task loss weights
    "building_weight": 1.0,
    "roof_weight": 0.5,
    "road_weight": 1.0,
    "waterbody_weight": 1.2,
    "road_centerline_weight": 1.0,
    "waterbody_line_weight": 1.2,
    "waterbody_point_weight": 1.5,
    "utility_line_weight": 1.2,
    "utility_poly_weight": 1.3,
    "bridge_weight": 1.5,
    "railway_weight": 1.3,
}

# Batch keys that hold target masks; each is "<layer>_mask".
TARGET_KEYS = [
    f"{layer}_mask"
    for layer in (
        "building",
        "road",
        "road_centerline",
        "waterbody",
        "waterbody_line",
        "waterbody_point",
        "utility_line",
        "utility_poly",
        "bridge",
        "railway",
        "roof_type",
    )
]

# One-line setup summary. The check mark is written as an escape so the message
# survives encoding round-trips of the notebook file.
print(f"\u2705 Setup | GPU: {gpu_id} | DATA: {DATA_DIR} (exists={DATA_DIR.exists()})")

---
## Cell 2 — Training Engine

In [None]:
from pathlib import Path
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
try:
    from torch.amp import GradScaler, autocast
except ImportError:
    from torch.cuda.amp import GradScaler, autocast

from models.feature_extractor import FeatureExtractor
from models.losses import MultiTaskLoss
from training.metrics import MetricTracker
from data.dataset import SvamitvaDataset
from data.augmentation import get_train_transforms

# One pair of checkpoint files is shared by every map and sub-map.
BEST_CKPT = CKPT_DIR.joinpath("MAP_best.pt")
LATEST_CKPT = CKPT_DIR.joinpath("MAP_latest.pt")


def move_targets(batch):
    """Return the target tensors of ``batch`` moved to the training device.

    Only keys listed in ``TARGET_KEYS`` that are actually present in ``batch``
    are included; the transfer is requested as non-blocking.
    """
    on_device = {}
    for key in TARGET_KEYS:
        if key not in batch:
            continue
        on_device[key] = batch[key].to(device, non_blocking=True)
    return on_device


def build_model(load_from: Path = None):
    """Build the ``FeatureExtractor`` and optionally load weights into it.

    Args:
        load_from: Checkpoint file to load. Ignored when ``None`` or when the
            file does not exist. The file may hold either a bare state dict or
            a dict that nests the state dict under ``"model"`` or
            ``"model_state_dict"``.

    Returns:
        The model, moved to the global ``device``.
    """
    m = FeatureExtractor(
        backbone=CONFIG["backbone"],
        pretrained=CONFIG["pretrained"],
        num_roof_classes=5,
    )

    if load_from is not None and load_from.exists():
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only checkpoints from a trusted source should be loaded here.
        state = torch.load(load_from, map_location="cpu", weights_only=False)

        # Pick the nested state dict when present, else treat the file as one.
        weights = state
        if isinstance(state, dict):
            for key in ("model", "model_state_dict"):
                nested = state.get(key)
                if isinstance(nested, dict) and nested:
                    weights = nested
                    break

        # strict=False tolerates mismatched keys, so report them: a checkpoint
        # from a different architecture would otherwise load nothing silently.
        outcome = m.load_state_dict(weights, strict=False)
        print(f"  Loaded weights from: {load_from.name}")
        if outcome.missing_keys or outcome.unexpected_keys:
            print(
                f"  [WARN] {load_from.name}: {len(outcome.missing_keys)} missing / "
                f"{len(outcome.unexpected_keys)} unexpected keys while loading"
            )

    return m.to(device)


def train_submap(sub_name: str, model_w, optimizer, scheduler, scaler, best_iou):
    """Train one sub-map (e.g. MAP1.42) for ``CONFIG["epochs_per_map"]`` epochs.

    The dataset is built over ``DATA_DIR`` and then narrowed to the samples
    whose ``map_name`` equals ``sub_name``. After every epoch the "latest"
    checkpoints are rewritten; the "best" checkpoints are rewritten whenever
    the epoch's average IoU exceeds ``best_iou``.

    Args:
        sub_name: Name of the sub-map directory under ``DATA_DIR``.
        model_w: Model to train (already on ``device``).
        optimizer: Optimizer bound to ``model_w``'s parameters.
        scheduler: LR scheduler; stepped once per epoch.
        scaler: Gradient scaler used on the mixed-precision path.
        best_iou: Best average IoU seen so far across all training.

    Returns:
        The tuple ``(model_w, optimizer, scheduler, best_iou, scaler)`` with
        ``best_iou`` updated. The inputs are returned unchanged when the
        sub-map directory is missing or contains no samples.

    Raises:
        KeyboardInterrupt: Re-raised after the current weights have been
            written to ``LATEST_CKPT``.
    """
    torch.cuda.empty_cache()
    sub_dir = DATA_DIR / sub_name
    if not sub_dir.exists():
        print(f"  [SKIP] {sub_dir} not found")
        return model_w, optimizer, scheduler, best_iou, scaler

    # NOTE(review): the dataset is constructed over the whole DATA_DIR and then
    # filtered, so this is repeated for every sub-map — confirm whether
    # SvamitvaDataset can be pointed at a single sub-map instead.
    ds = SvamitvaDataset(
        root_dir=DATA_DIR,
        image_size=CONFIG["image_size"],
        transform=get_train_transforms(CONFIG["image_size"]),
        mode="train",
    )
    ds.samples = [s for s in ds.samples if s["map_name"] == sub_name]

    if not ds.samples:
        print(f"  [SKIP] {sub_name}: 0 tiles")
        return model_w, optimizer, scheduler, best_iou, scaler

    loader = DataLoader(
        ds, batch_size=CONFIG["batch_size"],
        shuffle=True, num_workers=CONFIG["num_workers"],
        pin_memory=True, drop_last=False,
    )

    loss_fn = MultiTaskLoss(
        **{k: v for k, v in CONFIG.items() if k.endswith("_weight")}
    ).to(device)

    use_amp = CONFIG["mixed_precision"] and device.type == "cuda"
    clip_norm = CONFIG["gradient_clip"]

    for epoch in range(1, CONFIG["epochs_per_map"] + 1):
        model_w.train()
        tracker = MetricTracker()
        run_loss, n_steps, t0 = 0.0, 0, time.time()

        try:
            for batch in loader:
                images = batch["image"].to(device, non_blocking=True)
                targets = move_targets(batch)
                optimizer.zero_grad(set_to_none=True)

                # Skip batches whose input already contains NaN/Inf.
                if not torch.isfinite(images).all():
                    continue

                if use_amp:
                    # torch.autocast accepts device_type regardless of which
                    # `autocast` name the cell imported; the legacy
                    # torch.cuda.amp.autocast rejects that keyword.
                    with torch.autocast(device_type="cuda", enabled=True):
                        preds = model_w(images)
                        total_loss, _ = loss_fn(preds, targets)
                else:
                    preds = model_w(images)
                    total_loss, _ = loss_fn(preds, targets)

                # Skip the update entirely when the loss is NaN/Inf.
                if not torch.isfinite(total_loss):
                    continue

                if use_amp:
                    scaler.scale(total_loss).backward()
                    if clip_norm > 0:
                        # Gradients must be unscaled before their norm is clipped.
                        scaler.unscale_(optimizer)
                        nn.utils.clip_grad_norm_(model_w.parameters(), clip_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    total_loss.backward()
                    if clip_norm > 0:
                        nn.utils.clip_grad_norm_(model_w.parameters(), clip_norm)
                    optimizer.step()

                run_loss += total_loss.item()
                tracker.update(preds, targets)
                n_steps += 1

        except KeyboardInterrupt:
            print("\n  [INTERRUPT] Saving emergency checkpoint...")
            save_inference_checkpoint(model_w, LATEST_CKPT)
            raise

        scheduler.step()
        metrics = tracker.compute()
        avg_loss = run_loss / max(n_steps, 1)  # avoid 0/0 if every batch was skipped
        avg_iou = metrics.get("avg_iou", 0.0)

        print(
            f"    Epoch {epoch}/{CONFIG['epochs_per_map']} | "
            f"loss: {avg_loss:.4f} | iou: {avg_iou:.4f} | {time.time()-t0:.0f}s"
        )

        # Always save latest
        save_training_checkpoint(model_w, optimizer, epoch, sub_name, best_iou, LATEST_CKPT)
        save_inference_checkpoint(model_w, LATEST_CKPT)

        # Update best if improved
        if avg_iou > best_iou:
            best_iou = avg_iou
            save_training_checkpoint(model_w, optimizer, epoch, sub_name, best_iou, BEST_CKPT)
            save_inference_checkpoint(model_w, BEST_CKPT)
            print(f"    \u2192 New best! IoU = {best_iou:.4f}")

    return model_w, optimizer, scheduler, best_iou, scaler


def _submap_sort_key(name: str):
    """Order sub-map names by numeric suffix; non-numeric suffixes sort last."""
    suffix = name.split(".")[-1]
    if suffix.isdigit():
        return (0, int(suffix), "")
    return (1, 0, suffix)


def train_parent_map(parent_map: str, resume_from: Path = None):
    """Train each sub-map (MAP1.1, MAP1.2, ...) individually in sequence.

    All checkpoints go to MAP_best.pt / MAP_latest.pt.

    Args:
        parent_map: Parent map name; sub-maps are the directories in
            ``DATA_DIR`` whose name starts with ``parent_map + "."``.
        resume_from: Checkpoint to initialise the model from. When a sibling
            ``<stem>.train.pt`` file exists, ``best_iou`` is read from it.

    Returns:
        ``BEST_CKPT`` if it exists after training, otherwise ``LATEST_CKPT``.
        When the parent map has no sub-maps, ``resume_from`` is returned as is.

    Raises:
        Exception, KeyboardInterrupt: Anything raised while training is
            re-raised after the weights are written to ``MAP_crash_backup.pt``.
    """

    # A directory with a non-numeric suffix no longer aborts the whole run; it
    # is simply ordered after the numbered sub-maps.
    sub_maps = sorted(
        (d.name for d in DATA_DIR.iterdir()
         if d.is_dir() and d.name.startswith(parent_map + ".")),
        key=_submap_sort_key,
    )
    if not sub_maps:
        print(f"[SKIP] No sub-maps for {parent_map}")
        return resume_from

    resuming = resume_from is not None and resume_from.exists()

    print(f"\n{'='*70}")
    print(f"  Parent map : {parent_map}  ({len(sub_maps)} sub-maps)")
    print(f"  Sub-maps   : {sub_maps[0]} \u2192 {sub_maps[-1]}")
    print(f"  Resume     : {resume_from.name if resuming else 'SCRATCH'}")
    print(f"  Saving to  : {BEST_CKPT.name} / {LATEST_CKPT.name}")
    print(f"{'='*70}")

    model_w = build_model(load_from=resume_from)

    optimizer = torch.optim.AdamW(
        model_w.parameters(),
        lr=CONFIG["learning_rate"],
        weight_decay=CONFIG["weight_decay"]
    )
    # NOTE(review): T_max equals the epochs of ONE sub-map, but this scheduler
    # is shared by all sub-maps and keeps stepping past T_max — confirm the
    # resulting learning-rate curve is the intended one.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=CONFIG["epochs_per_map"], eta_min=1e-6
    )
    scaler = GradScaler(enabled=CONFIG["mixed_precision"] and device.type == "cuda")

    # Load best_iou from the training checkpoint if resuming. Only best_iou is
    # restored here; the optimizer state stored in that file is not.
    best_iou = 0.0
    if resuming:
        train_ckpt = resume_from.parent / (resume_from.stem + ".train.pt")
        if train_ckpt.exists():
            try:
                st = torch.load(train_ckpt, map_location="cpu", weights_only=False)
                best_iou = st.get("best_iou", 0.0)
                print(f"  Resuming with best_iou = {best_iou:.4f}")
            except Exception as err:
                # Best-effort: keep training from best_iou = 0.0, but say why.
                print(f"  [WARN] Could not read best_iou from {train_ckpt.name}: {err}")

    try:
        for i, sub_name in enumerate(sub_maps, 1):
            print(f"\n  [{i}/{len(sub_maps)}] Training {sub_name} \u2026")
            model_w, optimizer, scheduler, best_iou, scaler = train_submap(
                sub_name, model_w, optimizer, scheduler, scaler, best_iou
            )
    except (Exception, KeyboardInterrupt) as e:
        print(f"\n  \u26a0\ufe0f EMERGENCY SAVE: {e}")
        save_inference_checkpoint(model_w, CKPT_DIR / "MAP_crash_backup.pt")
        raise

    print(f"\n  \u2705 {parent_map} complete \u2014 best IoU: {best_iou:.4f}")
    return BEST_CKPT if BEST_CKPT.exists() else LATEST_CKPT


# ----------------------------
# CHECKPOINT HELPERS
# ----------------------------

def save_training_checkpoint(model, optimizer, epoch, map_name, best_iou, path):
    """Write a resumable training checkpoint next to ``path``.

    The file is named ``<path stem>.train.pt`` in ``path``'s directory (for
    example ``MAP_latest.pt`` -> ``MAP_latest.train.pt``) and contains the
    model and optimizer state dicts plus ``epoch``, ``map_name`` and
    ``best_iou``.

    The checkpoint is written to a temporary file first and then renamed over
    the target, so an interrupt during saving leaves the previous checkpoint
    intact instead of a truncated file.

    Args:
        model: Model to save; a ``.module`` attribute, if present, is unwrapped.
        optimizer: Optimizer whose state is stored.
        epoch: Epoch number recorded in the checkpoint.
        map_name: Sub-map name recorded in the checkpoint.
        best_iou: Best IoU recorded in the checkpoint.
        path: Path of the matching inference checkpoint (``str`` or ``Path``).
    """
    import os

    inner = model.module if hasattr(model, "module") else model
    path = Path(path)
    train_path = path.parent / (path.stem + ".train.pt")
    tmp_path = train_path.with_name(train_path.name + ".tmp")
    ckpt = {
        "model_state_dict": inner.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "map_name": map_name,
        "best_iou": best_iou,
    }
    try:
        torch.save(ckpt, tmp_path)
        os.replace(tmp_path, train_path)
    finally:
        # Only present if saving failed part-way; never leave a partial file.
        if tmp_path.exists():
            tmp_path.unlink()


def save_inference_checkpoint(model, path):
    """Write the model's weights (a bare state dict on CPU) to ``path``.

    The state dict is written to a temporary file first and then renamed over
    ``path``. This function is called from interrupt and crash handlers, so a
    second interrupt during saving must not destroy the previous checkpoint.

    Args:
        model: Model to save; a ``.module`` attribute, if present, is unwrapped.
        path: Destination file (``str`` or ``Path``).
    """
    import os

    inner = model.module if hasattr(model, "module") else model
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    cpu_state = {k: v.cpu() for k, v in inner.state_dict().items()}
    try:
        torch.save(cpu_state, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present if saving failed part-way; never leave a partial file.
        if tmp_path.exists():
            tmp_path.unlink()


# Last statement of the cell: reaching it means every definition above ran.
print("Training engine ready \u2713")


---
## Cell 3 — Execute Training
Trains each sub-map individually. One global `MAP_best.pt` / `MAP_latest.pt`. Asks permission before the next parent map.

In [None]:
# Discover sub-map folders ("MAP<k>.<i>") and group them by parent map ("MAP<k>").
import re
from collections import Counter


def _natural_key(name):
    """Split ``name`` into text and integer runs so that MAP2 sorts before MAP10."""
    # re.split with a capturing group alternates text, digits, text, ... so
    # elements at the same position always have the same type when compared.
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]


sub_folders = sorted(
    d.name for d in DATA_DIR.iterdir()
    if d.is_dir() and d.name.startswith("MAP") and "." in d.name
)

# The parent is everything before the first dot; count sub-maps in one pass.
_parent_counter = Counter(name.split(".")[0] for name in sub_folders)
parent_maps = sorted(_parent_counter, key=_natural_key)
sub_counts = {p: _parent_counter[p] for p in parent_maps}

print(f"\U0001F680 Found {len(parent_maps)} parent maps: {[f'{p} ({sub_counts[p]} sub-maps)' for p in parent_maps]}")
print(f"Checkpoints: {BEST_CKPT} / {LATEST_CKPT}\n")

# Start from the existing best checkpoint, if any; otherwise train from scratch.
prev_ckpt = BEST_CKPT if BEST_CKPT.exists() else None

for idx, p_name in enumerate(parent_maps):
    # Ask permission before each parent map (except the first)
    if idx > 0:
        answer = input(f"\n\U0001F514 Continue to {p_name} ({sub_counts[p_name]} sub-maps)? [yes/no]: ").strip().lower()
        if answer not in ("yes", "y"):
            print(f"\u26d4 Stopped before {p_name}. Checkpoints saved.")
            break

    print(f"\n\u23f3 Training {p_name} ({sub_counts[p_name]} sub-maps individually)...")
    ckpt = train_parent_map(p_name, resume_from=prev_ckpt)
    # Chain the next parent map onto the checkpoint this one produced.
    if ckpt and Path(ckpt).exists():
        prev_ckpt = ckpt
    else:
        print(f"\u274c {p_name} failed")
else:
    # Reached only when the loop was not cut short by the user, so a stopped
    # run is no longer reported as complete.
    print("\n*** DGX TRAINING COMPLETE ***")

print(f"Best checkpoint : {BEST_CKPT}  (exists={BEST_CKPT.exists()})")
print(f"Latest checkpoint: {LATEST_CKPT}  (exists={LATEST_CKPT.exists()})")