# PlantVillage Training Notebook
This notebook trains, validates, and tests a plant disease classifier using the PlantVillage dataset. It is designed to run on a GPU if available.

## 1. Load and Verify Dataset Paths
Set the root path to your PlantVillage dataset and verify the train/val/test folder structure.

In [2]:
!pip install torch torchvision numpy matplotlib scikit-learn pillow kagglehub

[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.25a0+6627725-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/lightning_utilities-0.12.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/looseversion-1.3.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/pyth

In [None]:
from pathlib import Path
import os
import kagglehub

# Download directly from Kaggle (requires Kaggle API credentials)
DATASET_ROOT = Path(kagglehub.dataset_download("abdallahalidev/plantvillage-dataset"))

# Optional: override with a local path instead of Kaggle download
# DATASET_ROOT = Path(r"D:\KnowYourPlant\datasets\PlantVillage")

TRAIN_DIR = DATASET_ROOT / "train"
VAL_DIR = DATASET_ROOT / "val"
TEST_DIR = DATASET_ROOT / "test"

RAW_DIR = None
if not (TRAIN_DIR.exists() and VAL_DIR.exists() and TEST_DIR.exists()):
    candidates = [
        DATASET_ROOT / "plantvillage dataset" / "color",
        DATASET_ROOT / "PlantVillage" / "color",
        DATASET_ROOT / "color",
        DATASET_ROOT,
    ]
    for candidate in candidates:
        if candidate.exists():
            RAW_DIR = candidate
            break

print("DATASET_ROOT:", DATASET_ROOT)
print("Train exists:", TRAIN_DIR.exists())
print("Val exists:", VAL_DIR.exists())
print("Test exists:", TEST_DIR.exists())
print("RAW_DIR:", RAW_DIR)

ROOT CONTENTS:
/root/.cache/kagglehub/datasets/abdallahalidev/plantvillage-dataset/versions/3/plantvillage dataset -> DIR


In [None]:
if RAW_DIR is None:
    print("RAW_DIR not found. Set DATASET_ROOT or create train/val/test splits.")
else:
    class_dirs = [p for p in RAW_DIR.iterdir() if p.is_dir()]
    print("RAW_DIR classes:", len(class_dirs))
    print("Sample classes:", [p.name for p in class_dirs[:5]])

Subdirectories: [PosixPath('/root/.cache/kagglehub/datasets/abdallahalidev/plantvillage-dataset/versions/3/plantvillage dataset')]
Using: /root/.cache/kagglehub/datasets/abdallahalidev/plantvillage-dataset/versions/3/plantvillage dataset

Classes inside:
color
grayscale
segmented


In [None]:
def summarize_split_dir(dir_path, name):
    if not dir_path.exists():
        print(f"{name} missing: {dir_path}")
        return
    class_dirs = [p for p in dir_path.iterdir() if p.is_dir()]
    print(f"{name} classes:", len(class_dirs))
    print(f"{name} sample:", [p.name for p in class_dirs[:5]])

summarize_split_dir(TRAIN_DIR, "train")
summarize_split_dir(VAL_DIR, "val")
summarize_split_dir(TEST_DIR, "test")

Using: /root/.cache/kagglehub/datasets/abdallahalidev/plantvillage-dataset/versions/3/plantvillage dataset/color
Exists: True
Total classes: 38
Sample classes: ['Soybean___healthy', 'Apple___Apple_scab', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Grape___Black_rot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus']


In [None]:
# Optional: split a single-folder PlantVillage dataset into train/val/test.
# Use this if your dataset is organized as:
# DATASET_ROOT/
#   class_a/
#   class_b/
# and you do NOT already have train/val/test folders.

import random
import shutil
from pathlib import Path

if RAW_DIR is None:
    raise ValueError("RAW_DIR is not set. Update DATASET_ROOT or set RAW_DIR manually.")

split_ratio = (0.8, 0.1, 0.1)  # train, val, test
seed = 42
use_symlinks = False  # True can save space but may require admin on Windows

def split_dataset(raw_dir, train_dir, val_dir, test_dir, ratios, seed, use_symlinks):
    if train_dir.exists() and val_dir.exists() and test_dir.exists():
        print("train/val/test already exist. Skipping split.")
        return

    random.seed(seed)
    class_dirs = [p for p in raw_dir.iterdir() if p.is_dir()]
    if not class_dirs:
        raise ValueError(f"No class folders found in {raw_dir}")

    for class_dir in class_dirs:
        images = [p for p in class_dir.iterdir() if p.is_file()]
        random.shuffle(images)

        n_total = len(images)
        n_train = int(n_total * ratios[0])
        n_val = int(n_total * ratios[1])
        n_test = n_total - n_train - n_val

        splits = {
            train_dir / class_dir.name: images[:n_train],
            val_dir / class_dir.name: images[n_train:n_train + n_val],
            test_dir / class_dir.name: images[n_train + n_val:],
        }

        for out_dir, files in splits.items():
            out_dir.mkdir(parents=True, exist_ok=True)
            for src in files:
                dst = out_dir / src.name
                if use_symlinks:
                    if dst.exists():
                        continue
                    os.symlink(src, dst)
                else:
                    if dst.exists():
                        continue
                    shutil.copy2(src, dst)

    print("Split complete.")

split_dataset(RAW_DIR, TRAIN_DIR, VAL_DIR, TEST_DIR, split_ratio, seed, use_symlinks)

Split complete.


## 2. Install and Import Dependencies
Install and import required libraries.

In [11]:
# If running in a fresh environment, uncomment the line below
# !pip install torch torchvision numpy matplotlib scikit-learn

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

## 3. Configure GPU and Mixed Precision
Detect CUDA, select device, and enable mixed precision if supported.

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using device:", device)

use_amp = use_cuda
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

Using device: cpu


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


## 4. Define Data Transforms and Augmentations
Create transforms for training and for validation/testing.

In [13]:
img_size = 224

train_transforms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_test_transforms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

## 5. Create Train/Validation/Test Dataloaders
Use ImageFolder and DataLoader to build iterable loaders.

In [14]:
batch_size = 32
num_workers = 2 if os.name == "nt" else 4

train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=train_transforms)
val_dataset = datasets.ImageFolder(VAL_DIR, transform=val_test_transforms)
test_dataset = datasets.ImageFolder(TEST_DIR, transform=val_test_transforms)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=use_cuda)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=use_cuda)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=use_cuda)

class_names = train_dataset.classes
num_classes = len(class_names)
print("Classes:", num_classes)

FileNotFoundError: Found no valid file for the classes plantvillage dataset. Supported extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp

## 6. Build the Model (Transfer Learning)
Load a pretrained model and replace the classifier head for PlantVillage classes.

In [None]:
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Replace the final layer
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

model = model.to(device)

## 7. Set Loss, Optimizer, and Scheduler
Configure criterion, optimizer, and learning-rate scheduler.

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

## 8. Train the Model
Implement a training loop with forward, backward, and metric tracking.

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type="cuda", enabled=scaler.is_enabled()):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels).item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total
    return epoch_loss, epoch_acc

## 9. Validate Each Epoch
Run validation after each epoch and track the best model.

In [None]:
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()
            total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total
    return epoch_loss, epoch_acc

num_epochs = 10
best_model_wts = copy.deepcopy(model.state_dict())
best_val_acc = 0.0

history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

for epoch in range(num_epochs):
    start = time.time()

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    scheduler.step()

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())

    elapsed = time.time() - start
    print(f"Epoch {epoch+1}/{num_epochs} | Train loss {train_loss:.4f} acc {train_acc:.4f} | Val loss {val_loss:.4f} acc {val_acc:.4f} | {elapsed:.1f}s")

## 10. Evaluate on Test Set
Compute final test metrics and optionally show a confusion matrix.

In [None]:
model.load_state_dict(best_model_wts)

test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"Test loss {test_loss:.4f} acc {test_acc:.4f}")

# Optional confusion matrix (requires scikit-learn)
try:
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

    all_preds = []
    all_labels = []
    model.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device, non_blocking=True)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    cm = confusion_matrix(all_labels, all_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    fig, ax = plt.subplots(figsize=(10, 10))
    disp.plot(ax=ax, xticks_rotation="vertical")
    plt.show()
except Exception as exc:
    print("Confusion matrix skipped:", exc)

## 11. Save and Load the Best Model
Save best weights and demonstrate loading.

In [None]:
best_model_path = "plantvillage_resnet18_best.pth"
torch.save(best_model_wts, best_model_path)
print("Saved:", best_model_path)

# Load later
model.load_state_dict(torch.load(best_model_path, map_location=device))
model = model.to(device)

## 12. Inference on New Images
Run prediction on a few sample images and map outputs to class labels.

In [None]:
from PIL import Image

# Update this to a few image paths
sample_images = [
    # r"D:\\KnowYourPlant\\sample1.jpg",
    # r"D:\\KnowYourPlant\\sample2.jpg",
]

def predict_image(model, image_path, transform, class_names, device):
    model.eval()
    image = Image.open(image_path).convert("RGB")
    tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(tensor)
        probs = torch.softmax(outputs, dim=1)
        conf, pred = torch.max(probs, 1)

    return class_names[pred.item()], conf.item()

for path in sample_images:
    label, conf = predict_image(model, path, val_test_transforms, class_names, device)
    print(f"{path} -> {label} ({conf:.3f})")