
# Paddy Disease Classification — **PyTorch + timm** and Multi‑Head model

This notebook demonstrates a **multi‑task** setup using **PyTorch** and and **[timm](https://github.com/huggingface/pytorch-image-models)** (Torch Image Models) for the Kaggle **Paddy** dataset.

1) Set up the environment and choose a timm backbone (`convnext_tiny`).  
2) Build a **custom Dataset** reading `train.csv` (`image_id`, `label`, `variety`, `age`).  
   - Images are stored in subfolders named by **label** (e.g., `train/<label>/<image_id>`).  
   - Each sample returns a tuple: **`(image_tensor, variety_idx, age_float, label_idx)`** as requested.  
3) Create DataLoaders with timm‑compatible transforms.  
4) Define a **multi‑head model**:  
   - Head A → **disease label** classification  
   - Head B → **variety** classification  
   - Head R → **age** regression  
5) Train and evaluate with a minimal, well‑commented loop.

> **Note:** Point `DATA_DIR` to your local Kaggle Paddy dataset. 


## 1) Setup & Configuration

In [1]:

# If timm isn't installed, uncomment:
# !pip install timm --quiet

import os, random, math, time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split

from torchvision import datasets, transforms
import timm

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)

DATA_DIR = Path("./data/")  
TRAIN_CSV = DATA_DIR / 'train.csv'
TRAIN_IMG_ROOT = DATA_DIR   / 'train_images'

USE_IMAGEFOLDER = False             # set False to use CSV dataset class below

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
 


Device: cuda


## 2) Custom Dataset (returns `(image, variety, age, label)`)


In [2]:

class PaddyMultitaskDataset(Dataset):
    def __init__(self, csv_path: Path, img_root: Path, transform=None):
        super().__init__()
        self.df = pd.read_csv(csv_path)
        expected_cols = {'image_id', 'label', 'variety', 'age'}
        missing = expected_cols - set(self.df.columns)
        if missing:
            raise ValueError(f"CSV is missing columns: {missing}")
        self.img_root = Path(img_root)
        self.transform = transform
        self.labels = sorted(self.df['label'].astype(str).unique())
        self.varieties = sorted(self.df['variety'].astype(str).unique())
        self.label_to_idx = {s:i for i,s in enumerate(self.labels)}
        self.variety_to_idx = {s:i for i,s in enumerate(self.varieties)}
        self.df['age'] = pd.to_numeric(self.df['age'], errors='coerce')
        if self.df['age'].isna().any():
            med = float(self.df['age'].median())
            self.df['age'] = self.df['age'].fillna(med)
        self.num_label_classes = len(self.labels)
        self.num_variety_classes = len(self.varieties)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_id = str(row['image_id'])
        label_name = str(row['label'])
        variety_name = str(row['variety'])
        age_val = float(row['age'])

        img_path = self.img_root / label_name / image_id
        if not img_path.exists():
            for ext in ('.jpg', '.jpeg', '.png', '.bmp'):
                cand = img_path.with_suffix(ext)
                if cand.exists():
                    img_path = cand
                    break
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found for row {idx}: {img_path}")

        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)

        y_label = self.label_to_idx[label_name]
        y_var = self.variety_to_idx[variety_name]
        y_age = age_val

        return img, torch.tensor(y_var, dtype=torch.long), torch.tensor(y_age, dtype=torch.float32), torch.tensor(y_label, dtype=torch.long)



## 3) Create datasets with timm transforms 

Not going to use ImageFolder dataset.  Instead use PaddyMultitaskDataset 



In [3]:
import pprint
pp = pprint.PrettyPrinter(indent=4)

# MODEL_NAME = "resnet18"  # try: 'efficientnet_b0', 'convnext_tiny', 'mobilenetv3_large_100', ...
MODEL_NAME = 'convnext_tiny'
config = timm.data.resolve_data_config({}, model=MODEL_NAME)
train_tfms = timm.data.create_transform(**config, is_training=True, hflip=0.5, auto_augment=None)
valid_tfms = timm.data.create_transform(**config, is_training=False)


In [4]:
#split the dataset into train/validation/test sets
df = pd.read_csv(TRAIN_CSV)

indices = np.arange(len(df))
np.random.shuffle(indices) #shuffle the indexes before splitting
ten_percent = int(len(indices) * 0.1)

tst_idx = indices[ :ten_percent]
val_idx = indices[ten_percent:2*ten_percent]
trn_idx = indices[2*ten_percent: ]
print(f"Train: {len(trn_idx)}, Valid: {len(val_idx)}, Test: {len(tst_idx)}")

train_csv_tmp = DATA_DIR / 'train_split.csv'
val_csv_tmp   = DATA_DIR / 'valid_split.csv'
test_csv_tmp  = DATA_DIR / 'test_split.csv'
df.iloc[tst_idx].to_csv(test_csv_tmp,  index=False)
df.iloc[trn_idx].to_csv(train_csv_tmp, index=False)
df.iloc[val_idx].to_csv(val_csv_tmp,   index=False)

train_ds = PaddyMultitaskDataset(train_csv_tmp, TRAIN_IMG_ROOT, transform=train_tfms)
valid_ds = PaddyMultitaskDataset(val_csv_tmp,   TRAIN_IMG_ROOT, transform=valid_tfms)
test_ds  = PaddyMultitaskDataset(test_csv_tmp,  TRAIN_IMG_ROOT, transform=valid_tfms)

#should be stratified split, but close enough for now
# train_ds, valid_ds = random_split(full_ds, [n_train, n_valid], generator=torch.Generator().manual_seed(42))

print('Label classes:', train_ds.labels)
print('Variety classes:', train_ds.varieties)
print('Train/Valid/test sizes:', len(train_ds), len(valid_ds), len(test_ds))


Train: 8327, Valid: 1040, Test: 1040
Label classes: ['bacterial_leaf_blight', 'bacterial_leaf_streak', 'bacterial_panicle_blight', 'blast', 'brown_spot', 'dead_heart', 'downy_mildew', 'hispa', 'normal', 'tungro']
Variety classes: ['ADT45', 'AndraPonni', 'AtchayaPonni', 'IR20', 'KarnatakaPonni', 'Onthanel', 'Ponni', 'RR', 'Surya', 'Zonal']
Train/Valid/test sizes: 8327 1040 1040


In [5]:
class_names=train_ds.labels

## 3) Create Dataloaders

- Shuffle the training loader; keep validation loader deterministic.  
- Adjust `BATCH_SIZE` to fit your GPU/CPU memory.

This is a datascience competition:<br>
the train_images folder contains images with class membership info (in the train.csv file).<br>
the test_images folder contains images that your model infers membership on.  These inferences are bundled into a file (see sample_submission.csv) which is submitted for ranking  

In [6]:

BATCH_SIZE = 256
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=True)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)

xb, y_var, y_age, y_lbl = next(iter(train_loader))
print('Batch shapes:', xb.shape, y_var.shape, y_age.shape, y_lbl.shape)



Batch shapes: torch.Size([256, 3, 224, 224]) torch.Size([256]) torch.Size([256]) torch.Size([256])


## 4) Load Multi‑head timm Model (Transfer Learning with **timm**)

- Create a **pretrained** model with 3 heads, one to predict label, one to predict variety, and one to predict age
- **Warm-up:** freeze backbone; train the classifier head first.



In [7]:

class MultiHeadNet(nn.Module):
    def __init__(self, model_name: str, num_label_classes: int, num_variety_classes: int, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        feat_dim = self.backbone.num_features
        self.head_label   = nn.Linear(feat_dim, num_label_classes)
        self.head_variety = nn.Linear(feat_dim, num_variety_classes)
        self.head_age     = nn.Linear(feat_dim, 1)

    def forward(self, x):
        feats = self.backbone(x)
        logits_label   = self.head_label(feats)
        logits_variety = self.head_variety(feats)
        age_pred       = self.head_age(feats).squeeze(1)
        return {'label': logits_label, 'variety': logits_variety, 'age': age_pred}

model = MultiHeadNet(MODEL_NAME, train_ds.num_label_classes, train_ds.num_variety_classes, pretrained=True).to(device)

for p in model.backbone.parameters():
    p.requires_grad = False
print('Trainable params (heads only):', sum(p.numel() for p in model.parameters() if p.requires_grad))


Trainable params (heads only): 16149


In [8]:
model.backbone

ConvNeXt(
  (stem): Sequential(
    (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): ConvNeXtStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): ConvNeXtBlock(
          (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (norm): Identity()
            (fc2): Linear(in_features=384, out_features=96, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
          (shortcut): Identity()
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)


In [9]:
#make sure just training the last layer
for name, p in model.named_parameters():
    print (f'Name={name},p.shape={p.shape}, p.requires_grad = {p.requires_grad}')


Name=backbone.stem.0.weight,p.shape=torch.Size([96, 3, 4, 4]), p.requires_grad = False
Name=backbone.stem.0.bias,p.shape=torch.Size([96]), p.requires_grad = False
Name=backbone.stem.1.weight,p.shape=torch.Size([96]), p.requires_grad = False
Name=backbone.stem.1.bias,p.shape=torch.Size([96]), p.requires_grad = False
Name=backbone.stages.0.blocks.0.gamma,p.shape=torch.Size([96]), p.requires_grad = False
Name=backbone.stages.0.blocks.0.conv_dw.weight,p.shape=torch.Size([96, 1, 7, 7]), p.requires_grad = False
Name=backbone.stages.0.blocks.0.conv_dw.bias,p.shape=torch.Size([96]), p.requires_grad = False
Name=backbone.stages.0.blocks.0.norm.weight,p.shape=torch.Size([96]), p.requires_grad = False
Name=backbone.stages.0.blocks.0.norm.bias,p.shape=torch.Size([96]), p.requires_grad = False
Name=backbone.stages.0.blocks.0.mlp.fc1.weight,p.shape=torch.Size([384, 96]), p.requires_grad = False
Name=backbone.stages.0.blocks.0.mlp.fc1.bias,p.shape=torch.Size([384]), p.requires_grad = False
Name=backb

In [10]:
#stopped here 9/29/25
print(model)

MultiHeadNet(
  (backbone): ConvNeXt(
    (stem): Sequential(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): ConvNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvNeXtBlock(
            (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=96, out_features=384, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.0, inplace=False)
              (norm): Identity()
              (fc2): Linear(in_features=384, out_features=96, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (shortcut): Identity()
            (drop_path): Identity()
          )
          (1): ConvNeXtBlock(
            (conv_dw): C


## 5) Optimizer & loss 


In [11]:
lr=2e-3
criterion = {
    'label':   nn.CrossEntropyLoss(),
    'variety': nn.CrossEntropyLoss(),
    'age':     nn.MSELoss()
}
loss_weights = {'label': 1.0, 'variety': 0.7, 'age': 0.5}
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)


# 6) Training & validation loop

- Log train/valid **loss** and **accuracy** per epoch.  
- warm up heads

In [12]:
%%time 

def compute_multitask_loss(outputs, targets, criteria, weights=None):
    total = 0.0
    for k in ['label', 'variety', 'age']:
        w = 1.0 if (weights is None or k not in weights) else weights[k]
        total = total + w * criteria[k](outputs[k], targets[k])
    return total

def train_one_epoch(model, loader, optimizer, criteria, device, weights=None):
    model.train()
    run_loss= 0.0
    correct_label, correct_variety, total = 0, 0, 0
    for images, y_var, y_age, y_lbl in loader:
        images = images.to(device); y_lbl = y_lbl.to(device); y_var = y_var.to(device); y_age = y_age.to(device)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)  #only feeding in images
        loss = compute_multitask_loss(outputs, {'label': y_lbl, 'variety': y_var, 'age': y_age}, criteria, weights)
        loss.backward()
        optimizer.step()
        bs = images.size(0)
        run_loss += loss.item() * bs;  
        total += bs
        correct_label   += (outputs['label'].argmax(1) == y_lbl).sum().item()
        correct_variety += (outputs['variety'].argmax(1) == y_var).sum().item()
    return run_loss / total, correct_label / total, correct_variety / total

@torch.no_grad()
def evaluate(model, loader, criteria, device, weights=None):
    model.eval()
    run_loss= 0.0
    correct_label, correct_variety, total, mae_age = 0, 0, 0, 0.0
    for images, y_var, y_age, y_lbl in loader:
        images = images.to(device); y_lbl = y_lbl.to(device); y_var = y_var.to(device); y_age = y_age.to(device)
        outputs = model(images)
        loss = compute_multitask_loss(outputs, {'label': y_lbl, 'variety': y_var, 'age': y_age}, criteria, weights)
        bs = images.size(0)
        run_loss += loss.item() * bs; total += bs
        correct_label   += (outputs['label'].argmax(1) == y_lbl).sum().item()
        correct_variety += (outputs['variety'].argmax(1) == y_var).sum().item()
        mae_age         += torch.abs(outputs['age'] - y_age).sum().item()
    return run_loss / total, correct_label / total, correct_variety / total, mae_age / total

def train_and_evaluate(model, train_loader, valid_loader, optimizer, criterion, loss_weights=None, num_epochs=10):
    for epoch in range(1, num_epochs + 1):
        # Train for one epoch
        tr_loss, tr_acc_lbl, tr_acc_var = train_one_epoch(model, train_loader, optimizer, criterion, device, weights=loss_weights)

        # Evaluate on validation set
        va_loss, va_acc_lbl, va_acc_var, va_mae_age = evaluate(model, valid_loader, criterion, device, weights=loss_weights)
        print(f"Epoch {epoch:02d} | train loss={tr_loss:.4f} tr_acc_label={tr_acc_lbl:.3f} tr_acc_variety={tr_acc_var:.3f} | valid loss={va_loss:.4f} va_acc_label={va_acc_lbl:.3f} va_acc_variety={va_acc_var:.3f} va_mae_age={va_mae_age:.3f}")

train_and_evaluate(model, train_loader, valid_loader, optimizer, criterion, loss_weights=loss_weights, num_epochs=20)


Epoch 01 | train loss=1153.2969 tr_acc_label=0.443 tr_acc_variety=0.680 | valid loss=382.1500 va_acc_label=0.622 va_acc_variety=0.714 va_mae_age=25.305
Epoch 02 | train loss=198.7945 tr_acc_label=0.596 tr_acc_variety=0.752 | valid loss=81.0375 va_acc_label=0.643 va_acc_variety=0.751 va_mae_age=9.989
Epoch 03 | train loss=75.3589 tr_acc_label=0.628 tr_acc_variety=0.779 | valid loss=69.9145 va_acc_label=0.714 va_acc_variety=0.802 va_mae_age=9.131
Epoch 04 | train loss=63.7352 tr_acc_label=0.654 tr_acc_variety=0.797 | valid loss=60.4969 va_acc_label=0.736 va_acc_variety=0.809 va_mae_age=8.443
Epoch 05 | train loss=58.4966 tr_acc_label=0.670 tr_acc_variety=0.808 | valid loss=54.5497 va_acc_label=0.761 va_acc_variety=0.826 va_mae_age=7.993
Epoch 06 | train loss=54.2428 tr_acc_label=0.685 tr_acc_variety=0.819 | valid loss=50.6365 va_acc_label=0.775 va_acc_variety=0.840 va_mae_age=7.685
Epoch 07 | train loss=52.6000 tr_acc_label=0.692 tr_acc_variety=0.823 | valid loss=48.3230 va_acc_label=0.7

In [13]:
def eval(loader=test_loader):
    run_loss, correct_label, correct_variety, mae_age = evaluate(model, loader, criterion, device, weights=loss_weights)
    print(f"loss={run_loss:.4f} acc_label={correct_label:.3f} acc_variety={correct_variety:.3f}")

eval()

loss=34.8115 acc_label=0.809 acc_variety=0.870


### (Optional) Fine-tune the whole network

After warming up the head, unfreeze the backbone and fine-tune at a **smaller LR**.

In [14]:
%%time 

# train the whole thing (unfreeze all layers)
for p in model.parameters():
    p.requires_grad = True

#make lr smaller for fine-tuning
lr1=lr/10

#change learning rate for fine-tuning
for g in optimizer.param_groups:
    g['lr'] = lr1

train_and_evaluate(model, train_loader, valid_loader, optimizer, criterion, loss_weights=loss_weights, num_epochs=20)


Epoch 01 | train loss=39.3526 tr_acc_label=0.751 tr_acc_variety=0.874 | valid loss=34.4629 va_acc_label=0.828 va_acc_variety=0.896 va_mae_age=6.360
Epoch 02 | train loss=38.7105 tr_acc_label=0.761 tr_acc_variety=0.870 | valid loss=34.4670 va_acc_label=0.830 va_acc_variety=0.895 va_mae_age=6.355
Epoch 03 | train loss=39.0011 tr_acc_label=0.759 tr_acc_variety=0.876 | valid loss=34.4646 va_acc_label=0.837 va_acc_variety=0.897 va_mae_age=6.351
Epoch 04 | train loss=39.3526 tr_acc_label=0.758 tr_acc_variety=0.875 | valid loss=34.4237 va_acc_label=0.829 va_acc_variety=0.902 va_mae_age=6.347
Epoch 05 | train loss=39.0868 tr_acc_label=0.755 tr_acc_variety=0.871 | valid loss=34.3939 va_acc_label=0.835 va_acc_variety=0.906 va_mae_age=6.342
Epoch 06 | train loss=38.5538 tr_acc_label=0.759 tr_acc_variety=0.878 | valid loss=34.3860 va_acc_label=0.838 va_acc_variety=0.897 va_mae_age=6.338
Epoch 07 | train loss=39.4844 tr_acc_label=0.761 tr_acc_variety=0.875 | valid loss=34.3677 va_acc_label=0.833 va

## 10) Save / Load

In [15]:

torch.save(model.state_dict(), 'multitask_convnext_tiny_paddy.pth')
print('Saved to multitask_convnext_tiny_paddy.pth')
# model.load_state_dict(torch.load('multitask_convnext_tiny_paddy.pth', map_location=device))
# model.eval()


Saved to multitask_convnext_tiny_paddy.pth


## 7) Evaluation & Confusion Matrix

In [16]:

@torch.no_grad()
def get_all_preds_targets(model, loader):
    model.eval()
    preds, targs = [], []
    for xb,_,_,yb in loader:
        xb = xb.to(device)
        logits = model(xb)

        #just want the label preds
        preds.append(logits['label'].argmax(1).cpu().numpy())
        targs.append(yb.numpy())
    return np.concatenate(preds), np.concatenate(targs)

preds, targs = get_all_preds_targets(model, valid_loader)

# Confusion matrix
num_classes = len(class_names)
cm = np.zeros((num_classes, num_classes), dtype=int)
for t, p in zip(targs, preds):
    cm[t, p] += 1

print('Confusion Matrix (6x6 for PaddyDoctor):')
print(cm)

# Per-class metrics
per_class = []
for k in range(num_classes):
    TP = cm[k,k]
    FP = cm[:,k].sum() - TP
    FN = cm[k,:].sum() - TP
    TN = cm.sum() - TP - FP - FN
    prec = TP/(TP+FP) if (TP+FP)>0 else 0.0
    rec  = TP/(TP+FN) if (TP+FN)>0 else 0.0
    f1   = (2*prec*rec)/(prec+rec) if (prec+rec)>0 else 0.0
    per_class.append((prec, rec, f1))

macro_p = float(np.mean([p for p,_,_ in per_class]))
macro_r = float(np.mean([r for _,r,_ in per_class]))
macro_f = float(np.mean([f for _,_,f in per_class]))
overall_acc = float((preds == targs).mean())

print('\nPer-class (precision, recall, f1):')
for name,(p,r,f) in zip(class_names, per_class):
    print(f'{name:>10s}: P={p:.3f} R={r:.3f} F1={f:.3f}')
print(f"\nMacro avg: P={macro_p:.3f} R={macro_r:.3f} F1={macro_f:.3f}")
print(f'Overall Accuracy: {overall_acc:.3f}')


Confusion Matrix (6x6 for PaddyDoctor):
[[ 21   1   0   4   2   0   1   6   1  10]
 [  1  34   0   1   1   0   1   4   1   1]
 [  0   0  21   1   0   1   0   0   0   0]
 [  0   1   0 136   4   0   3   5   4   8]
 [  3   2   1  10  75   0   4   1   0   0]
 [  0   0   1   0   0 166   0   0   0   1]
 [  3   0   0   5   3   0  38   6   1   5]
 [  0   0   0   8   0   0   5 139  13   5]
 [  0   0   2   0   1   0   0   3 171   3]
 [  4   0   0   6   2   0   4   4   1  70]]

Per-class (precision, recall, f1):
bacterial_leaf_blight: P=0.656 R=0.457 F1=0.538
bacterial_leaf_streak: P=0.895 R=0.773 F1=0.829
bacterial_panicle_blight: P=0.840 R=0.913 F1=0.875
     blast: P=0.795 R=0.845 F1=0.819
brown_spot: P=0.852 R=0.781 F1=0.815
dead_heart: P=0.994 R=0.988 F1=0.991
downy_mildew: P=0.679 R=0.623 F1=0.650
     hispa: P=0.827 R=0.818 F1=0.822
    normal: P=0.891 R=0.950 F1=0.919
    tungro: P=0.680 R=0.769 F1=0.722

Macro avg: P=0.811 R=0.792 F1=0.798
Overall Accuracy: 0.838


## 8) (Optional) Inference on Test Set + `submission.csv`

In [17]:

import pandas as pd
from PIL import Image

@torch.no_grad()
def predict_folder_images(img_dir, transform, class_names):
    paths = sorted([p for p in Path(img_dir).glob("*.*") if p.suffix.lower() in {".jpg",".jpeg",".png",".bmp"}])
    ids, labels = [], []
    for p in paths:
        img = Image.open(p).convert("RGB")
        x = transform(img).unsqueeze(0).to(device)
        logits = model(x)
        pred = logits.argmax(1).item()
        ids.append(p.name)
        labels.append(class_names[pred])
    return ids, labels

test_dir_A = DATA_DIR / "test"
test_dir_B = DATA_DIR / "test_images"
sub_path   = DATA_DIR / "submission.csv"

if test_dir_A.exists():
    ids, labels = predict_folder_images(test_dir_A, valid_tfms, class_names)
elif test_dir_B.exists():
    ids, labels = predict_folder_images(test_dir_B, valid_tfms, class_names)
else:
    ids, labels = [], []
    print("No test directory found; skipping submission.csv.")

if ids:
    df = pd.DataFrame({"image_id": ids, "label": labels})
    df.to_csv(sub_path, index=False)
    print("Saved:", sub_path)


AttributeError: 'dict' object has no attribute 'argmax'

In [None]:
!cat ./data/submission.csv | head -10

image_id,label
200001.jpg,hispa
200002.jpg,normal
200003.jpg,downy_mildew
200004.jpg,blast
200005.jpg,blast
200006.jpg,brown_spot
200007.jpg,dead_heart
200008.jpg,brown_spot
200009.jpg,hispa
cat: write error: Broken pipe



---

## Notes for Beginners

- **Why `timm`?** Lots of pretrained models + convenient transforms. Switching `MODEL_NAME` is an easy way to try stronger backbones.
- **Transforms:** Using `timm.data.create_transform` keeps preprocessing consistent with the chosen model.
- **Training recipe:** Freeze → train head → unfreeze → fine-tune at smaller LR.
- **OOM tips:** Lower `BATCH_SIZE` or try a smaller model (e.g., `efficientnet_b0`, `mobilenetv3_large_100`).
- **Save/load:** `torch.save(model.state_dict(), "model.pth")`, then `model.load_state_dict(torch.load("model.pth", map_location=device))`.
