# Wheat Disease Classification - Multimodal Training Pipeline

This notebook implements the complete training pipeline for wheat disease classification using RGB, multispectral (MS), and hyperspectral (HS) imagery.

## 1. Setup and Imports

In [None]:
import os
import joblib
import pandas as pd
import pytorch_lightning as pl
import torch
import kornia.augmentation as K
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from src.config import CFG, ID2LBL
from src.train import WheatClassifier
from src.utils import WheatDataset, make_df, infer_hs_channels, seed_everything

## 2. Configuration

In [None]:
# Start from the project defaults and override the paths/flags for this run.
cfg = CFG()
notebook_overrides = {
    "ROOT": "./data",
    "TRAIN_DIR": "train",
    "VAL_DIR": "test",  # this directory is used to build `test_df` further down
    "OUT_DIR": "./outputs",
    "WANDB_ENABLED": False,
}
for attr_name, attr_value in notebook_overrides.items():
    setattr(cfg, attr_name, attr_value)

# Checkpoints and the submission file are written under OUT_DIR.
os.makedirs(cfg.OUT_DIR, exist_ok=True)

## 3. Compute Statistics and Fit PCA (Run Once)

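This step computes normalization statistics and fits PCA on the hyperspectral data. The results are cached at `cfg.PCA_PATH`; if that file already exists, it is loaded instead of being recomputed.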

In [None]:
from src.stats import calculate_stats
import joblib

# Normalization stats (and the HS PCA model) are cached at cfg.PCA_PATH:
# compute and save them on the first run, reload them afterwards. In both
# cases any MS / HS mean-std entries that are found are copied onto `cfg`.
if not os.path.exists(cfg.PCA_PATH):
    stats = calculate_stats(cfg, verbose=False, fit_pca=True, pca_path=cfg.PCA_PATH)
    if 'ms_mean' in stats:
        cfg.MS_MEAN = stats['ms_mean']
        cfg.MS_STD = stats['ms_std']
    if 'hs_pca_mean' in stats:
        # NOTE(review): key names suggest these are stats of the PCA-projected
        # HS data rather than raw bands — confirm in src.stats.
        cfg.HS_MEAN = stats['hs_pca_mean']
        cfg.HS_STD = stats['hs_pca_std']
    print(f"PCA fitted: {stats['pca_explained_variance']:.1%} variance explained")
else:
    # joblib.load unpickles the file; acceptable only because this is a locally
    # produced artifact — never point PCA_PATH at untrusted data.
    pca_data = joblib.load(cfg.PCA_PATH)
    # The cache may be a bare PCA object rather than a dict (the dataset cell
    # handles both forms); in that case there are no stats to restore.
    if isinstance(pca_data, dict):
        # Compare against None explicitly: the stored stats may be NumPy arrays,
        # and `if array:` raises "truth value ... is ambiguous" for arrays with
        # more than one element.
        if pca_data.get('ms_mean') is not None:
            cfg.MS_MEAN = pca_data['ms_mean']
            cfg.MS_STD = pca_data['ms_std']
        if pca_data.get('hs_pca_mean') is not None:
            cfg.HS_MEAN = pca_data['hs_pca_mean']
            cfg.HS_STD = pca_data['hs_pca_std']
    print(f"PCA loaded from {cfg.PCA_PATH}")

## 4. Set Up Datasets

In [None]:
seed_everything(cfg.SEED)

# Dataframes for the training split (used for CV) and the held-out test split.
train_df = make_df(cfg.ROOT, cfg.TRAIN_DIR)
test_df = make_df(cfg.ROOT, cfg.VAL_DIR)
hs_ch = infer_hs_channels(train_df, cfg)

# Optionally reduce the HS cube with the cached PCA model. The cache is either
# a dict holding the model (plus, optionally, the input feature count) or the
# bare model object itself.
pca_model = pca_n_features = None
use_pca = cfg.PCA_COMPONENTS > 0 and cfg.USE_HS and os.path.exists(cfg.PCA_PATH)
if use_pca:
    pca_data = joblib.load(cfg.PCA_PATH)
    if isinstance(pca_data, dict):
        pca_model = pca_data['model']
        pca_n_features = pca_data.get('n_features', hs_ch)
    else:
        pca_model = pca_data
        pca_n_features = hs_ch
    # After projection the model sees PCA_COMPONENTS HS channels.
    hs_ch = cfg.PCA_COMPONENTS

# Geometric + noise augmentation, applied to the training dataset only.
augmentations = [
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    K.RandomRotation(degrees=90.0, p=0.5),
    K.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), p=0.3),
    K.RandomGaussianNoise(mean=0., std=0.03, p=0.2),
]
train_transforms = K.AugmentationSequential(*augmentations, data_keys=["image"])

pca_args = (pca_model, pca_n_features)
full_dataset = WheatDataset(train_df, cfg, hs_ch, train_transforms, *pca_args)
test_dataset = WheatDataset(test_df, cfg, hs_ch, None, *pca_args)

print(f"Train: {len(train_df)} | Test: {len(test_df)} | HS: {hs_ch} channels")

## 5. Initialize Model

In [None]:
# Instantiate once just to report model capacity relative to dataset size;
# the CV loop below builds a fresh model for every fold.
model = WheatClassifier(cfg, hs_ch, 3)
trainable_params = model.model.count_trainable_params()
params_per_sample = trainable_params / len(train_df)
print(f"Trainable params: {trainable_params:,} ({params_per_sample:.0f} per sample)")

## 7. Train Model with 5-Fold Cross-Validation

In [None]:
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Subset, DataLoader
import numpy as np

# Number of stratified cross-validation folds.
N_FOLDS = 5

# `full_dataset` was built with `train_transforms`, so validation subsets taken
# from it would score every fold on randomly augmented images. Validation folds
# are drawn from a transform-free twin instead; row indices line up because both
# datasets wrap the same `train_df`.
# NOTE(review): if WheatDataset preloads data in __init__, this doubles memory
# use — consider a per-subset transform switch in that case.
eval_dataset = WheatDataset(train_df, cfg, hs_ch, None, pca_model, pca_n_features)

skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=cfg.SEED)
fold_scores = []

for fold, (tr_idx, val_idx) in enumerate(skf.split(train_df, train_df['label'])):
    print(f"\nFold {fold+1}/{N_FOLDS}: Train={len(tr_idx)}, Val={len(val_idx)}")

    train_subset = Subset(full_dataset, tr_idx)
    val_subset = Subset(eval_dataset, val_idx)

    train_loader = DataLoader(train_subset, cfg.BATCH_SIZE, True, num_workers=cfg.NUM_WORKERS, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_subset, cfg.BATCH_SIZE, False, num_workers=cfg.NUM_WORKERS, pin_memory=True)

    # When the target file already exists, ModelCheckpoint (with its default
    # version counter) saves under a suffixed name such as `best-v1.ckpt`
    # instead of overwriting. The prediction cell loads `best.ckpt`, so a
    # leftover from an earlier run would silently be used. Remove it first.
    ckpt_dir = f'{cfg.OUT_DIR}/fold_{fold}'
    stale_ckpt = os.path.join(ckpt_dir, 'best.ckpt')
    if os.path.exists(stale_ckpt):
        os.remove(stale_ckpt)

    model = WheatClassifier(cfg, hs_ch, 3)
    checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, filename='best', monitor='val_f1', mode='max')
    # WandbLogger's leading positional parameters are `name` and `save_dir`,
    # so the project and run name must be passed by keyword.
    fold_logger = (
        WandbLogger(project=cfg.WANDB_PROJECT_NAME, name=f'{cfg.WANDB_RUN_NAME}_f{fold}')
        if cfg.WANDB_ENABLED else False
    )
    trainer = pl.Trainer(
        max_epochs=cfg.EPOCHS,
        callbacks=[checkpoint_cb, EarlyStopping(monitor='val_f1', patience=15, mode='max')],
        logger=fold_logger,
        accelerator='auto', devices=1, precision='16-mixed', deterministic=True,
    )
    trainer.fit(model, train_loader, val_loader)
    if cfg.WANDB_ENABLED:
        # Close this fold's W&B run; otherwise the next fold's logger attaches
        # to the still-active run and all folds end up in one run.
        fold_logger.experiment.finish()

    best_score = checkpoint_cb.best_model_score
    if best_score is None:
        raise RuntimeError(
            f"Fold {fold+1}: no checkpoint was saved; check that 'val_f1' is logged during validation."
        )
    fold_scores.append(best_score.item())

print(f"\n{N_FOLDS}-Fold CV: {' | '.join([f'F{i+1}={s:.3f}' for i,s in enumerate(fold_scores)])}")
print(f"Mean: {np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}")

# The fold with the highest validation F1 supplies the weights for test predictions.
best_fold = np.argmax(fold_scores)
print(f"\nUsing best fold {best_fold+1} (F1={fold_scores[best_fold]:.4f}) for test predictions")

## 8. Generate Test Predictions

In [None]:
# Restore the best fold's checkpoint and run it over the test split.
ckpt_path = f'{cfg.OUT_DIR}/fold_{best_fold}/best.ckpt'
model = WheatClassifier.load_from_checkpoint(ckpt_path, cfg=cfg, hs_channels=hs_ch, num_classes=3)

# shuffle=False keeps predictions aligned with the row order of test_df.
test_loader = DataLoader(test_dataset, cfg.BATCH_SIZE, False, num_workers=cfg.NUM_WORKERS, pin_memory=True)
predictor = pl.Trainer(accelerator='auto', devices=1)
test_preds = predictor.predict(model, test_loader)

# Each predict output is expected to be a dict carrying a 'preds' tensor.
preds = torch.cat([out['preds'] for out in test_preds]).cpu().numpy()

# Submission Id is the HS file's basename ('' if the row has no 'hs' entry).
ids = [os.path.basename(row.get('hs', '')) for _, row in test_df.iterrows()]
labels = [ID2LBL[p] for p in preds]
sub = pd.DataFrame({'Id': ids, 'Category': labels})

sub_path = os.path.join(cfg.OUT_DIR, 'submission.csv')
sub.to_csv(sub_path, index=False)
print(f"Submission: {sub_path}")