# iNaturalist 2021 Training Notebook
This notebook trains a ResNet18 model on the iNaturalist 2021 dataset (Mini version). It uses PyTorch and Torchvision.

## 1. Load and Verify Dataset Paths
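Set the root directory where torchvision downloads and caches the iNaturalist 2021 data, and make sure it exists. The train and validation splits are separate dataset versions (`2021_train_mini` and `2021_valid`), so no manual folder splitting is required.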

In [2]:
import torch
import torchvision
from torchvision import datasets, transforms
import os

# Dataset parameters
DATA_ROOT = './data'
VERSION = '2021_train_mini'

# Seed torch's global RNG (all devices) so the new classifier head's init,
# DataLoader shuffling and the random augmentations are reproducible on a
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

print(f"Using iNaturalist {VERSION} stored at {DATA_ROOT}")

[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.25a0+6627725-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/lightning_utilities-0.12.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/looseversion-1.3.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/pyth

In [None]:
# Ensure the download/cache directory exists before torchvision writes into it.
# exist_ok=True already makes this a no-op when the directory is present, so a
# separate os.path.exists() check is unnecessary.
os.makedirs(DATA_ROOT, exist_ok=True)

ROOT CONTENTS:
/root/.cache/kagglehub/datasets/abdallahalidev/plantvillage-dataset/versions/3/plantvillage dataset -> DIR


In [None]:
# Verify that the data root is a usable directory and list whatever is already
# cached there (empty on a fresh machine; filled once the datasets in Section 5
# are downloaded with download=True).
assert os.path.isdir(DATA_ROOT), f"DATA_ROOT {DATA_ROOT!r} is not a directory"
print("DATA_ROOT contents:", sorted(os.listdir(DATA_ROOT)))

Subdirectories: [PosixPath('/root/.cache/kagglehub/datasets/abdallahalidev/plantvillage-dataset/versions/3/plantvillage dataset')]
Using: /root/.cache/kagglehub/datasets/abdallahalidev/plantvillage-dataset/versions/3/plantvillage dataset

Classes inside:
color
grayscale
segmented


In [None]:
# Torchvision's INaturalist dataset handles splits internally (train/val).
# Intentionally a no-op: training and validation data are loaded as separate
# INaturalist versions ('2021_train_mini' and '2021_valid') in Section 5.
pass

Using: /root/.cache/kagglehub/datasets/abdallahalidev/plantvillage-dataset/versions/3/plantvillage dataset/color
Exists: True
Total classes: 38
Sample classes: ['Soybean___healthy', 'Apple___Apple_scab', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Grape___Black_rot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus']


In [None]:
# No manual splitting needed.
# Intentionally a no-op: no train/val/test folders are created on disk; the
# splits are defined when the datasets are built in Section 5.
pass

Split complete.


## 2. Install and Import Dependencies
Install and import required libraries.

In [11]:
# If running in a fresh environment, uncomment the line below
# !pip install torch torchvision numpy matplotlib scikit-learn

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

## 3. Configure GPU and Mixed Precision
Detect CUDA, select device, and enable mixed precision if supported.

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using device:", device)

use_amp = use_cuda
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

Using device: cpu


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


## 4. Define Data Transforms and Augmentations
Create transforms for training and for validation/testing.

In [13]:
img_size = 224

# ImageNet channel statistics, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize to a fixed square (non-square photos are stretched),
# then light geometric + photometric augmentation, tensor conversion, normalization.
train_transforms = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Evaluation pipeline: the same deterministic resize + normalization, no augmentation.
val_test_transforms = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

## 5. Create Train/Validation/Test Dataloaders
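Use torchvision's `INaturalist` dataset class and `DataLoader` to build iterable loaders. Labelled iNaturalist 2021 test data is not available through torchvision, so the validation split is reused as the test set.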

In [14]:
batch_size = 32
num_workers = 2 if os.name == 'nt' else 4
# Page-locked host memory only helps host->GPU copies; on a CPU-only machine the
# DataLoader merely warns about it, so enable pinning only when CUDA is present.
pin_memory = torch.cuda.is_available()

# Training Set (Mini)
train_dataset = datasets.INaturalist(root=DATA_ROOT, version=VERSION, target_type='full', download=True, transform=train_transforms)

# Validation Set
val_dataset = datasets.INaturalist(root=DATA_ROOT, version='2021_valid', target_type='full', download=True, transform=val_test_transforms)

# Test Set
# NOTE: torchvision's INaturalist offers no labelled 2021 test version, so the
# validation split is reused. The best epoch is also selected on this split, so
# any "test" metric computed from it is optimistically biased.
test_dataset = val_dataset


def _make_loader(dataset, shuffle):
    """Build a DataLoader using the shared batch-size/worker/pinning settings."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

# Derive the label space from the data instead of hard-coding it. With
# target_type='full' each label is an index into `all_categories`, whose entries
# are the category directory names (they encode the taxonomy).
class_names = list(train_dataset.all_categories)
num_classes = len(class_names)
# Labels only mean the same thing across splits if both list identical categories.
assert list(val_dataset.all_categories) == class_names, "train/val category lists differ"
print(f"Classes: {num_classes}")

FileNotFoundError: Found no valid file for the classes plantvillage dataset. Supported extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp

## 6. Build the Model (Transfer Learning)
Load an ImageNet-pretrained ResNet18 and replace its classifier head with one sized for the iNaturalist 2021 classes.

In [None]:
# Start from torchvision's default (ImageNet) pretrained ResNet18 weights.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Swap the final fully-connected layer for a freshly initialised one with one
# output per class; the rest of the backbone keeps its pretrained weights.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features=in_features, out_features=num_classes)

# Move parameters and buffers to the device chosen in Section 3.
model = model.to(device)

## 7. Set Loss, Optimizer, and Scheduler
Configure criterion, optimizer, and learning-rate scheduler.

In [None]:
# Cross-entropy over integer class labels (batch-mean reduction).
criterion = nn.CrossEntropyLoss(reduction="mean")
# Full fine-tuning: Adam updates every parameter, backbone included.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Multiply the LR by 0.1 every 5 scheduler.step() calls (once per epoch below).
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=5)

## 8. Train the Model
Implement a training loop with forward, backward, and metric tracking.

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device, scaler):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer stepped once per batch (through ``scaler``).
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            average over the batch.
        device: ``torch.device`` or device string each batch is moved to. Its
            type also selects the autocast backend, so CPU runs never request
            CUDA autocast.
        scaler: GradScaler-like object. Autocast is enabled exactly when
            ``scaler.is_enabled()`` is true.

    Returns:
        ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and top-1 accuracy.

    Raises:
        ValueError: if ``loader`` yields no samples (instead of a bare
            ZeroDivisionError).
    """
    model.train()
    device_type = torch.device(device).type
    running_loss = 0.0
    running_corrects = 0
    total = 0

    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # Mixed precision follows the scaler: AMP is on only when loss scaling is.
        with torch.amp.autocast(device_type=device_type, enabled=scaler.is_enabled()):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Re-weight the batch-mean loss by batch size so uneven final batches
        # contribute proportionally to the epoch mean.
        n_batch = labels.size(0)
        running_loss += loss.item() * n_batch
        running_corrects += (outputs.argmax(dim=1) == labels).sum().item()
        total += n_batch

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    return running_loss / total, running_corrects / total

## 9. Validate Each Epoch
Run validation after each epoch and track the best model.

In [None]:
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without updating it.

    Args:
        model: network to evaluate; switched to eval mode here (and left there).
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            average over the batch.
        device: ``torch.device`` or device string each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and top-1 accuracy.

    Raises:
        ValueError: if ``loader`` yields no samples (instead of a bare
            ZeroDivisionError).
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    # No gradients are needed for scoring; this saves memory and time.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Re-weight the batch-mean loss by batch size so the result is an
            # exact per-sample mean even when the last batch is smaller.
            n_batch = labels.size(0)
            running_loss += loss.item() * n_batch
            running_corrects += (outputs.argmax(dim=1) == labels).sum().item()
            total += n_batch

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")

    return running_loss / total, running_corrects / total

num_epochs = 10
# Snapshot of the weights from the epoch with the highest validation accuracy.
best_model_wts = copy.deepcopy(model.state_dict())
best_val_acc = 0.0

history = {key: [] for key in ("train_loss", "train_acc", "val_loss", "val_acc")}

for epoch in range(num_epochs):
    start = time.time()

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    # The scheduler is advanced once per completed epoch.
    scheduler.step()

    # Keys iterate in insertion order, matching the metric tuple below.
    for key, value in zip(history, (train_loss, train_acc, val_loss, val_acc)):
        history[key].append(value)

    # state_dict() shares storage with the live parameters, hence the deep copy.
    if val_acc > best_val_acc:
        best_val_acc, best_model_wts = val_acc, copy.deepcopy(model.state_dict())

    elapsed = time.time() - start
    print(f"Epoch {epoch+1}/{num_epochs} | Train loss {train_loss:.4f} acc {train_acc:.4f} | Val loss {val_loss:.4f} acc {val_acc:.4f} | {elapsed:.1f}s")

## 10. Evaluate on Test Set
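Compute final metrics for the best checkpoint on the test loader. Because the test loader reuses the validation split (see Section 5), these numbers are optimistically biased. A confusion matrix is impractical at 10,000 classes.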

In [None]:
# Report final metrics for the best checkpoint (chosen by validation accuracy).
# NOTE: test_loader wraps the validation split (test_dataset = val_dataset), so
# this is not an independent held-out estimate.
model.load_state_dict(best_model_wts)
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"Test loss {test_loss:.4f} acc {test_acc:.4f}")

# Confusion matrix is too large for 10,000 classes to display effectively.
print('Skipping confusion matrix for 10k classes.')

## 11. Save and Load the Best Model
Save best weights and demonstrate loading.

In [None]:
# Name the checkpoint after the dataset version actually trained on (the old
# name was a leftover from a PlantVillage version of this notebook).
best_model_path = f"inaturalist_{VERSION}_resnet18_best.pth"
torch.save(best_model_wts, best_model_path)
print("Saved:", best_model_path)

# Load later. A state_dict only contains tensors, so weights_only=True is
# sufficient and refuses arbitrary pickled objects from a tampered file.
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
model = model.to(device)

## 12. Inference on New Images
Run prediction on a few sample images and map outputs to class labels.

In [None]:
# Predict on a handful of validation images and map class indices to names.
# torchvision's INaturalist keeps the category directory names in
# `all_categories` (the 'full' target indexes into it); fall back to the
# numeric `class_names` if that attribute is unavailable.
n_show = 5
idx_to_name = getattr(val_dataset, "all_categories", class_names)

model.eval()
sample_inputs, sample_labels = next(iter(val_loader))
with torch.no_grad():
    sample_preds = model(sample_inputs[:n_show].to(device)).argmax(dim=1).cpu()

for pred_idx, true_idx in zip(sample_preds.tolist(), sample_labels[:n_show].tolist()):
    print(f"pred: {idx_to_name[pred_idx]}\n  true: {idx_to_name[true_idx]}")