# Plant Disease Detection using Transfer Learning (MobileNetV2)

## 1. Introduction
In this notebook, we develop a Convolutional Neural Network (CNN) based on the **MobileNetV2** architecture to classify plant diseases. The objective is to achieve high accuracy while maintaining a low parameter count, making the model suitable for mobile edge computing.

**Methodology:**
* **Data Source:** PlantVillage Dataset (38 Classes).
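* **Technique:** Transfer Learning (feature extraction with a frozen, ImageNet-pretrained backbone; fine-tuning the backbone is a possible follow-up step that the cells below do not perform).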
* **Framework:** TensorFlow/Keras.

In [None]:

import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import splitfolders
from sklearn.metrics import classification_report, confusion_matrix

# Configuration
# Fix the TensorFlow and NumPy seeds so repeated runs are comparable.
# NOTE(review): Python's built-in `random` module (used by the Grad-CAM cell to
# pick a sample image) is not seeded here — confirm whether that is intended.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Report the runtime: TensorFlow version and whether at least one GPU is visible.
print(f"TensorFlow Version: {tf.__version__}")
print(f"GPU Available: {len(tf.config.list_physical_devices('GPU')) > 0}")

## 2. Data Preparation and Engineering
We utilize the `split-folders` library to physically partition the dataset into Training, Validation, and Test sets. This prevents **data leakage**, a common issue in deep learning projects where validation data inadvertently bleeds into the training process.

In [None]:
# Path Configuration
# NOTE: Download the raw PlantVillage dataset to INPUT_FOLDER before splitting.
# Both paths start with '../', so they resolve relative to the current working
# directory (presumably the folder containing this notebook — verify).
INPUT_FOLDER = '../data/raw/PlantVillage'
OUTPUT_FOLDER = '../data/processed/PlantVillage_Split'

def split_dataset(input_path, output_path):
    """
    Copies the raw dataset into train/val/test sub-folders (80/10/10).

    If `input_path` does not exist, a message is printed and nothing is
    written. Files are copied, not moved, so the raw dataset stays intact.

    Args:
        input_path (str): Directory holding the raw dataset.
        output_path (str): Directory that receives the split dataset.
    """
    dataset_present = os.path.exists(input_path)

    if dataset_present:
        print("Splitting dataset... This may take a while.")
        # Fractions for (train, validation, test)
        split_fractions = (.8, .1, .1)
        splitfolders.ratio(
            input_path,
            output=output_path,
            seed=SEED,
            ratio=split_fractions,
            group_prefix=None,
            move=False,
        )
        print("Data splitting completed successfully.")
    else:
        print(f"Dataset not found at {input_path}. Please download it first.")

# Uncomment the line below to run splitting (Run once)
# split_dataset(INPUT_FOLDER, OUTPUT_FOLDER)

In [None]:
# Image Parameters
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

DATA_DIR = '../data/processed/PlantVillage_Split'
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
VAL_DIR = os.path.join(DATA_DIR, 'val')
TEST_DIR = os.path.join(DATA_DIR, 'test')

# Build the input pipelines with tf.keras.utils.image_dataset_from_directory.
# `labels='inferred'` derives each label from the sub-directory name, and
# `label_mode='categorical'` yields one-hot vectors, which is what the
# 'categorical_crossentropy' loss and the argmax-based evaluation expect.
# (`labels` itself only accepts 'inferred', None, or an explicit label list.)
train_ds = tf.keras.utils.image_dataset_from_directory(
    TRAIN_DIR,
    labels='inferred',
    label_mode='categorical',
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=SEED
)

# Validation and test data keep a fixed order so that metrics are repeatable
# and predictions can be aligned with labels.
val_ds = tf.keras.utils.image_dataset_from_directory(
    VAL_DIR,
    labels='inferred',
    label_mode='categorical',
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False
)

test_ds = tf.keras.utils.image_dataset_from_directory(
    TEST_DIR,
    labels='inferred',
    label_mode='categorical',
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False
)

# Retrieve class names (must be read before the dataset is wrapped by
# cache/prefetch, which drops the `class_names` attribute)
class_names = train_ds.class_names
print(f"Number of Classes: {len(class_names)}")
print(f"Example Classes: {class_names[:5]}")

# Performance Optimization: Prefetching and Caching
# The datasets are already batched, so shuffle(1000) reorders whole batches
# after the cache; without it every epoch would replay the cached order.
# NOTE(review): cache() with no filename keeps decoded images in memory —
# confirm the training split fits in RAM, or pass a cache file path.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

## 3. Model Architecture: MobileNetV2
We employ **MobileNetV2** as the backbone. The weights are pre-trained on ImageNet.
1.  **Preprocessing:** Inputs are scaled to `[-1, 1]`.
2.  **Base:** MobileNetV2 (frozen).
3.  **Head:** Global Average Pooling -> Dropout -> Dense (Softmax).

In [None]:
def build_model(num_classes):
    """
    Assembles a MobileNetV2 feature extractor topped with a softmax classifier.

    The ImageNet-pretrained backbone is frozen, so only the new Dense head has
    trainable weights. The returned model is not compiled; call
    `model.compile(...)` before training.

    Args:
        num_classes (int): Number of units in the softmax output layer.

    Returns:
        tf.keras.Model: Uncompiled model taking images of shape
        `IMG_SIZE + (3,)` and producing per-class probabilities.
    """
    input_shape = IMG_SIZE + (3,)

    # Pretrained backbone without its ImageNet classification top
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )
    # Keep the pretrained features fixed during training
    backbone.trainable = False

    # MobileNetV2-specific input scaling is part of the graph, so callers feed
    # unscaled pixel values
    image_input = tf.keras.Input(shape=input_shape)
    scaled = tf.keras.applications.mobilenet_v2.preprocess_input(image_input)

    # training=False keeps the backbone in inference mode
    feature_maps = backbone(scaled, training=False)
    pooled = layers.GlobalAveragePooling2D()(feature_maps)
    regularized = layers.Dropout(0.2)(pooled)

    # Classification head
    class_probs = layers.Dense(num_classes, activation='softmax')(regularized)

    return models.Model(image_input, class_probs)

# Instantiate the network with one output unit per discovered class and print
# its layer/parameter summary.
model = build_model(len(class_names))
model.summary()

# Compilation
# Adam with learning rate 1e-3; categorical cross-entropy matches the softmax
# output produced by build_model.
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

## 4. Training Process
We implement callbacks such as `EarlyStopping` and `ModelCheckpoint` to ensure the training process is efficient and saves the best performing weights.

In [None]:
# Upper bound on the number of epochs; early stopping may end training sooner.
EPOCHS = 20

# Callbacks
# Write the model to disk only when validation loss reaches a new minimum.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "plant_village_mobilenetv2.h5",
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)

# Stop after 5 epochs without val_loss improvement and restore the best
# weights seen so far.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=5,
    restore_best_weights=True,
    monitor='val_loss'
)

# Train the model
# `history` keeps the per-epoch metrics that plot_history reads.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[checkpoint_cb, early_stopping_cb]
)

In [None]:
def plot_history(history):
    """Draws accuracy and loss curves (training vs. validation) side by side.

    Args:
        history: Object whose `.history` dict holds the per-epoch lists
            'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
    """
    record = history.history
    # One entry per subplot:
    # (train series, val series, train label, val label, legend loc, title)
    panels = (
        (record['accuracy'], record['val_accuracy'],
         'Training Accuracy', 'Validation Accuracy',
         'lower right', 'Training and Validation Accuracy'),
        (record['loss'], record['val_loss'],
         'Training Loss', 'Validation Loss',
         'upper right', 'Training and Validation Loss'),
    )
    epoch_axis = range(len(record['accuracy']))

    plt.figure(figsize=(12, 6))
    for slot, panel in enumerate(panels, start=1):
        train_series, val_series, train_label, val_label, legend_loc, title = panel
        plt.subplot(1, 2, slot)
        plt.plot(epoch_axis, train_series, label=train_label)
        plt.plot(epoch_axis, val_series, label=val_label)
        plt.legend(loc=legend_loc)
        plt.title(title)
    plt.show()

# Visualize the learning curves recorded by model.fit.
plot_history(history)

## 5. Model Evaluation
The final evaluation is conducted on the **Test Set**, which the model never sees during training or checkpoint selection. This provides an unbiased estimate of how well the model generalizes to unseen images drawn from the same dataset.

In [None]:
# Evaluate on Test Data
loss, accuracy = model.evaluate(test_ds)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# Confusion Matrix & Classification Report
# Labels and predictions are collected batch by batch from the same iteration
# of test_ds so that they stay aligned.
y_true = []
y_pred = []

for images, labels in test_ds:
    # predict_on_batch runs a single forward pass; calling model.predict inside
    # a Python loop adds per-call overhead for every batch.
    preds = model.predict_on_batch(images)
    # One-hot labels / probability rows -> integer class indices
    y_true.extend(np.argmax(labels.numpy(), axis=1))
    y_pred.extend(np.argmax(preds, axis=1))

# Passing the full index list keeps the report and matrix aligned with
# class_names even if some class has no sample in the test split.
class_indices = list(range(len(class_names)))

print(classification_report(y_true, y_pred,
                            labels=class_indices, target_names=class_names))

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred, labels=class_indices)

plt.figure(figsize=(16, 14))
sns.heatmap(cm, cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.title('Confusion Matrix (Test Set)')
plt.tight_layout()
plt.show()

In [None]:

# [CELL 13] - Explainable AI (XAI): Grad-CAM Implementation
# ---------------------------------------------------------
# This module visualizes the "Region of Interest" (ROI) that the model focuses on.
# It uses Gradient-weighted Class Activation Mapping (Grad-CAM).

import cv2
import os
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

def get_img_array(img_path, size):
    """
    Reads an image from disk and returns it as a single-image batch.

    No pixel scaling is applied: the array holds the values produced by
    `img_to_array`, because the model is expected to carry its own
    preprocessing step.

    Args:
        img_path (str): Path of the image file.
        size (tuple): Target size handed to `load_img` as `target_size`.

    Returns:
        numpy.ndarray: Image array with a leading batch axis of length 1.
    """
    image_utils = tf.keras.preprocessing.image
    pil_image = image_utils.load_img(img_path, target_size=size)
    pixels = image_utils.img_to_array(pil_image)
    # Prepend the batch axis, e.g. (224, 224, 3) -> (1, 224, 224, 3)
    return np.expand_dims(pixels, axis=0)

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """
    Generates a Grad-CAM heatmap for a single image.

    The model is split into two parts: a sub-model from the backbone input to
    the requested convolutional layer, and a classifier made of the layers
    that follow the backbone in `model`. The gradient of the selected class
    score with respect to the convolutional feature maps weights each channel
    of those maps.

    Args:
        img_array: Batch holding one unscaled image. It is cast to float32 and
            passed through MobileNetV2 `preprocess_input` here, because the
            preprocessing step of `model` is bypassed.
        model (tf.keras.Model): Classifier to explain. The first layer that is
            itself a `tf.keras.Model` is treated as the backbone; if there is
            none, `model` is used and its last three layers are taken as the
            head.
        last_conv_layer_name (str): Name of the backbone layer whose output is
            visualized (e.g. 'out_relu').
        pred_index: Class index to explain. Defaults to the top predicted class.

    Returns:
        numpy.ndarray: Heatmap over the spatial grid of the chosen layer with
        values in [0, 1]; all zeros when no location contributes positively.

    Raises:
        ValueError: If `last_conv_layer_name` is not a layer of the backbone.
    """
    # 1. Locate the nested backbone (first layer that is itself a Model)
    backbone_pos = next(
        (pos for pos, layer in enumerate(model.layers)
         if isinstance(layer, tf.keras.Model)),
        None,
    )

    if backbone_pos is None:
        # Fallback: the model is not nested, assume 'model' is the base and
        # that GAP + Dropout + Dense are its last three layers
        base_model_layer = model
        head_layers = model.layers[-3:]
    else:
        base_model_layer = model.layers[backbone_pos]
        # Everything after the backbone forms the classification head, so the
        # head is found by position rather than by a fixed layer count
        head_layers = model.layers[backbone_pos + 1:]

    # 2. Access the last convolutional layer ('out_relu' for MobileNetV2)
    try:
        last_conv_layer = base_model_layer.get_layer(last_conv_layer_name)
    except ValueError as err:
        raise ValueError(
            f"Layer '{last_conv_layer_name}' not found in the model."
        ) from err

    # 3. Sub-model: backbone input -> output of the last conv layer
    last_conv_layer_model = tf.keras.Model(base_model_layer.inputs, last_conv_layer.output)

    # 4. Classifier model: conv feature maps -> class probabilities
    classifier_input = tf.keras.Input(shape=last_conv_layer.output.shape[1:])
    x = classifier_input
    for layer in head_layers:
        x = layer(x)
    classifier_model = tf.keras.Model(classifier_input, x)

    # 5. Compute gradients
    with tf.GradientTape() as tape:
        # Preprocess manually since the main model's preprocessing is bypassed
        inputs = tf.keras.applications.mobilenet_v2.preprocess_input(tf.cast(img_array, tf.float32))

        last_conv_layer_output = last_conv_layer_model(inputs)
        # The feature maps are an intermediate tensor, so they must be watched
        # explicitly for the tape to differentiate with respect to them
        tape.watch(last_conv_layer_output)

        preds = classifier_model(last_conv_layer_output)

        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # 6. Channel importance = gradient averaged over batch and spatial axes
    grads = tape.gradient(class_channel, last_conv_layer_output)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # 7. Weighted sum of the feature maps
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # 8. Normalize: keep positive evidence only, then scale by the rectified
    # maximum. divide_no_nan returns 0 instead of NaN when that maximum is 0.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    return heatmap.numpy()

def save_and_display_gradcam(img_path, heatmap, alpha=0.4, save_path=None):
    """
    Superimposes the heatmap on the original image and displays both.

    Args:
        img_path (str): Path of the image the heatmap was computed for.
        heatmap (numpy.ndarray): 2-D heatmap with values in [0, 1]; it may be
            coarser than the image (it is upsampled to the image size).
        alpha (float): Weight of the colorized heatmap in the overlay.
        save_path (str, optional): If given, the overlay is also written to
            this path with `cv2.imwrite`. Nothing is saved by default.

    Raises:
        ValueError: If the image at `img_path` cannot be read or decoded.
    """
    # Load original image (OpenCV returns BGR, or None on failure)
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Could not read or decode image at '{img_path}'.")
    img = cv2.resize(img, (224, 224))

    # Rescale heatmap to 0-255; NaNs and out-of-range values are neutralized
    # first so the uint8 conversion is well defined
    heatmap = np.nan_to_num(np.asarray(heatmap, dtype=np.float32))
    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))

    # Upsample the coarse map to the image size; without this the overlay
    # below cannot be added to the image (shapes would not match).
    # cv2.resize takes (width, height).
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

    # Colorize heatmap
    jet = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # Superimpose
    superimposed_img = jet * alpha + img
    superimposed_img = np.clip(superimposed_img, 0, 255).astype('uint8')

    if save_path is not None:
        cv2.imwrite(save_path, superimposed_img)

    # Display (Matplotlib expects RGB)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    superimposed_rgb = cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(img_rgb)
    plt.title("Original Image")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(superimposed_rgb)
    plt.title("Grad-CAM Attention Map")
    plt.axis('off')
    plt.show()

# --- EXECUTION BLOCK ---
# Picks a random test image, computes its Grad-CAM heatmap and displays it.
import traceback

try:
    # 1. Define the Test Directory explicitly to avoid 'test_ds' dependency issues
    # This path must match the 'test' sub-folder of OUTPUT_FOLDER used for splitting
    TEST_DIR_PATH = '../data/processed/PlantVillage_Split/test'

    if not os.path.exists(TEST_DIR_PATH):
        print(f"Error: Test directory not found at {TEST_DIR_PATH}")
    else:
        # 2. Collect all image paths from the test directory
        all_test_images = []
        for root, dirs, files in os.walk(TEST_DIR_PATH):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    all_test_images.append(os.path.join(root, file))

        if all_test_images:
            # 3. Select a random image
            random_img_path = random.choice(all_test_images)

            # The parent folder name is the class label
            actual_class = os.path.basename(os.path.dirname(random_img_path))
            print(f"Analyzing Image: {random_img_path}")
            print(f"Actual Class: {actual_class}")

            # 4. Prepare image
            img_array = get_img_array(random_img_path, size=(224, 224))

            # 5. Generate Heatmap
            # 'out_relu' is the standard last conv layer for MobileNetV2
            heatmap = make_gradcam_heatmap(img_array, model, 'out_relu')

            # 6. Display Results
            save_and_display_gradcam(random_img_path, heatmap)
        else:
            print("No images found in the test directory.")

except Exception as e:
    # Cell-level boundary: keep the notebook running, but show the full
    # traceback so the real cause is not hidden behind the one-line summary
    print(f"XAI Error: {e}")
    print("Ensure that the model is compiled and the 'out_relu' layer exists.")
    traceback.print_exc()