
# Paddy Disease Classification — **PyTorch + timm**

This is a **minimal, beginner-friendly** image classification pipeline using **PyTorch** and **[timm](https://github.com/huggingface/pytorch-image-models)** (Torch Image Models).  
It mirrors a typical structure with clear comments:

1) Setup & configuration  
2) Transforms (matched to the timm model)  
3) Datasets & dataloaders (train/valid split)  
4) Model (transfer learning with `timm.create_model`)  
5) Training loop (loss/optimizer)  
6) Evaluation & confusion matrix  
7) (Optional) Inference on test set + `submission.csv`

> **Note:** Point `DATA_DIR` to your local Kaggle Paddy dataset. This notebook avoids any fastai dependencies.


## 1) Setup & Configuration

In [4]:

# If timm isn't installed, uncomment:
# !pip install timm --quiet

import os, random, math, time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split

from torchvision import datasets, transforms
import timm

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)

DATA_DIR = Path("/path/to/paddy")  # <-- change me
USE_IMAGEFOLDER = True             # set False to use CSV dataset class below

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda
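
It can help to confirm that `DATA_DIR` points at the data and to see which loader fits it. `describe_layout` below is a hypothetical, stdlib-only sketch, not a timm or torchvision API. It looks for `train/` or `train_images/`, lists any class sub-folders (the layout `ImageFolder` expects), counts images sitting directly in that folder, and notes whether `train.csv` and `valid/` exist.

```python
from pathlib import Path
from typing import Optional

IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}

def describe_layout(data_dir) -> dict:
    """Summarise what data_dir contains (hypothetical helper, stdlib only)."""
    data_dir = Path(data_dir)
    # First existing candidate wins: a custom train/ folder, then Kaggle's train_images/
    train_dir: Optional[Path] = next(
        (d for d in (data_dir / "train", data_dir / "train_images") if d.is_dir()), None
    )
    entries = list(train_dir.iterdir()) if train_dir else []
    return {
        "train_dir": train_dir,
        "class_subfolders": sorted(p.name for p in entries if p.is_dir()),  # ImageFolder-style
        "flat_images": sum(1 for p in entries if p.is_file() and p.suffix.lower() in IMG_EXTS),
        "has_csv": (data_dir / "train.csv").is_file(),
        "has_valid": (data_dir / "valid").is_dir(),
    }
```

For example, run `describe_layout(DATA_DIR)`. A non-empty `class_subfolders` list means `ImageFolder` can read `train_dir` directly. A flat folder plus `has_csv=True` points to the CSV dataset class instead.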


### Get the Data 

Data comes from **[Paddy Doctor: Paddy Disease Classification](https://www.kaggle.com/competitions/paddy-disease-classification)**, a Kaggle competition whose goal is to identify the type of disease present in rice paddy leaf images. You can download it directly after signing up for a [Kaggle](https://www.kaggle.com/) account.

### (Optional) Download with the Kaggle CLI
A more convenient way to get the data is the **[Kaggle CLI (command line interface)](https://www.kaggle.com/docs/api)**. It lets you interact with Kaggle programmatically: you can browse and download datasets and competition files, and submit results. You need an API key to use it:

1) In your Kaggle Account settings, scroll to the API section and click the **Create New API Token** button.  
2) Kaggle generates a JSON file named `kaggle.json` and prompts you to save it to your computer.  
3) Put this file in `~/.kaggle/`, which is where the CLI looks for your key. Make sure only you can read it: `chmod 600 ~/.kaggle/kaggle.json`.  
4) You will probably reuse the key often. Keep it in persistent storage such as `/storage/cfg` and symlink it into `~/.kaggle/` from your `setup.sh`.
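
The key-placement steps above can be scripted. This is a sketch, not an official Kaggle tool, and `install_kaggle_key` and its `symlink` flag are hypothetical names. By default it copies `kaggle.json` into `~/.kaggle/`. With `symlink=True` it links to a copy kept in persistent storage instead. Either way it applies the same owner-only permissions as `chmod 600`.

```python
import os
import shutil
import stat
from pathlib import Path
from typing import Optional

def install_kaggle_key(src, home: Optional[Path] = None, symlink: bool = False) -> Path:
    """Place kaggle.json where the CLI expects it (~/.kaggle/) with 0o600 permissions."""
    src = Path(src).resolve()
    dest_dir = (Path(home) if home else Path.home()) / ".kaggle"
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / "kaggle.json"
    if dest.is_symlink() or dest.exists():
        dest.unlink()                      # replace any previous copy or link
    if symlink:
        dest.symlink_to(src)               # real file stays in persistent storage
    else:
        shutil.copy(src, dest)
    # Owner read/write only; os.chmod follows a symlink to the real file
    os.chmod(dest, stat.S_IRUSR | stat.S_IWUSR)
    return dest
```

For example, `install_kaggle_key(Path("/storage/cfg/kaggle.json"), symlink=True)` would link a key kept under `/storage/cfg`.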


In [13]:
!pip install --upgrade kaggle

Collecting kaggle
  Downloading kaggle-1.7.4.5-py3-none-any.whl.metadata (16 kB)
Downloading kaggle-1.7.4.5-py3-none-any.whl (181 kB)
Installing collected packages: kaggle
  Attempting uninstall: kaggle
    Found existing installation: kaggle 1.6.17
    Uninstalling kaggle-1.6.17:
      Successfully uninstalled kaggle-1.6.17
Successfully installed kaggle-1.7.4.5


In [14]:
!kaggle datasets list 

ref                                                        title                                                  size  lastUpdated                 downloadCount  voteCount  usabilityRating  
---------------------------------------------------------  -----------------------------------------------  ----------  --------------------------  -------------  ---------  ---------------  
mosapabdelghany/medical-insurance-cost-dataset             Medical Insurance Cost Dataset                        16425  2025-08-24 11:54:36.533000          18350        367  1.0              
zadafiyabhrami/global-crocodile-species-dataset            Global Crocodile Species Dataset                      57473  2025-08-26 08:46:11.950000          12036        325  1.0              
codebynadiia/gdp-per-country-20202025                      GDP per Country 2020–2025                              5677  2025-09-04 14:37:43.563000           9667        180  1.0              
minahilfatima12328/lifestyle-and-sleep-p
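
`kaggle datasets list` only confirms that the CLI is authenticated. To fetch this competition's files, the CLI command `kaggle competitions download -c <slug> -p <dir>` saves a `<slug>.zip` archive into `<dir>`. You may first need to accept the competition rules on its Kaggle page. The wrapper below is a hypothetical helper that runs that command through `subprocess` and unzips the result. The destination folder is up to you.

```python
import subprocess
import zipfile
from pathlib import Path
from typing import List

COMPETITION = "paddy-disease-classification"   # slug taken from the competition URL

def download_cmd(competition: str, dest) -> List[str]:
    # Same as the shell command: kaggle competitions download -c <slug> -p <dest>
    return ["kaggle", "competitions", "download", "-c", competition, "-p", str(dest)]

def fetch_and_extract(competition: str, dest, download: bool = True) -> Path:
    """Download <competition>.zip into dest (unless already there) and extract it in place."""
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    zip_path = dest / f"{competition}.zip"
    if download and not zip_path.exists():
        subprocess.run(download_cmd(competition, dest), check=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest)
    return dest
```

For example, `DATA_DIR = fetch_and_extract(COMPETITION, Path("paddy"))` would download and unpack the data, then point `DATA_DIR` at it.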


### (Optional) CSV-based Dataset

If you have `train_images/` plus a `train.csv` with `image_id,label` columns (the Kaggle layout), set `USE_IMAGEFOLDER = False` to use this dataset class.  
If the images are already arranged in one folder per class, keep `USE_IMAGEFOLDER = True` to use `ImageFolder` instead.
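
If your images sit in one flat folder next to a CSV, you can build class folders for `ImageFolder` without copying anything by creating symlinks. `link_into_class_folders` is a hypothetical stdlib sketch. It assumes a header row with `image_id` and `label` columns and that the images sit directly in `images_dir`. Kaggle's own `train_images/` already appears to be split by class, so this only matters for a flat copy.

```python
import csv
from collections import Counter
from pathlib import Path

def link_into_class_folders(images_dir, csv_path, out_dir) -> Counter:
    """Create out_dir/<label>/<image_id> symlinks from CSV rows; returns images per label."""
    images_dir, out_dir = Path(images_dir).resolve(), Path(out_dir)
    counts = Counter()
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            label_dir = out_dir / row["label"]
            label_dir.mkdir(parents=True, exist_ok=True)
            link = label_dir / row["image_id"]
            if not (link.exists() or link.is_symlink()):   # safe to re-run
                link.symlink_to(images_dir / row["image_id"])
            counts[row["label"]] += 1
    return counts
```

The returned `Counter` also gives a quick look at class balance before training.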


In [5]:

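from PIL import Image
import csv

class PaddyCSVDataset(Dataset):
    """(image_id, label) pairs read from a CSV; returns (transformed image, class index)."""
    def __init__(self, images_dir, csv_path, transform=None, class_to_idx=None):
        self.images_dir = Path(images_dir)
        self.transform  = transform
        rows = []
        with open(csv_path, newline="") as f:
            rd = csv.DictReader(f)
            if rd.fieldnames is None or 'image_id' not in rd.fieldnames:
                # No recognizable header: treat the first two columns as (image_id, label)
                f.seek(0)
                for r in csv.reader(f):
                    if len(r) >= 2:
                        rows.append((r[0], r[1]))
            else:
                # Header present: extra columns (if any) are ignored
                for r in rd:
                    rows.append((r['image_id'], r['label']))
        self.items = rows
        labels = sorted(set(label for _, label in self.items))
        # Pass class_to_idx to share one label mapping across several dataset instances
        self.class_to_idx = class_to_idx or {c: i for i, c in enumerate(labels)}
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}

    def __len__(self):
        return len(self.items)

    def _image_path(self, img_id, label_name):
        # Kaggle's train_images/ keeps one sub-folder per label; fall back to a flat folder
        nested = self.images_dir / label_name / img_id
        return nested if nested.exists() else self.images_dir / img_id

    def __getitem__(self, idx):
        img_id, label_name = self.items[idx]
        img = Image.open(self._image_path(img_id, label_name)).convert("RGB")
        if self.transform:
            img = self.transform(img)
        y = self.class_to_idx[label_name]
        return img, y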


## 2) Transforms (match the timm model’s expectations)

In [6]:

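MODEL_NAME = "resnet18"  # try: 'efficientnet_b0', 'convnext_tiny', 'mobilenetv3_large_100', ...

# resolve_data_config reads model-specific settings (input size, mean/std, crop_pct, interpolation)
# from a model *instance*; given only the name string it silently falls back to generic defaults.
# An instance created with pretrained=False is enough to read the config (no weights are downloaded).
_cfg_model = timm.create_model(MODEL_NAME, pretrained=False)
config = timm.data.resolve_data_config({}, model=_cfg_model)
del _cfg_model
_, H, W = config['input_size']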

train_tfms = timm.data.create_transform(
    **config,
    is_training=True,
    hflip=0.5,
    color_jitter=None,
    auto_augment=None
)

valid_tfms = timm.data.create_transform(
    **config,
    is_training=False
)

print("Resolved input size:", (H, W))


Resolved input size: (224, 224)
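
Here is roughly what the validation transform does. In recent timm versions with the default crop mode, `create_transform(is_training=False)` resizes the shorter side to `floor(size / crop_pct)` and center-crops to `size`. It then converts pixels to `[0, 1]` and normalizes each channel as `(x - mean) / std`. The stdlib sketch below reproduces only that arithmetic. `crop_pct=0.875` and the ImageNet mean/std values are common timm defaults used for illustration, not necessarily what `config` holds for your model, so print `config` to check.

```python
import math
from typing import Sequence, Tuple

# Common timm/torchvision ImageNet statistics (assumed here for illustration)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

def eval_resize_size(img_size: int, crop_pct: float) -> int:
    # Shorter side is resized to this, then a centered img_size x img_size crop is taken
    return math.floor(img_size / crop_pct)

def normalize_pixel(rgb: Sequence[int], mean: Sequence[float], std: Sequence[float]) -> Tuple[float, ...]:
    # ToTensor maps 0..255 to 0..1; Normalize then applies (x - mean) / std per channel
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, mean, std))

print(eval_resize_size(224, 0.875))  # 256
```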


## 3) Datasets & Dataloaders

In [7]:

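from torch.utils.data import Subset

BATCH_SIZE = 32
VALID_PCT  = 0.1

def split_indices(n, valid_pct=VALID_PCT, seed=42):
    # Reproducible shuffle of 0..n-1; the first valid_pct of it is held out for validation
    perm = torch.randperm(n, generator=torch.Generator().manual_seed(seed)).tolist()
    n_valid = int(valid_pct * n)
    return perm[n_valid:], perm[:n_valid]

if USE_IMAGEFOLDER:
    # Use train/ if you created one; otherwise Kaggle's train_images/ (one sub-folder per class)
    train_dir = DATA_DIR / "train"
    if not train_dir.exists():
        train_dir = DATA_DIR / "train_images"
    valid_dir = DATA_DIR / "valid"

    if valid_dir.exists():
        train_ds = datasets.ImageFolder(train_dir, transform=train_tfms)
        valid_ds = datasets.ImageFolder(valid_dir, transform=valid_tfms)
        class_names = train_ds.classes
    else:
        # Two views of the same files so validation gets valid_tfms, not training augmentation
        train_view = datasets.ImageFolder(train_dir, transform=train_tfms)
        valid_view = datasets.ImageFolder(train_dir, transform=valid_tfms)
        train_idx, valid_idx = split_indices(len(train_view))
        train_ds, valid_ds = Subset(train_view, train_idx), Subset(valid_view, valid_idx)
        class_names = train_view.classes
else:
    train_images = DATA_DIR / "train_images"
    train_csv    = DATA_DIR / "train.csv"
    train_view = PaddyCSVDataset(train_images, train_csv, transform=train_tfms)
    # Same label mapping in both views so class indices agree
    valid_view = PaddyCSVDataset(train_images, train_csv, transform=valid_tfms,
                                 class_to_idx=train_view.class_to_idx)
    train_idx, valid_idx = split_indices(len(train_view))
    train_ds, valid_ds = Subset(train_view, train_idx), Subset(valid_view, valid_idx)
    # Order names by index so class_names[i] is the label for class i
    class_names = [train_view.idx_to_class[i] for i in range(len(train_view.idx_to_class))]

num_classes = len(class_names)

pin = device.type == "cuda"   # pinned memory only helps when copying batches to a GPU
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=pin)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=pin)

print(f"Classes ({num_classes}):", class_names[:10], "..." if num_classes > 10 else "")
print("Train/Valid sizes:", len(train_ds), len(valid_ds))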


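
A random 90/10 split does not guarantee that every disease keeps the same share in train and valid. A stratified split is a common alternative. It is a swapped-in technique, not what this notebook does. `stratified_split` is a hypothetical stdlib helper that takes per-sample integer labels, for example `ImageFolder(...).targets`. It returns index lists you could pass to `torch.utils.data.Subset`.

```python
import random
from collections import defaultdict
from typing import List, Sequence, Tuple

def stratified_split(labels: Sequence[int], valid_pct: float = 0.1,
                     seed: int = 42) -> Tuple[List[int], List[int]]:
    """Split indices so each class sends about valid_pct of its samples to validation."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)              # local RNG: does not disturb the global seed
    train_idx, valid_idx = [], []
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        n_valid = round(valid_pct * len(idxs))
        valid_idx.extend(idxs[:n_valid])
        train_idx.extend(idxs[n_valid:])
    return sorted(train_idx), sorted(valid_idx)
```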

## 4) Model (Transfer Learning with **timm**)

In [None]:

model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=num_classes)
print("Classifier layer:", model.get_classifier())

# Freeze the whole network, then unfreeze only the classifier head.
# get_classifier() returns the head *module* (e.g. nn.Linear), so unfreeze its parameters directly.
for p in model.parameters():
    p.requires_grad = False
for p in model.get_classifier().parameters():
    p.requires_grad = True

model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Hand only the trainable (head) parameters to the optimizer
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=3e-4)
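
To confirm that only the head is trainable, count the parameters with `requires_grad=True`. The sketch below uses a tiny stand-in `nn.Sequential` in place of the timm model. Its last `nn.Linear` plays the role of the classifier head, and `count_params` is a hypothetical helper name.

```python
import torch.nn as nn

def count_params(model: nn.Module):
    """Return (trainable, total) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total

# Stand-in network: "backbone" Linear(8, 4) + "head" Linear(4, 3)
toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 3))
head = toy[-1]

for p in toy.parameters():
    p.requires_grad = False
for p in head.parameters():
    p.requires_grad = True

print(count_params(toy))  # (15, 51): head is 4*3 + 3; total adds 8*4 + 4
```

On the real model, `count_params(model)` after the freezing step should report far fewer trainable parameters than total.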


## 5) Training Loop (minimal, commented)

In [None]:

def train_one_epoch(model, loader, optimizer, criterion):
    model.train()                                  # training mode: dropout on, BatchNorm updates stats
    total_loss, correct, total = 0.0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad(set_to_none=True)      # clear gradients from the previous step
        logits = model(xb)                         # forward pass -> (batch, num_classes)
        loss = criterion(logits, yb)
        loss.backward()                            # backpropagate into trainable parameters
        optimizer.step()                           # update the trainable parameters
        # Weight by batch size so the epoch average is per sample, even if the last batch is smaller
        total_loss += loss.item() * xb.size(0)
        correct += (logits.argmax(1) == yb).sum().item()
        total += xb.size(0)
    return total_loss/total, correct/total

@torch.no_grad()                                   # no gradient tracking during evaluation
def evaluate(model, loader, criterion):
    model.eval()                                   # eval mode: dropout off, BatchNorm uses running stats
    total_loss, correct, total = 0.0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        total_loss += loss.item() * xb.size(0)
        correct += (logits.argmax(1) == yb).sum().item()
        total += xb.size(0)
    return total_loss/total, correct/total

EPOCHS = 5
for epoch in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    va_loss, va_acc = evaluate(model, valid_loader, criterion)
    print(f"epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} | valid loss {va_loss:.4f} acc {va_acc:.3f}")
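
The loop prints metrics but keeps only the final weights. A common addition, not in the original, is to remember the epoch with the best validation accuracy and save a checkpoint only when it improves. `BestTracker` is a hypothetical helper. The commented usage shows where a `torch.save` call would go.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class BestTracker:
    """Track the best value of a higher-is-better metric (e.g. validation accuracy)."""
    best: Optional[float] = None
    best_epoch: Optional[int] = None

    def update(self, value: float, epoch: int) -> bool:
        # True only when `value` beats everything seen so far (ties do not count)
        if self.best is None or value > self.best:
            self.best, self.best_epoch = value, epoch
            return True
        return False

# Usage inside the epoch loop (sketch):
# tracker = BestTracker()
# if tracker.update(va_acc, epoch):
#     torch.save(model.state_dict(), "best_model.pth")
```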


### (Optional) Fine-tune the whole network

In [None]:

for p in model.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

EPOCHS_FT = 3
for epoch in range(1, EPOCHS_FT+1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    va_loss, va_acc = evaluate(model, valid_loader, criterion)
    print(f"[FT] epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} | valid loss {va_loss:.4f} acc {va_acc:.3f}")
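
Unfreezing everything at a single `lr=1e-4` is the simplest option. A common variant gives the pretrained backbone a smaller learning rate than the freshly initialised head, using AdamW parameter groups. This is swapped in here, not the notebook's method. `make_param_groups` and the toy model are illustrative assumptions. With timm you could pass `model.get_classifier()` as `head`.

```python
import torch
import torch.nn as nn

def make_param_groups(model: nn.Module, head: nn.Module, backbone_lr: float, head_lr: float):
    """Split model parameters into backbone/head groups with separate learning rates."""
    head_ids = {id(p) for p in head.parameters()}          # compare by identity, not value
    backbone = [p for p in model.parameters() if id(p) not in head_ids]
    return [
        {"params": backbone, "lr": backbone_lr},
        {"params": list(head.parameters()), "lr": head_lr},
    ]

toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 3))   # stand-in model
for p in toy.parameters():
    p.requires_grad = True
opt = torch.optim.AdamW(make_param_groups(toy, toy[-1], backbone_lr=1e-5, head_lr=1e-4))
print([g["lr"] for g in opt.param_groups])  # [1e-05, 0.0001]
```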


## 6) Evaluation & Confusion Matrix

In [None]:

import numpy as np

@torch.no_grad()
def get_all_preds_targets(model, loader):
    model.eval()
    preds, targs = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        logits = model(xb)
        preds.append(logits.argmax(1).cpu().numpy())
        targs.append(yb.numpy())
    return np.concatenate(preds), np.concatenate(targs)

preds, targs = get_all_preds_targets(model, valid_loader)

num_classes = len(class_names)
cm = np.zeros((num_classes, num_classes), dtype=int)
for t, p in zip(targs, preds):
    cm[t, p] += 1

print("Confusion Matrix (truncated up to 10×10):")
show_n = min(10, num_classes)
print(cm[:show_n, :show_n])
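
Raw counts are easier to read as per-class rates. The loop above fills `cm` with true labels as rows and predictions as columns. So recall is the diagonal divided by row sums, and precision is the diagonal divided by column sums. This numpy sketch also builds the same matrix in vectorised form with `np.bincount`. Both function names are illustrative, not part of the notebook.

```python
import numpy as np

def confusion_matrix(targs: np.ndarray, preds: np.ndarray, num_classes: int) -> np.ndarray:
    # Encode each (true, pred) pair as one index, count them, reshape to rows=true, cols=pred
    flat = targs.astype(np.int64) * num_classes + preds.astype(np.int64)
    return np.bincount(flat, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def per_class_metrics(cm: np.ndarray):
    """Return (precision, recall) per class; classes with no samples/predictions get 0."""
    tp = np.diag(cm).astype(float)
    with np.errstate(divide="ignore", invalid="ignore"):
        recall = np.where(cm.sum(axis=1) > 0, tp / cm.sum(axis=1), 0.0)
        precision = np.where(cm.sum(axis=0) > 0, tp / cm.sum(axis=0), 0.0)
    return precision, recall
```

For a readable report, try `for name, p, r in zip(class_names, *per_class_metrics(cm)): print(f"{name:25s} precision {p:.2f} recall {r:.2f}")`.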


## 7) (Optional) Inference on Test Set + `submission.csv`

In [None]:

import pandas as pd
from PIL import Image

@torch.no_grad()
def predict_folder_images(model, img_dir, transform, class_names):
    model.eval()  # inference mode for dropout/BatchNorm, regardless of what earlier cells did
    exts = {".jpg", ".jpeg", ".png", ".bmp"}
    paths = sorted(p for p in Path(img_dir).glob("*.*") if p.suffix.lower() in exts)
    ids, pred_labels = [], []
    for p in paths:
        img = Image.open(p).convert("RGB")
        x = transform(img).unsqueeze(0).to(device)   # add a batch dimension: (1, C, H, W)
        pred = model(x).argmax(1).item()
        ids.append(p.name)                           # use the file name as image_id
        pred_labels.append(class_names[pred])
    return ids, pred_labels

test_dir_A = DATA_DIR / "test"
test_dir_B = DATA_DIR / "test_images"
# Save to the working directory: on Kaggle the input folders are read-only
sub_path   = Path("submission.csv")

if test_dir_A.exists():
    ids, labels = predict_folder_images(model, test_dir_A, valid_tfms, class_names)
elif test_dir_B.exists():
    ids, labels = predict_folder_images(model, test_dir_B, valid_tfms, class_names)
else:
    ids, labels = [], []
    print("No test directory found; skipping submission.csv.")

if ids:
    df = pd.DataFrame({"image_id": ids, "label": labels})
    df.to_csv(sub_path, index=False)
    print("Saved:", sub_path)
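
Kaggle competitions usually ship a `sample_submission.csv`. Writing predictions in the same row order, and checking that no test image is missing, is a cheap safeguard. The stdlib sketch below is a hypothetical helper. The `image_id,label` header matches the columns the cell above writes, and `sample_path` is assumed to point at the competition's sample file.

```python
import csv
from pathlib import Path
from typing import Optional, Sequence

def write_submission(ids: Sequence[str], labels: Sequence[str], out_path,
                     sample_path: Optional[Path] = None) -> int:
    """Write image_id,label rows; if sample_path is given, follow its row order."""
    pred = dict(zip(ids, labels))
    order = list(ids)
    if sample_path is not None:
        with open(sample_path, newline="") as f:
            order = [row["image_id"] for row in csv.DictReader(f)]
        missing = [i for i in order if i not in pred]
        if missing:
            raise ValueError(f"{len(missing)} test images have no prediction, e.g. {missing[:3]}")
    with open(out_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["image_id", "label"])
        w.writerows((i, pred[i]) for i in order)
    return len(order)
```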



---

## Notes for Beginners

- **Why `timm`?** Lots of pretrained models + convenient transforms. Switching `MODEL_NAME` is an easy way to try stronger backbones.
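- **Transforms:** `timm.data.create_transform` keeps preprocessing consistent with the chosen model, as long as the data config is resolved from a model instance. A bare model name only yields timm's generic defaults.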
- **Training recipe:** Freeze → train head → unfreeze → fine-tune at smaller LR.
- **OOM tips:** Lower `BATCH_SIZE` or try a smaller model (e.g., `efficientnet_b0`, `mobilenetv3_large_100`).
- **Save/load:** `torch.save(model.state_dict(), "model.pth")`, then `model.load_state_dict(torch.load("model.pth", map_location=device))`.
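
Saving only the `state_dict` means you must rebuild the same architecture with the same `num_classes` before loading. A slightly richer checkpoint also stores the model name and class names. This is a sketch, and the dictionary keys are arbitrary choices, not a timm or PyTorch convention. A toy `nn.Linear` stands in for the timm model so the example runs anywhere. With timm, `build_fn` could be `lambda name, n: timm.create_model(name, num_classes=n)`.

```python
from pathlib import Path

import torch
import torch.nn as nn

def save_checkpoint(model: nn.Module, path, model_name: str, class_names):
    torch.save({"model_name": model_name,
                "class_names": list(class_names),
                "state_dict": model.state_dict()}, path)

def load_checkpoint(path, build_fn, device="cpu"):
    # Only plain containers and tensors are stored, so the safer weights_only loader works
    ckpt = torch.load(path, map_location=device, weights_only=True)
    # build_fn(model_name, num_classes) must recreate the same architecture
    model = build_fn(ckpt["model_name"], len(ckpt["class_names"]))
    model.load_state_dict(ckpt["state_dict"])
    return model.to(device).eval(), ckpt["class_names"]

def build_toy(model_name, num_classes):   # hypothetical stand-in for timm.create_model
    return nn.Linear(4, num_classes)
```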
