In [33]:
# ===============================
# SECTION 1: Import Libraries & Setup
# ===============================

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import random
import copy

# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# ===============================
# DATASET PATH (Kaggle Notebook)
# ===============================
DATA_DIR = "/kaggle/input/datasets/snikhilrao/crop-disease-detection-dataset/Plant Village Dataset"

train_dir = os.path.join(DATA_DIR, "Train")
val_dir   = os.path.join(DATA_DIR, "Val")
test_dir  = os.path.join(DATA_DIR, "Test")

# Safety check
for d in [train_dir, val_dir]:
    if not os.path.exists(d):
        raise FileNotFoundError(f"Folder not found: {d}")
    print(f"Folder {d} found. Classes: {os.listdir(d)}")


Using device: cuda
Folder /kaggle/input/datasets/snikhilrao/crop-disease-detection-dataset/Plant Village Dataset/Train found. Classes: ['Grape - Healthy', 'Potato - Early Blight', 'Bell Pepper - Healthy', 'Potato - Late Blight', 'Corn (Maize) - Cercospora Leaf Spot', 'Tomato - Septoria Leaf Spot', 'Bell Pepper - Bacterial Spot', 'Cherry - Powdery Mildew', 'Apple - Healthy', 'Tomato - Late Blight', 'Tomato - Healthy', 'Tomato - Early Blight', 'Grape - Black Rot', 'Potato - Healthy', 'Corn (Maize) - Northern Leaf Blight', 'Strawberry - Leaf Scorch', 'Tomato - Bacterial Spot', 'Peach - Bacterial Spot', 'Corn (Maize) - Common Rust', 'Strawberry - Healthy', 'Cherry - Healthy', 'Grape - Esca (Black Measles)', 'Apple - Cedar Apple Rust', 'Tomato - Yellow Leaf Curl Virus', 'Apple - Apple Scab', 'Apple - Black Rot', 'Corn (Maize) - Healthy', 'Peach - Healthy', 'Grape - Leaf Blight']
Folder /kaggle/input/datasets/snikhilrao/crop-disease-detection-dataset/Plant Village Dataset/Val found. Classes:

In [34]:
# ===============================
# SECTION 2: Load Dataset & Apply Transformations
# ===============================

# Image transformations
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load datasets
train_dataset = torchvision.datasets.ImageFolder(root=train_dir, transform=train_transform)
val_dataset   = torchvision.datasets.ImageFolder(root=val_dir, transform=val_transform)

# Optional test dataset
if os.path.exists(test_dir):
    test_dataset = torchvision.datasets.ImageFolder(root=test_dir, transform=val_transform)
else:
    test_dataset = None

# Number of classes
NUM_CLASSES = len(train_dataset.classes)
print(f"Number of classes: {NUM_CLASSES}")
print("Classes:", train_dataset.classes)

# DataLoaders
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
if test_dataset:
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
if test_dataset:
    print(f"Test samples: {len(test_dataset)}")


Number of classes: 29
Classes: ['Apple - Apple Scab', 'Apple - Black Rot', 'Apple - Cedar Apple Rust', 'Apple - Healthy', 'Bell Pepper - Bacterial Spot', 'Bell Pepper - Healthy', 'Cherry - Healthy', 'Cherry - Powdery Mildew', 'Corn (Maize) - Cercospora Leaf Spot', 'Corn (Maize) - Common Rust', 'Corn (Maize) - Healthy', 'Corn (Maize) - Northern Leaf Blight', 'Grape - Black Rot', 'Grape - Esca (Black Measles)', 'Grape - Healthy', 'Grape - Leaf Blight', 'Peach - Bacterial Spot', 'Peach - Healthy', 'Potato - Early Blight', 'Potato - Healthy', 'Potato - Late Blight', 'Strawberry - Healthy', 'Strawberry - Leaf Scorch', 'Tomato - Bacterial Spot', 'Tomato - Early Blight', 'Tomato - Healthy', 'Tomato - Late Blight', 'Tomato - Septoria Leaf Spot', 'Tomato - Yellow Leaf Curl Virus']
Training samples: 53693
Validation samples: 12067
Test samples: 1358


In [35]:
# ===============================
# SECTION 3: Pre-trained ResNet18 & Federated Learning Setup
# ===============================

# Load pre-trained ResNet18
import torchvision.models as models
global_model = models.resnet18(pretrained=True)

# Replace final fully connected layer
num_ftrs = global_model.fc.in_features
global_model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
global_model = global_model.to(DEVICE)
print(global_model)

# -------- Federated Averaging --------
def federated_average(global_model, local_models):
    global_dict = global_model.state_dict()
    for key in global_dict.keys():
        global_dict[key] = torch.stack(
            [local_models[i].state_dict()[key].float() for i in range(len(local_models))], 0
        ).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model

# -------- Local Training Function --------
def train_local(model, train_dataset, epochs=3, batch_size=32, lr=0.0001):
    model.train()
    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(loader):.4f}")
    return model

# -------- Evaluation Function --------
def evaluate(model, dataset, batch_size=32):
    model.eval()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s] 


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [38]:
# ===============================
# SECTION 4: Federated Learning Training Loop
# ===============================

NUM_CLIENTS = 2       # reduce for better per-client data
LOCAL_EPOCHS = 3
ROUNDS = 2
BATCH_SIZE = 32
LR = 0.0001

# Split dataset into clients (balanced)
def split_dataset(dataset, num_clients):
    data_size = len(dataset)
    indices = list(range(data_size))
    random.shuffle(indices)
    split_size = data_size // num_clients
    client_datasets = []
    for i in range(num_clients):
        start = i * split_size
        end = start + split_size if i < num_clients - 1 else data_size
        client_datasets.append(Subset(dataset, indices[start:end]))
    return client_datasets

client_datasets = split_dataset(train_dataset, NUM_CLIENTS)
print(f"Dataset split into {NUM_CLIENTS} clients.")

# Federated Training
global_model = models.resnet18(pretrained=True)
global_model.fc = nn.Linear(global_model.fc.in_features, NUM_CLASSES)
global_model = global_model.to(DEVICE)

for round_idx in range(ROUNDS):
    print(f"\n--- Federated Round {round_idx+1}/{ROUNDS} ---")
    local_models = []

    for client_idx in range(NUM_CLIENTS):
        print(f"Training client {client_idx+1}/{NUM_CLIENTS}")
        local_model = copy.deepcopy(global_model)
        local_model = train_local(local_model, client_datasets[client_idx],
                                  epochs=LOCAL_EPOCHS, batch_size=BATCH_SIZE, lr=LR)
        local_models.append(local_model)

    # Aggregate weights
    global_model = federated_average(global_model, local_models)

    # Evaluate global model on validation set
    val_acc = evaluate(global_model, val_dataset, batch_size=BATCH_SIZE)
    print(f"Global Model Validation Accuracy after Round {round_idx+1}: {val_acc:.2f}%")

print("\n✅ Federated Training Complete!")


Dataset split into 2 clients.

--- Federated Round 1/2 ---
Training client 1/2
Epoch 1/3 - Loss: 0.2483
Epoch 2/3 - Loss: 0.0503
Epoch 3/3 - Loss: 0.0304
Training client 2/2
Epoch 1/3 - Loss: 0.2463
Epoch 2/3 - Loss: 0.0458
Epoch 3/3 - Loss: 0.0339
Global Model Validation Accuracy after Round 1: 99.59%

--- Federated Round 2/2 ---
Training client 1/2
Epoch 1/3 - Loss: 0.0374
Epoch 2/3 - Loss: 0.0214
Epoch 3/3 - Loss: 0.0259
Training client 2/2
Epoch 1/3 - Loss: 0.0369
Epoch 2/3 - Loss: 0.0244
Epoch 3/3 - Loss: 0.0224
Global Model Validation Accuracy after Round 2: 99.68%

✅ Federated Training Complete!


In [39]:
# ===============================
# SECTION 5: Test & Save Trained Global Model
# ===============================

# Evaluate on test dataset
if test_dataset is not None:
    global_model.eval()
    loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    correct = 0
    total = 0
    class_correct = [0] * NUM_CLASSES
    class_total = [0] * NUM_CLASSES

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = global_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            for i in range(len(labels)):
                label = labels[i]
                class_total[label] += 1
                if predicted[i] == label:
                    class_correct[label] += 1

    overall_acc = 100 * correct / total
    print(f"\nOverall Test Accuracy: {overall_acc:.2f}%")
    print("\nPer-Class Accuracy:")
    for i, class_name in enumerate(test_dataset.classes):
        if class_total[i] > 0:
            acc = 100 * class_correct[i] / class_total[i]
            print(f"{class_name}: {acc:.2f}%")
else:
    print("No test dataset found. Skipping test evaluation.")

# Save global model
MODEL_PATH = "/kaggle/working/global_resnet_federated.pth"
torch.save(global_model.state_dict(), MODEL_PATH)
print(f"\nGlobal model saved at: {MODEL_PATH}")



Overall Test Accuracy: 99.63%

Per-Class Accuracy:
Apple - Apple Scab: 100.00%
Apple - Black Rot: 100.00%
Apple - Cedar Apple Rust: 100.00%
Apple - Healthy: 100.00%
Bell Pepper - Bacterial Spot: 100.00%
Bell Pepper - Healthy: 100.00%
Cherry - Healthy: 100.00%
Cherry - Powdery Mildew: 100.00%
Corn (Maize) - Cercospora Leaf Spot: 95.56%
Corn (Maize) - Common Rust: 100.00%
Corn (Maize) - Healthy: 100.00%
Corn (Maize) - Northern Leaf Blight: 97.92%
Grape - Black Rot: 100.00%
Grape - Esca (Black Measles): 100.00%
Grape - Healthy: 100.00%
Grape - Leaf Blight: 100.00%
Peach - Bacterial Spot: 100.00%
Peach - Healthy: 100.00%
Potato - Early Blight: 100.00%
Potato - Healthy: 100.00%
Potato - Late Blight: 100.00%
Strawberry - Healthy: 100.00%
Strawberry - Leaf Scorch: 100.00%
Tomato - Bacterial Spot: 100.00%
Tomato - Early Blight: 100.00%
Tomato - Healthy: 100.00%
Tomato - Late Blight: 100.00%
Tomato - Septoria Leaf Spot: 100.00%
Tomato - Yellow Leaf Curl Virus: 95.92%

Global model saved at: /k

In [42]:
# ===============================
# SECTION 6: Single Image Inference
# ===============================

from PIL import Image

# Load trained model
loaded_model = models.resnet18(pretrained=False)
loaded_model.fc = nn.Linear(loaded_model.fc.in_features, NUM_CLASSES)
loaded_model.load_state_dict(torch.load(MODEL_PATH))
loaded_model = loaded_model.to(DEVICE)
loaded_model.eval()
print("Model loaded for inference.")

# Image transformation (must match training)
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

def predict_image(image_path):
    """
    Predict crop disease for a single image
    """
    img = Image.open(image_path).convert('RGB')
    img_tensor = inference_transform(img).unsqueeze(0).to(DEVICE)  # add batch dimension
    with torch.no_grad():
        outputs = loaded_model(img_tensor)
        _, predicted = torch.max(outputs.data, 1)
    class_name = train_dataset.classes[predicted.item()]
    return class_name

# Example usage
test_image_path = "/kaggle/input/datasets/snikhilrao/crop-disease-detection-dataset/Plant Village Dataset/Test/Apple - Apple Scab/03354abb-aa1c-4f9d-a1ef-9f40505cd539___FREC_Scab 3355.JPG"  # replace with your image path
predicted_class = predict_image(test_image_path)
print(f"The predicted disease class is: {predicted_class}")


Model loaded for inference.
The predicted disease class is: Apple - Apple Scab
