# Wheat Disease Classification - Multimodal Training Pipeline

This notebook implements the complete training pipeline for wheat disease classification using RGB, multispectral (MS), and hyperspectral (HS) imagery.

## 1. Setup and Imports

In [1]:
import os

import joblib
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from src.config import CFG, ID2LBL
from src.stats import calculate_stats
from src.train import WheatClassifier
from src.utils import WheatDataModule, seed_everything

  from .autonotebook import tqdm as notebook_tqdm


## 2. Configuration

In [2]:
# Start from the project defaults and override only what this run needs:
# local data/output locations, and Weights & Biases logging switched off.
cfg = CFG()
CFG_OVERRIDES = {
    "ROOT": "./data",
    "TRAIN_DIR": "train",
    # NOTE(review): VAL_DIR points at the "test" folder; confirm how WheatDataModule uses it.
    "VAL_DIR": "test",
    "OUT_DIR": "./outputs",
    "WANDB_ENABLED": False,
}
for attr_name, attr_value in CFG_OVERRIDES.items():
    setattr(cfg, attr_name, attr_value)

# Checkpoints and the submission file are written here later on.
os.makedirs(cfg.OUT_DIR, exist_ok=True)

## 3. Compute Statistics and Fit PCA (Run Once)

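This step computes normalization statistics and fits PCA on the hyperspectral data. If `cfg.PCA_PATH` already exists, the cached PCA and statistics are loaded from it instead of being recomputed.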

In [3]:
from src.stats import calculate_stats
import joblib

if not os.path.exists(cfg.PCA_PATH):
    stats = calculate_stats(cfg, verbose=False, fit_pca=True, pca_path=cfg.PCA_PATH)
    if 'ms_mean' in stats:
        cfg.MS_MEAN = stats['ms_mean']
        cfg.MS_STD = stats['ms_std']
    if 'hs_pca_mean' in stats:
        cfg.HS_MEAN = stats['hs_pca_mean']
        cfg.HS_STD = stats['hs_pca_std']
    print(f"PCA fitted: {stats['pca_explained_variance']:.1%} variance explained")
else:
    pca_data = joblib.load(cfg.PCA_PATH)
    if isinstance(pca_data, dict):
        if pca_data.get('ms_mean'):
            cfg.MS_MEAN = pca_data['ms_mean']
            cfg.MS_STD = pca_data['ms_std']
        if pca_data.get('hs_pca_mean'):
            cfg.HS_MEAN = pca_data['hs_pca_mean']
            cfg.HS_STD = pca_data['hs_pca_std']
    print(f"PCA loaded from {cfg.PCA_PATH}")

PCA fitted: 99.6% variance explained


## 4. Setup Data Module

In [4]:
# Seed first so any randomness during setup (e.g. a train/val split) is reproducible.
seed_everything(cfg.SEED)
dm = WheatDataModule(cfg)
dm.setup()

# Summarize split sizes and the hyperspectral channel count fed to the model.
n_train, n_val, n_test = (len(split) for split in (dm.train_ds, dm.val_ds, dm.test_ds))
print(f"Train: {n_train} | Val: {n_val} | Test: {n_test} | HS: {dm.hs_ch} channels")

Train: 540 | Val: 60 | Test: 300 | HS: 20 channels


## 5. Initialize Model

In [5]:
# The classifier's output size must agree with ID2LBL, which is used to decode
# predicted class ids into labels for the submission file.
NUM_CLASSES = 3
assert len(ID2LBL) == NUM_CLASSES, (
    f"ID2LBL has {len(ID2LBL)} labels but the model is built for {NUM_CLASSES} classes"
)

model = WheatClassifier(cfg, hs_channels=dm.hs_ch, num_classes=NUM_CLASSES)

# Parameters per training sample: a rough gauge of model capacity vs. data size.
trainable_params = model.model.count_trainable_params()
print(f"Trainable params: {trainable_params:,} ({trainable_params/len(dm.train_ds):.0f} per sample)")

Trainable params: 22,873,100 (42358 per sample)


## 6. Setup Training Callbacks and Logger

In [6]:
# Both callbacks watch the logged validation F1; higher is better.
MONITOR = "val_f1"

# Keep only the single best checkpoint, named after its epoch and score.
checkpoint_cb = ModelCheckpoint(
    dirpath=cfg.OUT_DIR,
    filename="best-{epoch:02d}-{val_f1:.4f}",
    monitor=MONITOR,
    mode="max",
    save_top_k=1,
    verbose=False,
)

# Stop once val_f1 has not improved for 10 consecutive validation checks.
early_stop_cb = EarlyStopping(monitor=MONITOR, mode="max", patience=10, verbose=False)

# Use W&B only when enabled; passing False disables Lightning's logger entirely.
if cfg.WANDB_ENABLED:
    logger = WandbLogger(project=cfg.WANDB_PROJECT_NAME, name=cfg.WANDB_RUN_NAME)
else:
    logger = False

## 7. Train Model

In [7]:
# Single-device training with the callbacks and logger from Section 6.
trainer = pl.Trainer(
    max_epochs=cfg.EPOCHS,
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_cb, early_stop_cb],
    logger=logger,
    precision='16-mixed',   # automatic mixed precision
    deterministic=True,     # request deterministic algorithms for reproducibility
    log_every_n_steps=10,
    enable_progress_bar=True,
)

trainer.fit(model, dm)

# best_model_score is None when no checkpoint was ever saved (e.g. validation
# never produced val_f1); formatting None with ':.4f' would raise TypeError.
best_score = checkpoint_cb.best_model_score
if best_score is None:
    print("Best val_f1: n/a (no checkpoint saved)")
else:
    print(f"Best val_f1: {float(best_score):.4f}")

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
You are using a CUDA device ('NVIDIA GeForce RTX 4090 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/krschap/code/foss/gdap_kaggle/.venv/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:881: Checkpoint directory /home/krschap/code/foss/gdap_kaggle/outputs exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/krscha

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/krschap/code/foss/gdap_kaggle/.venv/lib/python3.12/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.


Epoch 22: 100%|██████████| 8/8 [00:05<00:00,  1.47it/s, train_loss=0.832, train_acc=0.750, val_loss=1.090, val_acc=0.517, val_f1=0.503]
Best val_f1: 0.5469


## 8. Generate Test Predictions

In [8]:
# torch.load's weights-only unpickler rejects classes that are not allow-listed.
# Presumably the checkpoint's saved hyperparameters include the CFG object
# (confirm in WheatClassifier), so CFG is registered before restoring.
torch.serialization.add_safe_globals([CFG])

# Restore the best checkpoint from this run and predict on the test split.
test_preds = trainer.predict(model, dm.test_dataloader(), ckpt_path='best')
preds = torch.cat([batch['preds'] for batch in test_preds]).cpu().numpy()

# Predictions are paired with rows by position. This assumes the test
# dataloader preserves dm.test_df order and drops nothing (no shuffle, no drop_last).
assert len(preds) == len(dm.test_df), (
    f"Got {len(preds)} predictions for {len(dm.test_df)} test rows"
)


def _sample_id(row: pd.Series) -> str:
    """Return the basename of the first usable modality path (hs, then ms, then rgb).

    Missing paths may be None or NaN. NaN is truthy, so a plain `a or b or c`
    chain would select it and then crash inside os.path.basename.

    Raises:
        ValueError: if the row has no usable hs/ms/rgb path.
    """
    for col in ('hs', 'ms', 'rgb'):
        path = row.get(col)
        if pd.notna(path) and str(path):
            return os.path.basename(path)
    raise ValueError(f"Test row {row.name!r} has no hs/ms/rgb path")


sub = pd.DataFrame({
    'Id': [_sample_id(row) for _, row in dm.test_df.iterrows()],
    'Category': [ID2LBL[p] for p in preds],
})

sub_path = os.path.join(cfg.OUT_DIR, 'submission.csv')
sub.to_csv(sub_path, index=False)
print(f"Submission: {sub_path}")

Restoring states from the checkpoint path at /home/krschap/code/foss/gdap_kaggle/outputs/best-epoch=12-val_f1=0.5469.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/krschap/code/foss/gdap_kaggle/outputs/best-epoch=12-val_f1=0.5469.ckpt
/home/krschap/code/foss/gdap_kaggle/.venv/lib/python3.12/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.


Predicting DataLoader 0: 100%|██████████| 5/5 [00:00<00:00, 17.46it/s]
Submission: ./outputs/submission.csv
