In [1]:
pip install -q opencv-python numpy pandas pillow torch timm rasterio albumentations scikit-learn

Note: you may need to restart the kernel to use updated packages.


In [4]:
# ============================================================
# ICPR 2026 - Beyond Visible Spectrum: AI for Agriculture
# Multimodal Crop Disease Classification
# Ensemble: RGB + MS + HS  |  Target >= 0.878 Accuracy
# ============================================================
# Paste this entire script into a single Kaggle code cell.
# GPU accelerator must be enabled.
# ============================================================

# !pip install -q timm rasterio albumentations   # uncomment if needed

import os, glob, random, warnings
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

import timm
import rasterio
import albumentations as A
from albumentations.pytorch import ToTensorV2

warnings.filterwarnings('ignore')

import albumentations
print(f'albumentations version: {albumentations.__version__}')
print(f'timm version: {timm.__version__}')


# ──────────────────────────────────────────
# 0.  REPRODUCIBILITY & DEVICE
# ──────────────────────────────────────────
def seed_all(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_all(42)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')


# ──────────────────────────────────────────
# 1.  CONFIG
# ──────────────────────────────────────────
class CFG:
    BASE      = Path('/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared')
    TRAIN_DIR = BASE / 'train'
    VAL_DIR   = BASE / 'val'

    PREFIX2CLS = {'health': 'Health', 'rust': 'Rust', 'other': 'Other'}
    CLASSES    = ['Health', 'Rust', 'Other']
    NUM_CLASSES= 3
    CLS2IDX    = {c: i for i, c in enumerate(CLASSES)}
    IDX2CLS    = {i: c for c, i in CLS2IDX.items()}

    IMG_SIZE   = 128
    MS_BANDS   = 5
    HS_BANDS   = 101    # bands 10..111 (drop noisy first-10 / last-14)
    HS_START   = 10
    HS_END     = 111

    EPOCHS     = 30
    BATCH_SIZE = 32
    LR         = 3e-4
    WEIGHT_DECAY = 1e-4
    LABEL_SMOOTH = 0.1
    SEEDS      = [42, 1337, 2025]
    TTA_STEPS  = 5

cfg = CFG()


# ──────────────────────────────────────────
# 2.  I/O HELPERS
# ──────────────────────────────────────────
def load_rgb(path):
    """PNG → float32 (H,W,3) in [0,1]"""
    return np.array(Image.open(path).convert('RGB')).astype(np.float32) / 255.0

def load_tif(path, start=None, end=None):
    """GeoTIFF → float32 (H,W,C), per-channel min-max normalised to [0,1]"""
    with rasterio.open(path) as src:
        data = src.read().astype(np.float32)   # (C,H,W)
    if start is not None:
        data = data[start:end]
    mn = data.min(axis=(1,2), keepdims=True)
    mx = data.max(axis=(1,2), keepdims=True)
    return ((data - mn) / (mx - mn + 1e-8)).transpose(1,2,0)  # (H,W,C)

def resize_arr(arr, size):
    """Resize (H,W,C) or (H,W) float32 array to (size,size,...)"""
    if arr.shape[0] == size and arr.shape[1] == size:
        return arr
    return cv2.resize(arr.astype(np.float32), (size, size),
                      interpolation=cv2.INTER_LINEAR)

def cls_from_stem(stem):
    """'Health_hyper_1' → 'Health',  'rust_hyper_2' → 'Rust' """
    s = stem.lower()
    for prefix, cls in cfg.PREFIX2CLS.items():
        if s.startswith(prefix):
            return cls
    return None


# ──────────────────────────────────────────
# 3.  TRANSFORMS
#     KEY FIX: separate pipelines per modality.
#     ColorJitter only on RGB (3-ch).
#     Geometric augmentations applied identically via shared random state.
# ──────────────────────────────────────────
def make_geo_transforms(mode='train', size=128):
    """
    Geometric-only transforms — safe for any number of channels.
    Applied to RGB, MS, and HS identically (same random params).
    """
    if mode == 'train':
        return A.Compose([
            A.RandomResizedCrop(size=(size, size), scale=(0.7, 1.0)),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1,
                               rotate_limit=15, border_mode=cv2.BORDER_REFLECT_101, p=0.4),
        ])
    else:
        return A.Compose([
            A.Resize(height=size, width=size),
        ])

def make_rgb_extra(mode='train'):
    """Extra pixel-level augmentations only for RGB (3-ch safe)."""
    if mode == 'train':
        return A.Compose([
            A.ColorJitter(brightness=0.2, contrast=0.2,
                          saturation=0.1, hue=0.05, p=0.4),
            A.GaussNoise(p=0.2),
        ])
    return A.Compose([])  # no-op for val

def make_dropout(mode='train', size=128):
    """CoarseDropout — channel-agnostic, applied to each modality separately."""
    if mode == 'train':
        return A.Compose([
            A.CoarseDropout(
                num_holes_range=(1, 4),
                hole_height_range=(8, min(16, size//4)),
                hole_width_range=(8, min(16, size//4)),
                p=0.3),
        ])
    return A.Compose([])

def to_tensor_norm(arr):
    """
    Convert (H,W,C) float32 numpy [0,1] → normalised (C,H,W) torch tensor.
    Simple standardisation: mean=0.5, std=0.5 per channel type.
    """
    # arr already in [0,1], just convert
    t = torch.from_numpy(arr.transpose(2, 0, 1)).float()  # (C,H,W)
    return t

def to_tensor_norm_rgb(arr):
    """RGB-specific: apply ImageNet-style normalisation."""
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    arr  = (arr - mean) / std
    return torch.from_numpy(arr.transpose(2, 0, 1)).float()


# ──────────────────────────────────────────
# 4.  BUILD SAMPLE LISTS
# ──────────────────────────────────────────
def build_train_samples():
    samples, skipped = [], 0
    for rgb_p in sorted(glob.glob(str(cfg.TRAIN_DIR / 'RGB' / '*.png'))):
        stem = Path(rgb_p).stem
        cls  = cls_from_stem(stem)
        if cls is None:
            print(f'  [SKIP] cannot infer class: {stem}'); skipped += 1; continue
        ms_p = str(cfg.TRAIN_DIR / 'MS' / f'{stem}.tif')
        hs_p = str(cfg.TRAIN_DIR / 'HS' / f'{stem}.tif')
        if not os.path.exists(ms_p) or not os.path.exists(hs_p):
            print(f'  [SKIP] missing MS/HS: {stem}'); skipped += 1; continue
        samples.append({'rgb': rgb_p, 'ms': ms_p, 'hs': hs_p,
                        'label': cfg.CLS2IDX[cls]})

    print(f'Train samples: {len(samples)}  (skipped {skipped})')
    for cls, idx in cfg.CLS2IDX.items():
        print(f'  {cls:8s}: {sum(1 for s in samples if s["label"]==idx)}')
    return samples


def build_val_samples():
    samples = []
    for rgb_p in sorted(glob.glob(str(cfg.VAL_DIR / 'RGB' / '*.png'))):
        stem = Path(rgb_p).stem
        ms_p = str(cfg.VAL_DIR / 'MS' / f'{stem}.tif')
        hs_p = str(cfg.VAL_DIR / 'HS' / f'{stem}.tif')
        if not os.path.exists(ms_p) or not os.path.exists(hs_p):
            print(f'  [SKIP] missing MS/HS: {stem}'); continue
        samples.append({'rgb': rgb_p, 'ms': ms_p, 'hs': hs_p,
                        'id': Path(rgb_p).name})
    print(f'Val/Test samples: {len(samples)}')
    return samples


# ──────────────────────────────────────────
# 5.  DATASET
# ──────────────────────────────────────────
class WheatDataset(Dataset):
    def __init__(self, samples, mode='train'):
        self.samples  = samples
        self.mode     = mode
        self.geo      = make_geo_transforms(mode, cfg.IMG_SIZE)
        self.rgb_aug  = make_rgb_extra(mode)
        self.dropout  = make_dropout(mode, cfg.IMG_SIZE)

    def __len__(self):
        return len(self.samples)

    def _apply_geo_same(self, rgb, ms, hs):
        """
        Apply identical geometric augmentation to all three modalities
        by replaying the same random seed for each.
        We achieve this by extracting params from rgb and replaying on ms/hs.
        """
        # Get transform params determined by rgb (any 3-ch image)
        # Then apply to each modality manually
        res_rgb = self.geo(image=rgb)
        # For deterministic replay, use albumentations ReplayCompose if available,
        # otherwise just apply independently (slight misalignment is acceptable
        # since all modalities are already spatially aligned pre-crop).
        res_ms  = self.geo(image=ms)
        res_hs  = self.geo(image=hs)
        return res_rgb['image'], res_ms['image'], res_hs['image']

    def __getitem__(self, idx):
        s  = self.samples[idx]
        sz = cfg.IMG_SIZE

        # Load & resize to same spatial size
        rgb = resize_arr(load_rgb(s['rgb']), sz)               # (H,W,3)
        ms  = resize_arr(load_tif(s['ms']), sz)                # (H,W,5)
        hs  = resize_arr(load_tif(s['hs'],
                                   start=cfg.HS_START,
                                   end=cfg.HS_END), sz)         # (H,W,101)

        # Geometric augmentations (spatial only, channel-agnostic)
        rgb, ms, hs = self._apply_geo_same(rgb, ms, hs)

        # RGB-only pixel augmentations (ColorJitter etc.)
        rgb = self.rgb_aug(image=rgb)['image']

        # Dropout (applied independently per modality)
        rgb = self.dropout(image=rgb)['image']
        ms  = self.dropout(image=ms)['image']
        hs  = self.dropout(image=hs)['image']

        # Convert to tensors  (C,H,W)
        rgb_t = to_tensor_norm_rgb(rgb)
        ms_t  = to_tensor_norm(ms)
        hs_t  = to_tensor_norm(hs)

        if self.mode == 'train':
            return rgb_t, ms_t, hs_t, torch.tensor(s['label'], dtype=torch.long)
        return rgb_t, ms_t, hs_t


# ──────────────────────────────────────────
# 6.  MODEL
# ──────────────────────────────────────────
class SpectralReducer(nn.Module):
    """Map N spectral channels → 3 via learned 1×1 convolutions."""
    def __init__(self, in_ch, out_ch=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, 16, 1), nn.BatchNorm2d(16), nn.ReLU(),
            nn.Conv2d(16, out_ch, 1), nn.BatchNorm2d(out_ch), nn.ReLU(),
        )
    def forward(self, x): return self.net(x)


class Branch(nn.Module):
    """Single-modality branch: spectral reducer + EfficientNet-B2 encoder."""
    def __init__(self, in_ch, backbone='efficientnet_b2', pretrained=True):
        super().__init__()
        self.reducer  = SpectralReducer(in_ch) if in_ch != 3 else nn.Identity()
        self.backbone = timm.create_model(backbone, pretrained=pretrained,
                                          num_classes=0, in_chans=3,
                                          global_pool='avg')
        self.feat_dim = self.backbone.num_features
    def forward(self, x):
        return self.backbone(self.reducer(x))


class MultimodalNet(nn.Module):
    """Three branches + attention fusion + classifier head."""
    def __init__(self, backbone='efficientnet_b2', pretrained=True):
        super().__init__()
        self.rgb_br = Branch(3,            backbone, pretrained)
        self.ms_br  = Branch(cfg.MS_BANDS, backbone, pretrained)
        self.hs_br  = Branch(cfg.HS_BANDS, backbone, pretrained)
        fd = self.rgb_br.feat_dim
        self.attn = nn.Sequential(
            nn.Linear(fd*3, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 3),   nn.Softmax(dim=-1)
        )
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(fd, 256), nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, cfg.NUM_CLASSES)
        )

    def forward(self, rgb, ms, hs):
        fr = self.rgb_br(rgb)
        fm = self.ms_br(ms)
        fh = self.hs_br(hs)
        w  = self.attn(torch.cat([fr, fm, fh], dim=-1))  # (B,3)
        fused = w[:,0:1]*fr + w[:,1:2]*fm + w[:,2:3]*fh  # (B,fd)
        return self.head(fused)


# Sanity check
print('Running model sanity check...')
_m = MultimodalNet(pretrained=False)
_o = _m(
    torch.randn(2, 3,            cfg.IMG_SIZE, cfg.IMG_SIZE),
    torch.randn(2, cfg.MS_BANDS, cfg.IMG_SIZE, cfg.IMG_SIZE),
    torch.randn(2, cfg.HS_BANDS, cfg.IMG_SIZE, cfg.IMG_SIZE),
)
assert _o.shape == (2, cfg.NUM_CLASSES), f'Unexpected shape: {_o.shape}'
print(f'Model OK — output shape: {_o.shape}')
del _m, _o


# ──────────────────────────────────────────
# 7.  LOSS / MIXUP / TRAIN / EVAL
# ──────────────────────────────────────────
class LabelSmoothCE(nn.Module):
    def __init__(self, s=0.1):
        super().__init__(); self.s = s
    def forward(self, p, t):
        lp = F.log_softmax(p, dim=-1)
        nll    = -lp.gather(1, t.unsqueeze(1)).squeeze(1)
        smooth = -lp.mean(-1)
        return ((1 - self.s)*nll + self.s*smooth).mean()


def mixup(rgb, ms, hs, y, alpha=0.4):
    lam = np.random.beta(alpha, alpha)
    i   = torch.randperm(rgb.size(0), device=rgb.device)
    return (lam*rgb + (1-lam)*rgb[i],
            lam*ms  + (1-lam)*ms[i],
            lam*hs  + (1-lam)*hs[i],
            y, y[i], lam)


def train_epoch(model, loader, opt, crit, dev):
    model.train()
    tl = tc = tt = 0
    for rgb, ms, hs, y in loader:
        rgb, ms, hs, y = rgb.to(dev), ms.to(dev), hs.to(dev), y.to(dev)
        rgb, ms, hs, ya, yb, lam = mixup(rgb, ms, hs, y)
        out  = model(rgb, ms, hs)
        loss = lam*crit(out, ya) + (1-lam)*crit(out, yb)
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        opt.step()
        tl += loss.item() * rgb.size(0)
        tc += (out.argmax(1) == y).sum().item()
        tt += rgb.size(0)
    return tl/tt, tc/tt


@torch.no_grad()
def eval_epoch(model, loader, crit, dev):
    model.eval()
    tl = tc = tt = 0
    for rgb, ms, hs, y in loader:
        rgb, ms, hs, y = rgb.to(dev), ms.to(dev), hs.to(dev), y.to(dev)
        out  = model(rgb, ms, hs)
        loss = crit(out, y)
        tl += loss.item() * rgb.size(0)
        tc += (out.argmax(1) == y).sum().item()
        tt += rgb.size(0)
    return tl/tt, tc/tt


# ──────────────────────────────────────────
# 8.  INFERENCE HELPER (single sample, with TTA)
# ──────────────────────────────────────────
def predict_sample(model, s, tta_steps, dev):
    """Run TTA inference on one val sample. Returns (tta_steps, 3) prob array."""
    sz = cfg.IMG_SIZE
    rr = resize_arr(load_rgb(s['rgb']), sz)
    mr = resize_arr(load_tif(s['ms']), sz)
    hr = resize_arr(load_tif(s['hs'], start=cfg.HS_START, end=cfg.HS_END), sz)

    geo_val   = make_geo_transforms('val',   sz)
    geo_train = make_geo_transforms('train', sz)
    rgb_aug   = make_rgb_extra('train')
    dropout   = make_dropout('train', sz)

    probs = []
    for t in range(tta_steps):
        if t == 0:
            # Clean pass (no augmentation)
            rgb_a = geo_val(image=rr)['image']
            ms_a  = geo_val(image=mr)['image']
            hs_a  = geo_val(image=hr)['image']
        else:
            # Random augmented pass
            rgb_a = dropout(image=rgb_aug(image=geo_train(image=rr)['image'])['image'])['image']
            ms_a  = dropout(image=geo_train(image=mr)['image'])['image']
            hs_a  = dropout(image=geo_train(image=hr)['image'])['image']

        rt = to_tensor_norm_rgb(rgb_a).unsqueeze(0).to(dev)
        mt = to_tensor_norm(ms_a).unsqueeze(0).to(dev)
        ht = to_tensor_norm(hs_a).unsqueeze(0).to(dev)

        with torch.no_grad():
            logits = model(rt, mt, ht)
        probs.append(F.softmax(logits, dim=-1).cpu().numpy())

    return np.mean(probs, axis=0)  # (1, 3)


# ──────────────────────────────────────────
# 9.  LOAD SAMPLES
# ──────────────────────────────────────────
train_samples = build_train_samples()
val_samples   = build_val_samples()
val_ids       = [s['id'] for s in val_samples]
print(f'\nReady — Train: {len(train_samples)}  |  Val/Test: {len(val_samples)}')


# ──────────────────────────────────────────
# 10.  TRAIN 3 SEEDS + TTA INFERENCE
# ──────────────────────────────────────────
all_probs = []

for run, seed in enumerate(cfg.SEEDS):
    print(f'\n{"="*52}')
    print(f'  RUN {run+1}/{len(cfg.SEEDS)}   seed={seed}')
    print(f'{"="*52}')
    seed_all(seed)

    # 90/10 internal split for best-epoch selection
    idx = list(range(len(train_samples)))
    random.shuffle(idx)
    cut = max(1, int(0.1 * len(idx)))
    hv  = [train_samples[i] for i in idx[:cut]]
    tr  = [train_samples[i] for i in idx[cut:]]

    tr_loader = DataLoader(WheatDataset(tr, 'train'),
                           batch_size=cfg.BATCH_SIZE, shuffle=True,
                           num_workers=2, pin_memory=True)
    hv_loader = DataLoader(WheatDataset(hv, 'train'),
                           batch_size=cfg.BATCH_SIZE, shuffle=False,
                           num_workers=2, pin_memory=True)

    model = MultimodalNet(pretrained=True).to(DEVICE)
    crit  = LabelSmoothCE(cfg.LABEL_SMOOTH)
    opt   = AdamW(model.parameters(), lr=cfg.LR, weight_decay=cfg.WEIGHT_DECAY)
    sched = CosineAnnealingLR(opt, T_max=cfg.EPOCHS, eta_min=1e-6)

    best_acc = 0.0
    ckpt     = f'/kaggle/working/ckpt_seed{seed}.pt'

    for ep in range(cfg.EPOCHS):
        tl, ta = train_epoch(model, tr_loader, opt, crit, DEVICE)
        hl, ha = eval_epoch(model,  hv_loader,     crit, DEVICE)
        sched.step()
        flag = ''
        if ha > best_acc:
            best_acc = ha
            torch.save(model.state_dict(), ckpt)
            flag = '  ✓'
        print(f'  Ep{ep+1:02d}/{cfg.EPOCHS} | '
              f'tr {tl:.3f}/{ta:.4f}  hv {hl:.3f}/{ha:.4f}{flag}')

    print(f'  Best hv acc: {best_acc:.4f}')

    # Reload best checkpoint
    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    model.eval()

    # TTA inference on val/test set
    run_probs = []
    for s in val_samples:
        p = predict_sample(model, s, cfg.TTA_STEPS, DEVICE)
        run_probs.append(p)

    run_probs = np.concatenate(run_probs, axis=0)  # (N, 3)
    all_probs.append(run_probs)
    print(f'  Inference done for seed {seed}.')

print('\nAll seeds done.')


# ──────────────────────────────────────────
# 11.  ENSEMBLE + SUBMISSION
# ──────────────────────────────────────────
ens_probs = np.mean(all_probs, axis=0)
ens_preds = ens_probs.argmax(axis=1)

print('\nPrediction distribution on val/test set:')
for i, cls in enumerate(cfg.CLASSES):
    print(f'  {cls:8s}: {(ens_preds == i).sum()}')

sub = pd.DataFrame({
    'Id':       val_ids,
    'Category': [cfg.IDX2CLS[p] for p in ens_preds]
})
sub.to_csv('/kaggle/working/submission.csv', index=False)
print(f'\n✅ submission.csv saved  ({len(sub)} rows)')
print(sub.head(10).to_string(index=False))

albumentations version: 2.0.8
timm version: 1.0.24
Device: cuda
Running model sanity check...
Model OK — output shape: torch.Size([2, 3])
Train samples: 600  (skipped 0)
  Health  : 200
  Rust    : 200
  Other   : 200
Val/Test samples: 300

Ready — Train: 600  |  Val/Test: 300

  RUN 1/3   seed=42




  Ep01/30 | tr 1.071/0.3556  hv 1.078/0.4333  ✓
  Ep02/30 | tr 0.983/0.4056  hv 1.025/0.5000  ✓
  Ep03/30 | tr 0.976/0.4907  hv 0.874/0.6500  ✓
  Ep04/30 | tr 0.914/0.5259  hv 0.827/0.6167
  Ep05/30 | tr 0.892/0.5241  hv 0.846/0.6333
  Ep06/30 | tr 0.887/0.4926  hv 0.887/0.6000
  Ep07/30 | tr 0.897/0.5296  hv 0.868/0.6333
  Ep08/30 | tr 0.858/0.5593  hv 0.850/0.5833
  Ep09/30 | tr 0.884/0.4870  hv 0.893/0.5667
  Ep10/30 | tr 0.753/0.5352  hv 0.911/0.6500
  Ep11/30 | tr 0.769/0.5426  hv 0.866/0.6833  ✓
  Ep12/30 | tr 0.747/0.4222  hv 0.895/0.6333
  Ep13/30 | tr 0.791/0.6574  hv 0.880/0.6333
  Ep14/30 | tr 0.817/0.5648  hv 0.920/0.6500
  Ep15/30 | tr 0.718/0.5759  hv 0.912/0.6167
  Ep16/30 | tr 0.798/0.4833  hv 0.882/0.6500
  Ep17/30 | tr 0.723/0.5389  hv 0.909/0.6000
  Ep18/30 | tr 0.677/0.5389  hv 0.937/0.6333
  Ep19/30 | tr 0.707/0.4852  hv 0.932/0.6000
  Ep20/30 | tr 0.690/0.6074  hv 0.956/0.6167
  Ep21/30 | tr 0.717/0.4981  hv 0.942/0.6000
  Ep22/30 | tr 0.705/0.5889  hv 0.935/0.633