In [2]:
!pip install -q timm rasterio opencv-python-headless

In [1]:
# =============================================================================
# ICPR 2026 – Beyond Visible Spectrum: AI for Agriculture
# OPTIMIZED TRAINING + INFERENCE PIPELINE
#
# SUBMISSION FORMAT:
#   Id          → filename WITH .png extension  (e.g. val_02afcb0e.png)
#   Category    → Health | Rust | Other
#
# DATASET LAYOUT (confirmed):
#   train/RGB/  Health_hyper_1.png | rust_hyper_2.png | other_hyper_12.png
#   train/MS/   same stems, .tif
#   train/HS/   same stems, .tif  (125 bands, 450–950 nm)
#   val/RGB/    val_02afcb0e.png  (no class in filename)
#   val/MS/     val_02afcb0e.tif
#   val/HS/     val_02afcb0e.tif
#
# PASTE ENTIRE FILE INTO ONE SINGLE KAGGLE CELL AND RUN.
# =============================================================================

# ── Install timm if not present ──────────────────────────────────────────────
import subprocess, sys
try:
    import timm
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "timm"])
    import timm

import os, random, warnings
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import cv2
import rasterio
from rasterio.errors import NotGeoreferencedWarning

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.amp import GradScaler, autocast

warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# ── Reproducibility ──────────────────────────────────────────────────────────
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark     = False

device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
N_GPUS  = torch.cuda.device_count()
USE_AMP = torch.cuda.is_available()
print(f"[HW] device={device}  GPUs={N_GPUS}  AMP={USE_AMP}")
if torch.cuda.is_available():
    print(f"[HW] GPU={torch.cuda.get_device_name(0)}  "
          f"VRAM={torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")


# =============================================================================
# SECTION 1 – CONFIGURATION
# =============================================================================

class CFG:
    CLASSES     = ["Health", "Rust", "Other"]
    NUM_CLASSES = 3
    CLASS2IDX   = {"Health": 0, "Rust": 1, "Other": 2}
    IDX2CLASS   = {0: "Health", 1: "Rust", 2: "Other"}

    MS_BANDS  = 5
    MS_IN_CH  = 7       # 5 bands + NDVI + RENDVI

    # 40 HS bands across 4 wavelength clusters (700–860 nm)
    # Cluster A 700–740nm idx 62–71 | B 740–780nm idx 72–81
    # Cluster C 780–820nm idx 82–91 | D 820–860nm idx 92–101
    HS_KEY_BANDS = (list(range(62, 72)) + list(range(72, 82)) +
                    list(range(82, 92)) + list(range(92, 102)))
    HS_IN_CH     = 40

    IMG_SIZE     = 64

    BATCH_SIZE   = 24
    GRAD_ACCUM   = 2
    NUM_WORKERS  = 4
    LR           = 2e-4
    LR_MIN       = 1e-6
    WEIGHT_DECAY = 1e-2
    LABEL_SMOOTH = 0.10
    FOCAL_GAMMA  = 2.0
    MIXUP_ALPHA  = 0.2
    BAND_DROP_P  = 0.15
    VAL_FRAC     = 0.15
    WARMUP_STEPS = 100

    EPOCHS_S1_MS = 15
    EPOCHS_S1_HS = 15
    EPOCHS_S2    = 20
    EPOCHS_S3    = 30

    SUBMISSION   = "/kaggle/working/submission.csv"


# =============================================================================
# SECTION 2 – PATH DETECTION
# =============================================================================

def _has_split(p):
    return (p / "train").is_dir() and (p / "val").is_dir()

def find_root():
    hardcoded = [
        "/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared",
        "/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026",
        "/kaggle/input/Kaggle_Prepared",
        "/kaggle/working/Kaggle_Prepared",
        "/kaggle/working",
    ]
    for raw in hardcoded:
        p = Path(raw)
        if p.is_dir() and _has_split(p):
            print(f"[PATH] Root: {p}")
            return p
    base = Path("/kaggle/input")
    if base.is_dir():
        def _walk(d, depth):
            if depth == 0: return None
            try:
                for c in sorted(d.iterdir()):
                    if c.is_dir():
                        if _has_split(c): return c
                        r = _walk(c, depth - 1)
                        if r: return r
            except PermissionError: pass
        found = _walk(base, 6)
        if found:
            print(f"[PATH] Root (recursive): {found}")
            return found
    raise FileNotFoundError("Cannot find dataset root with train/ and val/")

ROOT      = find_root()
TRAIN_DIR = ROOT / "train"
VAL_DIR   = ROOT / "val"
print(f"[PATH] TRAIN={TRAIN_DIR}")
print(f"[PATH] VAL  ={VAL_DIR}")


# =============================================================================
# SECTION 3 – SAMPLE COLLECTION
# =============================================================================

_CLS_MAP = {c.lower(): c for c in CFG.CLASSES}

def class_from_stem(stem):
    """Extract class from filename prefix: 'Health_hyper_1' → ('Health', 0)"""
    prefix = stem.split("_")[0].lower()
    canon  = _CLS_MAP.get(prefix)
    if canon is None:
        return None, -1
    return canon, CFG.CLASS2IDX[canon]


def collect_train(split_dir):
    rgb_dir = split_dir / "RGB"
    ms_dir  = split_dir / "MS"
    hs_dir  = split_dir / "HS"
    pngs    = sorted(rgb_dir.glob("*.png")) + sorted(rgb_dir.glob("*.PNG"))
    print(f"[COLLECT] Train: {len(pngs)} PNG files")
    records, per_cls = [], defaultdict(int)
    for png in pngs:
        stem = png.stem
        canon, label = class_from_stem(stem)
        if canon is None:
            continue
        per_cls[canon] += 1
        records.append({
            "stem":     stem,
            "filename": png.name,          # stem + .png extension
            "ms_path":  ms_dir / (stem + ".tif"),
            "hs_path":  hs_dir / (stem + ".tif"),
            "label":    label,
            "cls_name": canon,
        })
    print(f"[COLLECT] Train records={len(records)}  dist={dict(per_cls)}")
    return records


def collect_val(split_dir):
    rgb_dir = split_dir / "RGB"
    ms_dir  = split_dir / "MS"
    hs_dir  = split_dir / "HS"
    pngs    = sorted(rgb_dir.glob("*.png")) + sorted(rgb_dir.glob("*.PNG"))
    print(f"[COLLECT] Val: {len(pngs)} PNG files")
    records = []
    for png in pngs:
        stem = png.stem
        records.append({
            "stem":     stem,
            "filename": png.name,          # e.g. val_02afcb0e.png
            "ms_path":  ms_dir / (stem + ".tif"),
            "hs_path":  hs_dir / (stem + ".tif"),
            "label":    -1,
            "cls_name": "__val__",
        })
    print(f"[COLLECT] Val records={len(records)}")
    return records


train_records = collect_train(TRAIN_DIR)
val_records   = collect_val(VAL_DIR)

# Label integrity check
bad = [r for r in train_records if r["label"] not in {0, 1, 2}]
if bad:
    raise ValueError(f"{len(bad)} train records with invalid label: {bad[0]}")
if len(train_records) == 0:
    raise RuntimeError("Zero training records. Check filenames like 'Health_hyper_1.png'")
if len(val_records) == 0:
    raise RuntimeError("Zero val records found.")
print(f"[OK] train={len(train_records)}  val={len(val_records)}")

# Class weights for focal loss
cls_counts  = defaultdict(int)
for r in train_records: cls_counts[r["label"]] += 1
cls_weights = torch.tensor(
    [len(train_records) / (CFG.NUM_CLASSES * cls_counts.get(i, 1))
     for i in range(CFG.NUM_CLASSES)], dtype=torch.float32)
print(f"[INFO] Class weights: {cls_weights.numpy().round(3)}")


# =============================================================================
# SECTION 4 – DATA LOADING
# =============================================================================

def load_ms(path):
    try:
        with rasterio.open(str(path)) as src:
            data = src.read().astype(np.float32)
    except Exception:
        return torch.zeros(CFG.MS_BANDS, CFG.IMG_SIZE, CFG.IMG_SIZE)
    for b in range(data.shape[0]):
        lo, hi = np.percentile(data[b], 1), np.percentile(data[b], 99)
        data[b] = np.clip((data[b] - lo) / (hi - lo + 1e-8), 0., 1.)
    out = np.stack([cv2.resize(data[b], (CFG.IMG_SIZE, CFG.IMG_SIZE),
                               interpolation=cv2.INTER_LINEAR)
                    for b in range(data.shape[0])])
    return torch.from_numpy(out)


def load_hs(path, band_dropout=False):
    bands_1idx = [b + 1 for b in CFG.HS_KEY_BANDS]
    try:
        with rasterio.open(str(path)) as src:
            data = src.read(bands_1idx).astype(np.float32)
    except Exception:
        return torch.zeros(CFG.HS_IN_CH, CFG.IMG_SIZE, CFG.IMG_SIZE)
    for b in range(data.shape[0]):
        lo, hi = np.percentile(data[b], 1), np.percentile(data[b], 99)
        data[b] = np.clip((data[b] - lo) / (hi - lo + 1e-8), 0., 1.)
    if band_dropout and random.random() < CFG.BAND_DROP_P:
        data[random.randint(0, data.shape[0] - 1)] = 0.0
    out = np.stack([cv2.resize(data[b], (CFG.IMG_SIZE, CFG.IMG_SIZE),
                               interpolation=cv2.INTER_LINEAR)
                    for b in range(data.shape[0])])
    return torch.from_numpy(out)


def inject_vi(ms):
    """(B,5,H,W) → (B,7,H,W) by appending NDVI and RENDVI channels."""
    RED    = ms[:, 2:3]
    RE     = ms[:, 3:4]
    NIR    = ms[:, 4:5]
    NDVI   = (NIR - RED) / (NIR + RED + 1e-8)
    RENDVI = (NIR - RE)  / (NIR + RE  + 1e-8)
    return torch.cat([ms, (NDVI+1)/2, (RENDVI+1)/2], dim=1)


# =============================================================================
# SECTION 5 – DATASET
# =============================================================================

class CropDS(Dataset):
    def __init__(self, records, augment=False):
        self.records = [r for r in records
                        if r["label"] == -1 or r["label"] in {0, 1, 2}]
        self.augment = augment

    def __len__(self): return len(self.records)

    def _aug(self, ms, hs):
        if random.random() > 0.5: ms, hs = torch.flip(ms,[-1]), torch.flip(hs,[-1])
        if random.random() > 0.5: ms, hs = torch.flip(ms,[-2]), torch.flip(hs,[-2])
        k = random.randint(0, 3)
        if k:
            ms = torch.rot90(ms, k, [-2,-1])
            hs = torch.rot90(hs, k, [-2,-1])
        return ms, hs

    def __getitem__(self, idx):
        r  = self.records[idx]
        ms = load_ms(r["ms_path"])
        hs = load_hs(r["hs_path"], band_dropout=self.augment)
        if self.augment:
            ms, hs = self._aug(ms, hs)
        return {
            "ms":       ms,
            "hs":       hs,
            "label":    torch.tensor(r["label"], dtype=torch.long),
            "stem":     r["stem"],
            "filename": r["filename"],     # includes .png extension
        }


def mixup_batch(ms, hs, labels, alpha=CFG.MIXUP_ALPHA):
    lam  = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    idx  = torch.randperm(ms.size(0), device=ms.device)
    return lam*ms + (1-lam)*ms[idx], lam*hs + (1-lam)*hs[idx], labels, labels[idx], lam


# =============================================================================
# SECTION 6 – MODELS
# =============================================================================

class SESpectralAttention(nn.Module):
    """Squeeze-and-Excitation block for spectral channel reweighting."""
    def __init__(self, n_ch, reduction=4):
        super().__init__()
        hidden = max(n_ch // reduction, 4)
        self.net = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(n_ch, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, n_ch), nn.Sigmoid(),
        )
    def forward(self, x):
        return x * self.net(x).unsqueeze(-1).unsqueeze(-1)


class GatedFusion(nn.Module):
    """Learns per-sample dynamic importance of MS vs HS features."""
    def __init__(self, feat_dim=256):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(feat_dim*2, feat_dim), nn.ReLU(inplace=True),
            nn.Linear(feat_dim, 2), nn.Softmax(dim=-1),
        )
        self.proj = nn.Linear(feat_dim, feat_dim)
    def forward(self, ms_f, hs_f):
        w = self.gate(torch.cat([ms_f, hs_f], dim=1))
        return self.proj(w[:,0:1]*ms_f + w[:,1:2]*hs_f)


def _make_efficientnet(in_ch):
    model   = timm.create_model("efficientnet_b0", pretrained=True,
                                 num_classes=0, global_pool="avg")
    orig_w  = model.conv_stem.weight.data          # (32, 3, 3, 3)
    avg_w   = orig_w.mean(dim=1, keepdim=True)     # (32, 1, 3, 3)
    new_w   = avg_w.repeat(1, in_ch, 1, 1) * (3.0 / in_ch)
    model.conv_stem = nn.Conv2d(in_ch, 32, 3, stride=2, padding=1, bias=False)
    model.conv_stem.weight = nn.Parameter(new_w)
    return model, model.num_features               # 1280 for B0


class MSEncoder(nn.Module):
    def __init__(self, feat_dim=256):
        super().__init__()
        self.backbone, out_dim = _make_efficientnet(CFG.MS_IN_CH)
        self.proj = nn.Sequential(nn.Linear(out_dim, feat_dim),
                                  nn.LayerNorm(feat_dim), nn.GELU(), nn.Dropout(0.3))
    def forward(self, x): return self.proj(self.backbone(x))
    def freeze(self):
        for p in self.backbone.parameters(): p.requires_grad = False
    def unfreeze_top(self, n=4):
        for blk in list(self.backbone.blocks)[-n:]:
            for p in blk.parameters(): p.requires_grad = True
        for p in self.backbone.conv_head.parameters(): p.requires_grad = True
        for p in self.backbone.bn2.parameters():       p.requires_grad = True
    def unfreeze_all(self):
        for p in self.backbone.parameters(): p.requires_grad = True


class HSEncoder(nn.Module):
    def __init__(self, feat_dim=256):
        super().__init__()
        self.se_attn  = SESpectralAttention(CFG.HS_IN_CH, reduction=4)
        self.backbone, out_dim = _make_efficientnet(CFG.HS_IN_CH)
        self.proj = nn.Sequential(nn.Linear(out_dim, feat_dim),
                                  nn.LayerNorm(feat_dim), nn.GELU(), nn.Dropout(0.3))
    def forward(self, x): return self.proj(self.backbone(self.se_attn(x)))
    def freeze(self):
        for p in self.backbone.parameters(): p.requires_grad = False
    def unfreeze_top(self, n=4):
        for blk in list(self.backbone.blocks)[-n:]:
            for p in blk.parameters(): p.requires_grad = True
        for p in self.backbone.conv_head.parameters(): p.requires_grad = True
        for p in self.backbone.bn2.parameters():       p.requires_grad = True
    def unfreeze_all(self):
        for p in self.backbone.parameters(): p.requires_grad = True


class MSModel(nn.Module):
    def __init__(self, feat_dim=256):
        super().__init__()
        self.encoder = MSEncoder(feat_dim)
        self.clf      = nn.Linear(feat_dim, CFG.NUM_CLASSES)
    def forward(self, ms, hs=None): return self.clf(self.encoder(ms))


class HSModel(nn.Module):
    def __init__(self, feat_dim=256):
        super().__init__()
        self.encoder = HSEncoder(feat_dim)
        self.clf      = nn.Linear(feat_dim, CFG.NUM_CLASSES)
    def forward(self, ms=None, hs=None): return self.clf(self.encoder(hs))


class FusionModel(nn.Module):
    def __init__(self, feat_dim=256, ms_w=None, hs_w=None):
        super().__init__()
        self.ms_enc = MSEncoder(feat_dim)
        self.hs_enc = HSEncoder(feat_dim)
        if ms_w: self.ms_enc.load_state_dict(ms_w, strict=False)
        if hs_w: self.hs_enc.load_state_dict(hs_w, strict=False)
        self.fusion = GatedFusion(feat_dim)
        self.clf    = nn.Sequential(
            nn.Linear(feat_dim, 128), nn.GELU(), nn.Dropout(0.4),
            nn.Linear(128, CFG.NUM_CLASSES),
        )
    def forward(self, ms, hs):
        return self.clf(self.fusion(self.ms_enc(ms), self.hs_enc(hs)))


# =============================================================================
# SECTION 7 – LOSS
# =============================================================================

class FocalLoss(nn.Module):
    def __init__(self, gamma=CFG.FOCAL_GAMMA, weights=None, smooth=CFG.LABEL_SMOOTH):
        super().__init__()
        self.gamma  = gamma
        self.smooth = smooth
        self.register_buffer("weights", weights)

    def forward(self, logits, targets):
        targets = targets.clamp(0, CFG.NUM_CLASSES - 1)
        log_p   = F.log_softmax(logits, dim=-1)
        p       = log_p.exp()
        with torch.no_grad():
            t = torch.full_like(log_p, self.smooth / (CFG.NUM_CLASSES - 1))
            t.scatter_(1, targets.long().unsqueeze(1), 1.0 - self.smooth)
        p_true  = (p * t).sum(dim=-1, keepdim=True)
        focal_w = (1.0 - p_true).pow(self.gamma)
        loss    = -(t * log_p).sum(dim=-1) * focal_w.squeeze(-1)
        if self.weights is not None:
            loss = loss * self.weights[targets.long()]
        return loss.mean()


def mixup_loss(criterion, logits, la, lb, lam):
    return lam * criterion(logits, la) + (1 - lam) * criterion(logits, lb)


# =============================================================================
# SECTION 8 – TRAINING LOOP
# =============================================================================

def warmup_lr(optimizer, step, warmup_steps, base_lr):
    if step < warmup_steps:
        lr = base_lr * step / max(warmup_steps, 1)
        for pg in optimizer.param_groups: pg["lr"] = lr


def _fwd(model, ms, hs, mode):
    ms_vi = inject_vi(ms)
    if   mode == "ms":     return model(ms=ms_vi)
    elif mode == "hs":     return model(hs=hs)
    else:                  return model(ms=ms_vi, hs=hs)


def _filter(ms, hs, label):
    mask = (label >= 0) & (label < CFG.NUM_CLASSES)
    return ms[mask], hs[mask], label[mask]


def train_epoch(model, loader, optimizer, criterion, scaler, mode,
                use_mixup, grad_accum, g_step, warmup_steps, base_lr):
    model.train()
    loss_sum = correct = total = 0
    optimizer.zero_grad(set_to_none=True)

    for i, batch in enumerate(loader):
        ms    = batch["ms"].to(device, non_blocking=True)
        hs    = batch["hs"].to(device, non_blocking=True)
        label = batch["label"].to(device, non_blocking=True)
        ms, hs, label = _filter(ms, hs, label)
        if label.numel() == 0: continue

        warmup_lr(optimizer, g_step, warmup_steps, base_lr)
        g_step += 1

        with autocast('cuda', enabled=USE_AMP):
            if use_mixup:
                ms_m, hs_m, la, lb, lam = mixup_batch(ms, hs, label)
                out  = _fwd(model, ms_m, hs_m, mode)
                loss = mixup_loss(criterion, out, la, lb, lam) / grad_accum
            else:
                out  = _fwd(model, ms, hs, mode)
                loss = criterion(out, label) / grad_accum

        scaler.scale(loss).backward()

        if (i + 1) % grad_accum == 0 or (i + 1) == len(loader):
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer); scaler.update()
            optimizer.zero_grad(set_to_none=True)

        loss_sum += loss.item() * grad_accum * label.size(0)
        with torch.no_grad():
            correct += (out.argmax(1) == label).sum().item()
        total += label.size(0)

    return loss_sum / max(total,1), correct / max(total,1), g_step


@torch.no_grad()
def eval_epoch(model, loader, criterion, mode):
    model.eval()
    loss_sum = correct = total = 0
    per_c = defaultdict(int); per_t = defaultdict(int)
    for batch in loader:
        ms    = batch["ms"].to(device, non_blocking=True)
        hs    = batch["hs"].to(device, non_blocking=True)
        label = batch["label"].to(device, non_blocking=True)
        ms, hs, label = _filter(ms, hs, label)
        if label.numel() == 0: continue
        with autocast('cuda', enabled=USE_AMP):
            out  = _fwd(model, ms, hs, mode)
            loss = criterion(out, label)
        loss_sum += loss.item() * label.size(0)
        preds     = out.argmax(1)
        correct  += (preds == label).sum().item()
        total    += label.size(0)
        for lbl in range(CFG.NUM_CLASSES):
            m = label == lbl
            per_c[lbl] += (preds[m] == lbl).sum().item()
            per_t[lbl] += m.sum().item()
    cls_str = " | ".join(f"{CFG.IDX2CLASS[k]}={per_c[k]/max(per_t[k],1):.3f}"
                         for k in sorted(CFG.IDX2CLASS))
    return loss_sum / max(total,1), correct / max(total,1), cls_str


def run_stage(model, tr_ld, va_ld, epochs, lr, mode, ckpt,
              use_mixup=False, warmup_steps=CFG.WARMUP_STEPS):
    cw        = cls_weights.to(device)
    criterion = FocalLoss(gamma=CFG.FOCAL_GAMMA, weights=cw,
                          smooth=CFG.LABEL_SMOOTH).to(device)
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                      lr=lr, weight_decay=CFG.WEIGHT_DECAY)
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=max(1, epochs // 2),
                                            eta_min=CFG.LR_MIN)
    scaler    = GradScaler('cuda', enabled=USE_AMP)
    best_acc  = 0.0; best_wts = None; g_step = 0

    print(f"\n{'─'*65}")
    print(f"  mode={mode.upper()}  epochs={epochs}  lr={lr:.0e}  mixup={use_mixup}")
    print(f"{'─'*65}")

    for ep in range(1, epochs + 1):
        tr_l, tr_a, g_step = train_epoch(model, tr_ld, optimizer, criterion,
                                          scaler, mode, use_mixup,
                                          CFG.GRAD_ACCUM, g_step,
                                          warmup_steps, lr)
        va_l, va_a, cls_str = eval_epoch(model, va_ld, criterion, mode)
        if g_step >= warmup_steps: scheduler.step(ep)

        print(f"  ep {ep:3d}/{epochs}  tr {tr_l:.4f}/{tr_a:.3f}  "
              f"va {va_l:.4f}/{va_a:.3f}  lr {optimizer.param_groups[0]['lr']:.2e}")
        print(f"    {cls_str}")

        if va_a > best_acc:
            best_acc = va_a
            best_wts = {k: v.clone() for k, v in model.state_dict().items()}
            torch.save(best_wts, ckpt)
            print(f"    ✓ checkpoint saved  val_acc={va_a:.4f}")

    if best_wts: model.load_state_dict(best_wts)
    print(f"  Stage best val_acc: {best_acc:.4f}")
    return model, best_acc


# =============================================================================
# SECTION 9 – DATALOADER SETUP
# =============================================================================

random.shuffle(train_records)
buckets = defaultdict(list)
for r in train_records: buckets[r["label"]].append(r)

int_train, int_val = [], []
for lbl, recs in buckets.items():
    n_val = max(1, int(len(recs) * CFG.VAL_FRAC))
    int_val.extend(recs[:n_val]); int_train.extend(recs[n_val:])
random.shuffle(int_train); random.shuffle(int_val)
print(f"\n[DS] int_train={len(int_train)}  int_val={len(int_val)}  test={len(val_records)}")

train_ds = CropDS(int_train,   augment=True)
val_ds   = CropDS(int_val,     augment=False)
test_ds  = CropDS(val_records, augment=False)

kw = dict(num_workers=CFG.NUM_WORKERS, pin_memory=True, persistent_workers=True)
train_ld = DataLoader(train_ds, CFG.BATCH_SIZE, shuffle=True,  drop_last=True,  **kw)
val_ld   = DataLoader(val_ds,   CFG.BATCH_SIZE, shuffle=False, drop_last=False, **kw)
test_ld  = DataLoader(test_ds,  CFG.BATCH_SIZE, shuffle=False, drop_last=False, **kw)
print(f"[DS] train={len(train_ld)} val={len(val_ld)} test={len(test_ld)} batches")


# =============================================================================
# SECTION 10 – STAGE 1a: MS BASELINE
# =============================================================================

print("\n" + "="*65)
print("  STAGE 1a – MS Baseline (EfficientNet-B0, 7ch, backbone frozen)")
print("="*65)

model_ms = MSModel(feat_dim=256).to(device)
model_ms.encoder.freeze()
if N_GPUS > 1: model_ms = nn.DataParallel(model_ms)

model_ms, acc_s1a = run_stage(model_ms, train_ld, val_ld,
                               CFG.EPOCHS_S1_MS, CFG.LR * 3, "ms",
                               "/kaggle/working/ckpt_ms_frozen.pth")

print("\n  [1a] Unfreezing top-4 backbone blocks…")
(model_ms.module if isinstance(model_ms, nn.DataParallel) else model_ms).encoder.unfreeze_top(4)
model_ms, acc_s1a = run_stage(model_ms, train_ld, val_ld,
                               10, CFG.LR / 3, "ms",
                               "/kaggle/working/ckpt_ms.pth",
                               warmup_steps=50)
print(f"[S1a DONE] val_acc={acc_s1a:.4f}")


# =============================================================================
# SECTION 11 – STAGE 1b: HS CORE MODEL
# =============================================================================

print("\n" + "="*65)
print("  STAGE 1b – HS Core (EfficientNet-B0, 40ch, SE-attention, frozen)")
print("="*65)

model_hs = HSModel(feat_dim=256).to(device)
model_hs.encoder.freeze()
if N_GPUS > 1: model_hs = nn.DataParallel(model_hs)

model_hs, acc_s1b = run_stage(model_hs, train_ld, val_ld,
                               CFG.EPOCHS_S1_HS, CFG.LR * 3, "hs",
                               "/kaggle/working/ckpt_hs_frozen.pth")

print("\n  [1b] Unfreezing top-4 backbone blocks…")
(model_hs.module if isinstance(model_hs, nn.DataParallel) else model_hs).encoder.unfreeze_top(4)
model_hs, acc_s1b = run_stage(model_hs, train_ld, val_ld,
                               10, CFG.LR / 3, "hs",
                               "/kaggle/working/ckpt_hs.pth",
                               warmup_steps=50)
print(f"[S1b DONE] val_acc={acc_s1b:.4f}")


# =============================================================================
# SECTION 12 – STAGE 2: GATED FUSION (partial unfreeze)
# =============================================================================

print("\n" + "="*65)
print("  STAGE 2 – Gated Fusion (top-4 blocks unfrozen)")
print("="*65)

def _enc_w(m):
    core = m.module if isinstance(m, nn.DataParallel) else m
    return core.encoder.state_dict() if hasattr(core, "encoder") else None

model_f = FusionModel(feat_dim=256,
                      ms_w=_enc_w(model_ms),
                      hs_w=_enc_w(model_hs)).to(device)

# Freeze full backbones, keep SE-attn + proj + top-4 blocks trainable
model_f.ms_enc.freeze(); model_f.hs_enc.freeze()
model_f.ms_enc.unfreeze_top(4); model_f.hs_enc.unfreeze_top(4)
for p in model_f.ms_enc.proj.parameters():      p.requires_grad = True
for p in model_f.hs_enc.proj.parameters():      p.requires_grad = True
for p in model_f.hs_enc.se_attn.parameters():   p.requires_grad = True

if N_GPUS > 1: model_f = nn.DataParallel(model_f)

model_f, acc_s2 = run_stage(model_f, train_ld, val_ld,
                              CFG.EPOCHS_S2, CFG.LR, "fusion",
                              "/kaggle/working/ckpt_fusion_s2.pth")
print(f"[S2 DONE] val_acc={acc_s2:.4f}")


# =============================================================================
# SECTION 13 – STAGE 3: FULL FINE-TUNE + MIXUP
# =============================================================================

print("\n" + "="*65)
print("  STAGE 3 – Full Fine-tune (all layers, Mixup ON)")
print("="*65)

core_f = model_f.module if isinstance(model_f, nn.DataParallel) else model_f
core_f.ms_enc.unfreeze_all(); core_f.hs_enc.unfreeze_all()

model_f, acc_s3 = run_stage(model_f, train_ld, val_ld,
                              CFG.EPOCHS_S3, CFG.LR / 5, "fusion",
                              "/kaggle/working/ckpt_fusion_final.pth",
                              use_mixup=True)
print(f"[S3 DONE] val_acc={acc_s3:.4f}")


# =============================================================================
# SECTION 14 – ABLATION SUMMARY
# =============================================================================

print("\n" + "="*65)
print("  ABLATION SUMMARY")
print("="*65)
print(f"  S1a  MS  baseline                : {acc_s1a:.4f}")
print(f"  S1b  HS  core (SE-attn, 40 bands): {acc_s1b:.4f}")
print(f"  S2   Fusion (partial unfreeze)   : {acc_s2:.4f}")
print(f"  S3   Fusion (full + Mixup)       : {acc_s3:.4f}")
print(f"  Previous baseline                : ~0.537")
print("="*65)


# =============================================================================
# SECTION 15 – INFERENCE & SUBMISSION
# =============================================================================
#
# SUBMISSION FORMAT (exactly as required):
#   Id        → filename WITH .png extension  e.g. val_02afcb0e.png
#   Category  → Health | Rust | Other
#
# TTA: 4 orientations averaged (original, H-flip, V-flip, 90° rotation)
# =============================================================================

@torch.no_grad()
def run_inference(model, loader, mode):
    model.eval()
    rows = []
    for batch in loader:
        ms        = batch["ms"].to(device, non_blocking=True)
        hs        = batch["hs"].to(device, non_blocking=True)
        filenames = batch["filename"]   # list of "val_xxxxx.png" strings

        with autocast('cuda', enabled=USE_AMP):
            # 4-fold TTA: original + H-flip + V-flip + 90° rotation
            p0 = F.softmax(_fwd(model, ms,                          hs,                          mode), -1)
            p1 = F.softmax(_fwd(model, torch.flip(ms, [-1]),        torch.flip(hs, [-1]),        mode), -1)
            p2 = F.softmax(_fwd(model, torch.flip(ms, [-2]),        torch.flip(hs, [-2]),        mode), -1)
            p3 = F.softmax(_fwd(model, torch.rot90(ms,1,[-2,-1]),   torch.rot90(hs,1,[-2,-1]),   mode), -1)
            probs = (p0 + p1 + p2 + p3) / 4.0

        preds = probs.argmax(1).cpu().numpy()
        for fname, pred in zip(filenames, preds):
            rows.append({
                "Id":       fname,                        # e.g. val_02afcb0e.png
                "Category": CFG.IDX2CLASS[int(pred)],    # Health | Rust | Other
            })
    return rows


print("\n[INFERENCE] Running 4-fold TTA on val set…")
rows = run_inference(model_f, test_ld, "fusion")

# Build and validate submission dataframe
sub = pd.DataFrame(rows, columns=["Id", "Category"])
print(f"[INFERENCE] Total rows    : {len(sub)}")
print(f"[INFERENCE] Sample IDs    : {sub['Id'].head(3).tolist()}")
print(f"[INFERENCE] Class dist    :\n{sub['Category'].value_counts()}")

# Safety: ensure all category values are valid
invalid = set(sub["Category"].unique()) - set(CFG.CLASSES)
if invalid:
    raise ValueError(f"Invalid class names found in submission: {invalid}")

# Save to Kaggle working directory
sub.to_csv(CFG.SUBMISSION, index=False)

# Verify the file was actually written
saved_size = Path(CFG.SUBMISSION).stat().st_size
print(f"\n[SAVED] {CFG.SUBMISSION}")
print(f"[SAVED] File size: {saved_size} bytes")
print(f"\n{sub.head(10).to_string(index=False)}")

print("\n" + "="*65)
print("  PIPELINE COMPLETE")
print(f"  MS  val_acc  : {acc_s1a:.4f}")
print(f"  HS  val_acc  : {acc_s1b:.4f}")
print(f"  S2  val_acc  : {acc_s2:.4f}")
print(f"  S3  val_acc  : {acc_s3:.4f}")
print(f"  Submission   : {CFG.SUBMISSION}")
print(f"  Format       : Id (filename.png) | Category")
print("="*65)

[HW] device=cuda  GPUs=2  AMP=True
[HW] GPU=Tesla T4  VRAM=15.6GB
[PATH] Root: /kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared
[PATH] TRAIN=/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared/train
[PATH] VAL  =/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared/val
[COLLECT] Train: 600 PNG files
[COLLECT] Train records=600  dist={'Health': 200, 'Other': 200, 'Rust': 200}
[COLLECT] Val: 300 PNG files
[COLLECT] Val records=300
[OK] train=600  val=300
[INFO] Class weights: [1. 1. 1.]

[DS] int_train=510  int_val=90  test=300
[DS] train=21 val=4 test=13 batches

  STAGE 1a – MS Baseline (EfficientNet-B0, 7ch, backbone frozen)


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]


─────────────────────────────────────────────────────────────────
  mode=MS  epochs=15  lr=6e-04  mixup=False
─────────────────────────────────────────────────────────────────
  ep   1/15  tr 0.5669/0.306  va 0.5372/0.289  lr 1.20e-04
    Health=0.167 | Rust=0.400 | Other=0.300
    ✓ checkpoint saved  val_acc=0.2889
  ep   2/15  tr 0.5033/0.397  va nan/0.367  lr 2.46e-04
    Health=0.400 | Rust=0.133 | Other=0.567
    ✓ checkpoint saved  val_acc=0.3667
  ep   3/15  tr 0.4419/0.498  va 0.5085/0.456  lr 3.72e-04
    Health=0.467 | Rust=0.133 | Other=0.767
    ✓ checkpoint saved  val_acc=0.4556
  ep   4/15  tr 0.4166/0.552  va 0.4422/0.433  lr 4.98e-04
    Health=0.567 | Rust=0.200 | Other=0.533
  ep   5/15  tr 0.4201/0.534  va 0.5026/0.433  lr 1.14e-04
    Health=0.500 | Rust=0.267 | Other=0.533
  ep   6/15  tr 0.4135/0.514  va 0.4789/0.433  lr 3.07e-05
    Health=0.400 | Rust=0.367 | Other=0.533
  ep   7/15  tr 0.4117/0.528  va 0.4875/0.433  lr 6.00e-04
    Health=0.433 | Rust=0.300 | 