In [1]:
!pip install -q numpy pandas opencv-python-headless rasterio torch torchvision

In [2]:
# =============================================================================
# ICPR 2026 – Beyond Visible Spectrum: AI for Agriculture
# COMPLETE TRAINING + INFERENCE PIPELINE
#
# CONFIRMED DATASET STRUCTURE:
#   /kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/
#     Kaggle_Prepared/
#       train/
#         RGB/   Health_hyper_1.png   rust_hyper_2.png   other_hyper_12.png ...
#         MS/    Health_hyper_1.tif   rust_hyper_2.tif   other_hyper_12.tif ...
#         HS/    Health_hyper_1.tif   rust_hyper_2.tif   other_hyper_12.tif ...
#       val/
#         RGB/   val_02afcb0e.png  ...  (no class in filename)
#         MS/    val_02afcb0e.tif  ...
#         HS/    val_02afcb0e.tif  ...
#
# CLASS EXTRACTION: parsed from filename prefix (case-insensitive)
#   "Health_hyper_1.png"  → Health (label=0)
#   "rust_hyper_2.png"    → Rust   (label=1)
#   "other_hyper_12.png"  → Other  (label=2)
#
# PASTE THIS ENTIRE FILE INTO ONE SINGLE KAGGLE CELL AND RUN.
# Do NOT split across cells — old cached cell state causes assertion errors.
# =============================================================================

import os, sys, random, warnings
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import cv2
import rasterio
from rasterio.errors import NotGeoreferencedWarning

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# ── Reproducibility ─────────────────────────────────────────────────────────
# Seed every RNG the pipeline touches (Python, NumPy, PyTorch CPU and all GPUs)
# and put cuDNN into deterministic, non-autotuned mode.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark     = False
torch.backends.cudnn.deterministic = True

# ── Hardware ────────────────────────────────────────────────────────────────
N_GPUS = torch.cuda.device_count()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[HW] device={device}  n_gpus={N_GPUS}")
if device.type == "cuda":
    # Report the first GPU only; total_memory is in bytes.
    print(f"[HW] GPU={torch.cuda.get_device_name(0)}  "
          f"VRAM={torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")


# =============================================================================
# SECTION 1 – CONFIGURATION
# =============================================================================

class CFG:
    """Single source of truth for class names, channel counts and hyper-parameters."""

    # ── Labels ──────────────────────────────────────────────────────────────
    # Exact strings required by Kaggle submission; list order defines the label id.
    CLASSES     = ["Health", "Rust", "Other"]
    NUM_CLASSES = len(CLASSES)
    CLASS2IDX   = dict(zip(CLASSES, range(len(CLASSES))))
    IDX2CLASS   = dict(enumerate(CLASSES))

    # ── Multispectral ───────────────────────────────────────────────────────
    # 5 raw bands [Blue, Green, Red, RedEdge, NIR]; NDVI and RENDVI are
    # appended at runtime, giving 7 input channels to the MS encoder.
    MS_BANDS  = 5
    MS_IN_CH  = MS_BANDS + 2

    # ── Hyperspectral ───────────────────────────────────────────────────────
    # HS: 125 total bands (450–950 nm), noisy front-10 & back-14 removed.
    # EDA peak-variance region: 781–821 nm → absolute band indices 82–92 (11 bands)
    HS_KEY_BANDS = [*range(82, 93)]      # 0-based absolute indices
    HS_IN_CH     = len(HS_KEY_BANDS)     # = 11

    # ── Data pipeline ───────────────────────────────────────────────────────
    IMG_SIZE     = 64       # resize target (H, W) for all modalities
    BATCH_SIZE   = 32
    NUM_WORKERS  = 4
    VAL_FRAC     = 0.15     # fraction of train used as internal validation

    # ── Optimisation ────────────────────────────────────────────────────────
    LR           = 3e-4
    WEIGHT_DECAY = 1e-4
    LABEL_SMOOTH = 0.05
    WARMUP_EP    = 3        # linear LR warmup epochs

    # ── Epoch budget per stage ──────────────────────────────────────────────
    EPOCHS_S1    = 25       # Stage 1: MS baseline
    EPOCHS_S2    = 25       # Stage 2: HS core
    EPOCHS_S3    = 35       # Stage 3: fusion fine-tune

    # ── Output ──────────────────────────────────────────────────────────────
    SUBMISSION   = "submission.csv"


# =============================================================================
# SECTION 2 – PATH DETECTION
# =============================================================================

def _has_split(p: Path) -> bool:
    return (p / "train").is_dir() and (p / "val").is_dir()

def find_root() -> Path:
    """
    Locate the dataset root, i.e. a directory holding ``train/`` and ``val/``.

    Lookup order:
      1. a fixed list of known Kaggle locations;
      2. a depth-limited (6 levels) search below ``/kaggle/input``.

    Returns:
        The first matching directory.

    Raises:
        FileNotFoundError: nothing matched. The ``/kaggle/input`` tree
        (up to 6 levels) is printed first to help set the path by hand.
    """
    known_locations = (
        "/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared",
        "/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026",
        "/kaggle/input/Kaggle_Prepared",
        "/kaggle/working/Kaggle_Prepared",
        "/kaggle/working",
    )
    for candidate in map(Path, known_locations):
        if candidate.is_dir() and _has_split(candidate):
            print(f"[PATH] Root (hardcoded): {candidate}")
            return candidate

    kaggle_in = Path("/kaggle/input")

    def _search(base, levels_left):
        """Depth-first search, alphabetical order; unreadable directories are skipped."""
        if levels_left == 0:
            return None
        try:
            for sub in sorted(base.iterdir()):
                if not sub.is_dir():
                    continue
                if _has_split(sub):
                    return sub
                hit = _search(sub, levels_left - 1)
                if hit:
                    return hit
        except PermissionError:
            pass
        return None

    found = _search(kaggle_in, 6) if kaggle_in.is_dir() else None
    if found:
        print(f"[PATH] Root (recursive search): {found}")
        return found

    # Nothing found: dump the visible tree so the user can pick the path manually.
    print("\n[FATAL] Could not locate dataset. /kaggle/input tree:\n")
    if kaggle_in.is_dir():
        for entry in sorted(kaggle_in.rglob("*")):
            level = len(entry.relative_to(kaggle_in).parts)
            if level > 6:
                continue
            marker = "/" if entry.is_dir() else ""
            print("  " * level + entry.name + marker)
    raise FileNotFoundError(
        "Set ROOT manually:\n"
        "  ROOT=Path('/your/path'); TRAIN_DIR=ROOT/'train'; VAL_DIR=ROOT/'val'"
    )

# Resolve the dataset root once; both split directories hang off it.
ROOT = find_root()
TRAIN_DIR, VAL_DIR = ROOT / "train", ROOT / "val"
print(f"[PATH] TRAIN_DIR = {TRAIN_DIR}")
print(f"[PATH] VAL_DIR   = {VAL_DIR}")


# =============================================================================
# SECTION 3 – FILENAME-BASED CLASS EXTRACTION & SAMPLE COLLECTION
# =============================================================================
#
# YOUR DATASET STRUCTURE (confirmed from screenshot):
#
#   train/RGB/  ← flat folder, NO class subfolders
#               Files named:  Health_hyper_1.png
#                             rust_hyper_2.png
#                             other_hyper_12.png
#
#   train/MS/   ← same stems, .tif extension
#   train/HS/   ← same stems, .tif extension
#
#   val/RGB/    ← flat folder, filenames have NO class prefix
#               Files named:  val_02afcb0e.png   (class unknown until submission)
#
# CLASS EXTRACTION LOGIC:
#   Split filename stem on "_", check if first token matches a known class
#   (case-insensitive). This correctly handles:
#     "Health_hyper_1"  → "health" → Health (label 0)
#     "rust_hyper_2"    → "rust"   → Rust   (label 1)
#     "other_hyper_12"  → "other"  → Other  (label 2)
#     "val_02afcb0e"    → "val"    → None   (label -1, val set only)
# =============================================================================

# Lower-cased class name → canonical spelling, e.g. "health" → "Health".
_CLS_MAP = dict(zip(map(str.lower, CFG.CLASSES), CFG.CLASSES))

def class_from_stem(stem: str):
    """
    Derive the class from the text before the first underscore of ``stem``.

    The comparison is case-insensitive.

    Returns:
        (canonical_class_name, label_int), or (None, -1) when the prefix is
        not one of the known classes.

    Examples:
      "Health_hyper_1"  → ("Health", 0)
      "rust_hyper_2"    → ("Rust",   1)
      "other_hyper_12"  → ("Other",  2)
      "val_02afcb0e"    → (None,    -1)
    """
    head, _, _ = stem.partition("_")
    canon = _CLS_MAP.get(head.lower())
    return (None, -1) if canon is None else (canon, CFG.CLASS2IDX[canon])


def collect_train_samples(split_dir: Path) -> list:
    """
    Build the labelled record list from the flat ``RGB`` folder of a split.

    Every ``*.png`` / ``*.PNG`` file contributes one record whose MS and HS
    paths share its stem with a ``.tif`` extension. The label comes from
    ``class_from_stem``; files whose prefix is not a known class are skipped,
    so no record here ever carries label -1.

    Returns:
        list of dicts with keys ``stem``, ``ms_path``, ``hs_path``,
        ``label`` (0, 1 or 2) and ``cls_name``.

    Raises:
        RuntimeError: the ``RGB`` sub-directory does not exist.
    """
    rgb_dir, ms_dir, hs_dir = (split_dir / sub for sub in ("RGB", "MS", "HS"))
    if not rgb_dir.is_dir():
        raise RuntimeError(f"[FATAL] {rgb_dir} not found")

    all_pngs = [p for pattern in ("*.png", "*.PNG")
                for p in sorted(rgb_dir.glob(pattern))]
    print(f"[COLLECT] Train: {len(all_pngs)} PNG files found in {rgb_dir}")

    records, skipped, per_cls = [], [], {}
    for stem in (p.stem for p in all_pngs):
        canon, label = class_from_stem(stem)
        if canon is None:
            # Unknown prefix: remember it for the report, keep it out of training.
            skipped.append(stem)
            continue
        per_cls[canon] = per_cls.get(canon, 0) + 1
        records.append({
            "stem":     stem,
            "ms_path":  ms_dir / (stem + ".tif"),
            "hs_path":  hs_dir / (stem + ".tif"),
            "label":    label,
            "cls_name": canon,
        })

    print(f"[COLLECT] Train records: {len(records)}  skipped: {len(skipped)}")
    print(f"[COLLECT] Class distribution: {dict(per_cls)}")
    if skipped:
        print(f"[COLLECT] Skipped stems (first 5): {skipped[:5]}")

    return records


def collect_val_samples(split_dir: Path) -> list:
    """
    Build the unlabelled record list from the flat ``RGB`` folder of a split.

    Every ``*.png`` / ``*.PNG`` file contributes one record with label -1 and
    class name ``"__val__"``; MS and HS paths share the stem with ``.tif``.

    Returns:
        list of dicts with keys ``stem``, ``ms_path``, ``hs_path``,
        ``label`` and ``cls_name``.

    Raises:
        RuntimeError: the ``RGB`` sub-directory does not exist.
    """
    rgb_dir = split_dir / "RGB"
    if not rgb_dir.is_dir():
        raise RuntimeError(f"[FATAL] {rgb_dir} not found")
    ms_dir, hs_dir = split_dir / "MS", split_dir / "HS"

    all_pngs = [p for pattern in ("*.png", "*.PNG")
                for p in sorted(rgb_dir.glob(pattern))]
    print(f"[COLLECT] Val: {len(all_pngs)} PNG files found in {rgb_dir}")

    records = [
        {
            "stem":     stem,
            "ms_path":  ms_dir / (stem + ".tif"),
            "hs_path":  hs_dir / (stem + ".tif"),
            "label":    -1,         # unknown – must never reach the loss
            "cls_name": "__val__",
        }
        for stem in (p.stem for p in all_pngs)
    ]

    print(f"[COLLECT] Val records: {len(records)}")
    return records


# Gather both record lists up front so every later section works from them.
train_records = collect_train_samples(TRAIN_DIR)
val_records   = collect_val_samples(VAL_DIR)

# Abort early: an empty list here means the folder layout or naming is wrong.
if not train_records:
    raise RuntimeError(
        "Zero training records collected.\n"
        f"RGB dir: {TRAIN_DIR/'RGB'}\n"
        "Check that files are named like 'Health_hyper_1.png', 'rust_hyper_2.png'"
    )
if not val_records:
    raise RuntimeError(f"Zero val records collected from {VAL_DIR/'RGB'}")

# Every training label must be a valid class index before anything reaches the loss.
bad = [rec for rec in train_records if rec["label"] not in {0, 1, 2}]
if bad:
    raise ValueError(
        f"{len(bad)} train records with invalid label. First: {bad[0]}"
    )
print(f"\n[OK] train={len(train_records)}  val={len(val_records)}")
print("[OK] All train labels verified in {0,1,2} – no CUDA scatter crash risk")


# =============================================================================
# SECTION 4 – DATA LOADING UTILITIES
# =============================================================================

def _scale_and_resize(data: np.ndarray) -> torch.Tensor:
    """
    Shared post-processing for both modalities.

    Each band of ``data`` (bands, H, W) is independently rescaled to [0, 1]
    using its own 1st–99th percentile range, then bilinearly resized to
    (CFG.IMG_SIZE, CFG.IMG_SIZE).

    Non-finite pixels (NaN/Inf, e.g. a nodata fill if the file uses one) are
    excluded from the percentile estimate; NaN pixels become 0. For fully
    finite input the result is the same as a plain percentile stretch.
    """
    size = (CFG.IMG_SIZE, CFG.IMG_SIZE)
    planes = []
    for band in data:
        finite = band[np.isfinite(band)]
        if finite.size == 0:
            # Nothing usable in this band: emit an all-zero plane.
            scaled = np.zeros_like(band)
        else:
            lo = np.percentile(finite, 1)
            hi = np.percentile(finite, 99)
            # 1e-8 keeps a constant band from dividing by zero.
            scaled = np.clip((band - lo) / (hi - lo + 1e-8), 0.0, 1.0)
            scaled = np.nan_to_num(scaled, nan=0.0)
        planes.append(cv2.resize(scaled.astype(np.float32, copy=False), size,
                                 interpolation=cv2.INTER_LINEAR))
    return torch.from_numpy(np.stack(planes))


def load_ms(path: Path) -> torch.Tensor:
    """
    Load every band of an MS .tif → float32 tensor (bands, IMG_SIZE, IMG_SIZE) in [0,1].

    The pipeline expects 5 bands (CFG.MS_BANDS); the band count is not checked here.
    Per-band 1–99th percentile clipping is used for robust normalization, so
    band values are rescaled independently of each other.
    """
    with rasterio.open(str(path)) as src:
        data = src.read().astype(np.float32)          # (bands, H, W)
    return _scale_and_resize(data)


def load_hs(path: Path) -> torch.Tensor:
    """
    Load only the HS bands listed in CFG.HS_KEY_BANDS (indices 82–92, 11 bands)
    → float32 tensor (11, IMG_SIZE, IMG_SIZE) in [0,1].

    Reading 11 of the 125 bands means about 91% less data than the full cube.
    rasterio uses 1-based band indexing → add 1 to each 0-based index.
    """
    bands_1idx = [b + 1 for b in CFG.HS_KEY_BANDS]
    with rasterio.open(str(path)) as src:
        data = src.read(bands_1idx).astype(np.float32)  # (11, H, W)
    return _scale_and_resize(data)


def inject_vi(ms: torch.Tensor) -> torch.Tensor:
    """
    Append two vegetation-index channels to a multispectral batch.

    Input:  (B, 5, H, W)  [Blue, Green, Red, RedEdge, NIR]
    Output: (B, 7, H, W)  original bands + NDVI (ch 5) + RENDVI (ch 6)

    Both indices are normalised differences, (NIR - X) / (NIR + X), shifted
    from [-1, 1] to [0, 1] so they share the range of the other channels.

    NOTE(review): ``load_ms`` rescales every band independently, so when fed
    its output these indices are computed on rescaled values rather than on
    raw band values — confirm this is intended.
    """
    def _unit_norm_diff(a, b):
        # Normalised difference; 1e-8 avoids 0/0 when both bands are zero.
        return ((a - b) / (a + b + 1e-8) + 1.0) / 2.0

    red, red_edge, nir = ms[:, 2:3], ms[:, 3:4], ms[:, 4:5]
    ndvi   = _unit_norm_diff(nir, red)
    rendvi = _unit_norm_diff(nir, red_edge)
    return torch.cat([ms, ndvi, rendvi], dim=1)


# =============================================================================
# SECTION 5 – DATASET CLASS
# =============================================================================

class CropDS(Dataset):
    """
    Paired MS + HS dataset built from collector record dicts.

    Each item is a dict with
      "ms"    : tensor from ``load_ms``  (zeros if the file cannot be read)
      "hs"    : tensor from ``load_hs``  (zeros if the file cannot be read)
      "label" : int64 scalar tensor (-1 for unlabelled records)
      "stem"  : filename stem

    VI channels (NDVI/RENDVI) are injected in the training loop, not here,
    so augmentation acts on the loaded band values only.
    Augmentation: H/V flip + 90° rotation. NO spectral perturbation.

    A file that fails to load is still replaced by zeros so a single bad
    sample does not abort the run, but the failure is now reported on stderr
    (once per file per process) instead of being swallowed silently.
    """
    def __init__(self, records: list, augment: bool = False):
        # Safety filter: remove any record with label outside {-1, 0, 1, 2}
        self.records = [r for r in records
                        if r["label"] == -1 or r["label"] in CFG.CLASS2IDX.values()]
        self.augment = augment
        # (modality, path) pairs already reported. DataLoader workers hold
        # their own copy, so a file may be reported once per worker process.
        self._warned = set()
        dropped = len(records) - len(self.records)
        if dropped:
            print(f"[DS] Dropped {dropped} records with invalid labels")

    def __len__(self): return len(self.records)

    def _aug(self, ms, hs):
        """Apply the same random flips / 90° rotation to both modalities."""
        if random.random() > 0.5:
            ms = torch.flip(ms, [-1]); hs = torch.flip(hs, [-1])
        if random.random() > 0.5:
            ms = torch.flip(ms, [-2]); hs = torch.flip(hs, [-2])
        k = random.randint(0, 3)
        if k:
            ms = torch.rot90(ms, k, [-2, -1])
            hs = torch.rot90(hs, k, [-2, -1])
        return ms, hs

    def _load_or_zeros(self, loader, path, fallback_shape, tag):
        """
        Return ``loader(path)``; on any exception report it and return zeros.

        Args:
            loader:         callable taking the path and returning a tensor.
            path:           file to load.
            fallback_shape: shape of the zero tensor used on failure.
            tag:            modality name used in the warning ("MS" / "HS").
        """
        try:
            return loader(path)
        except Exception as exc:
            key = (tag, str(path))
            if key not in self._warned:
                self._warned.add(key)
                # stderr print rather than warnings.warn: the notebook
                # globally ignores UserWarning.
                print(f"[DS][WARN] {tag} load failed for {path}: {exc!r} – "
                      f"substituting zeros", file=sys.stderr)
            return torch.zeros(*fallback_shape)

    def __getitem__(self, idx):
        r = self.records[idx]
        side = CFG.IMG_SIZE
        ms = self._load_or_zeros(load_ms, r["ms_path"],
                                 (CFG.MS_BANDS, side, side), "MS")
        hs = self._load_or_zeros(load_hs, r["hs_path"],
                                 (CFG.HS_IN_CH, side, side), "HS")
        if self.augment:
            ms, hs = self._aug(ms, hs)
        return {
            "ms":    ms,
            "hs":    hs,
            "label": torch.tensor(r["label"], dtype=torch.long),
            "stem":  r["stem"],
        }


# =============================================================================
# SECTION 6 – MODEL ARCHITECTURES
# =============================================================================

class DSConv(nn.Module):
    """
    Depthwise-separable convolution block.

    A 3×3 depthwise conv (one filter per input channel, optional stride)
    followed by a 1×1 pointwise conv, BatchNorm and ReLU6. Weight count is
    9·ic + ic·oc instead of 9·ic·oc for a standard 3×3 conv.
    """
    def __init__(self, ic, oc, stride=1):
        super().__init__()
        depthwise = nn.Conv2d(ic, ic, kernel_size=3, stride=stride, padding=1,
                              groups=ic, bias=False)
        pointwise = nn.Conv2d(ic, oc, kernel_size=1, bias=False)
        self.net = nn.Sequential(
            depthwise,
            pointwise,
            nn.BatchNorm2d(oc),
            nn.ReLU6(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


def _backbone(in_ch: int) -> nn.Sequential:
    """
    Build the CNN trunk used by both encoders: ``in_ch`` channels in,
    a flat 256-dim feature vector out.

    Layout: 3×3 conv stem (32 ch) → four stride-2 DSConv stages
    (64, 128, 128, 256 ch) → global average pool → flatten.
    For a 64×64 input the stages reduce the map 64→32→16→8→4.
    """
    stem = [
        nn.Conv2d(in_ch, 32, 3, padding=1, bias=False),
        nn.BatchNorm2d(32),
        nn.ReLU6(inplace=True),
    ]
    widths = (32, 64, 128, 128, 256)
    stages = [DSConv(c_in, c_out, stride=2)
              for c_in, c_out in zip(widths, widths[1:])]
    tail = [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
    return nn.Sequential(*stem, *stages, *tail)


class MSEncoder(nn.Module):
    """Maps a CFG.MS_IN_CH-channel MS input to a 128-dim feature vector."""
    def __init__(self):
        super().__init__()
        self.bb = _backbone(CFG.MS_IN_CH)
        projection = nn.Linear(256, 128)
        self.head = nn.Sequential(projection, nn.ReLU(inplace=True), nn.Dropout(0.3))

    def forward(self, x):
        feats = self.bb(x)
        return self.head(feats)


class HSEncoder(nn.Module):
    """
    Maps a CFG.HS_IN_CH-channel HS key-band input to a 128-dim feature vector.

    A gate (1×1 conv + sigmoid, so one value per band and pixel) multiplies
    the input before the backbone, letting the network learn how much weight
    each of the selected bands should receive.
    """
    def __init__(self):
        super().__init__()
        n_bands = CFG.HS_IN_CH
        self.attn = nn.Sequential(
            nn.Conv2d(n_bands, n_bands, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )
        self.bb = _backbone(n_bands)
        self.head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
        )

    def forward(self, x):
        gate = self.attn(x)
        return self.head(self.bb(x * gate))


class MSModel(nn.Module):
    """Stage 1 classifier: MS encoder + linear layer. ``hs`` is accepted but ignored."""
    def __init__(self):
        super().__init__()
        self.encoder = MSEncoder()
        self.clf = nn.Linear(128, CFG.NUM_CLASSES)

    def forward(self, ms, hs=None):
        feats = self.encoder(ms)
        return self.clf(feats)


class HSModel(nn.Module):
    """Stage 2 classifier: HS encoder + linear layer. ``ms`` is accepted but ignored."""
    def __init__(self):
        super().__init__()
        self.encoder = HSEncoder()
        self.clf = nn.Linear(128, CFG.NUM_CLASSES)

    def forward(self, ms=None, hs=None):
        feats = self.encoder(hs)
        return self.clf(feats)


class FusionModel(nn.Module):
    """
    Stage 3: Late-fusion MS + HS.
    Concat(MS-feat[128], HS-feat[128]) → 256 → FC head → CFG.NUM_CLASSES logits.

    Why late fusion:
      MS and HS have different spectral characteristics; mixing them before
      feature extraction destroys modality-specific structure.
      Late fusion lets each encoder specialise, then the FC head learns
      complementary combinations (MS=broad NDVI signal, HS=narrow 781–821nm).

    Args:
        ms_w: optional state dict for the MS encoder (warm start).
        hs_w: optional state dict for the HS encoder (warm start).

    Warm-start weights are loaded with ``strict=False``; any missing or
    unexpected key is reported on stderr so a naming mismatch cannot
    silently leave an encoder at its random initialisation.
    """
    def __init__(self, ms_w=None, hs_w=None):
        super().__init__()
        self.ms_enc = MSEncoder()
        self.hs_enc = HSEncoder()
        self._warm_start(self.ms_enc, ms_w, "MS")
        self._warm_start(self.hs_enc, hs_w, "HS")
        self.head = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.4),
            nn.Linear(128, 64),  nn.ReLU(inplace=True),
            nn.Linear(64, CFG.NUM_CLASSES),
        )

    @staticmethod
    def _warm_start(encoder, weights, tag):
        """
        Copy ``weights`` into ``encoder`` and report keys that did not match.

        Returns:
            (missing_keys, unexpected_keys) as lists; both empty when
            ``weights`` is None (nothing to load).
        """
        if weights is None:
            return [], []
        result = encoder.load_state_dict(weights, strict=False)
        missing, unexpected = list(result.missing_keys), list(result.unexpected_keys)
        if missing or unexpected:
            print(f"[FUSION][WARN] {tag} encoder warm-start mismatch: "
                  f"{len(missing)} missing, {len(unexpected)} unexpected keys "
                  f"(missing[:3]={missing[:3]}, unexpected[:3]={unexpected[:3]})",
                  file=sys.stderr)
        return missing, unexpected

    def forward(self, ms, hs):
        fused = torch.cat([self.ms_enc(ms), self.hs_enc(hs)], dim=1)
        return self.head(fused)


# =============================================================================
# SECTION 7 – TRAINING INFRASTRUCTURE
# =============================================================================

class SmoothCE(nn.Module):
    """
    Cross-entropy against label-smoothed targets.

    The true class receives probability ``1 - smooth``; the remaining
    ``smooth`` is split evenly over the other ``nc - 1`` classes.
    Targets are clamped into [0, nc-1] first, as a last-resort guard so an
    out-of-range label (e.g. -1) cannot break the scatter.
    """
    def __init__(self, nc=CFG.NUM_CLASSES, smooth=CFG.LABEL_SMOOTH):
        super().__init__()
        self.nc = nc
        self.smooth = smooth

    def forward(self, logits, targets):
        class_idx = targets.clamp(0, self.nc - 1).unsqueeze(1)
        log_p = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            # Start from the off-target mass everywhere, then overwrite the true class.
            soft = torch.full_like(log_p, self.smooth / (self.nc - 1))
            soft.scatter_(1, class_idx, 1.0 - self.smooth)
        per_sample = (soft * log_p).sum(dim=-1)
        return -per_sample.mean()


def _fwd(model, ms, hs, mode):
    """Run forward pass for given mode, injecting VI channels for MS."""
    ms_vi = inject_vi(ms)
    if   mode == "ms":     return model(ms=ms_vi)
    elif mode == "hs":     return model(hs=hs)
    else:                  return model(ms=ms_vi, hs=hs)


def _filter_batch(ms, hs, label):
    """Keep only the samples whose label lies in [0, CFG.NUM_CLASSES)."""
    keep = label.ge(0) & label.lt(CFG.NUM_CLASSES)
    return tuple(t[keep] for t in (ms, hs, label))


def train_epoch(model, loader, optimizer, criterion, mode):
    """
    Run one optimisation pass over ``loader``.

    Samples with out-of-range labels are dropped per batch, and gradients
    are clipped to a global norm of 1.0 before each step.

    Returns:
        (mean loss per sample, accuracy) over the samples actually used;
        (0.0, 0.0) if no sample was used.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    for batch in loader:
        ms, hs, label = (batch[key].to(device, non_blocking=True)
                         for key in ("ms", "hs", "label"))
        ms, hs, label = _filter_batch(ms, hs, label)
        if label.numel() == 0:
            continue

        optimizer.zero_grad(set_to_none=True)
        logits = _fwd(model, ms, hs, mode)
        loss = criterion(logits, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        n = label.size(0)
        running_loss += loss.item() * n
        n_correct    += (logits.argmax(1) == label).sum().item()
        n_seen       += n

    denom = max(n_seen, 1)
    return running_loss / denom, n_correct / denom


@torch.no_grad()
def eval_epoch(model, loader, criterion, mode):
    """
    Score ``model`` on ``loader`` in eval mode, without gradients.

    Samples with out-of-range labels are dropped per batch.

    Returns:
        (mean loss per sample, accuracy) over the samples actually used;
        (0.0, 0.0) if no sample was used.
    """
    model.eval()
    running_loss, n_correct, n_seen = 0, 0, 0
    for batch in loader:
        ms, hs, label = (batch[key].to(device, non_blocking=True)
                         for key in ("ms", "hs", "label"))
        ms, hs, label = _filter_batch(ms, hs, label)
        if label.numel() == 0:
            continue
        logits = _fwd(model, ms, hs, mode)
        loss = criterion(logits, label)
        n = label.size(0)
        running_loss += loss.item() * n
        n_correct    += (logits.argmax(1) == label).sum().item()
        n_seen       += n

    denom = max(n_seen, 1)
    return running_loss / denom, n_correct / denom


def train_model(model, tr_ld, va_ld, epochs, lr, mode, ckpt):
    """
    Train ``model`` for ``epochs`` epochs and keep the weights with the best
    validation accuracy.

    Setup:
      - SmoothCE loss, AdamW with CFG.WEIGHT_DECAY
      - Linear LR warmup over the first CFG.WARMUP_EP epochs
      - Cosine annealing down to 1% of ``lr`` for the remaining epochs
      - Whenever validation accuracy improves, the state dict is cloned
        and written to ``ckpt``

    Args:
        model:  network to train (optionally wrapped in DataParallel).
        tr_ld:  training DataLoader.
        va_ld:  validation DataLoader.
        epochs: number of epochs to run.
        lr:     peak learning rate (reached at the end of warmup).
        mode:   "ms", "hs" or "fusion"; forwarded to the epoch functions.
        ckpt:   path for the best-weights checkpoint.

    Returns:
        (model with the best weights loaded, best validation accuracy).
        If accuracy never rises above 0.0, no checkpoint is written and the
        model is returned with its final-epoch weights.
    """
    criterion = SmoothCE().to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=CFG.WEIGHT_DECAY)
    scheduler = CosineAnnealingLR(optimizer,
                                  T_max=max(1, epochs - CFG.WARMUP_EP),
                                  eta_min=lr * 0.01)
    best_acc = 0.0
    best_wts = None

    print(f"\n{'─'*62}")
    print(f"  mode={mode.upper()}  epochs={epochs}  lr={lr:.0e}")
    print(f"{'─'*62}")

    for ep in range(1, epochs + 1):
        if ep <= CFG.WARMUP_EP:
            warm_lr = lr * ep / CFG.WARMUP_EP
            for pg in optimizer.param_groups:
                pg["lr"] = warm_lr

        # Read the LR before the scheduler advances, so the log line shows the
        # rate this epoch actually trained with rather than the next epoch's.
        cur_lr = optimizer.param_groups[0]["lr"]

        tr_l, tr_a = train_epoch(model, tr_ld, optimizer, criterion, mode)
        va_l, va_a = eval_epoch (model, va_ld, criterion, mode)
        if ep > CFG.WARMUP_EP: scheduler.step()

        print(f"  ep {ep:3d}/{epochs}  "
              f"tr {tr_l:.4f}/{tr_a:.3f}  "
              f"va {va_l:.4f}/{va_a:.3f}  "
              f"lr {cur_lr:.2e}")

        if va_a > best_acc:
            best_acc = va_a
            # Clone: state_dict() returns live references that keep changing.
            best_wts = {k: v.clone() for k, v in model.state_dict().items()}
            torch.save(best_wts, ckpt)
            print(f"    ✓ checkpoint saved  val_acc={va_a:.4f}")

    if best_wts is not None:
        model.load_state_dict(best_wts)
    print(f"  Best val_acc: {best_acc:.4f}")
    return model, best_acc


# =============================================================================
# SECTION 8 – DATALOADER SETUP
# =============================================================================

# Stratified 85/15 internal train/val split.
# The shuffle is in place and driven by the seeded `random` module, so the
# split is reproducible when the whole cell is run from a fresh kernel.
random.shuffle(train_records)
buckets = defaultdict(list)
for r in train_records:
    buckets[r["label"]].append(r)

int_train, int_val = [], []
for lbl, recs in buckets.items():
    # At least one validation sample per class, even for tiny classes.
    n_val = max(1, int(len(recs) * CFG.VAL_FRAC))
    int_val.extend(recs[:n_val])
    int_train.extend(recs[n_val:])
random.shuffle(int_train); random.shuffle(int_val)

print(f"\n[DS] int_train={len(int_train)}  int_val={len(int_val)}  test={len(val_records)}")
for nm, recs in [("int_train", int_train), ("int_val", int_val)]:
    d = defaultdict(int)
    for r in recs: d[CFG.IDX2CLASS[r["label"]]] += 1
    print(f"[DS] {nm} class dist: {dict(d)}")

# Final label sanity check before any GPU work
for nm, recs in [("int_train", int_train), ("int_val", int_val)]:
    bad = [r for r in recs if r["label"] not in {0, 1, 2}]
    if bad:
        raise ValueError(f"[FATAL] {nm} still has {len(bad)} invalid labels: {bad[0]}")
print("[DS] Label sanity check PASSED – safe to train")

train_ds = CropDS(int_train,   augment=True)
val_ds   = CropDS(int_val,     augment=False)
test_ds  = CropDS(val_records, augment=False)   # unlabelled competition "val" split

# Pinned host memory only helps host→GPU copies; skip it on CPU-only runs.
kw = dict(num_workers=CFG.NUM_WORKERS, pin_memory=torch.cuda.is_available())
train_ld = DataLoader(train_ds, CFG.BATCH_SIZE, shuffle=True,  drop_last=True,  **kw)
val_ld   = DataLoader(val_ds,   CFG.BATCH_SIZE, shuffle=False, drop_last=False, **kw)
test_ld  = DataLoader(test_ds,  CFG.BATCH_SIZE, shuffle=False, drop_last=False, **kw)
print(f"[DS] batches – train={len(train_ld)}  val={len(val_ld)}  test={len(test_ld)}")


# =============================================================================
# SECTION 9 – STAGE 1: MS BASELINE
# =============================================================================

print("\n" + "=" * 62,
      "  STAGE 1 – Multispectral Baseline (5 bands + NDVI + RENDVI)",
      "=" * 62,
      sep="\n")

# Build the MS-only classifier; wrap it for multi-GPU when more than one is visible.
model_ms = MSModel().to(device)
if N_GPUS > 1:
    model_ms = nn.DataParallel(model_ms)
    print(f"[HW] DataParallel × {N_GPUS} GPUs")

model_ms, acc_s1 = train_model(
    model=model_ms, tr_ld=train_ld, va_ld=val_ld,
    epochs=CFG.EPOCHS_S1, lr=CFG.LR, mode="ms", ckpt="ckpt_ms.pth",
)
print(f"\n[S1 DONE] best val_acc = {acc_s1:.4f}")


# =============================================================================
# SECTION 10 – STAGE 2: HS CORE MODEL
# =============================================================================

print("\n" + "=" * 62,
      "  STAGE 2 – Hyperspectral Core (11 bands, 781–821 nm)",
      "=" * 62,
      sep="\n")

# Build the HS-only classifier; wrap it for multi-GPU when more than one is visible.
model_hs = HSModel().to(device)
if N_GPUS > 1:
    model_hs = nn.DataParallel(model_hs)

model_hs, acc_s2 = train_model(
    model=model_hs, tr_ld=train_ld, va_ld=val_ld,
    epochs=CFG.EPOCHS_S2, lr=CFG.LR, mode="hs", ckpt="ckpt_hs.pth",
)
print(f"\n[S2 DONE] best val_acc = {acc_s2:.4f}")


# =============================================================================
# SECTION 11 – STAGE 3: FUSION
# =============================================================================

print("\n" + "=" * 62,
      "  STAGE 3 – MS + HS Late-Fusion Model",
      "=" * 62,
      sep="\n")

def _enc_w(m):
    """Return the state dict of ``m``'s ``encoder`` (unwrapping DataParallel), or None if it has none."""
    core = m.module if isinstance(m, nn.DataParallel) else m
    if hasattr(core, "encoder"):
        return core.encoder.state_dict()
    return None

# Warm-start both fusion encoders from the single-modality models.
model_fusion = FusionModel(ms_w=_enc_w(model_ms),
                           hs_w=_enc_w(model_hs)).to(device)
if N_GPUS > 1:
    model_fusion = nn.DataParallel(model_fusion)

# Phase A: only the fusion head receives gradients.
# NOTE(review): train_epoch() calls model.train(), so BatchNorm running
# statistics and Dropout inside the "frozen" encoders remain active here.
print("\n[Phase A] Fusion head only – encoders frozen (5 epochs)")
fcore = model_fusion.module if isinstance(model_fusion, nn.DataParallel) else model_fusion
for p in [*fcore.ms_enc.parameters(), *fcore.hs_enc.parameters()]:
    p.requires_grad = False

model_fusion, _ = train_model(
    model=model_fusion, tr_ld=train_ld, va_ld=val_ld,
    epochs=5, lr=CFG.LR * 2, mode="fusion", ckpt="ckpt_fA.pth",
)

# Phase B: release the encoders and fine-tune everything at a lower LR.
print("\n[Phase B] End-to-end fine-tune – all layers unfrozen")
for p in [*fcore.ms_enc.parameters(), *fcore.hs_enc.parameters()]:
    p.requires_grad = True

model_fusion, acc_s3 = train_model(
    model=model_fusion, tr_ld=train_ld, va_ld=val_ld,
    epochs=CFG.EPOCHS_S3, lr=CFG.LR / 3, mode="fusion", ckpt="ckpt_fusion.pth",
)
print(f"\n[S3 DONE] best val_acc = {acc_s3:.4f}")


# =============================================================================
# SECTION 12 – ABLATION SUMMARY
# =============================================================================

# Side-by-side best internal-validation accuracy of the three stages.
print("\n" + "=" * 62,
      "  ABLATION SUMMARY (internal val)",
      "=" * 62,
      f"  Stage 1  MS  baseline  : {acc_s1:.4f}",
      f"  Stage 2  HS  core      : {acc_s2:.4f}",
      f"  Stage 3  Fusion        : {acc_s3:.4f}",
      "  Expected: Fusion ≥ HS > MS  (confirmed by EDA)",
      "=" * 62,
      sep="\n")


# =============================================================================
# SECTION 13 – INFERENCE & SUBMISSION
# =============================================================================

@torch.no_grad()
def run_inference(model, loader, mode, tta=True):
    """
    Predict a class for every sample in ``loader``.

    With ``tta=True`` the softmax probabilities of the original input and of
    its horizontally flipped copy are averaged before taking the argmax.

    Returns:
        list of {"Id": stem, "Category": class_name}, in loader order.
    """
    model.eval()
    rows = []
    for batch in loader:
        ms = batch["ms"].to(device, non_blocking=True)
        hs = batch["hs"].to(device, non_blocking=True)

        probs = F.softmax(_fwd(model, ms, hs, mode), dim=-1)
        if tta:
            flipped = F.softmax(_fwd(model, ms.flip(-1), hs.flip(-1), mode), dim=-1)
            probs = (probs + flipped) / 2.0

        pred_idx = probs.argmax(dim=1).cpu().numpy()
        rows.extend(
            {"Id": stem, "Category": CFG.IDX2CLASS[int(k)]}
            for stem, k in zip(batch["stem"], pred_idx)
        )
    return rows


print("\n[INFERENCE] Running on val set with fusion model + TTA…")
rows = run_inference(model_fusion, test_ld, "fusion", tta=True)
sub  = pd.DataFrame(rows)

print(f"[INFERENCE] Rows      : {len(sub)}")
print(f"[INFERENCE] Class dist:\n{sub['Category'].value_counts()}")

# Guard the output format: only the canonical class strings may appear.
invalid = set(sub["Category"].unique()).difference(CFG.CLASSES)
assert not invalid, f"Invalid class names in submission: {invalid}"

sub.to_csv(CFG.SUBMISSION, index=False)
print(f"[INFERENCE] Saved → {CFG.SUBMISSION}")
print(sub.head(8).to_string(index=False))


# =============================================================================
# SECTION 14 – SELF-EVALUATION (if result.csv present in val/)
# =============================================================================

# Optional scoring against ground truth, when val/result.csv ships with the data.
result_csv = VAL_DIR / "result.csv"
if not result_csv.exists():
    print("\n[INFO] No result.csv in val/ – Kaggle-style hidden labels.")
    print(f"[INFO] Upload {CFG.SUBMISSION} to the competition leaderboard.")
else:
    print("\n" + "=" * 62,
          "  SELF-EVALUATION  (result.csv found)",
          "=" * 62,
          sep="\n")
    gt    = pd.read_csv(result_csv)
    # Accept a few common spellings for the id and label columns (case-insensitive).
    id_c  = next((col for col in gt.columns if col.lower() in {"id","filename","stem"}), None)
    cat_c = next((col for col in gt.columns if col.lower() in {"category","label","class"}), None)
    if not (id_c and cat_c):
        print(f"  Unexpected GT columns: {list(gt.columns)}")
    else:
        gt  = gt.rename(columns={id_c: "Id", cat_c: "gt"})
        mg  = sub.merge(gt[["Id","gt"]], on="Id")
        acc = (mg["Category"] == mg["gt"]).mean()
        print(f"  Overall accuracy : {acc:.4f}  (n={len(mg)})")
        for cls in CFG.CLASSES:
            m = mg["gt"] == cls                           # rows whose truth is this class
            c = (mg.loc[m, "Category"] == cls).sum()      # ...and were predicted as it
            print(f"    {cls:8s}: {c}/{m.sum()} = {c/max(m.sum(),1):.4f}")


# Closing banner: best internal-validation accuracy per stage and the output file.
print("\n" + "=" * 62,
      "  PIPELINE COMPLETE",
      f"  MS  baseline  val_acc : {acc_s1:.4f}",
      f"  HS  core      val_acc : {acc_s2:.4f}",
      f"  Fusion        val_acc : {acc_s3:.4f}",
      f"  Output                : {CFG.SUBMISSION}",
      "=" * 62,
      sep="\n")

[HW] device=cpu  n_gpus=0
[PATH] Root (hardcoded): /kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared
[PATH] TRAIN_DIR = /kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared/train
[PATH] VAL_DIR   = /kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared/val
[COLLECT] Train: 600 PNG files found in /kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared/train/RGB
[COLLECT] Train records: 600  skipped: 0
[COLLECT] Class distribution: {'Health': 200, 'Other': 200, 'Rust': 200}
[COLLECT] Val: 300 PNG files found in /kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared/val/RGB
[COLLECT] Val records: 300

[OK] train=600  val=300
[OK] All train labels verified in {0,1,2} – no CUDA scatter crash risk

[DS] int_train=510  int_val=90  test=300
[DS] int_train class dist: {'Rust': 170, 'Other': 170, 'Health': 170}
[DS] int_val class dist: {'Other': 30, 'Health': 30, 'Rust': 30}
[