
# Paddy Disease Classification — **PyTorch + timm** and Multi‑Head model

This notebook demonstrates a **multi‑task** setup using **PyTorch** and **[timm](https://github.com/huggingface/pytorch-image-models)** (PyTorch Image Models) for the Kaggle **Paddy** dataset.

1) Set up the environment and choose a timm backbone (`convnext_tiny`).  
2) Build a **custom Dataset** reading `train.csv` (`image_id`, `label`, `variety`, `age`).  
   - Images are stored in subfolders named by **label** (e.g., `train_images/<label>/<image_id>`).  
   - Each sample returns a tuple: **`(image_tensor, variety_idx, age_float, label_idx)`**.  
3) Create DataLoaders with timm‑compatible transforms.  
4) Define a **multi‑head model**:  
   - Head A → **disease label** classification  
   - Head B → **variety** classification  
   - Head R → **age** regression  
5) Train and evaluate with a minimal, well‑commented loop.

> **Note:** Point `DATA_DIR` to your local Kaggle Paddy dataset. 


## 1) Setup & Configuration

In [1]:

# If timm isn't installed, uncomment:
# !pip install timm --quiet

import os, random, math, time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split

from torchvision import datasets, transforms
import timm

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)

DATA_DIR = Path("./data/")  
TRAIN_CSV = DATA_DIR / 'train.csv'
TRAIN_IMG_ROOT = DATA_DIR   / 'train_images'

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print("Device:", device)

import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

# Make sure Colab instances running this notebook can fetch the bundled utils_kp.py
try:
    import utils_kp as ut
except ModuleNotFoundError:
    !wget https://raw.githubusercontent.com/CNUClasses/471_code/master/week6/utils_kp.py
import utils_kp as ut
 
BATCH_SIZE = 256

# lr=0.01 #start with this one (it's too high; see lr finder below)
lr=2e-3 #lr finder approved this one
NUM_EPOCHS=5

Device: cuda:1


## 2) Custom Dataset (returns `(image, variety, age, label)`)


In [2]:

# see utils_kp.py for the PaddyMultitaskDataset class
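
Because the dataset class lives in `utils_kp.py` and is not shown in this notebook, the following is only a minimal sketch of the indexing logic such a dataset might use. It assumes class indices come from the sorted unique names, which is consistent with the class lists printed later. `PaddyRowsSketch`, the sample rows, and the file names are hypothetical, and the sketch returns the image path where a real dataset would open the image with PIL and apply the transform.

```python
import csv
import io
from pathlib import Path

# Hypothetical three-row stand-in for train.csv (same columns as the real file).
SAMPLE_CSV = """image_id,label,variety,age
img_a.jpg,blast,ADT45,45
img_b.jpg,hispa,Ponni,60
img_c.jpg,blast,IR20,70
"""


class PaddyRowsSketch:
    """Illustrative only: maps CSV rows to (path, variety_idx, age_float, label_idx)."""

    def __init__(self, rows, img_root):
        self.rows = list(rows)
        self.img_root = Path(img_root)
        # Assumption: class indices follow the sorted unique names.
        self.labels = sorted({r["label"] for r in self.rows})
        self.varieties = sorted({r["variety"] for r in self.rows})
        self.label_to_idx = {name: i for i, name in enumerate(self.labels)}
        self.variety_to_idx = {name: i for i, name in enumerate(self.varieties)}

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]
        # Images sit in a subfolder named after the disease label.
        path = self.img_root / r["label"] / r["image_id"]
        # A real Dataset would load `path` and apply the transform here.
        return (path, self.variety_to_idx[r["variety"]],
                float(r["age"]), self.label_to_idx[r["label"]])


rows = list(csv.DictReader(io.StringIO(SAMPLE_CSV)))
ds = PaddyRowsSketch(rows, "train_images")
print(len(ds), ds[1])
```

The tuple order `(image, variety, age, label)` matches what the loops later in the notebook unpack from each batch.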


## 3) Create datasets with timm transforms 

We do not use torchvision's `ImageFolder` here. Instead we use `PaddyMultitaskDataset`, which also returns the variety and age targets.



In [3]:
import pprint
pp = pprint.PrettyPrinter(indent=4)

# MODEL_NAME = "resnet18"  # try: 'efficientnet_b0', 'convnext_tiny', 'mobilenetv3_large_100', ...
MODEL_NAME = 'convnext_tiny'
config = timm.data.resolve_data_config({}, model=MODEL_NAME)
train_tfms = timm.data.create_transform(**config, is_training=True, hflip=0.5, auto_augment=None)
valid_tfms = timm.data.create_transform(**config, is_training=False)

pp.pprint(config)

{   'crop_mode': 'center',
    'crop_pct': 0.875,
    'input_size': (3, 224, 224),
    'interpolation': 'bicubic',
    'mean': (0.485, 0.456, 0.406),
    'std': (0.229, 0.224, 0.225)}
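
To make the config values concrete, here is a small worked sketch of what they imply for the validation transform. It assumes the usual center-crop pipeline (resize the short side to `input_size / crop_pct`, center-crop to `input_size`, then normalize with `mean` and `std`); `normalize_pixel` is a hypothetical helper written for this illustration, not a timm function.

```python
import math

# Values copied from the resolved config printed above.
input_size = (3, 224, 224)
crop_pct = 0.875
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Assumed eval pipeline: resize the short side, then center-crop to input_size.
resize_to = math.floor(input_size[1] / crop_pct)
print("resize short side to", resize_to, "then center-crop to", input_size[1])


def normalize_pixel(rgb_0_255):
    """Scale an RGB pixel to [0, 1], then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb_0_255, mean, std))


mid_gray = normalize_pixel((128, 128, 128))
print([round(c, 3) for c in mid_gray])
```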


In [4]:
#get the split datasets
train_df, valid_df, test_df=ut.getdataframes(DATA_DIR, TRAIN_CSV, TRAIN_IMG_ROOT, valid_pct=0.1, test_pct=0.1,verbose=True )

train_ds = ut.PaddyMultitaskDataset(train_df, TRAIN_IMG_ROOT, transform=train_tfms)
valid_ds = ut.PaddyMultitaskDataset(valid_df,   TRAIN_IMG_ROOT, transform=valid_tfms)
test_ds  = ut.PaddyMultitaskDataset(test_df,  TRAIN_IMG_ROOT, transform=valid_tfms)

print('Label classes:', train_ds.labels)
print('Variety classes:', train_ds.varieties)
print('Train/Valid/test sizes:', len(train_ds), len(valid_ds), len(test_ds))


Label classes: ['bacterial_leaf_blight', 'bacterial_leaf_streak', 'bacterial_panicle_blight', 'blast', 'brown_spot', 'dead_heart', 'downy_mildew', 'hispa', 'normal', 'tungro']
Variety classes: ['ADT45', 'AndraPonni', 'AtchayaPonni', 'IR20', 'KarnatakaPonni', 'Onthanel', 'Ponni', 'RR', 'Surya', 'Zonal']
Train/Valid/test sizes: 8327 1040 1040
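
As a quick sanity check on the printed sizes, the arithmetic below reproduces them from `valid_pct=0.1` and `test_pct=0.1`. How `ut.getdataframes` actually rounds or stratifies is not shown in this notebook, so the truncation used here is an assumption that merely happens to match the output.

```python
n_train, n_valid, n_test = 8327, 1040, 1040  # sizes printed above
n_total = n_train + n_valid + n_test
valid_pct = test_pct = 0.1

# Assumption: each hold-out size is the requested fraction, truncated to an integer.
expected_valid = int(n_total * valid_pct)
expected_test = int(n_total * test_pct)
expected_train = n_total - expected_valid - expected_test

print(n_total, expected_train, expected_valid, expected_test)
```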


In [5]:
class_names=train_ds.labels
class_names

['bacterial_leaf_blight',
 'bacterial_leaf_streak',
 'bacterial_panicle_blight',
 'blast',
 'brown_spot',
 'dead_heart',
 'downy_mildew',
 'hispa',
 'normal',
 'tungro']

## 3) Create Dataloaders

- Shuffle the training loader; keep validation loader deterministic.  
- Adjust `BATCH_SIZE` to fit your GPU/CPU memory.

This is a data science competition:<br>
The `train_images` folder contains the labeled images; their class membership is listed in `train.csv`.<br>
The `test_images` folder contains unlabeled images for your model to classify. The predictions are bundled into a file (see `sample_submission.csv`) and submitted for ranking.
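
A minimal sketch of how predicted class indices could be turned into a submission file follows. The column names `image_id` and `label` are an assumption about `sample_submission.csv`, so check them against the real file; the image names and predicted indices are hypothetical.

```python
import csv
import io

class_names = ['bacterial_leaf_blight', 'bacterial_leaf_streak',
               'bacterial_panicle_blight', 'blast', 'brown_spot', 'dead_heart',
               'downy_mildew', 'hispa', 'normal', 'tungro']

# Hypothetical model output: image file name -> predicted class index.
predicted = {"test_a.jpg": 3, "test_b.jpg": 8}

# An in-memory buffer keeps the sketch side-effect free; swap it for
# open("submission.csv", "w", newline="") to write a real file.
buffer = io.StringIO()
writer = csv.writer(buffer, lineterminator="\n")
writer.writerow(["image_id", "label"])  # assumed header
for image_id, idx in sorted(predicted.items()):
    writer.writerow([image_id, class_names[idx]])

submission_text = buffer.getvalue()
print(submission_text)
```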

In [6]:

BATCH_SIZE = 256
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=True)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)

xb, y_var, y_age, y_lbl = next(iter(train_loader))
print('Batch shapes:', xb.shape, y_var.shape, y_age.shape, y_lbl.shape)



Batch shapes: torch.Size([256, 3, 224, 224]) torch.Size([256]) torch.Size([256]) torch.Size([256])
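
The batch shown above is a full one, but the last batch of each loader is smaller because the loaders keep the default `drop_last=False`. The sketch below works out the batch sizes from the dataset sizes printed earlier; `batch_sizes` is a hypothetical helper, not part of PyTorch.

```python
def batch_sizes(n_samples, batch_size):
    """Sizes of the batches a DataLoader yields when drop_last is False."""
    full, rest = divmod(n_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


train_batches = batch_sizes(8327, 256)   # train_loader: BATCH_SIZE
valid_batches = batch_sizes(1040, 512)   # valid_loader: BATCH_SIZE * 2

print(len(train_batches), "train batches, last one has", train_batches[-1])
print(len(valid_batches), "valid batches:", valid_batches)
```

Uneven batch sizes are the reason the training and evaluation loops weight each batch loss by its size before averaging.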


## 4) Load Multi‑head timm Model (Transfer Learning with **timm**)

- Create a **pretrained** model with three heads: one predicts the disease label, one the variety, and one the age.
- **Warm-up:** freeze the backbone and train only the heads first.



In [7]:

class MultiHeadNet(nn.Module):
    """
    MultiHeadNet: a simple multi-task head on top of a timm backbone.

    Architecture:
     - backbone (timm model, num_classes=0, global_pool='avg') produces pooled feature vector of shape (B, feat_dim)
     - head_label   : Linear(feat_dim -> num_label_classes)    -> classification logits for disease label
     - head_variety : Linear(feat_dim -> num_variety_classes)  -> classification logits for variety
     - head_age     : Linear(feat_dim -> 1)                    -> scalar age prediction (regression)

    Forward input:
     - x : image tensor of shape (B, 3, H, W)

    Forward output:
     - dict with keys 'label', 'variety', 'age'
        - 'label'   : Tensor (B, num_label_classes)  (use CrossEntropyLoss)
        - 'variety' : Tensor (B, num_variety_classes) (use CrossEntropyLoss)
        - 'age'     : Tensor (B,)                     (use MSELoss or L1)

    Notes:
     - The backbone is created in __init__ below; you can freeze it after instantiation:
         for p in model.backbone.parameters(): p.requires_grad = False
     - This class is intended for transfer learning: warm up heads first, then optionally fine-tune backbone.
    """
    def __init__(self, model_name: str, num_label_classes: int, num_variety_classes: int, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        feat_dim = self.backbone.num_features
        self.head_label   = nn.Linear(feat_dim, num_label_classes)
        self.head_variety = nn.Linear(feat_dim, num_variety_classes)
        self.head_age     = nn.Linear(feat_dim, 1)

    def forward(self, x):
        feats = self.backbone(x)
        logits_label   = self.head_label(feats)
        logits_variety = self.head_variety(feats)
        age_pred       = self.head_age(feats).squeeze(1)
        return {'label': logits_label, 'variety': logits_variety, 'age': age_pred}  #returns a dictionary of outputs

model = MultiHeadNet(MODEL_NAME, train_ds.num_label_classes, train_ds.num_variety_classes, pretrained=True).to(device)

for p in model.backbone.parameters():
    p.requires_grad = False
print('Trainable params (heads only):', sum(p.numel() for p in model.parameters() if p.requires_grad))


Trainable params (heads only): 16149
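
The count of 16149 can be checked by hand. The sketch below assumes the pooled feature width of `convnext_tiny` is 768 (the value `model.backbone.num_features` should report) and that each head is a single `nn.Linear` with a bias, as in `MultiHeadNet`; `linear_params` is a hypothetical helper.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of an nn.Linear(in_features, out_features)."""
    return in_features * out_features + out_features


feat_dim = 768             # assumed pooled feature width of convnext_tiny
num_label_classes = 10
num_variety_classes = 10

head_params = {
    'head_label': linear_params(feat_dim, num_label_classes),
    'head_variety': linear_params(feat_dim, num_variety_classes),
    'head_age': linear_params(feat_dim, 1),
}
total_trainable = sum(head_params.values())
print(head_params, total_trainable)
```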


In [8]:
#check to see the model structure
# print(model.backbone)
# print(model)

In [9]:
#make sure just training the last layer
# for name, p in model.named_parameters():
#     print (f'Name={name},p.shape={p.shape}, p.requires_grad = {p.requires_grad}')



## 5) Optimizer & loss 


In [10]:

criterion = {
    'label':   nn.CrossEntropyLoss(),
    'variety': nn.CrossEntropyLoss(),
    'age':     nn.MSELoss()
}
loss_weights = {'label': 1.0, 'variety': 0.7, 'age': 0.5}
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
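
To show how the three losses combine, here is a worked single-sample example in plain Python. The logits, predictions, and targets are hypothetical, and `cross_entropy` and `squared_error` are hand-written stand-ins for what `nn.CrossEntropyLoss` and `nn.MSELoss` compute on one sample.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def squared_error(pred, target):
    return (pred - target) ** 2


loss_weights = {'label': 1.0, 'variety': 0.7, 'age': 0.5}

# Hypothetical outputs for one image.
losses = {
    'label': cross_entropy([0.0] * 10, 3),          # uniform logits over 10 classes
    'variety': cross_entropy([2.0, 0.0, 0.0], 0),   # fairly confident and correct
    'age': squared_error(0.4, 0.1),
}
total = sum(loss_weights[k] * losses[k] for k in losses)

for k, v in losses.items():
    print(f"{k:8s} loss={v:.4f} weight={loss_weights[k]}")
print(f"weighted total={total:.4f}")
```

A head with uniform logits over 10 classes has a loss of ln(10), which is a useful reference point when reading the first-epoch numbers.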


## 6) Training & validation loop

- Log train/valid **loss** and **accuracy** per epoch.  
- Warm up the heads first (the backbone is still frozen).

In [11]:
%%time 

def compute_multitask_loss(outputs, targets, criteria, weights=None):
    total = 0.0
    for k in ['label', 'variety', 'age']:
        w = 1.0 if (weights is None) else weights[k]
        # print(f'for key={k}, loss={criteria[k](outputs[k], targets[k])}')
        total = total + w * criteria[k](outputs[k], targets[k])

    #maybe return the average?
    return total

def train_one_epoch(model, loader, optimizer, criteria, device, weights=None):
    model.train()
    run_loss= 0.0
    correct_label, correct_variety, total = 0, 0, 0
    for images, y_var, y_age, y_lbl in loader:
        images = images.to(device); y_lbl = y_lbl.to(device); y_var = y_var.to(device); y_age = y_age.to(device)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)  #only feeding in images
        loss = compute_multitask_loss(outputs, {'label': y_lbl, 'variety': y_var, 'age': y_age}, criteria, weights)
        loss.backward()
        optimizer.step()
        bs = images.size(0)
        run_loss += loss.item() * bs
        total += bs
        correct_label   += (outputs['label'].argmax(1) == y_lbl).sum().item()
        correct_variety += (outputs['variety'].argmax(1) == y_var).sum().item()
    return run_loss / total, correct_label / total, correct_variety / total

@torch.no_grad()
def evaluate(model, loader, criteria, device, weights=None):
    model.eval()
    run_loss= 0.0
    correct_label, correct_variety, total, mae_age = 0, 0, 0, 0.0
    for images, y_var, y_age, y_lbl in loader:
        images = images.to(device); y_lbl = y_lbl.to(device); y_var = y_var.to(device); y_age = y_age.to(device)
        outputs = model(images)
        loss = compute_multitask_loss(outputs, {'label': y_lbl, 'variety': y_var, 'age': y_age}, criteria, weights)
        bs = images.size(0)
        run_loss += loss.item() * bs
        total += bs
        correct_label   += (outputs['label'].argmax(1) == y_lbl).sum().item()
        correct_variety += (outputs['variety'].argmax(1) == y_var).sum().item()
        mae_age         += torch.abs(outputs['age'] - y_age).sum().item()
    return run_loss / total, correct_label / total, correct_variety / total, mae_age / total

def train_and_evaluate(model, train_loader, valid_loader, optimizer, criterion, loss_weights=None, num_epochs=10):
    for epoch in range(1, num_epochs + 1):
        # Train for one epoch
        tr_loss, tr_acc_lbl, tr_acc_var = train_one_epoch(model, train_loader, optimizer, criterion, device, weights=loss_weights)

        # Evaluate on validation set
        va_loss, va_acc_lbl, va_acc_var, va_mae_age = evaluate(model, valid_loader, criterion, device, weights=loss_weights)
        print(f"Epoch {epoch:02d} | train loss={tr_loss:.4f} tr_acc_label={tr_acc_lbl:.3f} tr_acc_variety={tr_acc_var:.3f} | valid loss={va_loss:.4f} va_acc_label={va_acc_lbl:.3f} va_acc_variety={va_acc_var:.3f} va_mae_age={va_mae_age:.3f}")

train_and_evaluate(model, train_loader, valid_loader, optimizer, criterion, loss_weights=loss_weights, num_epochs=NUM_EPOCHS)


Epoch 01 | train loss=2.8847 tr_acc_label=0.443 tr_acc_variety=0.680 | valid loss=2.0369 va_acc_label=0.622 va_acc_variety=0.714 va_mae_age=0.696
Epoch 02 | train loss=2.0492 tr_acc_label=0.596 tr_acc_variety=0.752 | valid loss=1.7100 va_acc_label=0.643 va_acc_variety=0.751 va_mae_age=0.585
Epoch 03 | train loss=1.8736 tr_acc_label=0.628 tr_acc_variety=0.779 | valid loss=1.5222 va_acc_label=0.714 va_acc_variety=0.802 va_mae_age=0.566
Epoch 04 | train loss=1.7485 tr_acc_label=0.654 tr_acc_variety=0.797 | valid loss=1.4017 va_acc_label=0.736 va_acc_variety=0.809 va_mae_age=0.571
Epoch 05 | train loss=1.6760 tr_acc_label=0.670 tr_acc_variety=0.808 | valid loss=1.3235 va_acc_label=0.761 va_acc_variety=0.826 va_mae_age=0.548
CPU times: user 39.8 s, sys: 8.51 s, total: 48.3 s
Wall time: 3min 32s
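
Both loops accumulate `loss.item() * bs` and divide by the total sample count, rather than averaging the per-batch means. The sketch below, with hypothetical per-batch losses, shows why that matters when the last batch is smaller.

```python
batch_sizes = [256, 256, 135]        # e.g. two full batches and a short final one
batch_mean_losses = [1.0, 1.0, 2.0]  # hypothetical mean loss of each batch

# Per-sample average, as computed by train_one_epoch and evaluate.
weighted = sum(l * b for l, b in zip(batch_mean_losses, batch_sizes)) / sum(batch_sizes)

# Naive mean of batch means: over-weights the samples in the short batch.
naive = sum(batch_mean_losses) / len(batch_mean_losses)

print(f"sample-weighted={weighted:.4f}  mean-of-means={naive:.4f}")
```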


In [12]:
def evaluate_and_print(loader=test_loader):
    # named evaluate_and_print so it does not shadow Python's built-in eval()
    loss, acc_label, acc_variety, mae_age = evaluate(model, loader, criterion, device, weights=loss_weights)
    print(f"loss={loss:.4f} acc_label={acc_label:.3f} acc_variety={acc_variety:.3f}")

evaluate_and_print()

loss=1.3664 acc_label=0.744 acc_variety=0.828


### (Optional) Fine-tune the whole network

After warming up the heads, unfreeze the backbone and fine-tune at a **smaller LR**.

In [13]:
%%time 

# train the whole thing (unfreeze all layers)
for p in model.parameters():
    p.requires_grad = True

#make lr smaller for fine-tuning
lr1=lr/100

#change learning rate for fine-tuning
for g in optimizer.param_groups:
    g['lr'] = lr1
    
train_and_evaluate(model, train_loader, valid_loader, optimizer, criterion, loss_weights=loss_weights, num_epochs=NUM_EPOCHS)


Epoch 01 | train loss=1.6699 tr_acc_label=0.680 tr_acc_variety=0.807 | valid loss=0.9981 va_acc_label=0.830 va_acc_variety=0.854 va_mae_age=0.508
Epoch 02 | train loss=1.1864 tr_acc_label=0.774 tr_acc_variety=0.859 | valid loss=0.7778 va_acc_label=0.861 va_acc_variety=0.912 va_mae_age=0.468
Epoch 03 | train loss=0.9459 tr_acc_label=0.817 tr_acc_variety=0.893 | valid loss=0.6062 va_acc_label=0.895 va_acc_variety=0.927 va_mae_age=0.409
Epoch 04 | train loss=0.8361 tr_acc_label=0.840 tr_acc_variety=0.914 | valid loss=0.5067 va_acc_label=0.905 va_acc_variety=0.953 va_mae_age=0.382
Epoch 05 | train loss=0.6903 tr_acc_label=0.870 tr_acc_variety=0.932 | valid loss=0.4626 va_acc_label=0.913 va_acc_variety=0.952 va_mae_age=0.362
CPU times: user 2min 8s, sys: 7.63 s, total: 2min 16s
Wall time: 3min 34s
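
The cell above lowers one learning rate for every parameter. A common alternative, which this notebook does not use, is to give the backbone and the heads separate parameter groups so the heads keep a higher rate. The sketch below assumes PyTorch is installed; `TinyMultiHead` is a hypothetical stand-in with the same attribute naming as `MultiHeadNet`, not the real model.

```python
import torch
import torch.nn as nn


class TinyMultiHead(nn.Module):
    """Hypothetical stand-in: a 'backbone' plus two heads, no forward needed here."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.head_label = nn.Linear(4, 10)
        self.head_age = nn.Linear(4, 1)


tiny = TinyMultiHead()
base_lr = 2e-3

# Everything that is not part of the backbone counts as a head parameter.
head_params = [p for name, p in tiny.named_parameters()
               if not name.startswith("backbone.")]

opt = torch.optim.AdamW([
    {"params": tiny.backbone.parameters(), "lr": base_lr / 100},  # gentle updates
    {"params": head_params, "lr": base_lr},                       # heads keep base LR
])

group_lrs = [g["lr"] for g in opt.param_groups]
print(group_lrs)
```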


## 7) Save / Load model

In [14]:

torch.save(model.state_dict(), 'kaggle_paddy_timm_multihead.pth')
print('Saved to kaggle_paddy_timm_multihead.pth')
model.load_state_dict(torch.load('kaggle_paddy_timm_multihead.pth', map_location=device))



Saved to kaggle_paddy_timm_multihead.pth




<All keys matched successfully>
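
For reference, here is a small save/load round trip on a stand-in layer. It assumes a PyTorch version that supports the `weights_only` argument of `torch.load`, which restricts unpickling to tensors and plain containers; the `nn.Linear` model and the file name are hypothetical.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # hypothetical stand-in for the trained model

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = Path(tmp) / "demo.pth"
    torch.save(net.state_dict(), ckpt_path)
    # weights_only=True loads tensors without executing arbitrary pickled code.
    state = torch.load(ckpt_path, map_location="cpu", weights_only=True)

restored = nn.Linear(4, 2)
result = restored.load_state_dict(state)
restored.eval()  # switch to inference mode before predicting

same = all(torch.equal(a, b) for a, b in
           zip(net.state_dict().values(), restored.state_dict().values()))
print(result, same)
```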

## 8) Evaluation & Confusion Matrix

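Evaluate the fine-tuned model on `test_loader`. The test split is never used for training or tuning; apart from the quick check after the warm-up (In [12]), this is the only place it is scored.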

In [15]:
LOADER=test_loader

@torch.no_grad()
def get_all_preds_targets(model, loader):
    model.eval()
    preds, targs = [], []
    for xb,_,_,yb in loader:
        xb = xb.to(device)
        logits = model(xb)

        #just want the label preds
        preds.append(logits['label'].argmax(1).cpu().numpy())
        targs.append(yb.numpy())
    return np.concatenate(preds), np.concatenate(targs)

preds, targs = get_all_preds_targets(model, LOADER)

# Confusion matrix
num_classes = len(class_names)
cm = np.zeros((num_classes, num_classes), dtype=int)
for t, p in zip(targs, preds):
    cm[t, p] += 1

print(f'Confusion Matrix ({num_classes}x{num_classes}):')
print(cm)

# Per-class metrics
per_class = []
for k in range(num_classes):
    TP = cm[k,k]
    FP = cm[:,k].sum() - TP
    FN = cm[k,:].sum() - TP
    TN = cm.sum() - TP - FP - FN
    prec = TP/(TP+FP) if (TP+FP)>0 else 0.0
    rec  = TP/(TP+FN) if (TP+FN)>0 else 0.0
    f1   = (2*prec*rec)/(prec+rec) if (prec+rec)>0 else 0.0
    per_class.append((prec, rec, f1))

macro_p = float(np.mean([p for p,_,_ in per_class]))
macro_r = float(np.mean([r for _,r,_ in per_class]))
macro_f = float(np.mean([f for _,_,f in per_class]))
overall_acc = float((preds == targs).mean())

print('\nPer-class (precision, recall, f1):')
for name,(p,r,f) in zip(class_names, per_class):
    print(f'{name:>10s}: P={p:.3f} R={r:.3f} F1={f:.3f}')
print(f"\nMacro avg: P={macro_p:.3f} R={macro_r:.3f} F1={macro_f:.3f}")
print(f'Overall Accuracy: {overall_acc:.3f}')


Confusion Matrix (10x10):
[[ 45   1   0   2   3   0   1   2   0   3]
 [  0  42   0   0   1   0   0   0   0   1]
 [  0   0  36   0   0   0   0   0   3   0]
 [  1   2   0 145   2   0   0   3   0  11]
 [  2   2   0   2  91   0   0   0   2   0]
 [  0   0   1   0   0 142   0   0   0   0]
 [  2   0   0   3   2   1  51   0   2   6]
 [  1   1   0   3   0   0   0 129   8   3]
 [  0   0   3   0   0   0   0   3 168   0]
 [  0   0   0   0   1   0   2   3   1 101]]

Per-class (precision, recall, f1):
bacterial_leaf_blight: P=0.882 R=0.789 F1=0.833
bacterial_leaf_streak: P=0.875 R=0.955 F1=0.913
bacterial_panicle_blight: P=0.900 R=0.923 F1=0.911
     blast: P=0.935 R=0.884 F1=0.909
brown_spot: P=0.910 R=0.919 F1=0.915
dead_heart: P=0.993 R=0.993 F1=0.993
downy_mildew: P=0.944 R=0.761 F1=0.843
     hispa: P=0.921 R=0.890 F1=0.905
    normal: P=0.913 R=0.966 F1=0.939
    tungro: P=0.808 R=0.935 F1=0.867

Macro avg: P=0.908 R=0.901 F1=0.903
Overall Accuracy: 0.913
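
The reported metrics can be re-derived by hand from the printed confusion matrix, where rows are true classes and columns are predictions. The sketch below copies the matrix from the output above and recomputes precision and recall for two classes plus the overall accuracy; `precision_recall` is a hypothetical helper.

```python
# Confusion matrix copied from the output above (rows = true, columns = predicted).
cm = [
    [45, 1, 0, 2, 3, 0, 1, 2, 0, 3],
    [0, 42, 0, 0, 1, 0, 0, 0, 0, 1],
    [0, 0, 36, 0, 0, 0, 0, 0, 3, 0],
    [1, 2, 0, 145, 2, 0, 0, 3, 0, 11],
    [2, 2, 0, 2, 91, 0, 0, 0, 2, 0],
    [0, 0, 1, 0, 0, 142, 0, 0, 0, 0],
    [2, 0, 0, 3, 2, 1, 51, 0, 2, 6],
    [1, 1, 0, 3, 0, 0, 0, 129, 8, 3],
    [0, 0, 3, 0, 0, 0, 0, 3, 168, 0],
    [0, 0, 0, 0, 1, 0, 2, 3, 1, 101],
]
n = len(cm)


def precision_recall(k):
    """Precision = TP / column sum, recall = TP / row sum for class k."""
    tp = cm[k][k]
    predicted_k = sum(cm[r][k] for r in range(n))
    actual_k = sum(cm[k])
    return tp / predicted_k, tp / actual_k


p0, r0 = precision_recall(0)   # bacterial_leaf_blight
p9, r9 = precision_recall(9)   # tungro
total = sum(map(sum, cm))
accuracy = sum(cm[k][k] for k in range(n)) / total

print(f"class 0: P={p0:.3f} R={r0:.3f}")
print(f"class 9: P={p9:.3f} R={r9:.3f}")
print(f"samples={total} accuracy={accuracy:.3f}")
```

The low precision for `tungro` comes mostly from the 11 `blast` images predicted as `tungro` (row 3, column 9).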
