
# Paddy Disease Classification: **PyTorch + timm** with a Multi‑Input Model

This notebook demonstrates a **multi‑input** setup using **PyTorch** and **[timm](https://github.com/huggingface/pytorch-image-models)** (PyTorch Image Models) for the Kaggle **Paddy** dataset.

1) Set up the environment and choose a timm backbone (`convnext_tiny`).  
2) Build a **custom Dataset** reading `train.csv` (`image_id`, `label`, `variety`, `age`).  
   - Images are stored in subfolders named by **label** (e.g., `train/<label>/<image_id>`).  
   - Each sample returns a tuple: **`(image_tensor, variety_idx, age_float, label_idx)`**.  
3) Create DataLoaders with timm‑compatible transforms.  
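4) Define a **multi‑input model**: pooled timm image features are concatenated with a one‑hot variety vector and the normalized age, then passed through a small MLP that predicts the disease label.  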
5) Train and evaluate with a minimal, well‑commented loop.

> **Note:** Point `DATA_DIR` to your local Kaggle Paddy dataset. 


## 1) Setup & Configuration

In [1]:

# If timm isn't installed, uncomment:
# !pip install timm --quiet

import os, random, math, time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split

from torchvision import datasets, transforms
import timm

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)

DATA_DIR = Path("./data/")  
TRAIN_CSV = DATA_DIR / 'train.csv'
TRAIN_IMG_ROOT = DATA_DIR   / 'train_images'

# NOTE: "cuda:2" selects the third GPU; use "cuda" or "cuda:0" on a single-GPU machine.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print("Device:", device)

import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

# Make sure Colab instances running this notebook can fetch the accompanying utils_kp.py
try:
    import utils_kp as ut
except ModuleNotFoundError:
    !wget https://raw.githubusercontent.com/CNUClasses/471_code/master/week6/utils_kp.py
import utils_kp as ut
 
BATCH_SIZE = 256

# lr = 0.01  # initial guess; the LR finder showed this is too high
lr = 2e-3    # value suggested by the LR finder
NUM_EPOCHS=5

Device: cuda:2


## 2) Custom Dataset (returns `(image, variety, age, label)`)


In [2]:

# See utils_kp.py for the PaddyMultitaskDataset class
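
Since the real class lives in `utils_kp.py` and is not shown here, the following is only a minimal sketch of what a dataset with this return signature could look like. The class name `SketchPaddyDataset`, the max-based age scaling, and the tiny demo images are all assumptions made for illustration, not the contents of `utils_kp.py`.

```python
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class SketchPaddyDataset(Dataset):
    """Hypothetical stand-in for ut.PaddyMultitaskDataset (not the real class)."""

    def __init__(self, df, img_root, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_root = Path(img_root)
        self.transform = transform
        # Sorted class names give a stable name -> index mapping.
        self.labels = sorted(self.df["label"].unique())
        self.varieties = sorted(self.df["variety"].unique())
        self.label_to_idx = {name: i for i, name in enumerate(self.labels)}
        self.variety_to_idx = {name: i for i, name in enumerate(self.varieties)}
        # Assumption: age is scaled to [0, 1] by dividing by the maximum age.
        self.age_max = float(self.df["age"].max())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        # Images are stored as <img_root>/<label>/<image_id>.
        with Image.open(self.img_root / row["label"] / row["image_id"]) as im:
            img = im.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        else:
            # Fallback: HWC uint8 -> CHW float in [0, 1].
            img = torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).float() / 255.0
        variety_idx = self.variety_to_idx[row["variety"]]
        age = torch.tensor(float(row["age"]) / self.age_max, dtype=torch.float32)
        label_idx = self.label_to_idx[row["label"]]
        return img, variety_idx, age, label_idx


# Tiny demo with two synthetic images written to a temporary folder.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    rows = [("100.jpg", "blast", "ADT45", 45), ("101.jpg", "normal", "Ponni", 70)]
    for image_id, label, _, _ in rows:
        (root / label).mkdir(parents=True, exist_ok=True)
        Image.new("RGB", (32, 24), color=(120, 200, 80)).save(root / label / image_id)
    demo_df = pd.DataFrame(rows, columns=["image_id", "label", "variety", "age"])
    demo_ds = SketchPaddyDataset(demo_df, root)
    image, variety_idx, age, label_idx = demo_ds[0]

print(image.shape, variety_idx, age.item(), label_idx)
```

One design point worth noting: if the class-to-index mapping is derived from each split separately, a split that is missing a class would get different indices, so a real implementation would likely build the mapping once from the full `train.csv`.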


## 3) Create datasets with timm transforms 

We do not use torchvision's `ImageFolder` here. Instead we use `PaddyMultitaskDataset` from `utils_kp.py`, which also returns the variety and age for each image.



In [3]:
import pprint
pp = pprint.PrettyPrinter(indent=4)

# MODEL_NAME = "resnet18"  # try: 'efficientnet_b0', 'convnext_tiny', 'mobilenetv3_large_100', ...
MODEL_NAME = 'convnext_tiny'
config = timm.data.resolve_data_config({}, model=MODEL_NAME)
train_tfms = timm.data.create_transform(**config, is_training=True, hflip=0.5, auto_augment=None)
valid_tfms = timm.data.create_transform(**config, is_training=False)

pp.pprint(config)

{   'crop_mode': 'center',
    'crop_pct': 0.875,
    'input_size': (3, 224, 224),
    'interpolation': 'bicubic',
    'mean': (0.485, 0.456, 0.406),
    'std': (0.229, 0.224, 0.225)}
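
To make the config values above more concrete, here is a small NumPy sketch of what they imply for evaluation-time preprocessing. It assumes the usual timm behaviour of resizing to roughly `input_size / crop_pct` before the center crop and then normalizing per channel; it illustrates the arithmetic only and is not timm's actual transform code.

```python
import math

import numpy as np

# Values copied from the config printed above.
input_size = (3, 224, 224)
crop_pct = 0.875
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Assumed eval behaviour: resize so the center crop keeps crop_pct of the image.
resize_to = math.floor(input_size[1] / crop_pct)

# A fake channels-first RGB image already scaled to [0, 1].
img = np.full(input_size, 0.5, dtype=np.float32)

# Per-channel normalization: (x - mean) / std, broadcast over H and W.
normalized = (img - mean[:, None, None]) / std[:, None, None]

print(resize_to)            # 256
print(normalized[:, 0, 0])  # one normalized pixel per channel
```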


In [4]:
#get the split datasets
train_df, valid_df, test_df=ut.getdataframes(DATA_DIR, TRAIN_CSV, TRAIN_IMG_ROOT, valid_pct=0.1, test_pct=0.1,verbose=True )

train_ds = ut.PaddyMultitaskDataset(train_df, TRAIN_IMG_ROOT, transform=train_tfms)
valid_ds = ut.PaddyMultitaskDataset(valid_df,   TRAIN_IMG_ROOT, transform=valid_tfms)
test_ds  = ut.PaddyMultitaskDataset(test_df,  TRAIN_IMG_ROOT, transform=valid_tfms)

print('Label classes:', train_ds.labels)
print('Variety classes:', train_ds.varieties)
print('Train/Valid/test sizes:', len(train_ds), len(valid_ds), len(test_ds))


Label classes: ['bacterial_leaf_blight', 'bacterial_leaf_streak', 'bacterial_panicle_blight', 'blast', 'brown_spot', 'dead_heart', 'downy_mildew', 'hispa', 'normal', 'tungro']
Variety classes: ['ADT45', 'AndraPonni', 'AtchayaPonni', 'IR20', 'KarnatakaPonni', 'Onthanel', 'Ponni', 'RR', 'Surya', 'Zonal']
Train/Valid/test sizes: 8327 1040 1040
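
How `ut.getdataframes` performs the split is not shown in this notebook. As a rough sketch, a plain random split with `valid_pct=0.1` and `test_pct=0.1` over 10407 rows would produce the sizes printed above. The function name `split_indices` and the seed are hypothetical, and the real helper may well stratify by label or work differently.

```python
import numpy as np


def split_indices(n_rows, valid_pct=0.1, test_pct=0.1, seed=42):
    """Hypothetical random split; returns (train_idx, valid_idx, test_idx)."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_rows)
    n_valid = int(n_rows * valid_pct)
    n_test = int(n_rows * test_pct)
    valid_idx = order[:n_valid]
    test_idx = order[n_valid:n_valid + n_test]
    train_idx = order[n_valid + n_test:]
    return train_idx, valid_idx, test_idx


train_idx, valid_idx, test_idx = split_indices(8327 + 1040 + 1040)
print(len(train_idx), len(valid_idx), len(test_idx))  # 8327 1040 1040
```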


In [5]:
class_names=train_ds.labels
class_names

['bacterial_leaf_blight',
 'bacterial_leaf_streak',
 'bacterial_panicle_blight',
 'blast',
 'brown_spot',
 'dead_heart',
 'downy_mildew',
 'hispa',
 'normal',
 'tungro']

### Create DataLoaders

- Shuffle the training loader; keep validation loader deterministic.  
- Adjust `BATCH_SIZE` to fit your GPU/CPU memory.

This is a data science competition:<br>
The `train_images` folder contains images whose class labels are listed in `train.csv`.<br>
The `test_images` folder contains unlabeled images for your model to classify. Those predictions are bundled into a file (see `sample_submission.csv`) and submitted for ranking.

In [6]:

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=True)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)

xb, y_var, y_age, y_lbl = next(iter(train_loader))
print('Batch shapes:', xb.shape, y_var.shape, y_age.shape, y_lbl.shape)



Batch shapes: torch.Size([256, 3, 224, 224]) torch.Size([256]) torch.Size([256]) torch.Size([256])


## 4) Load Multi‑input timm Model (Transfer Learning with **timm**)

- Use the train/valid/test split to try this out.
- The model now predicts only the disease label class.
- **Warm-up:** freeze the backbone and train the classifier head first.



In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

class MultiModalClassifier(nn.Module):
    """
    Predicts the disease label (10 classes) from:
      - image (RGB tensor, e.g. [B,3,H,W])
      - variety (categorical index, int in [0..num_variety_classes-1])
      - age (normalized float scalar per sample)

    Pipeline:
      image --(timm backbone + global pool)--> feats   (B, feat_dim)
      variety_idx --(one_hot)--> one_hot_var           (B, num_variety_classes)
      age_norm  -------------------------------------> (B, 1)
      concat([feats, one_hot_var, age_norm]) --MLP--> logits_label (B, num_label_classes)
    """
    def __init__(
        self,
        model_name: str = "convnext_tiny",
        num_label_classes: int = 10,
        num_variety_classes: int = 10,
        pretrained: bool = True,
        mlp_hidden: int = 256,
        dropout: float = 0.2,
    ):
        super().__init__()

        # 1) CNN backbone with no classifier so we can get a pooled feature vector
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        feat_dim = self.backbone.num_features

        # 2) Remember variety class count for one-hot
        self.num_variety_classes = num_variety_classes

        # 3) Fusion MLP head (features + one-hot(variety) + age)
        fused_in = feat_dim + num_variety_classes + 1  # +1 for age scalar
        self.mlp = nn.Sequential(
            nn.Linear(fused_in, mlp_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, num_label_classes),  # final logits for label prediction
        )

    def forward(self, x):
        """
        x is a tuple: (images, variety_idx, age_norm)
          - images: float tensor [B, 3, H, W]
          - variety_idx: long/int tensor [B] (values in [0..num_variety_classes-1])
          - age_norm: float tensor [B] or [B,1] (already normalized)
        returns:
          - logits_label: tensor [B, num_label_classes]
        """
        images, variety_idx, age_norm = x  # unpack the tuple

        # Backbone features (B, feat_dim)
        feats = self.backbone(images)

        # One-hot encode variety (B, num_variety_classes)
        # F.one_hot expects Long dtype and returns int; cast to float for concat
        one_hot_var = F.one_hot(variety_idx.long(), num_classes=self.num_variety_classes).float()

        # Ensure age is shape (B,1)
        if age_norm.dim() == 1:
            age_norm = age_norm.unsqueeze(1)
        age_norm = age_norm.float()

        # Concatenate along feature dimension
        fused = torch.cat([feats, one_hot_var, age_norm], dim=1)

        # MLP head -> logits for label prediction
        logits_label = self.mlp(fused)
        return logits_label


In [8]:
model = MultiModalClassifier(
    model_name="convnext_tiny",
    num_label_classes=10,
    num_variety_classes=10,
    pretrained=True
).to(device)

for p in model.backbone.parameters():
    p.requires_grad = False
print('Trainable params (heads only):', sum(p.numel() for p in model.parameters() if p.requires_grad))


Trainable params (heads only): 202250
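
The trainable-parameter count printed above can be checked by hand. The sketch below rebuilds only the fusion head and runs it on a dummy fused batch. It assumes `feat_dim = 768` for `convnext_tiny` (the value that is consistent with the printed count) and uses random tensors in place of real backbone features.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

feat_dim = 768  # assumption: backbone.num_features for convnext_tiny
num_variety_classes, num_label_classes, mlp_hidden = 10, 10, 256
fused_in = feat_dim + num_variety_classes + 1  # +1 for the age scalar

# Same layer layout as the MLP head in MultiModalClassifier.
head = nn.Sequential(
    nn.Linear(fused_in, mlp_hidden),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    nn.Linear(mlp_hidden, num_label_classes),
)
n_params = sum(p.numel() for p in head.parameters())
print(n_params)  # (779*256 + 256) + (256*10 + 10) = 202250

# Dummy batch standing in for backbone features, variety indices and ages.
batch = 4
feats = torch.randn(batch, feat_dim)
variety_idx = torch.tensor([0, 3, 9, 3])
age_norm = torch.rand(batch)

fused = torch.cat(
    [feats, F.one_hot(variety_idx, num_variety_classes).float(), age_norm.unsqueeze(1)],
    dim=1,
)
logits = head(fused)
print(fused.shape, logits.shape)
```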


In [9]:
#check to see the model structure
# print(model.backbone)
# print(model)

In [10]:
#make sure just training the last layer
# for name, p in model.named_parameters():
#     print (f'Name={name},p.shape={p.shape}, p.requires_grad = {p.requires_grad}')



## 5) Optimizer & loss 


In [11]:
#single head now
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
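
For readers new to `nn.CrossEntropyLoss`, this short sketch shows what it computes for a batch of logits: the mean negative log-softmax probability of the true class. The logits and targets are made-up values used only for the comparison.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
demo_logits = torch.randn(4, 10)           # 4 samples, 10 disease classes
demo_targets = torch.tensor([0, 3, 8, 9])  # true class index per sample

# Built-in loss (default reduction is the mean over the batch).
loss_builtin = nn.CrossEntropyLoss()(demo_logits, demo_targets)

# Manual version: log-softmax, pick the true-class entry, negate, average.
log_probs = F.log_softmax(demo_logits, dim=1)
loss_manual = -log_probs[torch.arange(4), demo_targets].mean()

print(loss_builtin.item(), loss_manual.item())
```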


## 6) Training & validation loop

- Log train/valid **loss** and **accuracy** per epoch.  
- Warm up the classifier head first while the backbone stays frozen.

In [12]:
def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for images, variety_idx, age_norm, label_idx in loader:
        images      = images.to(device)
        variety_idx = variety_idx.to(device)
        age_norm    = age_norm.to(device)
        label_idx   = label_idx.to(device)

        
        logits = model((images, variety_idx, age_norm))
        loss = criterion(logits, label_idx)  # mean loss over the batch, so it is multiplied by the batch size below
 
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        correct += (logits.argmax(1) == label_idx).sum().item()
        total += images.size(0)
    return total_loss/total, correct/total

@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for images, variety_idx, age_norm, label_idx in loader:
        images      = images.to(device)
        variety_idx = variety_idx.to(device)
        age_norm    = age_norm.to(device)
        label_idx   = label_idx.to(device)

        logits = model((images, variety_idx, age_norm))
        loss = criterion(logits, label_idx)  # mean loss over the batch, so it is multiplied by the batch size below
        total_loss += loss.item() * images.size(0)
        correct += (logits.argmax(1) == label_idx).sum().item()
        total += images.size(0)
    return total_loss/total, correct/total

def train_and_evaluate(model, train_loader, valid_loader, optimizer, criterion, num_epochs=10):
    for epoch in range(1, num_epochs + 1):
        # Train for one epoch
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion)

        # Evaluate on validation set
        va_loss, va_acc = evaluate(model, valid_loader, criterion)
        print(f"Epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} | valid loss {va_loss:.4f} acc {va_acc:.3f}")

train_and_evaluate(model, train_loader, valid_loader, optimizer, criterion, num_epochs=NUM_EPOCHS)


Epoch 01 | train loss 1.5192 acc 0.472 | valid loss 0.9755 acc 0.665
Epoch 02 | train loss 1.1261 acc 0.613 | valid loss 0.8298 acc 0.709
Epoch 03 | train loss 0.9844 acc 0.660 | valid loss 0.7184 acc 0.754
Epoch 04 | train loss 0.9046 acc 0.692 | valid loss 0.6148 acc 0.783
Epoch 05 | train loss 0.8317 acc 0.717 | valid loss 0.5665 acc 0.805
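
The loops above multiply each batch's mean loss by the batch size before summing. The sketch below illustrates why: with 8327 training samples and a batch size of 256, the last batch holds only 135 samples, so a plain average of per-batch means would over-weight it. The per-sample losses are synthetic, and the last batch is deliberately made worse to exaggerate the effect.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, batch_size = 8327, 256

# Synthetic per-sample losses; the final partial batch is made artificially worse.
per_sample_loss = rng.random(n_samples)
per_sample_loss[-(n_samples % batch_size):] += 1.0

batches = [per_sample_loss[i:i + batch_size] for i in range(0, n_samples, batch_size)]

# Weighted accumulation, as in train_one_epoch / evaluate.
total_loss = sum(b.mean() * len(b) for b in batches)
weighted_avg = total_loss / n_samples

# Naive alternative: average the batch means without weighting.
naive_avg = float(np.mean([b.mean() for b in batches]))

true_avg = float(per_sample_loss.mean())
print(len(batches[-1]), weighted_avg, naive_avg, true_avg)
```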


### (Optional) Fine-tune the whole network

After warming up the head, unfreeze the backbone and fine-tune at a **smaller LR**.

In [13]:
# train the whole thing (unfreeze all layers)
for p in model.parameters():
    p.requires_grad = True

#make lr smaller for fine-tuning
lr1=lr/100

#change learning rate for fine-tuning
for g in optimizer.param_groups:
    g['lr'] = lr1

train_and_evaluate(model, train_loader, valid_loader, optimizer, criterion, num_epochs=NUM_EPOCHS)

Epoch 01 | train loss 0.8869 acc 0.699 | valid loss 0.4592 acc 0.838
Epoch 02 | train loss 0.6041 acc 0.796 | valid loss 0.3321 acc 0.875
Epoch 03 | train loss 0.5104 acc 0.829 | valid loss 0.2695 acc 0.906
Epoch 04 | train loss 0.4458 acc 0.850 | valid loss 0.2749 acc 0.910
Epoch 05 | train loss 0.3966 acc 0.868 | valid loss 0.2121 acc 0.940
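
The cell above lowers the learning rate for every parameter at once. A common alternative, which this notebook does not use, is to give the backbone and the head separate parameter groups with different learning rates. The sketch below shows the mechanics on a toy model; `ToyModel`, the `lr / 10` head rate, and all variable names are illustrative assumptions.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Tiny stand-in with the same attribute names (backbone, mlp)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.mlp = nn.Linear(4, 2)

    def forward(self, x):
        return self.mlp(torch.relu(self.backbone(x)))


toy_model = ToyModel()
base_lr = 2e-3

# One param group per sub-module, each with its own learning rate.
toy_optimizer = torch.optim.AdamW([
    {"params": toy_model.backbone.parameters(), "lr": base_lr / 100},  # pretrained part: small steps
    {"params": toy_model.mlp.parameters(), "lr": base_lr / 10},        # fresh head: larger steps
])

# One dummy optimization step to show the optimizer works as configured.
loss = toy_model(torch.randn(3, 8)).pow(2).mean()
toy_optimizer.zero_grad(set_to_none=True)
loss.backward()
toy_optimizer.step()

group_lrs = [g["lr"] for g in toy_optimizer.param_groups]
print(group_lrs)
```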


## 7) Save / Load model

In [14]:

torch.save(model.state_dict(), 'kaggle_paddy_timm_multihead.pth')
print('Saved to kaggle_paddy_timm_multihead.pth')
model.load_state_dict(torch.load('kaggle_paddy_timm_multihead.pth', map_location=device))



Saved to kaggle_paddy_timm_multihead.pth




<All keys matched successfully>

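As a small, self-contained illustration of the save/load round trip, the sketch below saves a toy module's `state_dict` to a temporary file and restores it into a fresh module. It passes `weights_only=True` to `torch.load`, which assumes a reasonably recent PyTorch release and restricts unpickling to tensors and plain containers. The toy `nn.Linear` and the file name are placeholders.

```python
import os
import tempfile

import torch
import torch.nn as nn

toy = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_weights.pth")
    torch.save(toy.state_dict(), path)
    # weights_only=True is enough for a plain state_dict of tensors.
    state = torch.load(path, map_location="cpu", weights_only=True)

restored = nn.Linear(4, 2)
result = restored.load_state_dict(state)
print(result)  # reports missing/unexpected keys, if any
```

Saving only the `state_dict` means the model class must be re-created in code before loading, which is what the cell above relies on.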
## 8) Evaluation & Confusion Matrix

Evaluate the model on `test_loader`. This is the **only** time the model sees this data.

In [15]:
LOADER=test_loader

@torch.no_grad()
def get_all_preds_targets(model, loader):
    model.eval()
    preds, targs = [], []
    for images, variety_idx, age_norm, label_idx in loader:
        images      = images.to(device)
        variety_idx = variety_idx.to(device)
        age_norm    = age_norm.to(device)
        label_idx   = label_idx.to(device)

        logits = model((images, variety_idx, age_norm))

        #just want the label preds
        preds.append(logits.argmax(1).cpu().numpy())
        targs.append(label_idx.cpu().numpy())
    return np.concatenate(preds), np.concatenate(targs)

preds, targs = get_all_preds_targets(model, LOADER)

# Confusion matrix
num_classes = len(class_names)
cm = np.zeros((num_classes, num_classes), dtype=int)
for t, p in zip(targs, preds):
    cm[t, p] += 1

print(f'Confusion Matrix ({num_classes}x{num_classes}; rows = true, columns = predicted):')
print(cm)

# Per-class metrics
per_class = []
for k in range(num_classes):
    TP = cm[k,k]
    FP = cm[:,k].sum() - TP
    FN = cm[k,:].sum() - TP
    TN = cm.sum() - TP - FP - FN
    prec = TP/(TP+FP) if (TP+FP)>0 else 0.0
    rec  = TP/(TP+FN) if (TP+FN)>0 else 0.0
    f1   = (2*prec*rec)/(prec+rec) if (prec+rec)>0 else 0.0
    per_class.append((prec, rec, f1))

macro_p = float(np.mean([p for p,_,_ in per_class]))
macro_r = float(np.mean([r for _,r,_ in per_class]))
macro_f = float(np.mean([f for _,_,f in per_class]))
overall_acc = float((preds == targs).mean())

print('\nPer-class (precision, recall, f1):')
for name,(p,r,f) in zip(class_names, per_class):
    print(f'{name:>10s}: P={p:.3f} R={r:.3f} F1={f:.3f}')
print(f"\nMacro avg: P={macro_p:.3f} R={macro_r:.3f} F1={macro_f:.3f}")
print(f'Overall Accuracy: {overall_acc:.3f}')


Confusion Matrix (10x10; rows = true, columns = predicted):
[[ 48   2   0   1   2   0   0   3   1   0]
 [  1  42   0   1   0   0   0   0   0   0]
 [  0   0  39   0   0   0   0   0   0   0]
 [  0   0   0 158   1   0   1   1   1   2]
 [  0   1   0   2  92   0   2   0   2   0]
 [  0   0   1   0   0 142   0   0   0   0]
 [  0   0   0   5   0   0  58   0   2   2]
 [  0   0   0   3   0   0   0 137   5   0]
 [  0   0   3   0   1   0   0   3 167   0]
 [  3   0   0   1   0   0   3   3   1  97]]

Per-class (precision, recall, f1):
bacterial_leaf_blight: P=0.923 R=0.842 F1=0.881
bacterial_leaf_streak: P=0.933 R=0.955 F1=0.944
bacterial_panicle_blight: P=0.907 R=1.000 F1=0.951
     blast: P=0.924 R=0.963 F1=0.943
brown_spot: P=0.958 R=0.929 F1=0.944
dead_heart: P=1.000 R=0.993 F1=0.996
downy_mildew: P=0.906 R=0.866 F1=0.885
     hispa: P=0.932 R=0.945 F1=0.938
    normal: P=0.933 R=0.960 F1=0.946
    tungro: P=0.960 R=0.898 F1=0.928

Macro avg: P=0.938 R=0.935 F1=0.936
Overall Accuracy: 0.942
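
The per-class loop above can also be written with vectorized NumPy operations. The sketch below first shows `np.add.at` building a small confusion matrix from made-up predictions, then recomputes precision, recall, F1 and accuracy directly from the 10x10 matrix printed above. It is an equivalent formulation for illustration, not a replacement for the cell.

```python
import numpy as np

# Building a confusion matrix without a Python loop (toy 3-class example).
toy_targs = np.array([0, 0, 1, 2])
toy_preds = np.array([0, 1, 1, 2])
toy_cm = np.zeros((3, 3), dtype=int)
np.add.at(toy_cm, (toy_targs, toy_preds), 1)  # rows = true, columns = predicted

# The 10x10 confusion matrix printed above.
cm = np.array([
    [48, 2, 0, 1, 2, 0, 0, 3, 1, 0],
    [1, 42, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 39, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 158, 1, 0, 1, 1, 1, 2],
    [0, 1, 0, 2, 92, 0, 2, 0, 2, 0],
    [0, 0, 1, 0, 0, 142, 0, 0, 0, 0],
    [0, 0, 0, 5, 0, 0, 58, 0, 2, 2],
    [0, 0, 0, 3, 0, 0, 0, 137, 5, 0],
    [0, 0, 3, 0, 1, 0, 0, 3, 167, 0],
    [3, 0, 0, 1, 0, 0, 3, 3, 1, 97],
])

tp = np.diag(cm).astype(float)
pred_totals = cm.sum(axis=0)  # column sums: TP + FP
true_totals = cm.sum(axis=1)  # row sums:    TP + FN

# Guarded divisions return 0.0 where a class has no predictions or no samples.
precision = np.divide(tp, pred_totals, out=np.zeros_like(tp), where=pred_totals > 0)
recall = np.divide(tp, true_totals, out=np.zeros_like(tp), where=true_totals > 0)
denom = precision + recall
f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)

accuracy = tp.sum() / cm.sum()
macro_f1 = f1.mean()
print(f"accuracy={accuracy:.3f} macro_f1={macro_f1:.3f}")
```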
