# Using same base

## Verify data

In [1]:
import os
from collections import Counter

DATA_ROOT = "/Users/apple/Downloads/data"
TRAIN_DIR = os.path.join(DATA_ROOT, "train")
VAL_DIR = os.path.join(DATA_ROOT, "val")

print("data structure :")
for split, path in [("train", TRAIN_DIR), ("val", VAL_DIR)]:
    print(f"\n{split.upper()}:")
    if not os.path.exists(path):
        print(f"{path} NON EXISTENT")
        continue
    
    for modality in ["RGB", "MS", "HS"]:
        mod_path = os.path.join(path, modality)
        if os.path.exists(mod_path):
            count = len([f for f in os.listdir(mod_path) if f.lower().endswith((".png", ".tif", ".tiff"))])
            print(f"  {modality}: {count} files")
        else:
            print(f"  {modality}: MISSING")

# Sample filenames
print("\n sample filenames:")
for split, path in [("train/RGB", os.path.join(TRAIN_DIR, "RGB")), ("val/HS", os.path.join(VAL_DIR, "HS"))]:
    if os.path.exists(path):
        files = [f for f in os.listdir(path)[:5] if f.lower().endswith((".png", ".tif"))]
        print(f"\n{split}: {files}")
    else:
        print(f"\n{split}: PATH MISSING")

# Train label parsing check
train_rgb_path = os.path.join(TRAIN_DIR, "RGB")
if os.path.exists(train_rgb_path):
    train_rgb_files = [f for f in os.listdir(train_rgb_path) if f.lower().endswith(".png")]
    labels = [f.split("_")[0] for f in train_rgb_files if "_" in f]
    label_counts = Counter(labels)
    print(f"\n train labels: {dict(label_counts)} (total files: {len(train_rgb_files)})")
else:
    print("\n !!! train/RGB missing")


data structure :

TRAIN:
  RGB: 600 files
  MS: 600 files
  HS: 600 files

VAL:
  RGB: 300 files
  MS: 300 files
  HS: 300 files

 sample filenames:

train/RGB: ['Rust_hyper_181.png', 'Rust_hyper_195.png', 'Other_hyper_117.png', 'Other_hyper_103.png', 'Rust_hyper_142.png']

val/HS: ['val_f720cb3d.tif', 'val_9eb4f156.tif', 'val_cc2ca03e.tif', 'val_8e9e3f20.tif', 'val_47de91ce.tif']

 train labels: {'Rust': 200, 'Other': 200, 'Health': 200} (total files: 600)


## Install dependencies + adapt baseline CFG

In [2]:
import subprocess
subprocess.run(["uv", "add", "timm", "tifffile", "opencv-python", "torch", "torchvision"], capture_output=True)

CompletedProcess(args=['uv', 'add', 'timm', 'tifffile', 'opencv-python', 'torch', 'torchvision'], returncode=0, stdout=b'', stderr=b'\x1b[2mResolved \x1b[1m94 packages\x1b[0m \x1b[2min 30ms\x1b[0m\x1b[0m\n\x1b[2mUninstalled \x1b[1m1 package\x1b[0m \x1b[2min 1.62s\x1b[0m\x1b[0m\n\x1b[2mInstalled \x1b[1m1 package\x1b[0m \x1b[2min 333ms\x1b[0m\x1b[0m\n \x1b[33m~\x1b[39m \x1b[1mtorch\x1b[0m\x1b[2m==2.10.0\x1b[0m\n')

In [3]:
# setup cfg
import os
from dataclasses import dataclass

@dataclass
class CFG:
    ROOT: str = "/Users/apple/Downloads/data"  
    TRAIN_DIR: str = "train"
    VAL_DIR: str = "val"
    
    USE_RGB: bool = True
    USE_MS: bool = True  
    USE_HS: bool = True
    
    IMG_SIZE: int = 224
    BATCH_SIZE: int = 16  # smaller for local debugging
    EPOCHS: int = 5       # few for testing
    LR: float = 3e-4
    WD: float = 1e-4
    
    NUM_WORKERS: int = 4
    SEED: int = 3557
    
    RGB_BACKBONE: str = "convnext_base"  # swap to resnet later
    AMP: bool = True
    
    HS_DROP_FIRST: int = 10
    HS_DROP_LAST: int = 14
    
    OUT_DIR: str = "./outputs"
    BEST_CKPT: str = "best.pt"

cfg = CFG()
print("CFG ready:")
print(f"  ROOT: {cfg.ROOT}")
print(f"  RGB backbone: {cfg.RGB_BACKBONE}")
os.makedirs(cfg.OUT_DIR, exist_ok=True)

CFG ready:
  ROOT: /Users/apple/Downloads/data
  RGB backbone: convnext_base


## Indexing + df

In [4]:
import os, re
from typing import Dict, List
from collections import defaultdict

LABELS = ["Health", "Rust", "Other"]
LBL2ID = {k: i for i, k in enumerate(LABELS)}
ID2LBL = {i: k for k, i in LBL2ID.items()}

def list_files(folder: str, exts: tuple) -> List[str]:
    if not os.path.isdir(folder):
        return []
    return sorted([os.path.join(folder, fn) 
                  for fn in os.listdir(folder) 
                  if fn.lower().endswith(exts)])

def base_id(path: str) -> str:
    return os.path.splitext(os.path.basename(path))[0]

def parse_label_from_train_name(bid: str) -> str:
    m = re.match(r"^(Health|Rust|Other)_", bid)
    return m.group(1) if m else None

def build_index(root: str, split: str) -> Dict[str, Dict[str, str]]:
    split_dir = os.path.join(root, split)
    rgb_dir = os.path.join(split_dir, "RGB")
    ms_dir  = os.path.join(split_dir, "MS") 
    hs_dir  = os.path.join(split_dir, "HS")
    
    rgb_files = list_files(rgb_dir, (".png",))
    ms_files  = list_files(ms_dir,  (".tif", ".tiff"))
    hs_files  = list_files(hs_dir,  (".tif", ".tiff"))
    
    idx = {}
    for p in rgb_files: idx.setdefault(base_id(p), {})["rgb"] = p
    for p in ms_files:  idx.setdefault(base_id(p), {})["ms"]  = p  
    for p in hs_files:  idx.setdefault(base_id(p), {})["hs"]  = p
    return idx

In [5]:
# build index
train_idx = build_index(cfg.ROOT, cfg.TRAIN_DIR)
val_idx   = build_index(cfg.ROOT, cfg.VAL_DIR)

print(f"Train IDs: {len(train_idx)}")
print(f"Val IDs:   {len(val_idx)}")

# check alignment
missing_modalities = []
for bid, paths in list(train_idx.items())[:10]:  # sample
    missing = [m for m in ["rgb","ms","hs"] if m not in paths]
    if missing: missing_modalities.append((bid, missing))

print(f"\nSample alignment: {len(missing_modalities)}/10 missing modalities")

Train IDs: 600
Val IDs:   300

Sample alignment: 0/10 missing modalities


## Hold-out split

In [6]:
# import subprocess
# subprocess.run(["uv", "add", "pandas", "numpy"], capture_output=True)

In [7]:
# build train_df/val_df + stratified holdout

import pandas as pd

def make_train_df(train_idx: Dict[str, Dict[str, str]]) -> pd.DataFrame:
    rows = []
    for bid, paths in train_idx.items():
        lab = parse_label_from_train_name(bid)
        if lab:
            rows.append({"base_id": bid, "label": lab, 
                        "rgb": paths.get("rgb"), "ms": paths.get("ms"), "hs": paths.get("hs")})
    return pd.DataFrame(rows)

def make_val_df(val_idx: Dict[str, Dict[str, str]]) -> pd.DataFrame:
    rows = []
    for bid, paths in val_idx.items():
        rows.append({"base_id": bid, "rgb": paths.get("rgb"), 
                    "ms": paths.get("ms"), "hs": paths.get("hs")})
    return pd.DataFrame(rows)

def stratified_holdout(df: pd.DataFrame, frac: float = 0.1, seed: int = 42) -> tuple:
    df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    parts = []
    for lab, g in df.groupby("label"):
        n = max(1, int(len(g) * frac))
        parts.append(g.iloc[:n])
    df_va = pd.concat(parts).drop_duplicates("base_id")
    df_tr = df[~df["base_id"].isin(df_va["base_id"])].reset_index(drop=True)
    return df_tr, df_va

In [8]:
train_df = make_train_df(train_idx)
val_df   = make_val_df(val_idx)

print(f"Train DF: {len(train_df)} rows; label dist: {train_df['label'].value_counts().to_dict()}")
print(f"Val DF:   {len(val_df)} rows")

# STRATIFIED HOLDOUT (10% of train for validation)
df_tr, df_va_holdout = stratified_holdout(train_df, frac=0.1, seed=cfg.SEED)
print(f"\nSPLITS:")
print(f"  Training:   {len(df_tr)}")
print(f"  Holdout VA: {len(df_va_holdout)}")

# Preview
print("\nSAMPLE TRAIN ROW:")
print(df_tr.head(2)[["base_id", "label"]].to_string(index=False))

Train DF: 600 rows; label dist: {'Health': 200, 'Other': 200, 'Rust': 200}
Val DF:   300 rows

SPLITS:
  Training:   540
  Holdout VA: 60

SAMPLE TRAIN ROW:
         base_id  label
Health_hyper_146 Health
Health_hyper_200 Health


## HS channel inf + preview dataset

In [9]:
import torch
import numpy as np
import cv2
import tifffile as tiff

def read_tiff_multiband(path: str) -> np.ndarray:
    arr = tiff.imread(path)
    if arr.ndim != 3:
        raise ValueError(f"Expected 3D TIFF, got {arr.shape}")
    if arr.shape[0] < arr.shape[1]:  # (C,H,W) -> (H,W,C)
        arr = np.transpose(arr, (1, 2, 0))
    return arr

def read_hs(path: str, drop_first: int, drop_last: int) -> torch.Tensor:
    arr = read_tiff_multiband(path)
    B = arr.shape[2]
    if B > (drop_first + drop_last + 1):
        arr = arr[:, :, drop_first:B-drop_last]
    # Min-max norm per band (simplified)
    for c in range(arr.shape[2]):
        arr[:,:,c] = (arr[:,:,c] - arr[:,:,c].min()) / (arr[:,:,c].max() - arr[:,:,c].min() + 1e-6)
    return torch.from_numpy(arr).permute(2, 0, 1).float()

def infer_hs_in_ch(train_df, val_df, cfg):
    for df in (train_df, val_df):
        hs_paths = df["hs"].dropna().tolist()
        if hs_paths:
            x = read_hs(hs_paths[0], cfg.HS_DROP_FIRST, cfg.HS_DROP_LAST)
            print(f"HS sample shape: {x.shape} ({x.shape[0]} bands)")
            return int(x.shape[0])
    return 101

In [10]:
# INFER HS CHANNELS
hs_in_ch = infer_hs_in_ch(train_df, val_df, cfg)
print(f"\nHS_IN_CH = {hs_in_ch}")

# Test one full sample
sample_row = df_tr.iloc[0]
print(f"\nTesting sample: {sample_row['base_id']}")
print(f"  RGB: {sample_row['rgb'] is not None}")
print(f"  MS:  {sample_row['ms'] is not None}")
print(f"  HS:  {sample_row['hs'] is not None}")

HS sample shape: torch.Size([101, 32, 32]) (101 bands)

HS_IN_CH = 101

Testing sample: Health_hyper_146
  RGB: True
  MS:  True
  HS:  True


## Full dataset class + quick test

In [11]:
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def read_rgb(path: str) -> torch.Tensor:
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    x = torch.from_numpy(img).permute(2, 0, 1)
    return (x - IMAGENET_MEAN) / IMAGENET_STD

def resize_tensor(x: torch.Tensor, size: int) -> torch.Tensor:
    return F.interpolate(x.unsqueeze(0), size=(size, size), mode="bilinear").squeeze(0)

class WheatMultiModalDataset(Dataset):
    def __init__(self, df, cfg, hs_in_ch, train=False):
        self.df = df.reset_index(drop=True)
        self.cfg = cfg
        self.hs_in_ch = hs_in_ch
        self.train = train
    
    def __len__(self): return len(self.df)
    
    def __getitem__(self, i):
        row = self.df.iloc[i]
        
        x_rgb = x_ms = x_hs = None
        m_rgb = m_ms = m_hs = 0.0
        
        # RGB
        if self.cfg.USE_RGB and row.get("rgb"):
            x_rgb = read_rgb(row["rgb"])
            x_rgb = resize_tensor(x_rgb, self.cfg.IMG_SIZE)
            m_rgb = 1.0
        
        # MS (5 bands)
        if self.cfg.USE_MS and row.get("ms"):
            arr = read_tiff_multiband(row["ms"])
            arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-6)  # simple norm
            x_ms = torch.from_numpy(arr).permute(2, 0, 1).float()
            x_ms = resize_tensor(x_ms, self.cfg.IMG_SIZE)
            m_ms = 1.0
        
        # HS
        if self.cfg.USE_HS and row.get("hs"):
            x_hs = read_hs(row["hs"], self.cfg.HS_DROP_FIRST, self.cfg.HS_DROP_LAST)
            x_hs = resize_tensor(x_hs, self.cfg.IMG_SIZE)
            m_hs = 1.0
        
        # Zero-pad missing
        if x_rgb is None: x_rgb = torch.zeros(3, self.cfg.IMG_SIZE, self.cfg.IMG_SIZE)
        if x_ms  is None: x_ms  = torch.zeros(5, self.cfg.IMG_SIZE, self.cfg.IMG_SIZE)
        if x_hs  is None: x_hs  = torch.zeros(self.hs_in_ch, self.cfg.IMG_SIZE, self.cfg.IMG_SIZE)
        
        mask = torch.tensor([m_rgb, m_ms, m_hs])
        
        if "label" in row:
            return {"id": row["base_id"], "rgb": x_rgb, "ms": x_ms, "hs": x_hs, 
                   "mask": mask, "y": torch.tensor(LBL2ID[row["label"]])}
        return {"id": row["base_id"], "rgb": x_rgb, "ms": x_ms, "hs": x_hs, "mask": mask}


In [12]:
# TEST DATASET
ds_tr = WheatMultiModalDataset(df_tr, cfg, hs_in_ch, train=True)
sample = ds_tr[0]

print(f"Sample shapes:")
print(f"  RGB: {sample['rgb'].shape}  mask: {sample['mask'][0]:.1f}")
print(f"  MS:  {sample['ms'].shape}  mask: {sample['mask'][1]:.1f}")
print(f"  HS:  {sample['hs'].shape}  mask: {sample['mask'][2]:.1f}")
print(f"  Label: {sample['id']} -> {ID2LBL[sample['y'].item()]}")

Sample shapes:
  RGB: torch.Size([3, 224, 224])  mask: 1.0
  MS:  torch.Size([5, 224, 224])  mask: 1.0
  HS:  torch.Size([101, 224, 224])  mask: 1.0
  Label: Health_hyper_146 -> Health


## ConvNeXt baseline

In [13]:
import torch.nn as nn
import timm

class SmallSpectralEncoder(nn.Module):
    def __init__(self, in_ch: int, emb_dim: int = 256):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, 32, 1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True)
        )
        self.block = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), 
            nn.Linear(128, emb_dim), nn.ReLU(inplace=True)
        )
    
    def forward(self, x): 
        x = self.stem(x)
        x = self.block(x)
        return self.head(x)

class MultiModalNet(nn.Module):
    def __init__(self, cfg, hs_in_ch, n_classes=3):
        super().__init__()
        self.use_rgb = cfg.USE_RGB
        self.use_ms = cfg.USE_MS
        self.use_hs = cfg.USE_HS
        
        feat_dims = []
        if self.use_rgb:
            self.rgb_enc = timm.create_model(cfg.RGB_BACKBONE, pretrained=True, num_classes=0, global_pool="avg")
            feat_dims.append(self.rgb_enc.num_features)
        if self.use_ms: 
            self.ms_enc = SmallSpectralEncoder(5, 256)
            feat_dims.append(256)
        if self.use_hs:
            self.hs_enc = SmallSpectralEncoder(hs_in_ch, 256) 
            feat_dims.append(256)
        
        self.classifier = nn.Sequential(
            nn.Linear(sum(feat_dims), 512), nn.ReLU(inplace=True), nn.Dropout(0.2),
            nn.Linear(512, n_classes)
        )
    
    def forward(self, rgb, ms, hs, mask):
        feats = []
        if self.use_rgb: feats.append(self.rgb_enc(rgb) * mask[:, 0:1])
        if self.use_ms:  feats.append(self.ms_enc(ms)  * mask[:, 1:2])
        if self.use_hs:  feats.append(self.hs_enc(hs)  * mask[:, 2:3])
        return self.classifier(torch.cat(feats, dim=1))


  from .autonotebook import tqdm as notebook_tqdm


In [14]:
# instantiate
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiModalNet(cfg, hs_in_ch).to(device)
print(f"Model created on {device}")
print(f"  RGB dim: {model.rgb_enc.num_features if cfg.USE_RGB else 'N/A'}")
print(f"  Total params: {sum(p.numel() for p in model.parameters()):,}")

Model created on cpu
  RGB dim: 1024
  Total params: 88,905,027


# ResNet

In [15]:
cfg.RGB_BACKBONE = "resnet50"  
cfg.BATCH_SIZE = 8              # smaller for CPU lol
print(f"Updated: RGB={cfg.RGB_BACKBONE}, Batch={cfg.BATCH_SIZE}")

# RELOAD MODEL with ResNet
model = MultiModalNet(cfg, hs_in_ch).to(device)
print(f"Now using ResNet50 + raw HS (PCA next)")
print(f"  RGB dim: {model.rgb_enc.num_features}")

Updated: RGB=resnet50, Batch=8
Now using ResNet50 + raw HS (PCA next)
  RGB dim: 2048


## PCA for hs dim reduction

In [16]:
# import subprocess
# subprocess.run(["uv", "add", "scikit-learn", "joblib"], capture_output=True)

In [21]:
# DEBUG: Check HS band counts
hs_paths = df_tr["hs"].dropna().tolist()[:20]  # first 20
band_counts = []
for path in hs_paths:
    try:
        arr = read_tiff_multiband(path)
        B_trim = arr.shape[2] - cfg.HS_DROP_FIRST - cfg.HS_DROP_LAST
        band_counts.append((os.path.basename(path), arr.shape, B_trim))
    except:
        pass

print("HS BAND COUNTS (first 20):")
unique_bands = set(b[2] for b in band_counts)
print(f"Raw shapes: {[b[1] for b in band_counts[:5]]}")
print(f"Trimmed bands: {band_counts[:5]}")
print(f"Unique trimmed B: {sorted(unique_bands)}")


HS BAND COUNTS (first 20):
Raw shapes: [(32, 32, 125), (32, 32, 126), (32, 32, 126), (32, 32, 125), (32, 32, 125)]
Trimmed bands: [('Health_hyper_146.tif', (32, 32, 125), 101), ('Health_hyper_200.tif', (32, 32, 126), 102), ('Health_hyper_14.tif', (32, 32, 126), 102), ('Other_hyper_116.tif', (32, 32, 125), 101), ('Other_hyper_27.tif', (32, 32, 125), 101)]
Unique trimmed B: [101, 102]


In [25]:
def fit_hs_pca_fixed(train_df, target_ch=101, n_components=30, sample_frac=0.1):
    all_pixels = []
    hs_paths = train_df["hs"].dropna().tolist()
    
    print(f"Sampling {int(len(hs_paths)*sample_frac)} patches")
    for path in np.random.choice(hs_paths, size=int(len(hs_paths)*sample_frac), replace=False):
        try:
            arr = read_tiff_multiband(path)
            B_full = arr.shape[2]
            start, end = cfg.HS_DROP_FIRST, B_full - cfg.HS_DROP_LAST
            
            # Take exactly target_ch bands
            actual_ch = min(end - start, target_ch)
            arr_trim = arr[:,:,start:start+actual_ch]
            
            # Pad if needed
            if arr_trim.shape[2] < target_ch:
                pad_shape = (arr_trim.shape[0], arr_trim.shape[1], target_ch - arr_trim.shape[2])
                arr_trim = np.pad(arr_trim, ((0,0),(0,0),(0,pad_shape[2])), mode='constant')
            
            # Flatten to pixels x bands (2D)
            pixels_patch = arr_trim.reshape(-1, target_ch)
            all_pixels.append(pixels_patch)
            
        except Exception as e:
            continue
    
    # Stack ALL pixels (2D: N_pixels × target_ch)
    pixels = np.vstack(all_pixels)
    print(f"PCA: {pixels.shape[0]:,} pixels × {pixels.shape[1]} bands")
    
    scaler = StandardScaler()
    pixels_scaled = scaler.fit_transform(pixels)
    
    pca = PCA(n_components=n_components)
    pca.fit(pixels_scaled)
    
    print(f"Variance: {pca.explained_variance_ratio_.sum():.1%} ({pca.n_components_} comps)")
    
    joblib.dump({"pca": pca, "scaler": scaler, "target_ch": target_ch}, 
                f"{cfg.OUT_DIR}/hs_pca.pkl")
    return pca, scaler, target_ch

pca, scaler, hs_pca_ch = fit_hs_pca_fixed(df_tr, target_ch=101, n_components=30)
print(f"PCA ready!!! 101→{pca.n_components_} bands")


Sampling 54 patches
PCA: 55,296 pixels × 101 bands
Variance: 100.0% (30 comps)
PCA ready!!! 101→30 bands


## Int PCA w/ the dataset

In [35]:
# LOAD PCA
pca_data = joblib.load(f"{cfg.OUT_DIR}/hs_pca.pkl")
HS_PCA = pca_data["pca"]
HS_SCALER = pca_data["scaler"] 
HS_PCA_CH = HS_PCA.n_components_  # 30
print(f"HS Dataset: raw {pca_data['target_ch']} → PCA {HS_PCA_CH} bands")

# def apply_hs_pca(x_hs: torch.Tensor) -> torch.Tensor:
#     """Apply scaler + PCA to HS tensor (C,H,W) → (pca_ch,H,W)"""
#     # (C,H,W) → (H*W,C) → scale → PCA → (H*W,pca_ch) → (pca_ch,H,W)
#     B, H, W = x_hs.shape
#     pixels = x_hs.permute(1,2,0).reshape(-1, B).numpy()  # H*W x C
    
#     pixels_scaled = HS_SCALER.transform(pixels)
#     pixels_pca = HS_PCA.transform(pixels_scaled)
    
#     return torch.from_numpy(pixels_pca.T).reshape(HS_PCA_CH, H, W).float()

def apply_hs_pca(x_hs: torch.Tensor) -> torch.Tensor:
    """Apply scaler + PCA: truncate to 101 → scale → PCA → 30"""
    B, H, W = x_hs.shape
    
    # FIXED: Truncate to exactly 101 bands (PCA input_ch)
    x_hs_trunc = x_hs[:101]  # take first 101
    
    pixels = x_hs_trunc.permute(1,2,0).reshape(-1, 101).numpy()
    
    pixels_scaled = HS_SCALER.transform(pixels)
    pixels_pca = HS_PCA.transform(pixels_scaled)
    
    return torch.from_numpy(pixels_pca.T).reshape(HS_PCA_CH, H, W).float()


class WheatPCADataset(Dataset):

    def __init__(self, df, cfg, hs_pca_ch, train=False):
        self.df = df.reset_index(drop=True)
        self.cfg = cfg
        self.hs_pca_ch = hs_pca_ch
        self.train = train
    
    def __len__(self): 
        return len(self.df)

    
    def __getitem__(self, i):
        row = self.df.iloc[i]
        x_rgb = x_ms = x_hs = None
        mask = torch.zeros(3)
        
        # RGB (unchanged)
        if self.cfg.USE_RGB and row.get("rgb"):
            x_rgb = read_rgb(row["rgb"])
            x_rgb = resize_tensor(x_rgb, self.cfg.IMG_SIZE)
            mask[0] = 1.0
        
        # MS (unchanged)  
        if self.cfg.USE_MS and row.get("ms"):
            arr = read_tiff_multiband(row["ms"])
            arr = (arr - arr.min(axis=(0,1), keepdims=True)) / (arr.max(axis=(0,1), keepdims=True) + 1e-6)
            x_ms = torch.from_numpy(arr).permute(2,0,1).float()
            x_ms = resize_tensor(x_ms, self.cfg.IMG_SIZE)
            mask[1] = 1.0
        
        # HS + PCA! 
        if self.cfg.USE_HS and row.get("hs"):
            x_hs_raw = read_hs(row["hs"], self.cfg.HS_DROP_FIRST, self.cfg.HS_DROP_LAST)
            x_hs = apply_hs_pca(resize_tensor(x_hs_raw, self.cfg.IMG_SIZE))
            mask[2] = 1.0
        
        # Pad missing
        if x_rgb is None: x_rgb = torch.zeros(3, self.cfg.IMG_SIZE, self.cfg.IMG_SIZE)
        if x_ms  is None: x_ms  = torch.zeros(5, self.cfg.IMG_SIZE, self.cfg.IMG_SIZE)
        if x_hs  is None: x_hs  = torch.zeros(self.hs_pca_ch, self.cfg.IMG_SIZE, self.cfg.IMG_SIZE)
        
        out = {"id": row["base_id"], "rgb": x_rgb, "ms": x_ms, "hs": x_hs, "mask": mask}
        if "label" in row:
            out["y"] = torch.tensor(LBL2ID[row["label"]])
        return out

# TEST PCA DATASET
print("Testing PCA dataset...")
ds_pca_tr = WheatPCADataset(df_tr, cfg, HS_PCA_CH, train=True)
sample_pca = ds_pca_tr[0]

print(f"PCA Sample shapes:")
print(f"  RGB: {sample_pca['rgb'].shape}  mask: {sample_pca['mask'][0]:.1f}")
print(f"  MS:  {sample_pca['ms'].shape}  mask: {sample_pca['mask'][1]:.1f}")
print(f"  HS:  {sample_pca['hs'].shape}  mask: {sample_pca['mask'][2]:.1f}")


HS Dataset: raw 101 → PCA 30 bands
Testing PCA dataset...
PCA Sample shapes:
  RGB: torch.Size([3, 224, 224])  mask: 1.0
  MS:  torch.Size([5, 224, 224])  mask: 1.0
  HS:  torch.Size([30, 224, 224])  mask: 1.0


## update model + train 1 epoch test

In [36]:
model = MultiModalNet(cfg, HS_PCA_CH).to(device)  # 30 HS input channels!
print(f"Model updated for PCA HS: {HS_PCA_CH} input channels")

Model updated for PCA HS: 30 input channels


In [37]:
model = MultiModalNet(cfg, HS_PCA_CH).to(device)  # Uses 30 HS chans
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.LR, weight_decay=cfg.WD)

print(f"Model: ResNet50 + PCA(30)")
print(f"  Params: {sum(p.numel() for p in model.parameters()):,}")

# Dataloaders (small for CPU test)
ds_tr_pca = WheatPCADataset(df_tr, cfg, HS_PCA_CH, train=True)
ds_va_pca = WheatPCADataset(df_va_holdout, cfg, HS_PCA_CH, train=False)

dl_tr = DataLoader(ds_tr_pca, batch_size=cfg.BATCH_SIZE, shuffle=True, 
                  num_workers=0, pin_memory=False)  # ← workers=0, no pin
dl_va = DataLoader(ds_va_pca, batch_size=cfg.BATCH_SIZE, shuffle=False, 
                  num_workers=0, pin_memory=False)

print(f"Dataloaders: train={len(dl_tr)} batches, val={len(dl_va)} batches")


Model: ResNet50 + PCA(30)
  Params: 25,368,611
Dataloaders: train=68 batches, val=8 batches


In [38]:
def train_one_batch(model, batch, optimizer, device):
    rgb = batch["rgb"].to(device)
    ms  = batch["ms"].to(device)  
    hs  = batch["hs"].to(device)
    mask = batch["mask"].to(device)
    y = batch["y"].to(device)
    
    optimizer.zero_grad()
    logits = model(rgb, ms, hs, mask)
    loss = F.cross_entropy(logits, y)
    loss.backward()
    optimizer.step()
    
    acc = (logits.argmax(1) == y).float().mean().item()
    return loss.item(), acc

# TEST
model.train()
batch = next(iter(dl_tr))
loss, acc = train_one_batch(model, batch, optimizer, device)
print(f"Loss: {loss:.4f}; acc: {acc:.3f}")


Loss: 1.1100; acc: 0.125


# Proper train

In [43]:
def train_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    for batch in loader:
        rgb, ms, hs, mask, y = [batch[k].to(device) for k in ["rgb","ms","hs","mask","y"]]
        
        optimizer.zero_grad()
        logits = model(rgb, ms, hs, mask)  # no autocast on CPU
        loss = F.cross_entropy(logits, y)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
        optimizer.step()
        
        total_loss += loss.item() * y.size(0)
    return total_loss / len(loader.dataset)

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total, correct = 0, 0
    conf = torch.zeros(3, 3, dtype=torch.long)

    for batch in loader:
        rgb  = batch["rgb"].to(device)
        ms   = batch["ms"].to(device)
        hs   = batch["hs"].to(device)
        mask = batch["mask"].to(device)
        y    = batch["y"].to(device)

        logits = model(rgb, ms, hs, mask)
        pred = logits.argmax(1)

        total += y.size(0)
        correct += (pred == y).sum().item()

        yt = y.cpu().numpy()
        yp = pred.cpu().numpy()
        for t, p in zip(yt, yp):
            conf[t, p] += 1

    acc = correct / max(1, total)

    f1s = []
    for c in range(3):
        tp = conf[c, c].item()
        fp = int(conf[:, c].sum().item() - tp)
        fn = int(conf[c, :].sum().item() - tp)
        prec = tp / max(1, (tp + fp))
        rec  = tp / max(1, (tp + fn))
        f1 = 0.0 if (prec + rec) == 0 else (2 * prec * rec / (prec + rec))
        f1s.append(f1)

    return {"acc": acc, "macro_f1": sum(f1s) / 3.0}



In [44]:
print("Training ResNet50 + PCA HS")
best_f1, patience = 0.0, 5
no_improve = 0

for epoch in range(1, 21):
    tr_loss = train_epoch(model, dl_tr, optimizer, device)
    metrics = evaluate(model, dl_va, device)

    acc_pct = metrics["acc"] * 100
    f1_pct  = metrics["macro_f1"] * 100

    print(f"E{epoch:2d}: loss={tr_loss:.4f} | acc={acc_pct:5.1f}% | F1={f1_pct:5.1f}%")

    if metrics["macro_f1"] > best_f1:
        best_f1 = metrics["macro_f1"]
        torch.save(model.state_dict(), f"{cfg.OUT_DIR}/trial2_best.pt")
        no_improve = 0
        print(f"NEW BEST F1: {f1_pct:5.1f}%")
    else:
        no_improve += 1
        if no_improve >= patience:
            print(f"Early stop (no improve {patience} epochs)")
            break

print(f"\nFINAL best F1: {best_f1*100:.1f}%")


Training ResNet50 + PCA HS
E 1: loss=0.9186 | acc= 33.3% | F1= 16.7%
NEW BEST F1:  16.7%
E 2: loss=0.8633 | acc= 33.3% | F1= 16.7%
E 3: loss=0.8163 | acc= 33.3% | F1= 16.7%
E 4: loss=0.7422 | acc= 33.3% | F1= 16.7%
E 5: loss=0.6335 | acc= 33.3% | F1= 16.7%
E 6: loss=0.6573 | acc= 33.3% | F1= 16.7%
Early stop (no improve 5 epochs)

FINAL best F1: 16.7%
