
# 🌾 Multi‑Head CNN with **timm** on Kaggle Paddy Dataset (`convnext_tiny`)

This notebook demonstrates a **multi‑task** setup using **PyTorch** and **timm** for the Kaggle **Paddy** dataset.

We will:  
1) Set up the environment and choose a timm backbone (`convnext_tiny`).  
2) Build a **custom Dataset** reading `train.csv` (`image_id`, `label`, `variety`, `age`).  
   - Images are stored in subfolders named by **label** (e.g., `train/<label>/<image_id>`).  
   - Each sample is returned as a tuple: **`(image_tensor, variety_idx, age_float, label_idx)`**.  
3) Create DataLoaders with timm‑compatible transforms.  
4) Define a **multi‑head model**:  
   - Head A → **disease label** classification  
   - Head B → **variety** classification  
   - Head R → **age** regression  
5) Train and evaluate with a minimal, well‑commented loop.


## 1) Setup

In [2]:

# If needed:
# !pip install -q timm pandas

import os, random
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import timm
from torchvision import transforms

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    The same value is handed to Python's ``random`` module, NumPy's global
    generator, ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all RNGs first so later random operations are repeatable across runs.
set_seed(42)

# Use the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('Device:', device)

# Expected layout: <DATA_ROOT>/train.csv and <DATA_ROOT>/train/<label>/<image_id>.
DATA_ROOT = Path("./data/")
TRAIN_CSV = DATA_ROOT.joinpath('train.csv')
TRAIN_IMG_ROOT = DATA_ROOT.joinpath('train')

# Stop here, before any dataset or model is built, if the data is not in place.
for required, message in (
    (TRAIN_CSV, f"Missing {TRAIN_CSV}"),
    (TRAIN_IMG_ROOT, f"Missing {TRAIN_IMG_ROOT} (folder of label subdirs)"),
):
    assert required.exists(), message


Device: cpu


AssertionError: Missing data/train.csv

## 2) Custom Dataset (returns `(image, variety, age, label)`)

In [None]:

class PaddyMultitaskDataset(Dataset):
    """Paddy images paired with variety, age and disease-label targets.

    Rows come from a CSV with the columns ``image_id``, ``label``, ``variety``
    and ``age``; each image is looked up at ``img_root/<label>/<image_id>``.
    Non-numeric or missing ages are replaced by the median age of the CSV.

    Args:
        csv_path: CSV file to read.
        img_root: Directory holding one sub-folder per label.
        transform: Optional callable applied to the loaded RGB PIL image.
        labels: Optional class-name vocabulary for the ``label`` column. When
            omitted, the sorted unique values found in this CSV are used.
            Pass the training vocabulary when building a validation or test
            split so class indices agree between splits.
        varieties: Same as ``labels`` but for the ``variety`` column.

    Raises:
        ValueError: If a required column is missing, or if the CSV contains a
            label/variety that is not part of an explicitly supplied
            vocabulary.
    """

    # Extensions tried (by replacing the suffix) when the path built from
    # ``image_id`` does not exist as-is.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, csv_path: Path, img_root: Path, transform=None,
                 labels=None, varieties=None):
        super().__init__()
        self.df = pd.read_csv(csv_path)
        expected_cols = {'image_id', 'label', 'variety', 'age'}
        missing = expected_cols - set(self.df.columns)
        if missing:
            raise ValueError(f"CSV is missing columns: {missing}")
        self.img_root = Path(img_root)
        self.transform = transform

        self.labels = self._resolve_vocab('label', labels)
        self.varieties = self._resolve_vocab('variety', varieties)
        self.label_to_idx = {s: i for i, s in enumerate(self.labels)}
        self.variety_to_idx = {s: i for i, s in enumerate(self.varieties)}

        # Unparseable ages become NaN, then every NaN is filled with the median.
        self.df['age'] = pd.to_numeric(self.df['age'], errors='coerce')
        if self.df['age'].isna().any():
            med = float(self.df['age'].median())
            self.df['age'] = self.df['age'].fillna(med)

        self.num_label_classes = len(self.labels)
        self.num_variety_classes = len(self.varieties)

    def _resolve_vocab(self, column, vocab):
        """Return the class names for ``column``, validating an explicit ``vocab``."""
        present = set(self.df[column].astype(str).unique())
        if vocab is None:
            return sorted(present)
        vocab = [str(v) for v in vocab]
        unknown = present - set(vocab)
        if unknown:
            raise ValueError(
                f"CSV column '{column}' has values outside the given vocabulary: {sorted(unknown)}"
            )
        return vocab

    def __len__(self):
        """Return the number of CSV rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, variety_idx, age, label_idx)``: the (optionally
            transformed) image, then a ``long`` tensor, a ``float32`` tensor
            and a ``long`` tensor.

        Raises:
            FileNotFoundError: If no image file can be located for the row.
        """
        row = self.df.iloc[idx]
        image_id = str(row['image_id'])
        label_name = str(row['label'])
        variety_name = str(row['variety'])
        age_val = float(row['age'])

        img_path = self.img_root / label_name / image_id
        if not img_path.exists():
            for ext in self._IMAGE_EXTENSIONS:
                cand = img_path.with_suffix(ext)
                if cand.exists():
                    img_path = cand
                    break
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found for row {idx}: {img_path}")

        # convert() yields an independent in-memory image, so the file handle
        # can be closed right away instead of lingering until garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        y_label = self.label_to_idx[label_name]
        y_var = self.variety_to_idx[variety_name]

        return (
            img,
            torch.tensor(y_var, dtype=torch.long),
            torch.tensor(age_val, dtype=torch.float32),
            torch.tensor(y_label, dtype=torch.long),
        )


## 3) Transforms (timm) & Train/Valid split

In [None]:

MODEL_NAME = 'convnext_tiny'

# resolve_data_config reads preprocessing settings (input size, mean/std,
# interpolation, crop) from a model *instance*. Given only the name string it
# finds no config and silently falls back to generic defaults, so build a
# weight-less throwaway model just to read its config.
_cfg_model = timm.create_model(MODEL_NAME, pretrained=False)
cfg = timm.data.resolve_data_config({}, model=_cfg_model)
del _cfg_model

train_tfms = timm.data.create_transform(**cfg, is_training=True, hflip=0.5, auto_augment=None)
valid_tfms = timm.data.create_transform(**cfg, is_training=False)

# Random 90/10 split over row positions (uses NumPy's global RNG).
df = pd.read_csv(TRAIN_CSV)
indices = np.arange(len(df))
np.random.shuffle(indices)
val_frac = 0.1
n_val = max(1, int(len(indices) * val_frac))
val_idx = indices[:n_val]
train_idx = indices[n_val:]

# The dataset class reads from a CSV path, so each split is written to disk.
train_csv_tmp = DATA_ROOT / 'train_split.csv'
val_csv_tmp   = DATA_ROOT / 'valid_split.csv'
df.iloc[train_idx].to_csv(train_csv_tmp, index=False)
df.iloc[val_idx].to_csv(val_csv_tmp,   index=False)

train_ds = PaddyMultitaskDataset(train_csv_tmp, TRAIN_IMG_ROOT, transform=train_tfms)
valid_ds = PaddyMultitaskDataset(val_csv_tmp,   TRAIN_IMG_ROOT, transform=valid_tfms)

# Each dataset derives its class list from its own CSV, so a class missing
# from one split would shift the indices of the other. Re-map both splits onto
# the vocabulary of the full CSV so index i means the same class everywhere.
all_labels = sorted(df['label'].astype(str).unique())
all_varieties = sorted(df['variety'].astype(str).unique())
for ds in (train_ds, valid_ds):
    ds.labels = all_labels
    ds.varieties = all_varieties
    ds.label_to_idx = {name: i for i, name in enumerate(all_labels)}
    ds.variety_to_idx = {name: i for i, name in enumerate(all_varieties)}
    ds.num_label_classes = len(all_labels)
    ds.num_variety_classes = len(all_varieties)

print('Label classes:', train_ds.labels)
print('Variety classes:', train_ds.varieties)
print('Train/Valid sizes:', len(train_ds), len(valid_ds))


## 4) DataLoaders

In [None]:

BATCH_SIZE = 32

# Pinned host memory only speeds up host-to-GPU copies; requesting it on a
# CPU-only machine does nothing useful and makes PyTorch emit a warning.
PIN_MEMORY = torch.cuda.is_available()
loader_kwargs = dict(num_workers=2, pin_memory=PIN_MEMORY)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, **loader_kwargs)
# No gradients are kept during validation, so a larger batch is used there.
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE * 2, shuffle=False, **loader_kwargs)

# Pull a single batch as a sanity check of the tuple order and tensor shapes.
xb, y_var, y_age, y_lbl = next(iter(train_loader))
print('Batch shapes:', xb.shape, y_var.shape, y_age.shape, y_lbl.shape)


## 5) Multi‑head model (`convnext_tiny` backbone)

In [None]:

class MultiHeadNet(nn.Module):
    """Shared timm backbone feeding three task-specific linear heads.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        num_label_classes: Output width of the disease-label head.
        num_variety_classes: Output width of the variety head.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name: str, num_label_classes: int, num_variety_classes: int, pretrained=True):
        super().__init__()
        # num_classes=0 asks timm for a headless model; the heads below are
        # sized from the feature width it reports.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg',
        )
        width = self.backbone.num_features
        self.head_label = nn.Linear(width, num_label_classes)
        self.head_variety = nn.Linear(width, num_variety_classes)
        self.head_age = nn.Linear(width, 1)

    def forward(self, x):
        """Run the backbone once and apply every head to the shared features.

        Returns:
            Dict with ``'label'`` and ``'variety'`` (outputs of the two
            classification heads) and ``'age'`` (output of the regression
            head with dimension 1 squeezed away).
        """
        shared = self.backbone(x)
        return {
            'label': self.head_label(shared),
            'variety': self.head_variety(shared),
            'age': self.head_age(shared).squeeze(1),
        }

model = MultiHeadNet(
    MODEL_NAME,
    train_ds.num_label_classes,
    train_ds.num_variety_classes,
    pretrained=True,
)
model = model.to(device)

# Warm-up phase: switch off gradients for the backbone so only the heads train.
for backbone_param in model.backbone.parameters():
    backbone_param.requires_grad = False

n_trainable = sum(w.numel() for w in model.parameters() if w.requires_grad)
print('Trainable params (heads only):', n_trainable)


## 6) Loss functions & optimizer

In [None]:

# One loss per head, keyed by the names used in the model's output dict.
criteria = dict(
    label=nn.CrossEntropyLoss(),
    variety=nn.CrossEntropyLoss(),
    age=nn.SmoothL1Loss(beta=1.0),
)

# Multiplier applied to each task's loss before the losses are summed.
loss_weights = dict(label=1.0, variety=0.7, age=0.5)

# Hand the optimizer only the parameters that currently require gradients.
trainable_params = [w for w in model.parameters() if w.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=3e-4)


## 7) Training / evaluation utilities

In [None]:

def compute_multitask_loss(outputs, targets, criteria, weights=None):
    """Combine the per-task losses into one weighted sum.

    Args:
        outputs: Mapping with predictions under 'label', 'variety' and 'age'.
        targets: Mapping with the matching targets under the same keys.
        criteria: Mapping from task name to a callable ``loss(output, target)``.
        weights: Optional mapping from task name to a multiplier. A task that
            is absent from it (or every task when ``weights`` is None) uses 1.0.

    Returns:
        ``0.0`` plus the weighted losses, accumulated in the order
        label, variety, age.
    """
    task_weights = {} if weights is None else weights
    total = 0.0
    for task in ('label', 'variety', 'age'):
        scale = task_weights[task] if task in task_weights else 1.0
        task_loss = criteria[task](outputs[task], targets[task])
        total = total + scale * task_loss
    return total

def train_one_epoch(model, loader, optimizer, criteria, device, weights=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module whose forward returns a dict with 'label', 'variety'
            and 'age' entries.
        loader: Iterable yielding ``(images, y_var, y_age, y_lbl)`` batches.
        optimizer: Optimizer stepped once per batch.
        criteria: Per-task loss callables, forwarded to ``compute_multitask_loss``.
        device: Device every batch tensor is moved to.
        weights: Optional per-task loss weights, forwarded unchanged.

    Returns:
        Tuple ``(mean_loss, label_accuracy, variety_accuracy)``. The loss is
        the batch losses averaged with batch-size weighting; accuracies are
        computed from the outputs produced before each optimizer step.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    hits_label = 0
    hits_variety = 0
    for images, y_var, y_age, y_lbl in loader:
        images, y_var, y_age, y_lbl = (
            t.to(device) for t in (images, y_var, y_age, y_lbl)
        )

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        targets = {'label': y_lbl, 'variety': y_var, 'age': y_age}
        loss = compute_multitask_loss(outputs, targets, criteria, weights)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size
        hits_label += (outputs['label'].argmax(1) == y_lbl).sum().item()
        hits_variety += (outputs['variety'].argmax(1) == y_var).sum().item()
    return loss_sum / seen, hits_label / seen, hits_variety / seen

@torch.no_grad()
def evaluate(model, loader, criteria, device, weights=None):
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module whose forward returns a dict with 'label', 'variety'
            and 'age' entries.
        loader: Iterable yielding ``(images, y_var, y_age, y_lbl)`` batches.
        criteria: Per-task loss callables, forwarded to ``compute_multitask_loss``.
        device: Device every batch tensor is moved to.
        weights: Optional per-task loss weights, forwarded unchanged.

    Returns:
        Tuple ``(mean_loss, label_accuracy, variety_accuracy, age_mae)``, each
        normalised by the number of samples seen.
    """
    model.eval()
    loss_sum = 0.0
    seen = 0
    hits_label = 0
    hits_variety = 0
    abs_age_err = 0.0
    for images, y_var, y_age, y_lbl in loader:
        images, y_var, y_age, y_lbl = (
            t.to(device) for t in (images, y_var, y_age, y_lbl)
        )

        outputs = model(images)
        targets = {'label': y_lbl, 'variety': y_var, 'age': y_age}
        loss = compute_multitask_loss(outputs, targets, criteria, weights)

        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size
        hits_label += (outputs['label'].argmax(1) == y_lbl).sum().item()
        hits_variety += (outputs['variety'].argmax(1) == y_var).sum().item()
        abs_age_err += torch.abs(outputs['age'] - y_age).sum().item()
    return loss_sum / seen, hits_label / seen, hits_variety / seen, abs_age_err / seen


## 8) Train & validate (warm‑up on heads)

In [None]:

EPOCHS = 3
for ep in range(1, EPOCHS + 1):
    # One optimisation pass over the training split, then a validation pass.
    tr_loss, tr_acc_lbl, tr_acc_var = train_one_epoch(
        model, train_loader, optimizer, criteria, device, weights=loss_weights
    )
    va_loss, va_acc_lbl, va_acc_var, va_mae_age = evaluate(
        model, valid_loader, criteria, device, weights=loss_weights
    )
    # Only the training loss is reported from the training pass; the
    # accuracies and the age MAE shown are validation metrics.
    summary = (
        f"epoch {ep:02d} | train loss {tr_loss:.4f} | val loss {va_loss:.4f} | "
        f"label acc {va_acc_lbl:.3f} | variety acc {va_acc_var:.3f} | age MAE {va_mae_age:.2f}"
    )
    print(summary)


## 9) (Optional) Unfreeze & fine‑tune the backbone

In [None]:

# Re-enable gradients for the backbone so the whole network is trained.
for backbone_param in model.backbone.parameters():
    backbone_param.requires_grad = True

# A fresh optimizer is needed so the newly trainable backbone weights are included.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

EPOCHS_FT = 2
for ep in range(1, EPOCHS_FT + 1):
    tr_loss, tr_acc_lbl, tr_acc_var = train_one_epoch(
        model, train_loader, optimizer, criteria, device, weights=loss_weights
    )
    va_loss, va_acc_lbl, va_acc_var, va_mae_age = evaluate(
        model, valid_loader, criteria, device, weights=loss_weights
    )
    summary = (
        f"[FT] epoch {ep:02d} | train loss {tr_loss:.4f} | val loss {va_loss:.4f} | "
        f"label acc {va_acc_lbl:.3f} | variety acc {va_acc_var:.3f} | age MAE {va_mae_age:.2f}"
    )
    print(summary)


## 10) Save / Load

In [None]:

# Persist only the weights (the state dict), not the module object itself.
state = model.state_dict()
torch.save(state, 'multitask_convnext_tiny_paddy.pth')
print('Saved to multitask_convnext_tiny_paddy.pth')
# To restore later, rebuild MultiHeadNet with the same arguments and run:
# model.load_state_dict(torch.load('multitask_convnext_tiny_paddy.pth', map_location=device))
# model.eval()
