In [7]:
from torchinfo import summary
from torch import nn

class PlantDiseaseModel(nn.Module):
    """CNN classifier for plant-disease images.

    Five convolutional stages (Conv 3x3 "same" -> BatchNorm -> ReLU -> MaxPool 2x2)
    widen the channels 3 -> 64 -> 128 -> 256 -> 512 -> 512 while halving the spatial
    size at every stage. Adaptive average pooling then collapses whatever spatial
    extent remains to 1x1, so height/width are not fixed by the architecture, but
    they must survive five 2x2 max-pools (i.e. be at least 32 pixels).
    An MLP head (512 -> 256 -> num_classes, dropout before the last layer) returns
    raw logits (no softmax).

    Arguments:
    num_classes -- number of output logits
    dropout_rate -- dropout probability used in the classifier head
    """

    def __init__(self, num_classes, dropout_rate=0.5):
        super().__init__()
        # Attribute names and the layer order inside each Sequential define the
        # state_dict keys that saved checkpoints are loaded against - keep them.
        self.conv_block1 = self._conv_stage(3, 64)
        self.conv_block2 = self._conv_stage(64, 128)
        self.conv_block3 = self._conv_stage(128, 256)
        self.conv_block4 = self._conv_stage(256, 512)
        self.conv_block5 = self._conv_stage(512, 512)

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc_block = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one Conv3x3(same) -> BatchNorm -> ReLU -> MaxPool(2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same"),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to class logits (N, num_classes)."""
        features = x
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3,
                      self.conv_block4, self.conv_block5):
            features = stage(features)
        pooled = self.global_avg_pool(features)
        return self.fc_block(pooled)

# Layer-by-layer output shapes and parameter counts for a 15-class model on one 224x224 RGB image.
model_summary = summary(PlantDiseaseModel(num_classes=15), input_size=(1, 3, 224, 224))
print(model_summary)

Layer (type:depth-idx)                   Output Shape              Param #
PlantDiseaseModel                        [1, 15]                   --
├─Sequential: 1-1                        [1, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 64, 224, 224]         1,792
│    └─BatchNorm2d: 2-2                  [1, 64, 224, 224]         128
│    └─ReLU: 2-3                         [1, 64, 224, 224]         --
│    └─MaxPool2d: 2-4                    [1, 64, 112, 112]         --
├─Sequential: 1-2                        [1, 128, 56, 56]          --
│    └─Conv2d: 2-5                       [1, 128, 112, 112]        73,856
│    └─BatchNorm2d: 2-6                  [1, 128, 112, 112]        256
│    └─ReLU: 2-7                         [1, 128, 112, 112]        --
│    └─MaxPool2d: 2-8                    [1, 128, 56, 56]          --
├─Sequential: 1-3                        [1, 256, 28, 28]          --
│    └─Conv2d: 2-9                       [1, 256, 56, 56]          295,168
│

In [10]:
import kagglehub

# Download the PlantVillage dataset from Kaggle; `path` is the local directory
# kagglehub returns and is used by the later cells.
DATASET_HANDLE = "emmarex/plantdisease"
path = kagglehub.dataset_download(DATASET_HANDLE)
print(f"Path to dataset files: {path}")

  from .autonotebook import tqdm as notebook_tqdm


Downloading from https://www.kaggle.com/api/v1/datasets/download/emmarex/plantdisease?dataset_version_number=1...


100%|██████████| 658M/658M [06:41<00:00, 1.72MB/s] 

Extracting files...





Path to dataset files: /home/estienne/.cache/kagglehub/datasets/emmarex/plantdisease/versions/1


In [23]:
import json
from multiprocessing.spawn import prepare
import pickle
from sklearn.preprocessing import LabelEncoder
import os
from torchvision import transforms

def load_images(directory_root, extensions=('.jpg', '.jpeg', '.png')):
    """Collect image file paths and their labels from a one-folder-per-class layout.

    Every non-hidden immediate subdirectory of ``directory_root`` is a class; its
    name is used as the label of each image file directly inside it. Hidden
    entries (names starting with ".") are skipped, as are non-files and files
    whose extension (case-insensitive) is not in ``extensions``.

    Arguments:
    directory_root -- root folder containing one subfolder per class
    extensions -- accepted file extensions including the dot (default: jpg/jpeg/png)

    Returns:
    (image_list, label_list) -- parallel lists of image paths and folder-name labels,
    ordered by class folder name, then file name, so repeated runs give the same order.
    """
    accepted = tuple(ext.lower() for ext in extensions)
    image_list, label_list = [], []
    print("[INFO] Loading images...")

    # os.listdir order is filesystem-dependent; sorting keeps the output (and any
    # split seeded with a fixed random_state downstream) reproducible.
    for disease_folder in sorted(os.listdir(directory_root)):
        # Hidden folders such as .ipynb_checkpoints are not classes.
        if disease_folder.startswith("."):
            continue
        disease_folder_path = os.path.join(directory_root, disease_folder)
        if not os.path.isdir(disease_folder_path):
            continue

        for img_name in sorted(os.listdir(disease_folder_path)):
            if img_name.startswith("."):
                continue
            img_path = os.path.join(disease_folder_path, img_name)
            if img_name.lower().endswith(accepted) and os.path.isfile(img_path):
                image_list.append(img_path)
                label_list.append(disease_folder)

    print("[INFO] Image loading completed")
    print(f"Total images: {len(image_list)}")
    return image_list, label_list

def prepare_data(directory_root, image_size=(256, 256), batch_size=32, test_size=0.3, valid_ratio=0.5, random_state=42,
                 label_encoder_path='label_encoder.pkl', transform_path='inference_transform.pkl'):
    """Fit and save the preprocessing artifacts needed at inference time.

    Scans ``directory_root`` with ``load_images``, fits a ``LabelEncoder`` on the
    folder-name labels and builds the evaluation transform (resize -> tensor ->
    normalise). Both objects are pickled to disk so inference code can reload them.

    NOTE(review): this version does not split the data or build data loaders.
    ``batch_size``, ``test_size``, ``valid_ratio`` and ``random_state`` are accepted
    for interface compatibility but are currently unused.

    Arguments:
    directory_root -- root folder with one subfolder per class
    image_size -- (height, width) passed to ``transforms.Resize``
    label_encoder_path -- where to pickle the fitted label encoder
    transform_path -- where to pickle the evaluation transform

    Returns:
    (label_encoder, valid_test_transform) -- the fitted encoder and the transform

    Raises:
    ValueError -- if no images are found under ``directory_root``
    """
    # Load images and labels
    image_paths, labels = load_images(directory_root)
    if not image_paths:
        # An empty scan almost always means a wrong path; fail loudly instead of
        # pickling an encoder that knows no classes.
        raise ValueError(f"No images found under {directory_root!r}")

    # Fit the class-name <-> integer mapping used to decode model outputs.
    label_encoder = LabelEncoder()
    label_encoder.fit(labels)

    # Save label encoder for inference
    with open(label_encoder_path, 'wb') as f:
        pickle.dump(label_encoder, f)

    # Deterministic evaluation preprocessing (no augmentation); mean/std are the
    # standard ImageNet channel statistics.
    valid_test_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Save image transformation for inference
    with open(transform_path, 'wb') as f:
        pickle.dump(valid_test_transform, f)

    return label_encoder, valid_test_transform


# Fit and save the label encoder and inference transform from the downloaded images.
# os.path.join keeps the path portable and matches how the sampling cell builds it;
# the trailing ';' keeps the notebook from echoing any return value.
prepare_data(os.path.join(path, "PlantVillage"));

[INFO] Loading images...
[INFO] Image loading completed
Total images: 20638


In [29]:
import os
import shutil
import random

# Copy a small, reproducible sample of images per class for quick manual checks.
# NOTE(review): these files come from the same folder prepare_data scans, so they are
# probably not held-out data - confirm before reading predictions on them as test accuracy.
SAMPLES_PER_CLASS = 5
SAMPLE_SEED = 42
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

source_dir = os.path.join(path, "PlantVillage")
test_dir = os.path.join("plant_disease", "images")
# Local seeded RNG: reproducible, and leaves the global `random` state untouched.
sampler = random.Random(SAMPLE_SEED)

os.makedirs(test_dir, exist_ok=True)

for class_name in sorted(os.listdir(source_dir)):
    class_path = os.path.join(source_dir, class_name)
    if class_name.startswith(".") or not os.path.isdir(class_path):
        continue

    # Sort so the seeded sample does not depend on filesystem listing order.
    images = sorted(
        img for img in os.listdir(class_path)
        if not img.startswith(".") and img.lower().endswith(IMAGE_EXTENSIONS)
    )
    selected_images = sampler.sample(images, min(SAMPLES_PER_CLASS, len(images)))

    target_class_dir = os.path.join(test_dir, class_name)
    os.makedirs(target_class_dir, exist_ok=True)

    for img_name in selected_images:
        src_img = os.path.join(class_path, img_name)
        dst_img = os.path.join(target_class_dir, img_name)
        shutil.copy2(src_img, dst_img)

print("Random test images copied to:", test_dir)

Random test images copied to: plant_disease/images


In [33]:
def predict_plant_disease(image_path, model_path, label_encoder_path, transform_path):
    """
    Predict plant disease from an image

    Loads the trained weights, the fitted label encoder and the evaluation
    transform, runs a single forward pass and returns the most probable class.

    Arguments:
    image_path -- path to the image file
    model_path -- path to the trained model's state_dict
    label_encoder_path -- path to the saved (pickled) label encoder
    transform_path -- path to the saved (pickled) transform

    Returns:
    predicted_class -- the predicted disease class
    confidence -- softmax probability of that class, as a percentage
    """
    import torch
    import pickle
    from PIL import Image

    # Pick the device first so the checkpoint can be mapped onto it while loading.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load label encoder and transform.
    # NOTE: pickle.load can execute arbitrary code - only load artifacts this
    # project produced itself.
    with open(label_encoder_path, 'rb') as f:
        label_encoder = pickle.load(f)

    with open(transform_path, 'rb') as f:
        transform = pickle.load(f)

    # Predicted indices are decoded through the label encoder, so the output layer
    # must have exactly as many classes as it knows; deriving the count here avoids
    # relying on a separate class_names.json that may be missing or out of sync.
    num_classes = len(label_encoder.classes_)

    # Load model. map_location lets a GPU-saved checkpoint load on a CPU-only
    # machine; weights_only restricts unpickling to tensors and plain containers.
    model = PlantDiseaseModel(num_classes=num_classes)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Load and preprocess image; the context manager closes the file handle
    # (convert() returns a new, fully loaded image).
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Get prediction
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # Get top prediction
    top_prob, top_class = torch.max(probabilities, 0)
    predicted_class = label_encoder.inverse_transform([top_class.item()])[0]
    confidence = float(top_prob.item()) * 100

    return predicted_class, confidence

import os

# Example usage: classify one of the sampled test images.
image_path = "plant_disease/images/Tomato__Target_Spot/5d157d30-435f-4572-9722-43439ebe6ed2___Com.G_TgS_FL 1048.JPG"
if not os.path.exists(image_path):
    # The sampling cell chooses images at random, so this exact file may not exist
    # after a re-run; fall back to the first sampled image instead.
    sampled_root = "plant_disease/images"
    candidates = sorted(
        os.path.join(dirpath, name)
        for dirpath, _, filenames in os.walk(sampled_root)
        for name in filenames
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    if not candidates:
        raise FileNotFoundError(f"No sample images found under {sampled_root!r}")
    image_path = candidates[0]
    print(f"[INFO] Example image not found; using {image_path}")

model_path = "plant_disease/model.pth"
label_encoder_path = "label_encoder.pkl"
transform_path = "inference_transform.pkl"
predicted_class, confidence = predict_plant_disease(image_path, model_path, label_encoder_path, transform_path)
print(f"Predicted class: {predicted_class}, Confidence: {confidence:.2f}%")


Predicted class: Tomato__Target_Spot, Confidence: 99.96%
