In [None]:
import os, re, random
from dataclasses import dataclass
from typing import Dict, Optional, List, Tuple

import numpy as np
import pandas as pd

import cv2
import tifffile as tiff

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm
import wandb

In [None]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
import wandb, time
wandb.login()


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: [wandb.login()] Using explicit session credentials for https://api.wandb.ai.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnguyenthu-m462004[0m ([33mnguyenthu-m462004-fpt-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
!unzip "/content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared.zip" -d /content/data

Archive:  /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared.zip
   creating: /content/data/Kaggle_Prepared/
   creating: /content/data/Kaggle_Prepared/train/
   creating: /content/data/Kaggle_Prepared/train/HS/
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_1.tif  
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_10.tif  
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_100.tif  
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_101.tif  
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_102.tif  
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_103.tif  
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_104.tif  
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_105.tif  
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_106.tif  
  inflating: /content/data/Kaggle_Prepared/train/HS/Health_hyper_107.tif  
  inflating: /content/data/Kaggle_P

In [None]:
@dataclass
class CFG:
    ROOT: str = "/content/data/Kaggle_Prepared"
    TRAIN_DIR: str = "train"
    VAL_DIR: str = "val"

    USE_RGB: bool = True
    USE_MS: bool  = True
    USE_HS: bool  = True

    IMG_SIZE: int = 224

    BATCH_SIZE: int = 32
    EPOCHS: int = 50
    LR: float = 3e-4
    WD: float = 1e-4

    NUM_WORKERS: int = 1
    SEED: int = 3557

    RGB_BACKBONE: str = "resnet_b18"
    AMP: bool = True
    SEEDS = [0, 42, 123, 2026, 999]

    HS_BANDS = 125
    HS_DROP_FIRST = 10
    HS_DROP_LAST  = 14
    HS_MEAN: tuple = None
    HS_STD:  tuple = None

    MS_MEAN: tuple = (0.0, 0.0, 0.0, 0.0, 0.0)
    MS_STD:  tuple = (1.0, 1.0, 1.0, 1.0, 1.0)

    OUT_DIR: str = "/content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output"
    BEST_CKPT: str = "/content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/best.pt"


LABELS = ["Health", "Rust", "Other"]
LBL2ID = {k: i for i, k in enumerate(LABELS)}
ID2LBL = {i: k for k, i in LBL2ID.items()}

In [None]:
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def list_files(folder: str, exts: Tuple[str, ...]) -> List[str]:
    if not os.path.isdir(folder):
        return []
    out = []
    for fn in os.listdir(folder):
        if fn.lower().endswith(exts):
            out.append(os.path.join(folder, fn))
    return sorted(out)

def base_id(path: str) -> str:
    return os.path.splitext(os.path.basename(path))[0]

def parse_label_from_train_name(bid: str) -> Optional[str]:
    m = re.match(r"^(Health|Rust|Other)_", bid)
    return m.group(1) if m else None

def build_index(root: str, split: str) -> Dict[str, Dict[str, str]]:
    split_dir = os.path.join(root, split)
    rgb_dir = os.path.join(split_dir, "RGB")
    ms_dir  = os.path.join(split_dir, "MS")
    hs_dir  = os.path.join(split_dir, "HS")

    rgb_files = list_files(rgb_dir, (".png", ".jpg", ".jpeg"))
    ms_files  = list_files(ms_dir, (".tif", ".tiff"))
    hs_files  = list_files(hs_dir, (".tif", ".tiff"))

    idx: Dict[str, Dict[str, str]] = {}
    for p in rgb_files:
        idx.setdefault(base_id(p), {})["rgb"] = p
    for p in ms_files:
        idx.setdefault(base_id(p), {})["ms"] = p
    for p in hs_files:
        idx.setdefault(base_id(p), {})["hs"] = p
    return idx

def make_train_df(train_idx: Dict[str, Dict[str, str]]) -> pd.DataFrame:
    rows = []
    for bid, paths in train_idx.items():
        lab = parse_label_from_train_name(bid)
        if lab is None:
            continue
        rows.append({
            "base_id": bid,
            "label": lab,
            "rgb": paths.get("rgb"),
            "ms":  paths.get("ms"),
            "hs":  paths.get("hs"),
        })
    return pd.DataFrame(rows)

def make_val_df(val_idx: Dict[str, Dict[str, str]]) -> pd.DataFrame:
    rows = []
    for bid, paths in val_idx.items():
        rows.append({
            "base_id": bid,
            "rgb": paths.get("rgb"),
            "ms":  paths.get("ms"),
            "hs":  paths.get("hs"),
        })
    return pd.DataFrame(rows)

def stratified_holdout(df: pd.DataFrame, frac: float = 0.15, seed: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame]:
    df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    parts = []
    for lab, g in df.groupby("label"):
        n = max(1, int(len(g) * frac))
        parts.append(g.iloc[:n])
    df_va = pd.concat(parts).drop_duplicates("base_id")
    df_tr = df[~df["base_id"].isin(df_va["base_id"])].reset_index(drop=True)
    df_va = df_va.reset_index(drop=True)
    return df_tr, df_va

In [None]:
def compute_band_stats(df, modality, read_fn):
    sums, sqs, cnt = None, None, 0

    for p in df[modality]:
        x = read_fn(p)          # (C,H,W), float32
        x = x.view(x.shape[0], -1)

        if sums is None:
            sums = x.sum(dim=1)
            sqs  = (x ** 2).sum(dim=1)
        else:
            sums += x.sum(dim=1)
            sqs  += (x ** 2).sum(dim=1)

        cnt += x.shape[1]

    mean = sums / cnt
    std  = torch.sqrt(sqs / cnt - mean ** 2)
    std[std < 1e-6] = 1.0
    return mean, std


In [None]:
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# def fix_channels(x: torch.Tensor, target_c: int) -> torch.Tensor:
#     # x: (C,H,W) -> (target_c,H,W) by crop or zero-pad
#     c, h, w = x.shape
#     if c == target_c:
#         return x
#     if c > target_c:
#         return x[:target_c]
#     pad = torch.zeros((target_c - c, h, w), dtype=x.dtype)
#     return torch.cat([x, pad], dim=0)

def read_rgb(path):
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    return torch.from_numpy(img).permute(2, 0, 1)

def read_tiff_multiband(path: str) -> np.ndarray:
    arr = tiff.imread(path)  # (H,W,C) or (C,H,W)
    if arr.ndim != 3:
        raise ValueError(f"Expected 3D TIFF, got {arr.shape} for {path}")
    if arr.shape[0] < arr.shape[1] and arr.shape[0] < arr.shape[2]:
        arr = np.transpose(arr, (1, 2, 0))  # -> (H,W,C)
    return arr

def normalize_per_band_minmax(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    H, W, C = x.shape
    flat = x.reshape(-1, C)
    mn = flat.min(axis=0)
    mx = flat.max(axis=0)
    denom = (mx - mn)
    denom[denom < eps] = 1.0
    x = (x - mn.reshape(1, 1, C)) / denom.reshape(1, 1, C)
    return np.clip(x, 0.0, 1.0)

def read_ms(path, mean, std):
    arr = read_tiff_multiband(path).astype(np.float32)
    x = torch.from_numpy(arr).permute(2, 0, 1)
    x = (x - mean[:, None, None]) / std[:, None, None]
    return x

def read_hs(path, drop_first, drop_last, mean, std, target_ch):
    arr = read_tiff_multiband(path).astype(np.float32)  # (H,W,B)
    B = arr.shape[2]
    if drop_first + drop_last < B:
        arr = arr[:, :, drop_first:B - drop_last]

    x = torch.from_numpy(arr).permute(2, 0, 1)  # (C,H,W)
    C = x.shape[0]
    if C > target_ch: # solve probs about mismatch chanel
        x = x[:target_ch]
    elif C < target_ch:
        pad = torch.zeros(target_ch - C, *x.shape[1:], dtype=x.dtype)
        x = torch.cat([x, pad], dim=0)
    x = (x - mean[:, None, None]) / std[:, None, None]
    return x

def resize_tensor(x: torch.Tensor, size: int) -> torch.Tensor:
    # x: (C,H,W) -> (C,size,size)
    return F.interpolate(x.unsqueeze(0), size=(size, size), mode="bilinear", align_corners=False).squeeze(0)

def apply_joint_aug(x_rgb, x_ms, x_hs):
    k = random.randint(0, 3)
    do_h = random.random() < 0.5
    do_v = random.random() < 0.5

    def _tf(x):
        if x is None:
            return None
        if k:
            x = torch.rot90(x, k, dims=(1, 2))
        if do_h:
            x = torch.flip(x, dims=(2,))
        if do_v:
            x = torch.flip(x, dims=(1,))
        return x

    return _tf(x_rgb), _tf(x_ms), _tf(x_hs)

# def infer_hs_in_ch(df_train: pd.DataFrame, df_val: pd.DataFrame, cfg: CFG) -> int:
#     for df in (df_train, df_val):
#         if "hs" in df.columns:
#             for p in df["hs"].dropna().tolist():
#                 if p and os.path.exists(p):
#                     x = read_hs(p, cfg.HS_DROP_FIRST, cfg.HS_DROP_LAST)
#                     return int(x.shape[0])
#     return 101

In [None]:
class WheatMultiModalDataset(Dataset):
    def __init__(self, df: pd.DataFrame, cfg: CFG, train: bool):
        self.df = df.reset_index(drop=True)
        self.cfg = cfg
        self.train = train
        if cfg.USE_MS:
                    self.ms_mean = torch.tensor(cfg.MS_MEAN, dtype=torch.float32)
                    self.ms_std  = torch.tensor(cfg.MS_STD,  dtype=torch.float32)
        if cfg.USE_HS:
            hs_ch = cfg.HS_BANDS - cfg.HS_DROP_FIRST - cfg.HS_DROP_LAST

            if cfg.HS_MEAN is None:
                self.hs_mean = torch.zeros(hs_ch, dtype=torch.float32)
                self.hs_std  = torch.ones(hs_ch, dtype=torch.float32)
                self.hs_target_ch = hs_ch
            else:
                self.hs_mean = torch.tensor(cfg.HS_MEAN, dtype=torch.float32)
                self.hs_std  = torch.tensor(cfg.HS_STD,  dtype=torch.float32)
                self.hs_target_ch = hs_ch

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i: int):
        row = self.df.iloc[i]
        bid = row["base_id"]

        x_rgb = x_ms = x_hs = None
        m_rgb = m_ms = m_hs = 0.0

        if self.cfg.USE_RGB and row.get("rgb") is not None:
            x_rgb = read_rgb(row["rgb"])
            x_rgb = resize_tensor(x_rgb, self.cfg.IMG_SIZE)
            m_rgb = 1.0

        if self.cfg.USE_MS and row.get("ms") is not None:
            x_ms = read_ms(
                row["ms"],
                mean=self.ms_mean,
                std=self.ms_std
            )
            x_ms = resize_tensor(x_ms, self.cfg.IMG_SIZE)
            m_ms = 1.0


        if self.cfg.USE_HS and isinstance(row.get("hs"), str) and row["hs"]:
          x_hs = read_hs(
              row["hs"],
              self.cfg.HS_DROP_FIRST,
              self.cfg.HS_DROP_LAST,
              mean=self.hs_mean,
              std=self.hs_std,
              target_ch=self.hs_target_ch
          )

          x_hs = resize_tensor(x_hs, self.cfg.IMG_SIZE)
          m_hs = 1.0

        if self.train:
            x_rgb, x_ms, x_hs = apply_joint_aug(x_rgb, x_ms, x_hs)

        if self.cfg.USE_RGB and x_rgb is None:
            x_rgb = torch.zeros(3, self.cfg.IMG_SIZE, self.cfg.IMG_SIZE, dtype=torch.float32)
        if self.cfg.USE_MS and x_ms is None:
            x_ms = torch.zeros(5, self.cfg.IMG_SIZE, self.cfg.IMG_SIZE, dtype=torch.float32)
        if self.cfg.USE_HS and x_hs is None:
            hs_ch = self.cfg.HS_BANDS - self.cfg.HS_DROP_FIRST - self.cfg.HS_DROP_LAST
            x_hs = torch.zeros(hs_ch, self.cfg.IMG_SIZE, self.cfg.IMG_SIZE, dtype=torch.float32)

        mask = torch.tensor([m_rgb, m_ms, m_hs], dtype=torch.float32)

        if "label" in row:
            y = LBL2ID[row["label"]]
            return {"id": bid, "rgb": x_rgb, "ms": x_ms, "hs": x_hs, "mask": mask, "y": torch.tensor(y, dtype=torch.long)}
        else:
            return {"id": bid, "rgb": x_rgb, "ms": x_ms, "hs": x_hs, "mask": mask}

In [None]:
# class SmallSpectralEncoder(nn.Module):
#     def __init__(self, in_ch: int, emb_dim: int = 256):
#         super().__init__()
#         self.stem = nn.Sequential(
#             nn.Conv2d(in_ch, 32, kernel_size=1, bias=False),
#             nn.BatchNorm2d(32),
#             nn.ReLU(inplace=True),
#         )
#         self.block = nn.Sequential(
#             nn.Conv2d(32, 64, 3, padding=1, bias=False),
#             nn.BatchNorm2d(64),
#             nn.ReLU(inplace=True),

#             nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),

#             nn.Conv2d(128, 128, 3, padding=1, bias=False),
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),
#         )
#         self.head = nn.Sequential(
#             nn.AdaptiveAvgPool2d(1),
#             nn.Flatten(),
#             nn.Linear(128, emb_dim),
#             nn.ReLU(inplace=True),
#         )

#     def forward(self, x):
#         x = self.stem(x)
#         x = self.block(x)
#         return self.head(x)

# class MultiModalNet(nn.Module):
#     def __init__(self, cfg: CFG, hs_in_ch: int, n_classes: int = 3):
#         super().__init__()

#         self.use_rgb = cfg.USE_RGB
#         self.use_ms  = cfg.USE_MS
#         self.use_hs  = cfg.USE_HS

#         feat_dims = []

#         if self.use_rgb:
#             self.rgb_enc = timm.create_model(
#                 "vit_base_patch16_224",
#                 pretrained=True,
#                 num_classes=0
#             )
#             for p in self.rgb_enc.parameters(): #freeze
#                 p.requires_grad = False

#             feat_dims.append(self.rgb_enc.num_features)


#         if self.use_ms:
#             self.ms_enc = SmallSpectralEncoder(5, 256)
#             feat_dims.append(256)

#         if self.use_hs:
#             self.hs_enc = SmallSpectralEncoder(hs_in_ch, 256)
#             feat_dims.append(256)

#         self.classifier = nn.Sequential(
#             nn.Linear(sum(feat_dims), 512),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(512, n_classes),
#         )

#     def forward(self, rgb, ms, hs, mask):
#         feats = []
#         if self.use_rgb:
#             feats.append(self.rgb_enc(rgb) * mask[:, 0:1])
#         if self.use_ms:
#             feats.append(self.ms_enc(ms) * mask[:, 1:2])
#         if self.use_hs:
#             feats.append(self.hs_enc(hs) * mask[:, 2:3])

#         f = torch.cat(feats, dim=1)
#         return self.classifier(f)

In [None]:
class SmallSpectralEncoder(nn.Module):
    def __init__(self, in_ch: int, emb_dim: int = 256):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, 32, kernel_size=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.block = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 128, 3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, emb_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.block(x)
        return self.head(x)

class MultiModalNet(nn.Module):
    def __init__(self, cfg: CFG, hs_in_ch: int, n_classes: int = 3):
        super().__init__()

        self.use_rgb = cfg.USE_RGB
        self.use_ms  = cfg.USE_MS
        self.use_hs  = cfg.USE_HS

        feat_dims = []

        # ===== RGB ResNet =====
        if self.use_rgb:
            self.rgb_enc = timm.create_model(
                "resnet18",
                pretrained=True,
                num_classes=0,
                global_pool="avg"
            )

            for p in self.rgb_enc.parameters():
                p.requires_grad = False
            for p in self.rgb_enc.layer4.parameters():
                p.requires_grad = True
            rgb_dim = self.rgb_enc.num_features



            self.rgb_norm = nn.LayerNorm(rgb_dim)
            feat_dims.append(rgb_dim)

        # ===== MS =====
        if self.use_ms:
            self.ms_enc = SmallSpectralEncoder(5, 256)
            self.ms_norm = nn.LayerNorm(256)
            feat_dims.append(256)

        if self.use_hs:
            self.hs_enc = SmallSpectralEncoder(hs_in_ch, 256)
            self.hs_norm = nn.LayerNorm(256)
            feat_dims.append(256)

        fusion_dim = sum(feat_dims)

        self.gate = nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim),
            nn.Sigmoid()
        )

        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, n_classes),
        )

    def forward(self, rgb, ms, hs, mask):
        feats = []

        if self.use_rgb:
            f_rgb = self.rgb_norm(self.rgb_enc(rgb))
            feats.append(f_rgb * mask[:, 0:1])

        if self.use_ms:
            f_ms = self.ms_norm(self.ms_enc(ms))
            feats.append(f_ms * mask[:, 1:2])

        if self.use_hs:
            f_hs = self.hs_norm(self.hs_enc(hs))
            feats.append(f_hs * mask[:, 2:3])

        f = torch.cat(feats, dim=1)
        f = f * self.gate(f)

        return self.classifier(f)


In [None]:
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total, correct = 0, 0
    conf = np.zeros((3, 3), dtype=np.int64)

    for batch in loader:
        rgb  = batch["rgb"].to(device)
        ms   = batch["ms"].to(device)
        hs   = batch["hs"].to(device)
        mask = batch["mask"].to(device)
        y    = batch["y"].to(device)

        logits = model(rgb, ms, hs, mask)
        pred = logits.argmax(dim=1)

        total += y.size(0)
        correct += (pred == y).sum().item()

        yt = y.cpu().numpy()
        yp = pred.cpu().numpy()
        for t, p in zip(yt, yp):
            conf[t, p] += 1

    acc = correct / max(1, total)

    f1s = []
    for c in range(3):
        tp = conf[c, c]
        fp = conf[:, c].sum() - tp
        fn = conf[c, :].sum() - tp
        prec = tp / max(1, (tp + fp))
        rec  = tp / max(1, (tp + fn))
        f1 = 0.0 if (prec + rec) == 0 else (2 * prec * rec / (prec + rec))
        f1s.append(f1)

    return {"acc": float(acc), "macro_f1": float(np.mean(f1s))}

def train_one_epoch(model, loader, optimizer, scaler, device):
    model.train()
    total_loss, n = 0.0, 0

    for batch in loader:
        rgb  = batch["rgb"].to(device)
        ms   = batch["ms"].to(device)
        hs   = batch["hs"].to(device)
        mask = batch["mask"].to(device)
        y    = batch["y"].to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device.type, enabled=(scaler is not None)):
            logits = model(rgb, ms, hs, mask)
            loss = F.cross_entropy(logits, y)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        bs = y.size(0)
        total_loss += loss.item() * bs
        n += bs

    return total_loss / max(1, n)

@torch.no_grad()
def predict(model, loader, device):
    model.eval()
    preds = []
    ids = []
    for batch in loader:
        rgb  = batch["rgb"].to(device)
        ms   = batch["ms"].to(device)
        hs   = batch["hs"].to(device)
        mask = batch["mask"].to(device)

        logits = model(rgb, ms, hs, mask)
        p = logits.argmax(dim=1).cpu().numpy().tolist()
        preds.extend([ID2LBL[x] for x in p])
        ids.extend(batch["id"])
    return ids, preds

In [None]:
# cfg = CFG()
# wandb.init(
#     project="wheat-multimodal-vit",
#     name=f"ViT-RGB_CNN-MS-HS_bs{cfg.BATCH_SIZE}_lr{cfg.LR}",
#     config=cfg.__dict__,
#     tags=["ViT", "Multimodal", "RGB-MS-HS", "CNN-vs-ViT"]
# )

# seed_everything(cfg.SEED)
# os.makedirs(cfg.OUT_DIR, exist_ok=True)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print("Device:", device)
# print("Backbone:", cfg.RGB_BACKBONE)

# train_idx = build_index(cfg.ROOT, cfg.TRAIN_DIR)
# val_idx   = build_index(cfg.ROOT, cfg.VAL_DIR)

# train_df = make_train_df(train_idx)
# val_df   = make_val_df(val_idx)

# print(f"Indexed train IDs: {len(train_idx)} | usable labeled train rows: {len(train_df)}")
# print(f"Indexed val IDs:   {len(val_idx)}   | val rows: {len(val_df)}")

# if len(train_df) == 0:
#     raise RuntimeError("No training samples found. Check ROOT/train and filename label pattern (Health_/Rust_/Other_).")

# hs_in_ch = cfg.HS_BANDS - cfg.HS_DROP_FIRST - cfg.HS_DROP_LAST
# print("HS channels after trimming:", hs_in_ch)

# df_tr, df_va = stratified_holdout(train_df, frac=0.1, seed=cfg.SEED)
# print(f"Train split: {len(df_tr)} | Holdout split: {len(df_va)}")

# ds_tr = WheatMultiModalDataset(df_tr, cfg, train=True)
# ds_va = WheatMultiModalDataset(df_va, cfg, train=False)
# ds_te = WheatMultiModalDataset(val_df, cfg, train=False)

# dl_tr = DataLoader(ds_tr, batch_size=cfg.BATCH_SIZE, shuffle=True,
#                    num_workers=cfg.NUM_WORKERS, pin_memory=True, drop_last=True)
# dl_va = DataLoader(ds_va, batch_size=cfg.BATCH_SIZE, shuffle=False,
#                    num_workers=cfg.NUM_WORKERS, pin_memory=True)
# dl_te = DataLoader(ds_te, batch_size=cfg.BATCH_SIZE, shuffle=False,
#                    num_workers=cfg.NUM_WORKERS, pin_memory=True)
# sample = ds_tr[0]
# print(sample["hs"].shape)

# model = MultiModalNet(cfg, hs_in_ch=hs_in_ch, n_classes=3).to(device)
# wandb.watch(
#     model,
#     log="gradients",
#     log_freq=100
# )
# optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),lr=cfg.LR,weight_decay=cfg.WD)
# scaler = torch.cuda.amp.GradScaler(enabled=(cfg.AMP and device.type == "cuda"))

# best_f1 = -1.0
# best_path = os.path.join(cfg.OUT_DIR, cfg.BEST_CKPT)

# for ep in range(1, cfg.EPOCHS + 1):
#     tr_loss = train_one_epoch(model, dl_tr, optimizer,scaler if scaler.is_enabled() else None,device)
#     metrics = evaluate(model, dl_va, device)
#     wandb.log({"epoch": ep, "train/loss": tr_loss,"val/acc": metrics["acc"],"val/macro_f1": metrics["macro_f1"],"lr": optimizer.param_groups[0]["lr"]})

#     print(f"Epoch {ep:02d}, loss={tr_loss:.4f}, val_acc={metrics['acc']:.4f}, val_macro_F1={metrics['macro_f1']:.4f}")

#     if metrics["macro_f1"] > best_f1:
#         best_f1 = metrics["macro_f1"]
#         torch.save({"model": model.state_dict(), "hs_in_ch": hs_in_ch, "cfg": cfg.__dict__},best_path)
#         print(f"  -> saved best to {best_path}")

# wandb.run.summary["best_macro_f1"] = best_f1
# wandb.run.summary["backbone"] = cfg.RGB_BACKBONE
# wandb.run.summary["modalities"] = "RGB + MS + HS"

# ckpt = torch.load(best_path, map_location=device)
# model.load_state_dict(ckpt["model"], strict=True)

# pred_ids, pred_labels = predict(model, dl_te, device)


# sub_ids = []
# for _, r in val_df.iterrows():
#     if isinstance(r.get("hs"), str) and r.get("hs"):
#         sub_ids.append(os.path.basename(r["hs"]))
#     elif isinstance(r.get("ms"), str) and r.get("ms"):
#         sub_ids.append(os.path.basename(r["ms"]))
#     else:
#         sub_ids.append(os.path.basename(r["rgb"]))
# wandb.finish()

In [None]:
!pip -q install timm wandb scikit-learn

In [None]:
import os
import torch
import numpy as np
import wandb, time
from torch.utils.data import DataLoader

# ======================
# seeds to run
# ======================
SEEDS = [0, 42, 123, 2024, 999]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

all_best_f1 = []

# ======================
# loop over seeds
# ======================
for seed in SEEDS:
    print(f"\n========== RUN WITH SEED {seed} ==========")

    cfg = CFG()
    cfg.SEED = seed

    seed_everything(cfg.SEED)
    os.makedirs(cfg.OUT_DIR, exist_ok=True)

    wandb.init(
        project="wheat-multimodal-resnet",
        name=f"ResNet-RGB_CNN-MS-HS_seed{seed}_bs{cfg.BATCH_SIZE}_lr{cfg.LR}-{time.strftime('%Y%m%d-%H%M%S')}",
        config=cfg.__dict__,
        tags=["ResNet", "Multimodal", "RGB-MS-HS", "CNN-vs-ResNet"]
    )

    print("Backbone:", cfg.RGB_BACKBONE)

    # ======================
    # build index & dataframe
    # ======================
    train_idx = build_index(cfg.ROOT, cfg.TRAIN_DIR)
    val_idx   = build_index(cfg.ROOT, cfg.VAL_DIR)

    train_df = make_train_df(train_idx)
    val_df   = make_val_df(val_idx)

    print(f"Indexed train IDs: {len(train_idx)} | usable labeled train rows: {len(train_df)}")
    print(f"Indexed val IDs:   {len(val_idx)}   | val rows: {len(val_df)}")

    if len(train_df) == 0:
        raise RuntimeError("No training samples found.")

    # ======================
    # HS channels
    # ======================
    hs_in_ch = cfg.HS_BANDS - cfg.HS_DROP_FIRST - cfg.HS_DROP_LAST
    print("HS channels after trimming:", hs_in_ch)

    # ======================
    # stratified split (THIS is where seed matters)
    # ======================
    df_tr, df_va = stratified_holdout(train_df, frac=0.15, seed=cfg.SEED)
    print(f"Train split: {len(df_tr)} | Holdout split: {len(df_va)}")

    # ======================
    # datasets & loaders
    # ======================
    ds_tr = WheatMultiModalDataset(df_tr, cfg, train=True)
    ds_va = WheatMultiModalDataset(df_va, cfg, train=False)
    ds_te = WheatMultiModalDataset(val_df, cfg, train=False)

    dl_tr = DataLoader(
        ds_tr,
        batch_size=cfg.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=True,
        drop_last=True
    )

    dl_va = DataLoader(
        ds_va,
        batch_size=cfg.BATCH_SIZE,
        shuffle=False,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=True
    )

    dl_te = DataLoader(
        ds_te,
        batch_size=cfg.BATCH_SIZE,
        shuffle=False,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=True
    )

    # ======================
    # model
    # ======================
    model = MultiModalNet(
        cfg,
        hs_in_ch=hs_in_ch,
        n_classes=3
    ).to(device)

    wandb.watch(model, log="gradients", log_freq=100)

    resnet_params = []
    other_params = []

    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue

        if "rgb_enc" in name:
            resnet_params.append(p)
        else:
            other_params.append(p)

    optimizer = torch.optim.AdamW(
        [
            {"params": resnet_params, "lr": cfg.LR * 0.1},
            {"params": other_params, "lr": cfg.LR},
        ],
        weight_decay=cfg.WD
    )


    scaler = torch.amp.GradScaler(
        enabled=(cfg.AMP and device.type == "cuda")
    )

    # ======================
    # training loop
    # ======================
    best_f1 = -1.0
    best_path = os.path.join(cfg.OUT_DIR, f"best_seed{seed}.pt")
    bad_epochs = 0
    patience = 7

    for ep in range(1, cfg.EPOCHS + 1):
        tr_loss = train_one_epoch(
            model,
            dl_tr,
            optimizer,
            scaler if scaler.is_enabled() else None,
            device
        )

        metrics = evaluate(model, dl_va, device)

        wandb.log({
            "epoch": ep,
            "train/loss": tr_loss,
            "val/acc": metrics["acc"],
            "val/macro_f1": metrics["macro_f1"],
            "lr": optimizer.param_groups[0]["lr"]
        })

        print(
            f"Epoch {ep:02d} | loss={tr_loss:.4f} | val_F1={metrics['macro_f1']:.4f} | "
            f"Best F1: {best_f1:.4f} | Bad Epochs: {bad_epochs}/{patience}"
        )

        if metrics["macro_f1"] > best_f1:
            best_f1 = metrics["macro_f1"]
            bad_epochs = 0 # Reset bad_epochs if performance improves
            torch.save(
                {   'epoch': ep,
                    "model": model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    "hs_in_ch": hs_in_ch,
                    "cfg": cfg.__dict__,
                    "seed": seed
                },
                best_path
            )
            print(f"  -> saved best to {best_path}")
        else:
            bad_epochs += 1
        if bad_epochs >= patience:
            print("Early stopping")
            break

    # ======================
    # log summary
    # ======================
    wandb.run.summary["best_macro_f1"] = best_f1
    wandb.run.summary["seed"] = seed
    wandb.run.summary["backbone"] = cfg.RGB_BACKBONE
    wandb.run.summary["modalities"] = "RGB + MS + HS"

    all_best_f1.append(best_f1)

    wandb.finish()

# ======================
# final report
# ======================
mean_f1 = np.mean(all_best_f1)
std_f1  = np.std(all_best_f1)

print("\n========== FINAL RESULT ==========")
print(f"Macro-F1: {mean_f1:.4f} \u00b1 {std_f1:.4f}")

Device: cuda



Backbone: resnet_b18
Indexed train IDs: 600 | usable labeled train rows: 600
Indexed val IDs:   300   | val rows: 300
HS channels after trimming: 101
Train split: 510 | Holdout split: 90


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

Epoch 01 | loss=0.8925 | val_F1=0.5864 | Best F1: -1.0000 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed0.pt
Epoch 02 | loss=0.7876 | val_F1=0.6941 | Best F1: 0.5864 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed0.pt
Epoch 03 | loss=0.7400 | val_F1=0.7077 | Best F1: 0.6941 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed0.pt
Epoch 04 | loss=0.7225 | val_F1=0.6335 | Best F1: 0.7077 | Bad Epochs: 0/7
Epoch 05 | loss=0.6759 | val_F1=0.6611 | Best F1: 0.7077 | Bad Epochs: 1/7
Epoch 06 | loss=0.7094 | val_F1=0.6528 | Best F1: 0.7077 | Bad Epochs: 2/7
Epoch 07 | loss=0.6877 | val_F1=0.6239 | Best F1: 0.7077 | Bad Epochs: 3/7
Epoch 08 | loss=0.6197 | val_F1=0.6699 | Best F1: 0.7077 | Bad Epochs: 4/7
Epoch 09 | loss=0.6412 | val_F1=0.7069 | Best F1: 0.7077 | Bad Epochs: 5/7
Epoch 10 | l

0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train/loss,█▅▄▄▂▃▃▁▂▂
val/acc,▁▇█▅▆▆▂▅█▁
val/macro_f1,▂▇█▅▆▅▄▆█▁

0,1
backbone,resnet_b18
best_macro_f1,0.70769
epoch,10
lr,3e-05
modalities,RGB + MS + HS
seed,0
train/loss,0.65592
val/acc,0.61111
val/macro_f1,0.55476





Backbone: resnet_b18
Indexed train IDs: 600 | usable labeled train rows: 600
Indexed val IDs:   300   | val rows: 300
HS channels after trimming: 101
Train split: 510 | Holdout split: 90
Epoch 01 | loss=0.8735 | val_F1=0.5251 | Best F1: -1.0000 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed42.pt
Epoch 02 | loss=0.7723 | val_F1=0.6180 | Best F1: 0.5251 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed42.pt
Epoch 03 | loss=0.7364 | val_F1=0.6219 | Best F1: 0.6180 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed42.pt
Epoch 04 | loss=0.7035 | val_F1=0.5626 | Best F1: 0.6219 | Bad Epochs: 0/7
Epoch 05 | loss=0.7313 | val_F1=0.6398 | Best F1: 0.6219 | Bad Epochs: 1/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed42.pt
Epoch 06 | loss=0

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▆▅▄▅▃▃▂▂▂▁▂▁▁▁
val/acc,▄▃▅▁▇▂▅█▄▃▅▂▂▄▇
val/macro_f1,▁▅▆▃▆▃▇█▅▅▆▃▄▆█

0,1
backbone,resnet_b18
best_macro_f1,0.67478
epoch,15
lr,3e-05
modalities,RGB + MS + HS
seed,42
train/loss,0.5764
val/acc,0.7
val/macro_f1,0.66709





Backbone: resnet_b18
Indexed train IDs: 600 | usable labeled train rows: 600
Indexed val IDs:   300   | val rows: 300
HS channels after trimming: 101
Train split: 510 | Holdout split: 90
Epoch 01 | loss=0.8809 | val_F1=0.4935 | Best F1: -1.0000 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed123.pt
Epoch 02 | loss=0.7171 | val_F1=0.5226 | Best F1: 0.4935 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed123.pt
Epoch 03 | loss=0.7231 | val_F1=0.6632 | Best F1: 0.5226 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed123.pt
Epoch 04 | loss=0.7397 | val_F1=0.6534 | Best F1: 0.6632 | Bad Epochs: 0/7
Epoch 05 | loss=0.6615 | val_F1=0.6572 | Best F1: 0.6632 | Bad Epochs: 1/7
Epoch 06 | loss=0.6740 | val_F1=0.5717 | Best F1: 0.6632 | Bad Epochs: 2/7
Epoch 07 | loss=0.6306 | val_F1=0.6312 | Bes

0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train/loss,█▄▄▅▃▃▂▁▃▂
val/acc,▁▃███▄█▄▇▇
val/macro_f1,▁▂███▄▇▅▆▅

0,1
backbone,resnet_b18
best_macro_f1,0.66317
epoch,10
lr,3e-05
modalities,RGB + MS + HS
seed,123
train/loss,0.62301
val/acc,0.64444
val/macro_f1,0.5989





Backbone: resnet_b18
Indexed train IDs: 600 | usable labeled train rows: 600
Indexed val IDs:   300   | val rows: 300
HS channels after trimming: 101
Train split: 510 | Holdout split: 90
Epoch 01 | loss=0.8951 | val_F1=0.4969 | Best F1: -1.0000 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed2024.pt
Epoch 02 | loss=0.7602 | val_F1=0.6188 | Best F1: 0.4969 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed2024.pt
Epoch 03 | loss=0.7269 | val_F1=0.6681 | Best F1: 0.6188 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed2024.pt
Epoch 04 | loss=0.6974 | val_F1=0.6746 | Best F1: 0.6681 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed2024.pt
Epoch 05 | loss=0.6791 | val_F1=0.6623 | Best F1: 0.6746 | Bad Epochs: 0/7
Epoch 06 

0,1
epoch,▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇██
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▆▅▅▄▃▄▃▃▃▃▄▂▂▁▂▁▁
val/acc,▁▆▇▆▆▇▅▅▆▆█▆▆▆▇▆▆▅
val/macro_f1,▁▅▆▆▆▇▆▅▆▆█▆▅▆▇▆▆▅

0,1
backbone,resnet_b18
best_macro_f1,0.73507
epoch,18
lr,3e-05
modalities,RGB + MS + HS
seed,2024
train/loss,0.49077
val/acc,0.64444
val/macro_f1,0.63175





Backbone: resnet_b18
Indexed train IDs: 600 | usable labeled train rows: 600
Indexed val IDs:   300   | val rows: 300
HS channels after trimming: 101
Train split: 510 | Holdout split: 90
Epoch 01 | loss=0.8660 | val_F1=0.5080 | Best F1: -1.0000 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed999.pt
Epoch 02 | loss=0.7390 | val_F1=0.6312 | Best F1: 0.5080 | Bad Epochs: 0/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed999.pt
Epoch 03 | loss=0.6934 | val_F1=0.5779 | Best F1: 0.6312 | Bad Epochs: 0/7
Epoch 04 | loss=0.6782 | val_F1=0.5666 | Best F1: 0.6312 | Bad Epochs: 1/7
Epoch 05 | loss=0.6569 | val_F1=0.6586 | Best F1: 0.6312 | Bad Epochs: 2/7
  -> saved best to /content/drive/MyDrive/SP26/DAT301m/PBL1/Kaggle_Prepared_resnet/output/best_seed999.pt
Epoch 06 | loss=0.6499 | val_F1=0.5836 | Best F1: 0.6586 | Bad Epochs: 0/7
Epoch 07 | loss=0.6252 | val_F1=0.6792 | Bes

0,1
epoch,▁▂▂▃▃▄▄▅▅▆▆▇▇█
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▆▅▄▄▄▃▃▃▂▂▃▁▁
val/acc,▁▆▆▅▇▄█▆▄▃█▆▆▄
val/macro_f1,▁▆▄▃▇▄█▆▃▄█▅▆▂

0,1
backbone,resnet_b18
best_macro_f1,0.6792
epoch,14
lr,3e-05
modalities,RGB + MS + HS
seed,999
train/loss,0.51654
val/acc,0.58889
val/macro_f1,0.52112



Macro-F1: 0.6920 ± 0.0260
