In [1]:
pip install numpy pandas matplotlib torch torchvision pillow rasterio scikit-learn

Note: you may need to restart the kernel to use updated packages.


In [2]:
"""
=============================================================================
ICPR 2026 — Multimodal Wheat Disease Classification  ─  v6 FINAL
Target: > 0.857 on competition validation set
=============================================================================

WHY v5 GOT 0.53:
─────────────────
  Band means collapse spatial patches to single numbers.
  Health vs Rust have nearly IDENTICAL mean spectra.
  SVM/RF on means = predicting nearly random for that pair.

WHAT ACTUALLY DISCRIMINATES Health vs Rust:
────────────────────────────────────────────
  RUST creates:
    • High LOCAL VARIANCE in red-edge bands (patchy lesions)
    • High ENTROPY in NIR texture (non-uniform canopy)
    • SKEWED pixel distribution (some pixels very bright = lesion spots)
    • Low chlorophyll in RED band but HETEROGENEOUS (not uniform)

  HEALTHY creates:
    • Low variance (uniform green canopy)
    • Symmetric pixel distributions
    • Smooth spatial gradients

SOLUTION — Three-part approach:
─────────────────────────────────
  PART A: Rich spatial texture features (the key insight)
    - Per-band: mean, std, skew, kurtosis, p10,p25,p50,p75,p90
    - GLCM-like: local variance map stats, co-occurrence energy
    - Within-patch NDVI map: std, skew, fraction of stressed pixels
    - Red-edge gradient magnitude statistics
    - HS: above for ALL 101 bands → captures spectral heterogeneity

  PART B: CNN on actual image patches (spatial learning)
    - EfficientNet-B0 fine-tuned on RGB (224×224) — full fine-tuning
    - Custom CNN on HS patch (101 bands, spatial) — learns texture
    - Cross-validated, trained with MixUp + CutMix + strong augmentation

  PART C: Stacking ensemble
    - Train SVM/XGBoost/RF on PART A features
    - Combine with PART B CNN softmax outputs via logistic regression meta-learner
    - Final prediction = weighted average

This approach is grounded in what remote sensing literature actually uses
for disease detection with small datasets.
=============================================================================
"""

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import time, warnings, random
from pathlib import Path
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy.stats import skew, kurtosis

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from PIL import Image
import rasterio

from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.pipeline import Pipeline
from sklearn.base import clone
import xgboost as xgb

warnings.filterwarnings("ignore")

# =============================================================================
# 0.  CONFIG
# =============================================================================

COMP_ROOT  = "/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026"
DATA_ROOT  = os.path.join(COMP_ROOT, "Kaggle_Prepared")
TRAIN_DIR  = os.path.join(DATA_ROOT, "train")
VAL_DIR    = os.path.join(DATA_ROOT, "val")
OUTPUT_DIR = "/kaggle/working"
SUBMIT_CSV = os.path.join(OUTPUT_DIR, "submission.csv")

CLASSES    = ["Health", "Rust", "Other"]
N_CLS      = 3
PREFIX_MAP = {"health": 0, "rust": 1, "other": 2}

SEED       = 42
N_FOLDS    = 5
HS_START   = 10
HS_END     = 111          # 101 valid HS bands

# CNN settings
IMG_SIZE   = 224
BATCH_SIZE = 16
N_EPOCHS   = 80           # with early stop, usually ~40-50 effective epochs
LR_HEAD    = 1e-3         # new layers
LR_BACK    = 5e-5         # pretrained backbone
PATIENCE   = 15
WARMUP     = 3

os.makedirs(OUTPUT_DIR, exist_ok=True)

# =============================================================================
# 1.  SEED + DEVICE
# =============================================================================

def set_seed(s):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic = True

set_seed(SEED)

def get_device():
    if not torch.cuda.is_available():
        print("[DEVICE] CPU only"); return torch.device("cpu")
    try:
        _ = torch.zeros(1).cuda() + 1
        torch.cuda.empty_cache()
        print(f"[DEVICE] {torch.cuda.get_device_name(0)}")
        return torch.device("cuda")
    except Exception as e:
        print(f"[DEVICE] CUDA failed: {e} → CPU")
        return torch.device("cpu")

DEVICE = get_device()

# =============================================================================
# 2.  FILE HELPERS
# =============================================================================

def label_from_fn(fn):
    stem = Path(fn).stem.lower()
    for pfx, idx in PREFIX_MAP.items():
        if stem.startswith(pfx): return idx
    for pfx, idx in PREFIX_MAP.items():
        if pfx in stem: return idx
    return -1

def find_file(d, stem, exts=(".tif",".tiff",".png",".jpg")):
    if not d or not os.path.isdir(d): return None
    for ext in exts:
        p = os.path.join(d, stem + ext)
        if os.path.isfile(p): return p
    return None

def load_tif(path):
    """GeoTIFF → float32 (C,H,W) per-band min-max normalised [0,1]."""
    with rasterio.open(path) as src:
        data = src.read().astype(np.float32)
    for b in range(data.shape[0]):
        lo, hi = data[b].min(), data[b].max()
        data[b] = (data[b] - lo) / (hi - lo + 1e-8)
    return data

# =============================================================================
# 3.  FEATURE EXTRACTION  ← the key fix vs v5
#     We extract SPATIAL TEXTURE features, not just means
# =============================================================================

def band_stats(band_2d: np.ndarray) -> np.ndarray:
    """
    Extract 12 statistics from a single 2D band.
    These capture DISTRIBUTION SHAPE, not just central tendency.
    Rust lesions create high variance + positive skew in affected bands.
    """
    flat = band_2d.flatten().astype(np.float64)
    if flat.std() < 1e-8:   # constant band — return zeros
        return np.zeros(12, dtype=np.float32)
    return np.array([
        flat.mean(),
        flat.std(),
        float(skew(flat)),          # asymmetry: rust → positive skew
        float(kurtosis(flat)),      # peakedness: healthy → low kurtosis
        np.percentile(flat, 5),
        np.percentile(flat, 10),
        np.percentile(flat, 25),
        np.percentile(flat, 50),
        np.percentile(flat, 75),
        np.percentile(flat, 90),
        np.percentile(flat, 95),
        # Local variance (texture) — variance of a 3×3 local mean subtracted
        float(ndimage.uniform_filter(band_2d, size=3).flatten().std()),
    ], dtype=np.float32)


def extract_ms_features(ms: np.ndarray) -> np.ndarray:
    """
    5-band MS patch (5,H,W) → rich feature vector.

    Structure:
      5 bands × 12 stats           = 60
      Per-pixel index maps (stats) = 60
      Total                        = 120
    """
    feats = []
    b, g, r, re, nir = [ms[i] for i in range(5)]
    e = 1e-8

    # Per-band texture stats
    for band in [b, g, r, re, nir]:
        feats.append(band_stats(band))  # 12 each → 60

    # Per-pixel vegetation index MAPS → stats on the distribution
    # (captures within-patch heterogeneity)
    def idx_stats(map_2d):
        return band_stats(map_2d)  # same 12 stats

    ndvi_map   = (nir - r)  / (nir + r  + e)
    rendvi_map = (nir - re) / (nir + re + e)
    ndre_map   = (re  - r)  / (re  + r  + e)
    gndvi_map  = (nir - g)  / (nir + g  + e)
    evi_map    = 2.5*(nir-r)/ (nir + 6*r - 7.5*b + 1 + e)

    feats.append(idx_stats(ndvi_map))    # 12
    feats.append(idx_stats(rendvi_map))  # 12
    feats.append(idx_stats(ndre_map))    # 12
    feats.append(idx_stats(gndvi_map))   # 12
    feats.append(idx_stats(evi_map))     # 12

    # Total: 5×12 + 5×12 = 120
    return np.concatenate(feats).astype(np.float32)


def extract_hs_features(hs: np.ndarray) -> np.ndarray:
    """
    101-band HS patch (101,H,W) → feature vector.

    Strategy: per-band stats for ALL 101 bands (captures full spectral curve
    shape AND its spatial variability), plus key region summaries.

    Per-band: mean, std, skew, p25, p50, p75  → 6 stats × 101 = 606
    Key regions: 6 summary features
    Total: 612
    """
    feats = []

    for b in range(hs.shape[0]):
        flat = hs[b].flatten().astype(np.float64)
        if flat.std() < 1e-8:
            feats.append(np.zeros(6, dtype=np.float32))
        else:
            feats.append(np.array([
                flat.mean(),
                flat.std(),
                float(skew(flat)),
                np.percentile(flat, 25),
                np.percentile(flat, 50),
                np.percentile(flat, 75),
            ], dtype=np.float32))

    # Key spectral region summaries (from EDA)
    # Green peak ~band 15-20
    gp_std     = hs[13:20].std()
    # Red-edge slope bands ~53-65 (700-750 nm)
    re_mean    = hs[53:65].mean()
    re_std     = hs[53:65].std()     # HIGH for rust (patchy re-absorption)
    re_slope   = hs[60:65].mean() - hs[53:58].mean()
    # NIR plateau ~73-84 (781-821 nm)
    nir_mean   = hs[73:84].mean()
    nir_std    = hs[73:84].std()

    feats.append(np.array([gp_std, re_mean, re_std, re_slope,
                            nir_mean, nir_std], dtype=np.float32))

    return np.concatenate(feats).astype(np.float32)   # 612


def build_all_features(split="train", verbose=True):
    """
    Load all samples for a split, extract features.
    Returns X_ms (N,120), X_hs (N,612), labels (N,), stems list, paths dict.
    """
    rgb_dir = os.path.join(TRAIN_DIR if split=="train" else VAL_DIR, "RGB")
    ms_dir  = os.path.join(TRAIN_DIR if split=="train" else VAL_DIR, "MS")
    hs_dir  = os.path.join(TRAIN_DIR if split=="train" else VAL_DIR, "HS")

    files = sorted(f for f in os.listdir(rgb_dir)
                   if f.lower().endswith((".png",".jpg",".jpeg")))
    N = len(files)
    if verbose: print(f"  [{split}] {N} files")

    X_ms   = np.zeros((N, 120), dtype=np.float32)
    X_hs   = np.zeros((N, 612), dtype=np.float32)
    labels = np.full(N, -1, dtype=np.int32)
    stems  = []
    paths  = {"rgb": [], "ms": [], "hs": []}

    for i, fn in enumerate(files):
        stem = Path(fn).stem; stems.append(stem)
        if split == "train":
            labels[i] = label_from_fn(fn)

        rgb_p = os.path.join(rgb_dir, fn)
        ms_p  = find_file(ms_dir, stem)
        hs_p  = find_file(hs_dir, stem)
        paths["rgb"].append(rgb_p)
        paths["ms"] .append(ms_p)
        paths["hs"] .append(hs_p)

        if ms_p and os.path.isfile(ms_p):
            try:    X_ms[i] = extract_ms_features(load_tif(ms_p))
            except Exception as ex:
                if verbose: print(f"  [WARN MS] {stem}: {ex}")

        if hs_p and os.path.isfile(hs_p):
            try:    X_hs[i] = extract_hs_features(load_tif(hs_p)[HS_START:HS_END])
            except Exception as ex:
                if verbose: print(f"  [WARN HS] {stem}: {ex}")

        if verbose and (i+1) % 200 == 0:
            print(f"    {i+1}/{N}")

    return X_ms, X_hs, labels, stems, paths

# =============================================================================
# 4.  CLASSICAL ML  (on texture features — much stronger than band means)
# =============================================================================

def build_classifiers():
    """Returns dict of sklearn pipelines."""
    return {
        "svm_rbf": Pipeline([
            ("sc",  StandardScaler()),
            ("clf", SVC(C=10,  gamma="scale", kernel="rbf",
                        probability=True, random_state=SEED,
                        class_weight="balanced")),
        ]),
        "svm_c100": Pipeline([
            ("sc",  StandardScaler()),
            ("clf", SVC(C=100, gamma="scale", kernel="rbf",
                        probability=True, random_state=SEED,
                        class_weight="balanced")),
        ]),
        "rf": Pipeline([
            ("sc",  StandardScaler()),
            ("clf", RandomForestClassifier(
                n_estimators=600, max_depth=None, min_samples_leaf=1,
                max_features="sqrt", class_weight="balanced",
                random_state=SEED, n_jobs=-1)),
        ]),
        "et": Pipeline([
            ("sc",  StandardScaler()),
            ("clf", ExtraTreesClassifier(
                n_estimators=600, max_depth=None, min_samples_leaf=1,
                max_features="sqrt", class_weight="balanced",
                random_state=SEED, n_jobs=-1)),
        ]),
        "xgb": Pipeline([
            ("sc",  StandardScaler()),
            ("clf", xgb.XGBClassifier(
                n_estimators=500, max_depth=5, learning_rate=0.03,
                subsample=0.8, colsample_bytree=0.8, min_child_weight=3,
                eval_metric="mlogloss", random_state=SEED,
                n_jobs=-1, verbosity=0)),
        ]),
        "gb": Pipeline([
            ("sc",  StandardScaler()),
            ("clf", GradientBoostingClassifier(
                n_estimators=300, max_depth=4, learning_rate=0.05,
                subsample=0.8, random_state=SEED)),
        ]),
    }

# =============================================================================
# 5.  CNN  — Fine-tuned EfficientNet-B3 on RGB
#     B3 is stronger than B0, still fits in GPU memory
# =============================================================================

class RGBDataset(Dataset):
    def __init__(self, rgb_paths, labels=None, augment=False):
        self.paths   = rgb_paths
        self.labels  = labels
        self.augment = augment
        self.tf = self._build_tf(augment)

    @staticmethod
    def _build_tf(aug):
        if aug:
            return transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomCrop(IMG_SIZE),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(45),
                transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                       saturation=0.3, hue=0.1),
                transforms.RandomGrayscale(p=0.05),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
                transforms.RandomErasing(p=0.25, scale=(0.02,0.2)),
            ])
        return transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
        ])

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        try:
            img = Image.open(p).convert("RGB") if p and os.path.isfile(p) \
                  else Image.new("RGB", (IMG_SIZE, IMG_SIZE))
        except:
            img = Image.new("RGB", (IMG_SIZE, IMG_SIZE))
        x = self.tf(img)
        if self.labels is not None:
            lbl = int(self.labels[idx])
            assert 0 <= lbl < N_CLS, f"Bad label {lbl}"
            return x, lbl
        return x


def build_cnn():
    """EfficientNet-B3 with custom classification head."""
    model = models.efficientnet_b3(
        weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
    in_feat = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(in_feat, 512),
        nn.BatchNorm1d(512),
        nn.GELU(),
        nn.Dropout(0.3),
        nn.Linear(512, N_CLS),
    )
    return model


def mixup_data(x, y, alpha=0.3):
    """MixUp augmentation — interpolates pairs of samples."""
    if alpha <= 0: return x, y, y, 1.0
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0), device=x.device)
    return (lam*x + (1-lam)*x[idx], y, y[idx], lam)


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1-lam) * criterion(pred, y_b)


def get_warmup_cosine_scheduler(optimizer, warmup_ep, total_ep, min_lr=1e-6):
    def lr_lambda(ep):
        if ep < warmup_ep:
            return float(ep+1) / float(warmup_ep)
        prog = float(ep - warmup_ep) / float(max(1, total_ep - warmup_ep))
        cosine = 0.5 * (1.0 + np.cos(np.pi * prog))
        base_lr = optimizer.param_groups[0]["initial_lr"]
        return max(min_lr / base_lr, cosine)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def train_cnn_fold(tr_paths, tr_labels, vl_paths, vl_labels, device):
    """Train one fold of the CNN. Returns best model + val probs."""
    model = build_cnn().to(device)

    # Separate LR: backbone vs head
    backbone_params = list(model.features.parameters())
    head_params     = list(model.classifier.parameters())
    optimizer = torch.optim.AdamW([
        {"params": backbone_params, "lr": LR_BACK,
         "initial_lr": LR_BACK},
        {"params": head_params,     "lr": LR_HEAD,
         "initial_lr": LR_HEAD},
    ], weight_decay=1e-4)

    scheduler = get_warmup_cosine_scheduler(optimizer, WARMUP, N_EPOCHS)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    tr_ds = RGBDataset(tr_paths, tr_labels, augment=True)
    vl_ds = RGBDataset(vl_paths, vl_labels, augment=False)

    # Weighted sampler for balanced batches
    cls_counts = np.bincount(tr_labels, minlength=N_CLS)
    weights    = 1.0 / cls_counts[tr_labels]
    sampler    = WeightedRandomSampler(weights, len(weights), replacement=True)

    tr_ld = DataLoader(tr_ds, batch_size=BATCH_SIZE, sampler=sampler,
                       num_workers=0, drop_last=True)
    vl_ld = DataLoader(vl_ds, batch_size=BATCH_SIZE, shuffle=False,
                       num_workers=0)

    best_acc   = 0.0
    best_state = None
    no_improve = 0

    for ep in range(1, N_EPOCHS + 1):
        # ── Train ────────────────────────────────────────────────────────────
        model.train()
        for x, y in tr_ld:
            x, y = x.to(device), y.to(device)
            x, ya, yb, lam = mixup_data(x, y, alpha=0.4)
            optimizer.zero_grad()
            loss = mixup_criterion(criterion, model(x), ya, yb, lam)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()

        # ── Validate ─────────────────────────────────────────────────────────
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in vl_ld:
                x, y   = x.to(device), y.to(device)
                preds  = model(x).argmax(1)
                correct += (preds == y).sum().item()
                total   += y.size(0)
        acc = correct / total

        if acc > best_acc:
            best_acc   = acc
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1

        if ep % 10 == 0:
            print(f"      ep {ep:03d} val_acc={acc:.4f}  best={best_acc:.4f}")

        if no_improve >= PATIENCE:
            print(f"      Early stop at ep {ep}. Best val acc={best_acc:.4f}")
            break

    model.load_state_dict(best_state)
    return model, best_acc


@torch.no_grad()
def get_probs(model, paths, device, augment=False, n_tta=1):
    """
    Get softmax probabilities.
    If n_tta > 1, run n_tta augmented forward passes and average.
    """
    model.eval()
    ds = RGBDataset(paths, labels=None, augment=augment)
    ld = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    all_probs = []
    for x in ld:
        logits = model(x.to(device))
        all_probs.append(torch.softmax(logits, 1).cpu().numpy())
    return np.vstack(all_probs)


@torch.no_grad()
def get_probs_tta(model, paths, device, n_tta=5):
    """Average probabilities over n_tta augmented views."""
    all_runs = []
    # 1 clean pass
    all_runs.append(get_probs(model, paths, device, augment=False))
    # n_tta-1 augmented passes
    for _ in range(n_tta - 1):
        all_runs.append(get_probs(model, paths, device, augment=True))
    return np.mean(all_runs, axis=0)

# =============================================================================
# 6.  STACKING META-LEARNER
#     Takes OOF probs from classical + CNN, learns optimal combination
# =============================================================================

def build_meta_learner():
    return LogisticRegression(
        C=1.0, max_iter=1000, random_state=SEED,
        multi_class="multinomial", solver="lbfgs")

# =============================================================================
# 7.  FULL CV PIPELINE
# =============================================================================

def run_cv(X_feat, y, rgb_paths_all, comp_rgb_paths, device):
    """
    5-Fold Stratified CV:
      Each fold trains:
        - Classical models on texture features (X_feat)
        - CNN on RGB patches

      Collects:
        - OOF probs for all models
        - Test probs for final ensemble

    Returns:
      oof_classical  : (N, 3) averaged classical OOF probs
      oof_cnn        : (N, 3) CNN OOF probs
      test_classical : (N_test, 3) classical test probs (fold-averaged)
      test_cnn       : (N_test, 3) CNN test probs (fold-averaged, with TTA)
    """
    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
    N_test = len(comp_rgb_paths)

    oof_classical  = np.zeros((len(X_feat), 3), dtype=np.float32)
    oof_cnn        = np.zeros((len(X_feat), 3), dtype=np.float32)
    test_classical_folds = []
    test_cnn_folds       = []

    for fold, (tr_idx, vl_idx) in enumerate(skf.split(X_feat, y)):
        print(f"\n{'─'*55}")
        print(f"  FOLD {fold+1}/{N_FOLDS}")
        print(f"{'─'*55}")

        X_tr, X_vl = X_feat[tr_idx], X_feat[vl_idx]
        y_tr, y_vl = y[tr_idx],      y[vl_idx]
        rgb_tr = [rgb_paths_all[i] for i in tr_idx]
        rgb_vl = [rgb_paths_all[i] for i in vl_idx]

        # ── Classical models ──────────────────────────────────────────────────
        print(f"\n  [Classical ML]")
        classifiers = build_classifiers()
        fold_cl_oof = np.zeros((len(vl_idx), 3), dtype=np.float32)

        for name, pipe in classifiers.items():
            p = clone(pipe)
            p.fit(X_tr, y_tr)
            vl_prob = p.predict_proba(X_vl)
            vl_acc  = accuracy_score(y_vl, vl_prob.argmax(1))
            print(f"    {name:<12} val_acc={vl_acc:.4f}")
            fold_cl_oof += vl_prob

        oof_classical[vl_idx] = fold_cl_oof / len(classifiers)
        cl_oof_acc = accuracy_score(y_vl, oof_classical[vl_idx].argmax(1))
        print(f"    {'ENSEMBLE':<12} val_acc={cl_oof_acc:.4f}")

        # ── CNN ───────────────────────────────────────────────────────────────
        print(f"\n  [CNN EfficientNet-B3]")
        cnn_model, cnn_acc = train_cnn_fold(
            rgb_tr, y_tr, rgb_vl, y_vl, device)
        print(f"    CNN best val_acc={cnn_acc:.4f}")

        oof_cnn[vl_idx] = get_probs(cnn_model, rgb_vl, device)
        cnn_test_probs  = get_probs_tta(cnn_model, comp_rgb_paths, device,
                                        n_tta=5)
        test_cnn_folds.append(cnn_test_probs)

        # Cleanup GPU memory after each fold
        del cnn_model
        if DEVICE.type == "cuda": torch.cuda.empty_cache()

    # Average CNN test probs across folds
    test_cnn = np.mean(test_cnn_folds, axis=0)

    return oof_classical, oof_cnn, test_cnn


def cv_classical_on_test(X_feat_train, y_train, X_feat_test):
    """
    Retrain classical models on all train data, predict on test.
    Separate function to ensure correct test features are used.
    """
    print("\n  [Classical — final retrain on all 600 samples]")
    classifiers = build_classifiers()
    all_probs   = []
    for name, pipe in classifiers.items():
        p = clone(pipe)
        p.fit(X_feat_train, y_train)
        prob = p.predict_proba(X_feat_test)
        acc  = accuracy_score(y_train, p.predict(X_feat_train))
        print(f"    {name:<12} train_acc={acc:.4f}")
        all_probs.append(prob)
    return np.mean(all_probs, axis=0)

# =============================================================================
# 8.  MAIN
# =============================================================================

def main():
    t0 = time.time()
    print("\n" + "="*60)
    print("  ICPR 2026 Wheat Disease  —  v6 SPATIAL TEXTURE + CNN")
    print("="*60)

    # ── Step 1: Extract features ──────────────────────────────────────────────
    print("\n[STEP 1] Extracting spatial texture features from TRAIN …")
    X_ms_tr, X_hs_tr, y_tr, stems_tr, paths_tr = build_all_features("train")
    print(f"  MS features: {X_ms_tr.shape}  HS features: {X_hs_tr.shape}")
    assert (y_tr >= 0).all() and (y_tr < N_CLS).all(), \
        f"Bad train labels: {np.unique(y_tr)}"
    counts = dict(zip(*np.unique(y_tr, return_counts=True)))
    print(f"  Label distribution: {counts}")

    print("\n[STEP 1b] Extracting features from competition VAL …")
    X_ms_vl, X_hs_vl, _, stems_vl, paths_vl = build_all_features("val")

    # ── Step 2: Concatenate MS + HS features ─────────────────────────────────
    print("\n[STEP 2] Concatenating MS+HS texture features …")
    X_train = np.concatenate([X_ms_tr, X_hs_tr], axis=1)   # (600, 732)
    X_test  = np.concatenate([X_ms_vl, X_hs_vl], axis=1)   # (300, 732)
    print(f"  Train feature matrix: {X_train.shape}")
    print(f"  Test  feature matrix: {X_test.shape}")

    # ── Step 3: CV training ───────────────────────────────────────────────────
    print(f"\n[STEP 3] {N_FOLDS}-Fold Cross-Validation …")
    oof_classical, oof_cnn, test_cnn = run_cv(
        X_train, y_tr,
        rgb_paths_all  = paths_tr["rgb"],
        comp_rgb_paths = paths_vl["rgb"],
        device         = DEVICE,
    )

    # Classical test predictions from full retrain
    test_classical = cv_classical_on_test(X_train, y_tr, X_test)

    # ── Step 4: Evaluate OOF ─────────────────────────────────────────────────
    print("\n[STEP 4] OOF Evaluation …")

    oof_cl_acc  = accuracy_score(y_tr, oof_classical.argmax(1))
    oof_cnn_acc = accuracy_score(y_tr, oof_cnn.argmax(1))

    # Try various ensemble weights and pick best
    best_ens_acc = 0.0
    best_w       = (0.5, 0.5)
    for w_cl in np.arange(0.1, 1.0, 0.05):
        w_cn  = 1.0 - w_cl
        combo = w_cl * oof_classical + w_cn * oof_cnn
        acc   = accuracy_score(y_tr, combo.argmax(1))
        if acc > best_ens_acc:
            best_ens_acc = acc
            best_w       = (w_cl, w_cn)

    print(f"\n  OOF Classical accuracy : {oof_cl_acc:.4f}")
    print(f"  OOF CNN accuracy       : {oof_cnn_acc:.4f}")
    print(f"  OOF Ensemble accuracy  : {best_ens_acc:.4f}  "
          f"(w_classical={best_w[0]:.2f}, w_cnn={best_w[1]:.2f})")

    oof_final = best_w[0] * oof_classical + best_w[1] * oof_cnn
    print("\n  Classification Report (OOF):")
    print(classification_report(y_tr, oof_final.argmax(1),
                                target_names=CLASSES, zero_division=0))
    cm = confusion_matrix(y_tr, oof_final.argmax(1))
    print("  Confusion Matrix:")
    print("         " + "  ".join(f"{c:>8}" for c in CLASSES))
    for i, row in enumerate(cm):
        print(f"  {CLASSES[i]:>6}  " + "  ".join(f"{v:>8}" for v in row))

    # ── Step 5: Build submission ──────────────────────────────────────────────
    print("\n[STEP 5] Building submission …")
    test_probs = best_w[0] * test_classical + best_w[1] * test_cnn
    idx2cls    = {i: c for i, c in enumerate(CLASSES)}

    rows = []
    for stem, prob in zip(stems_vl, test_probs):
        rows.append({"Id": stem + ".png",
                     "Category": idx2cls[int(prob.argmax())]})
    sub = pd.DataFrame(rows)
    sub.to_csv(SUBMIT_CSV, index=False)

    print(f"\n  {len(sub)} predictions → {SUBMIT_CSV}")
    print("  Distribution:")
    print(sub["Category"].value_counts().to_string())

    # ── Plot confusion matrix ─────────────────────────────────────────────────
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, mat, title in zip(
        axes,
        [confusion_matrix(y_tr, oof_classical.argmax(1)),
         confusion_matrix(y_tr, oof_cnn.argmax(1)),
         cm],
        ["Classical OOF", "CNN OOF", "Ensemble OOF"]
    ):
        im = ax.imshow(mat, cmap="Blues")
        ax.set_xticks(range(N_CLS)); ax.set_xticklabels(CLASSES, rotation=30)
        ax.set_yticks(range(N_CLS)); ax.set_yticklabels(CLASSES)
        ax.set_title(f"{title}\nacc={accuracy_score(y_tr, (oof_classical if 'Classical' in title else oof_cnn if 'CNN' in title else oof_final).argmax(1)):.4f}")
        for r in range(N_CLS):
            for c_ in range(N_CLS):
                ax.text(c_, r, str(mat[r,c_]), ha="center", va="center",
                        color="white" if mat[r,c_] > mat.max()/2 else "black")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "confusion_matrices.png"), dpi=100)
    plt.close()

    elapsed = (time.time() - t0) / 60
    print(f"\n{'='*60}")
    print(f"  OOF Accuracy  : {best_ens_acc:.4f}")
    print(f"  Runtime       : {elapsed:.1f} min")
    print(f"  Submission    : {SUBMIT_CSV}")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()

[DEVICE] Tesla T4

  ICPR 2026 Wheat Disease  —  v6 SPATIAL TEXTURE + CNN

[STEP 1] Extracting spatial texture features from TRAIN …
  [train] 600 files
    200/600
    400/600
    600/600
  MS features: (600, 120)  HS features: (600, 612)
  Label distribution: {np.int32(0): np.int64(200), np.int32(1): np.int64(200), np.int32(2): np.int64(200)}

[STEP 1b] Extracting features from competition VAL …
  [val] 300 files
    200/300

[STEP 2] Concatenating MS+HS texture features …
  Train feature matrix: (600, 732)
  Test  feature matrix: (300, 732)

[STEP 3] 5-Fold Cross-Validation …

───────────────────────────────────────────────────────
  FOLD 1/5
───────────────────────────────────────────────────────

  [Classical ML]
    svm_rbf      val_acc=0.5833
    svm_c100     val_acc=0.5167
    rf           val_acc=0.5667
    et           val_acc=0.5583
    xgb          val_acc=0.5750
    gb           val_acc=0.5750
    ENSEMBLE     val_acc=0.5833

  [CNN EfficientNet-B3]
Downloading: "https://d

100%|██████████| 47.2M/47.2M [00:00<00:00, 169MB/s] 


      ep 010 val_acc=0.4667  best=0.5333
      ep 020 val_acc=0.4833  best=0.5833
      Early stop at ep 26. Best val acc=0.5833
    CNN best val_acc=0.5833

───────────────────────────────────────────────────────
  FOLD 2/5
───────────────────────────────────────────────────────

  [Classical ML]
    svm_rbf      val_acc=0.4917
    svm_c100     val_acc=0.5250
    rf           val_acc=0.5167
    et           val_acc=0.5167
    xgb          val_acc=0.5667
    gb           val_acc=0.5333
    ENSEMBLE     val_acc=0.5583

  [CNN EfficientNet-B3]
      ep 010 val_acc=0.4333  best=0.5833
      ep 020 val_acc=0.5000  best=0.6083
      ep 030 val_acc=0.5583  best=0.6083
      Early stop at ep 33. Best val acc=0.6083
    CNN best val_acc=0.6083

───────────────────────────────────────────────────────
  FOLD 3/5
───────────────────────────────────────────────────────

  [Classical ML]
    svm_rbf      val_acc=0.6500
    svm_c100     val_acc=0.5333
    rf           val_acc=0.6000
    et          