# Potato Disease Classification

*This notebook demonstrates a complete machine learning pipeline for classifying potato plant diseases using deep learning with TensorFlow.*

## Project Overview

Potato diseases can cause significant crop loss for farmers worldwide. Early detection of these diseases is crucial for effective treatment and prevention of crop damage. This project uses Convolutional Neural Networks (CNNs) to classify potato leaves into three categories:

1. Early Blight (Alternaria solani)
2. Late Blight (Phytophthora infestans)
3. Healthy

## Table of Contents

1. [Setup and Dependencies](#setup-and-dependencies)
2. [Data Acquisition and Preparation](#data-acquisition-and-preparation)
3. [Data Exploration and Visualization](#data-exploration-and-visualization)
4. [Data Preprocessing](#data-preprocessing)
5. [Model Building and Training](#model-building-and-training)
6. [Model Evaluation](#model-evaluation)
7. [Deployment](#deployment)

## Setup and Dependencies

*First, we'll import all the necessary libraries for our project.*


In [None]:
# Deep learning framework
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import (
    Conv2D,
    Dense,
    Flatten,
    Input,
    MaxPool2D,
    RandomFlip,
    RandomRotation,
    RandomZoom,
    Rescaling,
    Resizing,
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing import image_dataset_from_directory

# Machine learning evaluation metrics
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.preprocessing import label_binarize

# Data manipulation and analysis
import numpy as np
import pandas as pd

# Visualization libraries
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import seaborn as sns
from plotly.subplots import make_subplots

# Image processing
from PIL import Image
import io

# Utilities
from tqdm import tqdm
import gradio as gr
import os
import shutil
import zipfile


## Data Acquisition and Preparation

### Downloading Dataset from Kaggle

*We'll use the PlantVillage dataset from Kaggle, which contains images of various plant diseases including potato diseases.*

**Note:** To download from Kaggle, you need to:
1. Create a Kaggle account if you don't have one
2. Go to your account settings and create an API token
3. Download the kaggle.json file and upload it below


#### Downloading on Google Colab
Run the next cell if you are using Google Colab. If you are running the notebook locally, skip it and use the cell under "Downloading locally" instead.

In [None]:
# Install Kaggle API and setup authentication
%pip install -q kaggle
from google.colab import files

# Upload your kaggle.json API token file (skip if already uploaded)
files.upload()

# Create a Kaggle directory and set permissions for the API token
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download the PlantVillage dataset
!kaggle datasets download -d arjuntejaswi/plant-village

# Extract the downloaded dataset
!unzip plant-village.zip

#### Downloading locally


In [None]:
os.makedirs("DATASET", exist_ok=True)

kaggle_json_path = os.path.expanduser("kaggle.json")
kaggle_dir = os.path.expanduser("~/.kaggle")
os.makedirs(kaggle_dir, exist_ok=True)
if not os.path.exists(os.path.join(kaggle_dir, "kaggle.json")):
    shutil.copy(kaggle_json_path, os.path.join(kaggle_dir, "kaggle.json"))

os.chmod(os.path.join(kaggle_dir, "kaggle.json"), 0o600)

!kaggle datasets download -d arjuntejaswi/plant-village -p "DATASET"

with zipfile.ZipFile("DATASET/plant-village.zip", "r") as zip_ref:
    zip_ref.extractall("DATASET")


### Extracting Potato Disease Images

*The PlantVillage dataset contains images of various plants and diseases. For this project, we'll focus only on potato diseases, so we'll extract just those images.*


In [None]:
# Define source and destination directories
data_dir = "/content/PlantVillage"  # Dataset location on Colab; after the local download use "DATASET/PlantVillage"
save_dir = "PotatoDir"  # Where we'll save potato images

# Create a new directory structure with only potato-related folders
for folder in tqdm(os.listdir(data_dir), desc="Extracting potato images"):
    # Select only folders that contain potato images
    if folder.startswith("Potato"):
        folder_path = os.path.join(save_dir, folder)

        # Ensure the destination folder exists
        os.makedirs(folder_path, exist_ok=True)

        # os.makedirs returns only after the directory exists,
        # so no delay is needed before copying files into it

        # Copy all images from the source to the destination folder
        for image_file in os.listdir(os.path.join(data_dir, folder)):
            src = os.path.join(data_dir, folder, image_file)
            dst = os.path.join(folder_path, image_file)
            try:
                shutil.copy(src, dst)
            except FileNotFoundError as e:
                print(f"File not found: {e}")


In [None]:
# Verify the extraction by counting images in each class
print("Class distribution in the potato disease dataset:")
for folder in os.listdir(save_dir):
    image_count = len(os.listdir(os.path.join(save_dir, folder)))
    print(f"{folder} : {image_count} images")

## Data Exploration and Visualization

### Loading the Dataset

*Now we'll load our potato disease images using TensorFlow's dataset utilities, which will help us efficiently process and batch the data for training.*


In [None]:
# Define image processing parameters
IMAGE_SIZE = 256  # All images will be resized to 256x256 pixels
BATCH_SIZE = 32  # Number of images processed in each training batch

# Load images from the directory with automatic labeling based on the folder structure
save_dir = "/content/PotatoDir"
dataset = image_dataset_from_directory(
    save_dir,  # Path to the image directory
    shuffle=True,  # Shuffle so batches mix all classes instead of following folder order
    image_size=(IMAGE_SIZE, IMAGE_SIZE),  # Resize all images to consistent dimensions
    batch_size=BATCH_SIZE,  # Define batch size for training
)


In [None]:
# Extract class names from the dataset
class_names = dataset.class_names
print("Classes in our dataset:")
for i, class_name in enumerate(class_names):
    print(f"  {i}: {class_name}")


In [None]:
# Examine the shape of our batched data
print("Dataset structure:")
for image_batch, labels_batch in dataset.take(1):
    print(f"Image batch shape: {image_batch.shape}")
    print(f"Label batch shape: {labels_batch.shape}")
    print(f"Labels in this batch: {labels_batch.numpy()}")

### Visualizing Sample Images

*Let's examine some sample images from each class to better understand our dataset. This will help us identify any potential issues and get familiar with the visual characteristics of each disease.*


In [None]:
fig = plt.figure(figsize=(14, 14))
fig.suptitle("Sample Images from the Potato Disease Dataset", fontsize=20)

# Take one batch of images and display the first 16
for image_batch, label_batch in dataset.take(1):
    for i in range(16):
        # Create a subplot in a 4x4 grid
        ax = plt.subplot(4, 4, i + 1)

        # Convert tensor to a numpy array and display the image
        plt.imshow(image_batch[i].numpy().astype(np.uint8))

        # Add the class names as the subplot title
        class_name = class_names[label_batch.numpy()[i]]
        plt.title(class_name.replace("Potato___", "").replace("_", " "), fontsize=12)
        plt.axis("off")

plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.show()


### Key Observations about the Dataset

*From the sample images, we can observe:*

- **Early Blight**: Characterized by brown spots with concentric rings, typically on older leaves
- **Late Blight**: Shows dark brown patches with fuzzy white growth in humid conditions
- **Healthy**: Green leaves without visible lesions or discoloration

*These visual differences will be important features for our model to learn.*



## Data Preprocessing

*Before training our model, we need to prepare our data by splitting it into training, validation, and test sets, and setting up data augmentation to improve model generalization.*

### Splitting the Dataset

*We'll use a standard 80/10/10 split:*
- *80% for training (learning patterns)*
- *10% for validation (tuning hyperparameters)*
- *10% for testing (final evaluation)*
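
*Because the dataset is already batched, a take/skip split counts batches rather than images, and `int()` truncation can leave a few batches unassigned. The sketch below illustrates that arithmetic. The helper name `split_sizes` and the total of 68 batches are assumptions for illustration only; the real count depends on the dataset.*

```python
def split_sizes(num_batches, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Return (train, val, test, unused) batch counts for a take/skip split."""
    train = int(num_batches * train_ratio)
    val = int(num_batches * val_ratio)
    test = int(num_batches * test_ratio)
    unused = num_batches - (train + val + test)
    return train, val, test, unused


# 68 is a hypothetical batch count used only for illustration
train, val, test, unused = split_sizes(68)
print(f"train={train}, val={val}, test={test}, unused={unused}")
```

*Any "unused" batches are simply never read by the three splits, so the ratios are approximate rather than exact.*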


In [None]:
def split_dataset(
    ds, train_ratio, val_ratio, test_ratio, shuffle=True, shuffle_size=10000
):
    """
    Split a TensorFlow dataset into training, validation, and test sets.

    Parameters:
    -----------
    ds : tf.data.Dataset
        The dataset to split
    train_ratio : float
        Proportion of data to use for training (e.g., 0.8 for 80%)
    val_ratio : float
        Proportion of data to use for validation (e.g., 0.1 for 10%)
    test_ratio : float
        Proportion of data to use for testing (e.g., 0.1 for 10%)
    shuffle : bool
        Whether to shuffle the dataset before splitting
    shuffle_size : int
        Buffer size for shuffling

    Returns:
    --------
    train_ds, val_ds, test_ds : tuple of tf.data.Dataset
        The split datasets
    """
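    # Shuffle the dataset if requested (with a fixed seed for reproducibility).
    # By default tf.data reshuffles on every pass over the dataset, so the
    # take()/skip() splits below could be drawn from different orderings and
    # overlap; reshuffle_each_iteration=False keeps this shuffle's order fixed.
    if shuffle:
        ds = ds.shuffle(shuffle_size, seed=2024, reshuffle_each_iteration=False)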

    # Calculate the size of each split
    TRAIN_SIZE = int(len(ds) * train_ratio)
    VAL_SIZE = int(len(ds) * val_ratio)
    TEST_SIZE = int(len(ds) * test_ratio)

    # Split the dataset
    train_ds = ds.take(TRAIN_SIZE)
    val_ds = ds.skip(TRAIN_SIZE).take(VAL_SIZE)
    test_ds = ds.skip(TRAIN_SIZE + VAL_SIZE).take(TEST_SIZE)

    # Optimize dataset performance with caching, shuffling and prefetching
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)
    val_ds = val_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)
    test_ds = test_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)

    return train_ds, val_ds, test_ds

In [None]:
# Split the dataset using our function with an 80/10/10 ratio
train_ds, val_ds, test_ds = split_dataset(dataset, 0.8, 0.1, 0.1)

# Verify our split sizes
print("Dataset split summary:")
# Image counts are approximate because the final batch of the dataset may be partial
print(f"  Training batches:   {len(train_ds)} (~{len(train_ds) * BATCH_SIZE} images)")
print(f"  Validation batches: {len(val_ds)} (~{len(val_ds) * BATCH_SIZE} images)")
print(f"  Testing batches:    {len(test_ds)} (~{len(test_ds) * BATCH_SIZE} images)")

### Image Preprocessing Pipeline

*We'll create two preprocessing components:*

1. **Resize and Rescale**: Ensures all images have consistent dimensions and pixel values in the range [0,1]
2. **Data Augmentation**: Generates new training examples by applying random transformations to existing images

#### Resize and Rescale Layer

*This preprocessing step standardizes our input images:*


In [None]:
# Create a sequential model for resizing and rescaling images
resize_rescale = Sequential(
    [
        # Resize all images to a standard size
        Resizing(IMAGE_SIZE, IMAGE_SIZE),
        # Rescale pixel values from [0,255] to [0,1]
        Rescaling(1.0 / 255),
    ]
)


#### Data Augmentation

*Data augmentation is a powerful technique to prevent overfitting and improve model generalization. By creating variations of our training images, we help the model learn more robust features.*


In [None]:
# Create a sequential model for data augmentation
data_augmentation = Sequential(
    [
        # Randomly flip images both horizontally and vertically
        # This simulates different plant orientations
        RandomFlip("horizontal_and_vertical"),
        # Randomly rotate images by up to ±144° (the factor 0.4 is a fraction of a full 360° turn)
        # This helps the model become invariant to the orientation of leaves
        RandomRotation(0.4),
        # Randomly zoom in/out by up to 20%
        # This simulates variations in distance from the plant
        RandomZoom(0.2),
    ]
)


*These augmentation techniques will help our model generalize better to new, unseen images and become more robust to variations in how the potato leaves are photographed.*
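
*Keras documents the `RandomRotation` factor as a fraction of a full turn (2π), so it is worth converting it to degrees before choosing a value. The helper below, `rotation_range_degrees`, is a hypothetical name introduced for this sketch and is not part of Keras.*

```python
def rotation_range_degrees(factor):
    """Convert a RandomRotation factor (fraction of a full turn) to a degree range."""
    max_angle = factor * 360.0
    return -max_angle, max_angle


for factor in (0.1, 0.4, 0.5):
    low, high = rotation_range_degrees(factor)
    print(f"factor={factor}: rotation drawn from [{low:.0f}, {high:.0f}] degrees")
```

*A factor of 0.5 therefore already covers every possible orientation, and larger values add nothing new.*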

### Visualizing Data Augmentation Effects

*Let's visualize how our augmentation techniques transform the original images. This helps us understand what variations our model will learn from.*

#### Random Rotation Visualization

*Random rotation simulates different angles at which the leaves might be photographed.*


In [None]:
def visualize_rotation(ds, factor):
    """Visualize the effect of random rotation on sample images"""
    fig, axs = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f"Effect of Random Rotation (Factor: {factor})", fontsize=16)

    # Create a rotation layer with the specified factor
    rand_rot = RandomRotation(factor)

    for image, label in ds.take(1):
        for i in range(2):
            # Original image
            axs[i, 0].imshow(image[i].numpy().astype(np.uint8))
            class_name = (
                class_names[label[i]].replace("Potato___", "").replace("_", " ")
            )
            axs[i, 0].set_title(f"Original: {class_name}", fontsize=12)
            axs[i, 0].axis("off")

            # Rotated image
            rotated_image = rand_rot(image[i]).numpy().astype(np.uint8)
            axs[i, 1].imshow(rotated_image)
            axs[i, 1].set_title(f"Rotated: {class_name}", fontsize=12)
            axs[i, 1].axis("off")

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()


# Visualize rotation with a factor of 0.5 (up to 180 degrees)
visualize_rotation(train_ds, 0.5)


#### Random Zoom Visualization

*Random zoom simulates variations in camera distance and focus when photographing plant leaves.*


In [None]:
def visualize_zoom(ds, factor):
    """Visualize the effect of random zoom on sample images"""
    fig, axs = plt.subplots(4, 4, figsize=(14, 14))
    fig.suptitle(f"Effect of Random Zoom (Factor: {factor})", fontsize=18)

    # Create a zoom layer with the specified factor
    rand_zoom = RandomZoom(factor)

    for image, label in ds.take(1):
        for i in range(8):
            row, col = i // 2, (i % 2) * 2

            # Original image
            axs[row, col].imshow(image[i].numpy().astype(np.uint8))
            class_name = (
                class_names[label[i]].replace("Potato___", "").replace("_", " ")
            )
            axs[row, col].set_title(f"Original: {class_name}", fontsize=10)
            axs[row, col].axis("off")

            # Zoomed image
            zoomed_image = rand_zoom(image[i]).numpy().astype(np.uint8)
            axs[row, col + 1].imshow(zoomed_image)
            axs[row, col + 1].set_title(f"Zoomed: {class_name}", fontsize=10)
            axs[row, col + 1].axis("off")

    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    plt.show()


# Visualize zoom with a factor of 0.5 (zoom in/out by up to 50%)
visualize_zoom(train_ds, 0.5)


*These visualizations show how data augmentation creates diverse training examples from our original dataset. The model will learn to recognize diseases regardless of the leaf orientation, angle, or how zoomed in the photo is.*

## Model Building and Training

*Now we'll design and train a Convolutional Neural Network (CNN) to classify potato diseases.*

### CNN Architecture Design

*CNNs are the standard architecture for image classification tasks. Our model will use:*

- *Multiple convolutional layers to extract features from images*
- *Max pooling to reduce dimensionality and computational load*
- *Dense layers at the end for classification*
- *Data augmentation during training to improve generalization*
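
*To make the pooling step concrete, here is a small pure-Python sketch of 2×2 max pooling with stride 2 on a single-channel grid. The function name `max_pool_2x2` and the sample values are illustrative assumptions; a `MaxPool2D` layer with `pool_size=(2, 2)` applies the same reduction to each channel of a tensor.*

```python
def max_pool_2x2(grid):
    """Apply 2x2 max pooling with stride 2 to a 2D list with even dimensions."""
    pooled = []
    for r in range(0, len(grid), 2):
        row = []
        for c in range(0, len(grid[0]), 2):
            # Collect the four values of the current 2x2 window and keep the largest
            window = (grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1])
            row.append(max(window))
        pooled.append(row)
    return pooled


feature_map = [
    [1, 3, 2, 0],
    [4, 2, 1, 5],
    [0, 1, 7, 2],
    [3, 2, 4, 6],
]
print(max_pool_2x2(feature_map))  # [[4, 5], [3, 7]]
```

*Each pooling stage halves the height and width, which is why the later layers operate on much smaller feature maps.*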


In [None]:
# Redefine our image parameters
IMAGE_SIZE = 256  # Input image dimensions
BATCH_SIZE = 32  # Batch size for training

# Define the input shape for our model
input_shape = (IMAGE_SIZE, IMAGE_SIZE, 3)  # Height, Width, Channels

# Build a sequential CNN model
seq_model = Sequential(
    [
        # Input layer specifying the input shape
        Input(shape=input_shape),
        # Preprocessing layers
        resize_rescale,  # Resize and normalize images
        data_augmentation,  # Apply data augmentation during training
        # First convolutional block
        Conv2D(filters=32, kernel_size=(3, 3), activation="relu", padding="same"),
        MaxPool2D(pool_size=(2, 2)),
        # Second convolutional block (doubling filters)
        Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
        MaxPool2D(pool_size=(2, 2)),
        # Third convolutional block (doubling filters again)
        Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
        MaxPool2D(pool_size=(2, 2)),
        # Fourth convolutional block (doubling filters again)
        Conv2D(filters=256, kernel_size=(3, 3), activation="relu", padding="same"),
        MaxPool2D(pool_size=(2, 2)),
        # Flatten the 3D feature maps to 1D feature vectors
        Flatten(),
        # First dense layer for a high-level feature combination
        Dense(units=128, activation="relu"),
        # Output layer with softmax activation for multi-class classification
        # Number of units equals the number of disease classes
        Dense(units=len(class_names), activation="softmax"),
    ]
)

# Display a summary of the model architecture
seq_model.summary()


*Model Architecture Highlights:*

- *Progressive increase in filter count (32→64→128→256) to learn hierarchical features*
- *Max pooling after each convolutional layer to reduce spatial dimensions*
- *ReLU activation for non-linearity in feature extraction*
- *128-unit dense layer to combine high-level features*
- *Softmax output to get class probabilities*
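
*The layer sizes above can be checked by hand. The sketch below recomputes the feature-map shapes and trainable parameter count for this architecture, assuming three output classes (the classes listed in the overview) and 3×3 kernels with "same" padding. The helper names `conv_params` and `dense_params` are introduced here for illustration; the output of `seq_model.summary()` remains the authoritative source.*

```python
def conv_params(in_channels, out_channels, kernel=3):
    """Weights plus biases for one Conv2D layer."""
    return (kernel * kernel * in_channels + 1) * out_channels


def dense_params(in_units, out_units):
    """Weights plus biases for one Dense layer."""
    return (in_units + 1) * out_units


size, channels, total = 256, 3, 0
for filters in (32, 64, 128, 256):
    total += conv_params(channels, filters)
    size //= 2  # "same" convolution keeps the size; 2x2 pooling halves it
    channels = filters
    print(f"after block with {filters} filters: {size}x{size}x{channels}")

flat = size * size * channels
num_classes = 3  # assumption: Early Blight, Late Blight, Healthy
total += dense_params(flat, 128) + dense_params(128, num_classes)
print(f"flattened features: {flat}, trainable parameters: {total}")
```

*Under these assumptions, most of the parameters sit in the first dense layer, because it connects all 65,536 flattened features to 128 units.*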


### Model Compilation

*Before training, we need to compile the model by specifying:*

- *Optimizer: Algorithm to update the model weights*
- *Loss function: Measures how far the model's predictions are from the true values*
- *Metrics: Used to monitor the training and evaluation process*
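
*As a worked example of the loss function, sparse categorical cross-entropy for one sample is the negative log of the probability the model assigns to the true class. The sketch below uses made-up probabilities and a plain-Python helper (not the Keras implementation) to show how the loss grows as the true class receives less probability.*

```python
import math


def sparse_categorical_crossentropy(probs, label):
    """Loss for one sample: negative log of the probability given to the true class."""
    return -math.log(probs[label])


probs = [0.7, 0.2, 0.1]  # hypothetical softmax output for one image
for label in range(3):
    loss = sparse_categorical_crossentropy(probs, label)
    print(f"true class {label}: loss = {loss:.4f}")
```

*The label is passed as an integer index, which is what "sparse" refers to; no one-hot encoding is needed.*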


In [None]:
# Compile the model with appropriate settings for a classification task
seq_model.compile(
    # Adam optimizer with default learning rate (good for most tasks)
    optimizer="adam",
    # Sparse categorical crossentropy is appropriate when classes are mutually exclusive
    # and labels are integers (not one-hot encoded)
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)


### Training Callbacks

*Callbacks allow us to customize the training process. We'll use:*

- *EarlyStopping: Stops training when a monitored metric stops improving*
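
*The patience mechanism can be illustrated without training anything. The function below is a simplified model of the idea (count epochs without improvement, stop once the count reaches `patience`), not Keras's actual implementation. The validation losses are made up and epochs are 0-indexed.*

```python
def simulate_early_stopping(val_losses, patience):
    """Return (stop_epoch, best_epoch) for a list of per-epoch validation losses."""
    best_loss = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Training ran to completion without triggering early stopping
    return len(val_losses) - 1, best_epoch


losses = [1.00, 0.80, 0.70, 0.75, 0.72, 0.71, 0.74, 0.73, 0.90]  # made-up values
stop_epoch, best_epoch = simulate_early_stopping(losses, patience=5)
print(f"stopped after epoch {stop_epoch}, best epoch was {best_epoch}")
```

*With `restore_best_weights=True`, the model kept after stopping corresponds to the best epoch rather than the last one.*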


In [None]:
# Define callback for early stopping
cb = EarlyStopping(
    monitor="val_loss",  # Monitor validation loss
    patience=5,  # Number of epochs with no improvement after which to stop
    verbose=1,  # Provides updates in the console
    restore_best_weights=True,  # Restore the best weights when stopped
)


### Model Training

*Now we'll train our model on the prepared dataset. The model will learn to identify patterns associated with each disease class.*


In [None]:
# Track the start time to measure training duration
import time

start_time = time.time()

# Train the model
history = seq_model.fit(
    train_ds,  # Training dataset (already batched, so batch_size is not passed to fit)
    epochs=25,  # Maximum number of complete passes through the dataset
    validation_data=val_ds,  # Dataset for validation
    callbacks=[cb],  # List of callbacks to apply during training
    verbose=1,  # Progress bar mode
)

# Calculate and print training time
training_time = time.time() - start_time
print(
    f"\nTraining completed in {training_time:.2f} seconds ({training_time / 60:.2f} minutes)"
)


*During training, we can observe:*

1. *The loss and accuracy on both training and validation sets*
2. *Whether the model is improving over epochs*
3. *If/when early stopping is triggered*

*This gives us insights into how well our model is learning the patterns in the data.*

### Analyzing Training Results

*Let's examine how our model performed during training by visualizing the training history.*


In [None]:
# Print the metrics from the last epoch that ran
# (with restore_best_weights=True, the weights kept may come from an earlier epoch)
final_epoch = len(history.history["loss"]) - 1
print(f"Training Results after {final_epoch + 1} epochs:")
print(f"  Training Accuracy:   {history.history['accuracy'][final_epoch]:.4f}")
print(f"  Validation Accuracy: {history.history['val_accuracy'][final_epoch]:.4f}")
print(f"  Training Loss:       {history.history['loss'][final_epoch]:.4f}")
print(f"  Validation Loss:     {history.history['val_loss'][final_epoch]:.4f}")


In [None]:
# Function to create an interactive visualization of training history
def plot_history(history):
    """Create an interactive plot of training history using Plotly"""
    # Extract metrics from the history object
    loss = history.history["loss"]
    val_loss = history.history["val_loss"]
    acc = history.history["accuracy"]
    val_acc = history.history["val_accuracy"]

    # Find the best epochs for loss and accuracy
    best_loss_epoch = np.argmin(val_loss)
    best_acc_epoch = np.argmax(val_acc)

    # Create subplots for loss and accuracy
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Model Loss", "Model Accuracy"),
        shared_xaxes=True,
    )

    # Add loss traces
    fig.add_trace(
        go.Scatter(
            x=list(range(len(loss))),
            y=loss,
            mode="lines",
            line=dict(color="#00B5F7", width=2),
            name="Training Loss",
            legendgroup="train",
        ),
        row=1,
        col=1,
    )

    fig.add_trace(
        go.Scatter(
            x=list(range(len(val_loss))),
            y=val_loss,
            mode="lines",
            line=dict(color="#FF6B6B", width=2),
            name="Validation Loss",
            legendgroup="val",
        ),
        row=1,
        col=1,
    )

    # Mark the best loss epoch
    fig.add_trace(
        go.Scatter(
            x=[best_loss_epoch],
            y=[min(val_loss)],
            mode="markers",
            marker=dict(color="#4ECDC4", size=10, symbol="star"),
            name=f"Best Loss Epoch ({best_loss_epoch})",
            legendgroup="best_points",
            hoverinfo="text",
            hovertext=f"Best Validation Loss: {min(val_loss):.4f} at epoch {best_loss_epoch}",
        ),
        row=1,
        col=1,
    )

    # Add accuracy traces
    fig.add_trace(
        go.Scatter(
            x=list(range(len(acc))),
            y=acc,
            mode="lines",
            line=dict(color="#00B5F7", width=2),
            name="Training Accuracy",
            legendgroup="train",
            showlegend=False,
        ),
        row=1,
        col=2,
    )

    fig.add_trace(
        go.Scatter(
            x=list(range(len(val_acc))),
            y=val_acc,
            mode="lines",
            line=dict(color="#FF6B6B", width=2),
            name="Validation Accuracy",
            legendgroup="val",
            showlegend=False,
        ),
        row=1,
        col=2,
    )

    # Mark the best accuracy epoch
    fig.add_trace(
        go.Scatter(
            x=[best_acc_epoch],
            y=[max(val_acc)],
            mode="markers",
            marker=dict(color="#4ECDC4", size=10, symbol="star"),
            name=f"Best Accuracy Epoch ({best_acc_epoch})",
            legendgroup="best_points",
            hoverinfo="text",
            hovertext=f"Best Validation Accuracy: {max(val_acc):.4f} at epoch {best_acc_epoch}",
        ),
        row=1,
        col=2,
    )

    # Update layout for better appearance
    fig.update_layout(
        template="plotly_dark",
        paper_bgcolor="rgba(0,0,0,1)",
        plot_bgcolor="rgba(0,0,0,1)",
        title=dict(text="Model's Training Performance", y=0.98),
        height=500,
        width=1200,
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=1),
        xaxis_title="Epoch",
        xaxis2_title="Epoch",
        yaxis_title="Loss",
        yaxis2_title="Accuracy",
    )

    # Add grid lines with dark theme colors
    fig.update_xaxes(
        showgrid=True,
        gridwidth=1,
        gridcolor="rgba(255,255,255,0.2)",
        showline=True,
        mirror=True,
    )
    fig.update_yaxes(
        showgrid=True,
        gridwidth=1,
        gridcolor="rgba(255,255,255,0.2)",
        showline=True,
        mirror=True,
    )

    # Show the figure
    fig.show()


# Plot the training history
plot_history(history)


## Model Evaluation

*Now that we have a trained model, we need to evaluate its performance on the test dataset (data the model hasn't seen during training). This will give us an unbiased estimate of how well the model will perform on new, unseen data.*

### Test Set Performance

*First, let's calculate the overall accuracy on our test set:*


In [None]:
# Evaluate the model on the test dataset
score = seq_model.evaluate(test_ds, verbose=0)

# Print the evaluation results
print("Test Set Evaluation Results:")
print(f"  Loss:     {score[0]:.4f}")
print(f"  Accuracy: {score[1] * 100:.2f}%")


*The test accuracy gives us a good overall measure of the model's performance, but we need more detailed metrics to fully understand its strengths and weaknesses for each disease class.*

### Visualizing Predictions on Test Images

*Let's examine how our model performs on specific test images. This visual analysis helps us understand what kinds of images the model handles well and where it might struggle.*


In [None]:
def visualize_predictions(model, dataset, class_names, num_images=16):
    """Visualize model predictions on sample images from the dataset"""
    for image_batch, label_batch in dataset.take(1):
        # Set the style before creating the figure so the figure background uses it too
        plt.style.use("dark_background")
        fig = plt.figure(figsize=(16, 16))
        fig.suptitle("Model Predictions on Test Images", fontsize=24, y=0.98)

        for i in range(min(num_images, len(image_batch))):
            plt.subplot(4, 4, i + 1)

            image = image_batch[i].numpy().astype(np.uint8)
            plt.imshow(image)

            # Get model prediction for this image
            pred = model.predict(image.reshape(1, IMAGE_SIZE, IMAGE_SIZE, 3), verbose=0)

            # Extract prediction details
            pred_label = class_names[np.argmax(pred)]
            true_label = class_names[label_batch[i]]
            confidence = np.max(pred) * 100

            pred_label_clean = pred_label.replace("Potato___", "").replace("_", " ")
            true_label_clean = true_label.replace("Potato___", "").replace("_", " ")

            # Color code the title based on correctness of prediction
            title_color = "green" if pred_label == true_label else "red"

            # Set title with prediction information
            plt.title(
                f"Pred: {pred_label_clean}\nTrue: {true_label_clean}\nConf: {confidence:.1f}%",
                color=title_color,
                fontsize=10,
                pad=8,
            )

            plt.axis("off")

        plt.tight_layout()
        plt.subplots_adjust(top=0.94)
        plt.show()
        break  # Process one batch


visualize_predictions(seq_model, test_ds, class_names)


*From these visualizations, we can observe:*

1. *Green titles indicate correct predictions, red titles indicate misclassifications*
2. *The confidence level (percentage) indicates how certain the model is about its prediction*
3. *We can analyze which types of images are more challenging for the model*

*This helps us understand the model's strengths and limitations in a practical context.*


### Confusion Matrix and Classification Report

*A confusion matrix is a table of prediction counts. In scikit-learn's layout, each row is a true class and each column is a predicted class, so the diagonal holds the correct predictions and every off-diagonal cell shows how often one class was mistaken for another. The true positives, false positives, and false negatives for each class can be read directly from it.*


In [None]:
# Collect true labels and predictions across the entire test set
y_true = []
y_pred = []


print("Generating predictions on the test dataset...")
for i, (image_batch, label_batch) in enumerate(test_ds):
    # Add true labels to our list
    y_true.extend(label_batch.numpy())

    # Get predictions for this batch
    batch_predictions = seq_model.predict(image_batch, verbose=0)

    # Convert prediction probabilities to class indices
    batch_pred_classes = [np.argmax(pred) for pred in batch_predictions]
    y_pred.extend(batch_pred_classes)

    # Show progress
    print(f"Processed batch {i + 1}/{len(test_ds)}", end="\r")

print(f"\nCompleted predictions on {len(y_true)} test images")


In [None]:
def plot_confusion_matrix(y_true, y_pred, class_names):
    """Create a visually enhanced confusion matrix"""

    cm = confusion_matrix(y_true, y_pred)

    # Convert to a DataFrame for better visualization
    cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)

    readable_class_names = [
        name.replace("Potato___", "").replace("_", " ") for name in class_names
    ]
    cm_df.index = readable_class_names
    cm_df.columns = readable_class_names

    plt.figure(figsize=(10, 8))
    plt.style.use("dark_background")

    # Create an annotated heatmap of the raw counts
    sns.heatmap(
        cm_df,
        annot=True,  # Show values in cells
        fmt="d",  # Use the integer format for counts
        cmap="Blues",  # Use a color palette that's easier to interpret
        linewidths=1,  # Add lines between cells
        cbar=True,  # Show color bar
        square=True,  # Make cells square-shaped
    )

    # Customize the plot
    plt.title("Confusion Matrix: Potato Disease Classification", fontsize=16, pad=20)
    # confusion_matrix puts true classes on rows (y-axis) and predictions on columns (x-axis)
    plt.ylabel("True Class", fontsize=14, labelpad=10)
    plt.xlabel("Predicted Class", fontsize=14, labelpad=10)

    plt.xticks(rotation=45, ha="right", fontsize=12)
    plt.yticks(rotation=0, fontsize=12)

    plt.tight_layout()
    plt.show()

    return cm_df


cm_df = plot_confusion_matrix(y_true, y_pred, class_names)


In [None]:
# Display a detailed classification report
print("\nDetailed Classification Report:")
print("----------------------------------")

# Generate the classification report
report = classification_report(
    y_true,
    y_pred,
    target_names=[
        name.replace("Potato___", "").replace("_", " ") for name in class_names
    ],
    digits=4,
)

print(report)


*The confusion matrix and classification report provide important insights:*

- *Precision: The proportion of positive identifications that were actually correct*
- *Recall: The proportion of actual positives that were correctly identified*
- *F1-score: The harmonic mean of precision and recall*
- *Support: The number of samples in each class*

*These metrics help us understand how well our model performs for each disease class and identify any systematic misclassifications.*
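
*These definitions can be traced back to the confusion matrix itself. The sketch below derives the four quantities for each class from a small, made-up matrix, following scikit-learn's convention that rows are true classes and columns are predicted classes. The function name `per_class_metrics` is introduced here for illustration.*

```python
def per_class_metrics(cm):
    """Compute (precision, recall, f1, support) per class from a square confusion matrix.

    Rows are true classes and columns are predicted classes.
    """
    n = len(cm)
    metrics = []
    for k in range(n):
        tp = cm[k][k]
        fp = sum(cm[r][k] for r in range(n)) - tp  # predicted k, actually another class
        fn = sum(cm[k]) - tp  # actually k, predicted as another class
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics.append((precision, recall, f1, tp + fn))
    return metrics


cm = [[8, 2, 0], [1, 9, 0], [0, 0, 10]]  # hypothetical counts
for k, (p, r, f1, support) in enumerate(per_class_metrics(cm)):
    print(f"class {k}: precision={p:.3f} recall={r:.3f} f1={f1:.3f} support={support}")
```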

### ROC Curve Analysis

*The Receiver Operating Characteristic (ROC) curve plots the true positive rate against the false positive rate of a binary classifier as its decision threshold is varied. For multi-class problems like ours, we plot one curve per class using a one-vs-rest scheme: each class in turn is treated as the positive class and the remaining classes as negative.*

*This analysis helps us understand how well the model can distinguish between classes across different threshold settings.*
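
*Each point on a ROC curve comes from applying one threshold to the predicted probabilities for a single class. The sketch below computes a few such points by hand for made-up scores. The function name `roc_point` is an assumption for illustration; `sklearn.metrics.roc_curve` sweeps the thresholds automatically.*

```python
def roc_point(scores, is_positive, threshold):
    """Return (false_positive_rate, true_positive_rate) at one threshold.

    Assumes the data contains at least one positive and one negative sample.
    """
    tp = fp = fn = tn = 0
    for score, positive in zip(scores, is_positive):
        predicted_positive = score >= threshold
        if predicted_positive and positive:
            tp += 1
        elif predicted_positive:
            fp += 1
        elif positive:
            fn += 1
        else:
            tn += 1
    return fp / (fp + tn), tp / (tp + fn)


# Hypothetical probabilities for one class, and whether each image truly belongs to it
scores = [0.95, 0.80, 0.60, 0.40, 0.30, 0.10]
is_positive = [True, True, False, True, False, False]
for threshold in (0.9, 0.5, 0.2):
    fpr_value, tpr_value = roc_point(scores, is_positive, threshold)
    print(f"threshold={threshold}: FPR={fpr_value:.2f}, TPR={tpr_value:.2f}")
```

*Lowering the threshold moves the point up and to the right: more true positives are found, at the cost of more false positives.*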


In [None]:
# Collect prediction probabilities and true labels from the test dataset
print("Generating prediction probabilities for ROC curve analysis...")
y_probs = []
y_true = []


for i, (image_batch, label_batch) in enumerate(test_ds):
    batch_probs = seq_model.predict(image_batch, verbose=0)
    y_probs.extend(batch_probs)
    y_true.extend(label_batch.numpy())
    print(f"Processed batch {i + 1}/{len(test_ds)}", end="\r")

print(f"\nCollected predictions for {len(y_true)} test images")


y_probs = np.array(y_probs)
y_true = np.array(y_true)


In [None]:
# Prepare data for ROC curve calculation

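# Binarize the labels for multi-class ROC analysis
# This converts integer class labels to a binary matrix representation.
# Listing every class index (not only those present in y_true) keeps the columns
# aligned with the columns of y_probs even if a class is missing from the test set.
y_true_bin = label_binarize(y_true, classes=np.arange(len(class_names)))
n_classes = y_true_bin.shape[1]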

# Verify data integrity
if len(y_true) == 0 or len(y_probs) == 0 or len(y_true) != len(y_probs):
    raise ValueError(
        "Empty or mismatched data: y_true and y_probs must be non-empty and have the same length."
    )


In [None]:
# Calculate ROC curves and AUC for each class
fpr, tpr, roc_auc = {}, {}, {}

# Process each class separately
for i in range(n_classes):
    # Calculate false positive rate and true positive rate
    fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
    # Calculate the area under the curve
    roc_auc[i] = auc(fpr[i], tpr[i])

# Calculate the micro-average ROC curve and AUC (treats the problem as binary by flattening all classes)
fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_probs.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Calculate macro-average ROC curve and AUC (average of per-class curves)
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])


In [None]:
plt.style.use("dark_background")
plt.figure(figsize=(10, 6))
plt.subplots_adjust(top=0.9, bottom=0.18)

# Define class-specific colors and readable class names
colors = ["#00B5F7", "#FF6B6B", "#4ECDC4", "#FFD93D", "#FF8066", "#95A5A6"]
readable_class_names = [
    name.replace("Potato___", "").replace("_", " ") for name in class_names
]

# Plot ROC curve for each class
for i in range(n_classes):
    plt.plot(
        fpr[i],
        tpr[i],
        color=colors[i % len(colors)],
        lw=2.5,
        label=f"{readable_class_names[i]} (AUC = {roc_auc[i]:.3f})",
    )

# Plot micro-average ROC curve
plt.plot(
    fpr["micro"],
    tpr["micro"],
    label=f"Micro-average (AUC = {roc_auc['micro']:.3f})",
    color="#FF69B4",
    linestyle=":",
    lw=2,
)

# Plot macro-average ROC curve
plt.plot(
    fpr["macro"],
    tpr["macro"],
    label=f"Macro-average (AUC = {roc_auc['macro']:.3f})",
    color="#87CEEB",
    linestyle=":",
    lw=2,
)

# Plot diagonal line representing random chance
plt.plot([0, 1], [0, 1], "w--", lw=2, label="Random Chance (AUC = 0.5)")

# Customize the plot appearance
plt.xlim([-0.01, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate", fontsize=14)
plt.ylabel("True Positive Rate", fontsize=14)
plt.title("ROC Curves for Potato Disease Classification", fontsize=16, pad=20)
plt.legend(loc="lower right", fontsize=10)
plt.grid(True, alpha=0.3)

# Add explanatory text with light background
plt.figtext(
    0.5,
    0.01,
    "A perfect classifier would have an AUC of 1.0 (top-left corner)\n"
    "Random guessing would give an AUC of 0.5 (diagonal line)",
    ha="center",
    fontsize=10,
    color="white",
    bbox={"facecolor": "#2C3E50", "alpha": 0.7, "pad": 5},
)


plt.show()


### Interpreting the ROC Curves

*From the ROC curves, we can observe:*

- **Area Under the Curve (AUC)**: A higher AUC indicates better model performance
  - AUC = 1.0: Perfect classification
  - AUC = 0.5: No better than random guessing (the diagonal line)

- **Micro-average**: Pools every (sample, class) entry of the label indicator matrix into a single binary problem, so classes with more samples carry more weight

- **Macro-average**: Averages the per-class ROC curves, so every class counts equally regardless of how many samples it has

*The high AUC values across all classes indicate that our model has strong discriminative power for potato disease classification.*
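
*The difference between the two averages is easier to see on a tiny example. The sketch below is plain Python on made-up probabilities and uses hypothetical helpers (`binary_auc`, `micro_macro_auc`). It computes AUC as the fraction of positive/negative pairs ranked correctly, and it takes the macro value as the simple mean of per-class AUCs. That is a common simplification and can differ slightly from the interpolated curve averaging used in the cells above.*

```python
def binary_auc(labels, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. Assumes both labels 0 and 1 occur.
    """
    pos = [s for label, s in zip(labels, scores) if label == 1]
    neg = [s for label, s in zip(labels, scores) if label == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def micro_macro_auc(y_true, y_probs, n_classes):
    # One-hot encode the integer labels (the "label indicator matrix")
    one_hot = [[1 if label == k else 0 for k in range(n_classes)] for label in y_true]
    # Macro: one AUC per class (one-vs-rest), then a plain mean
    per_class = [
        binary_auc([row[k] for row in one_hot], [row[k] for row in y_probs])
        for k in range(n_classes)
    ]
    macro = sum(per_class) / n_classes
    # Micro: flatten all (sample, class) entries into one binary problem
    flat_labels = [v for row in one_hot for v in row]
    flat_scores = [v for row in y_probs for v in row]
    micro = binary_auc(flat_labels, flat_scores)
    return per_class, macro, micro


# Made-up labels and probabilities (NOT results from this notebook)
toy_y_true = [0, 1, 2, 1]
toy_y_probs = [
    [0.7, 0.2, 0.1],
    [0.3, 0.6, 0.1],
    [0.2, 0.3, 0.5],
    [0.5, 0.25, 0.25],  # a class-1 sample the model gets wrong
]
per_class_auc, macro_auc, micro_auc = micro_macro_auc(toy_y_true, toy_y_probs, 3)
print("Per-class AUC:", per_class_auc)
print(f"Macro-average AUC: {macro_auc:.4f}")
print(f"Micro-average AUC: {micro_auc:.4f}")
```

*On balanced data the two averages tend to be close; a large gap usually points to one class being much harder, or much rarer, than the others.*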

## Model Saving and Deployment

*After training and evaluating our model, we need to save it so it can be loaded and used for future predictions without retraining.*

### Saving the Trained Model

*We'll save our model in the native Keras format (`.keras`), which stores the model architecture, the trained weights, and the optimizer state in a single file.*


In [None]:
# Define save directory
save_dir = "SAVED_MODELS"

# Create a directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)


# Function to get the next available model number
def get_next_model_number():
    prefix, suffix = "model_", ".keras"
    # Only count files named exactly model_<integer>.keras; other names
    # (e.g. model_best.keras) are skipped instead of raising ValueError
    existing_numbers = [
        int(f[len(prefix) : -len(suffix)])
        for f in os.listdir(save_dir)
        if f.startswith(prefix)
        and f.endswith(suffix)
        and f[len(prefix) : -len(suffix)].isdecimal()
    ]
    if not existing_numbers:
        return 0
    return max(existing_numbers) + 1


# Get the next available model number
model_version = get_next_model_number()

# Define the complete save path
model_path = os.path.join(save_dir, f"model_{model_version}.keras")

# Save the model
print(f"Saving model to {model_path}...")
seq_model.save(model_path)
print("Model saved successfully!")


In [None]:
# Verify that the model can be loaded again
from tensorflow.keras.models import load_model

# Try loading the model to verify it saved correctly
print(f"Verifying model by loading it from {model_path}...")
try:
    loaded_model = load_model(model_path)
    print("Model loaded successfully! Architecture:")
    loaded_model.summary()
    print("\nModel verification complete. The model is ready for deployment.")
except Exception as e:
    print(f"Error loading model: {e}")


### Model Deployment Considerations

*Now that we have a trained and saved model, it can be deployed in various ways:*

1. **Web Application**: Using frameworks like Flask, Django, or Gradio to create a user-friendly interface
2. **Mobile Application**: Converting the model to TensorFlow Lite for mobile deployment
3. **Cloud API**: Hosting the model on cloud platforms like AWS, Azure, or Google Cloud
4. **Edge Devices**: Deploying to IoT devices for in-field detection

*The deployment approach depends on the specific use case and target users. Here, we use Gradio to build a simple web interface.*


In [None]:
# Load the trained model
MODEL_PATH = "/content/model_0.keras"  # Update this path to your saved model
model = tf.keras.models.load_model(MODEL_PATH)

# Define class names (should match your training data)
CLASS_NAMES = ["Potato___Early_blight", "Potato___Late_blight", "Potato___healthy"]

# Image preprocessing parameters
IMAGE_SIZE = 256


def preprocess_image(image):
    """
    Preprocess the input image for model prediction
    """
    # Convert the PIL image to a 3-channel RGB numpy array
    # (uploads may be RGBA or grayscale, which would change the channel count)
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))

    # Add batch dimension
    image = tf.expand_dims(image, axis=0)

    return image


def create_probability_plot(probabilities, class_names):
    """
    Create a bar plot of prediction probabilities
    """
    # Set the dark theme
    plt.style.use("dark_background")
    fig, ax = plt.subplots(figsize=(10, 6))
    fig.subplots_adjust(top=0.9, bottom=0.2)
    fig.patch.set_facecolor("#1e1e1e")
    ax.set_facecolor("#1e1e1e")

    # Create a color map - highlight the highest probability with modern colors
    colors = [
        "#ff6b6b" if i != np.argmax(probabilities) else "#4ecdc4"
        for i in range(len(probabilities))
    ]

    bars = ax.bar(
        class_names,
        probabilities * 100,
        color=colors,
        edgecolor="white",
        linewidth=1.5,
        alpha=0.9,
    )

    # Add percentage labels on bars with white text
    for bar, prob in zip(bars, probabilities):
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            height + 1,
            f"{prob * 100:.1f}%",
            ha="center",
            va="bottom",
            fontweight="bold",
            color="white",
            fontsize=11,
        )

    # Style the plot with white text
    ax.set_title(
        "Potato Disease Classification Probabilities",
        fontsize=16,
        fontweight="bold",
        color="white",
        pad=20,
    )
    ax.set_xlabel("Disease Classes", fontsize=14, color="white")
    ax.set_ylabel("Probability (%)", fontsize=14, color="white")
    ax.set_ylim(0, 105)

    # Style tick labels
    ax.tick_params(axis="x", colors="white", labelsize=12)
    ax.tick_params(axis="y", colors="white", labelsize=12)

    # Add a subtle grid for better readability
    ax.grid(axis="y", alpha=0.3, color="gray", linestyle="--")

    # Remove top and right spines for a cleaner look
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_color("white")
    ax.spines["bottom"].set_color("white")

    plt.tight_layout()

    # Convert plot to image
    buf = io.BytesIO()
    plt.savefig(
        buf,
        format="png",
        dpi=150,
        bbox_inches="tight",
        facecolor="#1e1e1e",
        edgecolor="none",
    )
    buf.seek(0)
    plt.close()

    # Reset to the default style to avoid affecting other plots
    plt.style.use("default")

    return Image.open(buf)


def get_disease_info(predicted_class):
    """
    Return information about the predicted disease
    """
    disease_info = {
        "Potato___Early_blight": {
            "description": "Early blight is a common potato disease caused by Alternaria solani. It appears as dark spots with concentric rings on leaves.",
            "treatment": "Use fungicides, practice crop rotation, and ensure proper plant spacing for air circulation.",
            "severity": "Moderate",
        },
        "Potato___Late_blight": {
            "description": "Late blight is a serious potato disease caused by Phytophthora infestans. It can cause rapid destruction of leaves and tubers.",
            "treatment": "Apply fungicides preventively, remove infected plants, and avoid overhead watering.",
            "severity": "High",
        },
        "Potato___healthy": {
            "description": "The potato plant appears healthy with no visible signs of disease.",
            "treatment": "Continue regular monitoring and maintain good agricultural practices.",
            "severity": "None",
        },
    }

    return disease_info.get(
        predicted_class,
        {
            "description": "Unknown disease classification.",
            "treatment": "Consult with agricultural experts.",
            "severity": "Unknown",
        },
    )


def predict_disease(image):
    """
    Main prediction function for Gradio interface
    """
    if image is None:
        return "Please upload an image", None, ""

    try:
        # Preprocess the image
        processed_image = preprocess_image(image)

        # Make prediction
        predictions = model.predict(processed_image, verbose=0)
        probabilities = predictions[0]

        # Get predicted class
        predicted_class_idx = np.argmax(probabilities)
        predicted_class = CLASS_NAMES[predicted_class_idx]
        confidence = probabilities[predicted_class_idx] * 100

        # Create probability plot
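        # Strip the full "Potato___" prefix (three underscores) so that no
        # stray leading space is left in the display names
        disease_names = [
            name.replace("Potato___", "").replace("_", " ").title()
            for name in CLASS_NAMES
        ]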
        plot_image = create_probability_plot(probabilities, disease_names)

        # Get disease information
        disease_info = get_disease_info(predicted_class)

        # Format results
        result_text = f"""
                    ## 🔍 **Prediction Results**

                    **Predicted Disease:** {predicted_class.replace("Potato___", "").replace("_", " ").title()}
                    **Confidence:** {confidence:.2f}%

                    ## 📋 **Disease Information**

                    **Description:** {disease_info["description"]}

                    **Recommended Treatment:** {disease_info["treatment"]}

                    **Severity Level:** {disease_info["severity"]}

                    ---
                    *Note: This is an AI-based prediction. For critical decisions, please consult with agricultural experts.*
                    """

        return (
            result_text,
            plot_image,
            f"Prediction: {predicted_class.replace('Potato___', '').replace('_', ' ').title()} ({confidence:.1f}%)",
        )

    except Exception as e:
        return f"Error processing image: {str(e)}", None, "Error occurred"


# Create Gradio interface
def create_gradio_app():
    """
    Create and configure the Gradio interface
    """

    # Custom CSS for better styling
    css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    .output-markdown {
        font-size: 14px;
    }
    .image-upload {
        border: 2px dashed #4CAF50;
        border-radius: 10px;
    }
    """

    with gr.Blocks(css=css, title="Potato Disease Classifier") as app:
        gr.Markdown("""
        # 🥔 Potato Disease Classification System

        Upload an image of a potato leaf to detect potential diseases. The system can identify:
        - **Early Blight** - Caused by Alternaria solani
        - **Late Blight** - Caused by Phytophthora infestans
        - **Healthy** - No disease detected

        ---
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                gr.Markdown("### 📤 Upload Image")
                input_image = gr.Image(
                    type="pil",
                    label="Upload a potato leaf image",
                    elem_classes=["image-upload"],
                )

                predict_btn = gr.Button(
                    "🔍 Analyze Disease", variant="primary", size="lg"
                )

                # Quick status
                status_text = gr.Textbox(label="Status", interactive=False, max_lines=1)

            with gr.Column(scale=2):
                # Output section
                gr.Markdown("### 📊 Results")

                with gr.Row():
                    result_text = gr.Markdown(
                        value="Upload an image and click 'Analyze Disease' to see results.",
                        elem_classes=["output-markdown"],
                    )

                probability_plot = gr.Image(
                    label="Probability Distribution", type="pil"
                )

        # Examples section
        gr.Markdown("### 🖼️ Example Images")
        gr.Markdown("*Click on any example image below to test the classifier:*")

        # You can add example images here if you have them
        gr.Examples(
            examples=[
                ["/content/Early_Blight.JPG"],
                ["/content/Late_Blight.JPG"],
                ["/content/Healthy.JPG"],
            ],
            inputs=input_image,
        )

        # Connect the prediction function
        predict_btn.click(
            fn=predict_disease,
            inputs=[input_image],
            outputs=[result_text, probability_plot, status_text],
        )

        # Also allow prediction on image upload
        input_image.change(
            fn=predict_disease,
            inputs=[input_image],
            outputs=[result_text, probability_plot, status_text],
        )

        gr.Markdown("""
        ---
        ### ℹ️ **Important Notes:**
        - Ensure the image clearly shows potato leaves
        - Good lighting and focus improve accuracy
        - This tool is for educational/research purposes
        - Always verify results with agricultural experts for critical decisions

        **Model Accuracy:** ~98.5% on test dataset
        """)

    return app


# Launch the app
if __name__ == "__main__":
    try:
        # Create and launch the Gradio app
        app = create_gradio_app()

        # Launch with custom settings
        app.launch(
            share=True,  # Set to True to create a public link
            server_name="0.0.0.0",  # Allow access from any IP
            server_port=7860,  # Default Gradio port
            debug=True,  # Enable debug mode
            show_error=True,  # Show detailed error messages
        )

    except Exception as e:
        print(f"Error launching Gradio app: {e}")
        print("Make sure you have installed all required packages:")
        print("pip install gradio tensorflow pillow matplotlib seaborn")
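
*The app repeats the same `replace(...).title()` chain in several places, and a mistyped prefix (for example two underscores instead of three) would silently leave a stray leading space in the label. One way to avoid that, sketched here with a hypothetical `format_class_name` helper that is not part of the original notebook, is to centralize the conversion:*

```python
CLASS_PREFIX = "Potato___"  # dataset folder prefix, as used in CLASS_NAMES


def format_class_name(raw_name):
    """Turn a dataset folder name into a human-readable display label."""
    # Drop the prefix only when it is actually present
    if raw_name.startswith(CLASS_PREFIX):
        raw_name = raw_name[len(CLASS_PREFIX):]
    # Replace separators, trim leftovers, and capitalize each word
    return raw_name.replace("_", " ").strip().title()


display_names = [
    format_class_name(name)
    for name in ["Potato___Early_blight", "Potato___Late_blight", "Potato___healthy"]
]
print(display_names)
```

*With a single helper, the plot labels, the result text, and the status line would all be derived the same way.*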

## Conclusion and Future Work

*This project successfully demonstrates how deep learning can be applied to agricultural challenges, specifically the detection of potato plant diseases. Our model achieved high accuracy in distinguishing between early blight, late blight, and healthy potato plants.*

### Key Achievements

1. **Effective Disease Classification**: The model achieved high accuracy on the test dataset (see the Model Evaluation section for the exact figures), demonstrating its effectiveness in identifying potato diseases.

2. **User-Friendly Interface**: We created an intuitive web application that allows farmers and agricultural experts to use the model without technical knowledge.

3. **Comprehensive Pipeline**: We implemented a complete machine learning pipeline from data preparation through model training to deployment.

### Limitations and Future Improvements

1. **Dataset Expansion**:
   - Include more disease classes (viral diseases, nutrient deficiencies, etc.)
   - Add more images with varying lighting conditions, angles, and backgrounds
   - Incorporate time-series data to track disease progression

2. **Model Enhancements**:
   - Experiment with more advanced architectures like ResNet, EfficientNet, or Vision Transformers
   - Implement explainable AI techniques to highlight which parts of the leaf influenced the classification
   - Add disease severity estimation (percentage of leaf affected)

3. **Deployment Improvements**:
   - Develop a mobile application for field use without internet connectivity
   - Implement a recommendation system for treatment options based on disease severity
   - Add multi-language support for global use

### Potential Impact

This technology has the potential to significantly impact potato farming by:

- Enabling earlier disease detection, as soon as symptoms begin to show on the leaves
- Reducing unnecessary pesticide use through targeted treatment
- Decreasing crop losses and improving food security
- Making expert knowledge more accessible to small-scale farmers

*With further development and refinement, this system could become an invaluable tool for sustainable agriculture and food security.*


## References and Resources

### Datasets

- PlantVillage Dataset: [Kaggle - Plant Village](https://www.kaggle.com/datasets/arjuntejaswi/plant-village)

### Technical Resources

- TensorFlow Documentation: [https://www.tensorflow.org/api_docs](https://www.tensorflow.org/api_docs)
- Keras Documentation: [https://keras.io/api/](https://keras.io/api/)
- Gradio Documentation: [https://gradio.app/docs/](https://gradio.app/docs/)

### Scientific References

1. Mohanty, S.P., Hughes, D.P. & Salathé, M. (2016). Using Deep Learning for Image-Based Plant Disease Detection. Frontiers in Plant Science, 7:1419.

2. Ferentinos, K.P. (2018). Deep learning models for plant disease detection and diagnosis. Computers and Electronics in Agriculture, 145, 311-318.

3. Singh, V., & Misra, A. K. (2017). Detection of plant leaf diseases using image segmentation and soft computing techniques. Information Processing in Agriculture, 4(1), 41-49.

### Additional Learning Resources

- [CNN Explainer](https://poloclub.github.io/cnn-explainer/): Interactive visualization for understanding convolutional neural networks
- [Deep Learning Specialization](https://www.coursera.org/specializations/deep-learning): Coursera specialization by Andrew Ng
