# Plant Disease Recognition Using MobileNet

**MSc Data Science - Deep Learning Applications (CMP-L016)**

**Project 17: Plant Disease Recognition Using MobileNet Variants**

---

**Author:** [Your Name]

**Date:** December 2025

## 1. Setup and GPU Configuration

In [None]:
# Install the kagglehub client (imported later in this notebook to download the dataset); -q keeps pip's output quiet.
!pip install -q kagglehub

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf

# Look up the GPUs TensorFlow can see and turn on memory growth for each of them.
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("No GPU found. Please enable GPU in Runtime > Change runtime type > GPU")
else:
    try:
        # Configure every device first; the summary is printed only if all calls succeed.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"GPU devices found: {len(gpus)}")
        for gpu in gpus:
            print(f"  - {gpu.name}")
    except RuntimeError as e:
        # Report the configuration error instead of aborting the notebook.
        print(e)

# Environment summary, printed regardless of whether a GPU was found.
print(f"\nTensorFlow Version: {tf.__version__}")
print(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}")

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image
from sklearn.metrics import classification_report, confusion_matrix
import warnings
# Ignore every Python warning raised from here on (presumably to keep notebook output uncluttered).
# NOTE(review): this also hides deprecation and numerical warnings from later cells.
warnings.filterwarnings('ignore')

## 2. Download Dataset

In [None]:
import kagglehub

# Download the Kaggle dataset; the returned value is used by later cells as the local root path of the files.
path = kagglehub.dataset_download("karagwaanntreasure/plant-disease-detection")
print("Path to dataset files:", path)

In [None]:
# Preferred location: a "Dataset" folder directly under the download root.
dataset_path = os.path.join(path, "Dataset")

if not os.path.exists(dataset_path):
    # Otherwise take the first "Dataset" directory found while walking the tree,
    # and fall back to the download root itself when there is none.
    dataset_path = next(
        (
            os.path.join(root, "Dataset")
            for root, dirs, _files in os.walk(path)
            if "Dataset" in dirs
        ),
        path,
    )

print(f"Dataset path: {dataset_path}")

## 3. Data Exploration

In [None]:
# Each sub-directory of the dataset root is treated as one class; names are sorted alphabetically.
class_names = sorted(
    entry
    for entry in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, entry))
)
num_classes = len(class_names)

print(f"Number of classes: {num_classes}")
print(f"\nClass names:")
# Numbered listing, starting at 1.
for i, name in enumerate(class_names):
    print(f"  {i+1}. {name}")

In [None]:
# Map each class name to the number of .jpg/.jpeg/.png files (case-insensitive) in its folder.
class_counts = {
    class_name: sum(
        f.lower().endswith(('.jpg', '.jpeg', '.png'))
        for f in os.listdir(os.path.join(dataset_path, class_name))
    )
    for class_name in class_names
}

# Tabulate the counts, largest class first.
df_counts = (
    pd.DataFrame(list(class_counts.items()), columns=['Class', 'Count'])
    .sort_values(by='Count', ascending=False)
)

print(f"Total images: {df_counts['Count'].sum()}")
print(f"\nClass distribution:")
print(df_counts.to_string(index=False))

In [None]:
# Horizontal bar chart of images per class, one viridis colour per bar.
fig, ax = plt.subplots(figsize=(14, 8))
colors = plt.cm.viridis(np.linspace(0, 1, len(df_counts)))
bars = ax.barh(df_counts['Class'], df_counts['Count'], color=colors)
ax.set_xlabel('Number of Images', fontsize=12)
ax.set_ylabel('Disease Class', fontsize=12)
ax.set_title('Distribution of Images Across Disease Classes', fontsize=14, fontweight='bold')
# df_counts is sorted descending, so flipping the y axis puts the largest class at the top.
ax.invert_yaxis()
# Write each count just to the right of its bar.
for bar, count in zip(bars, df_counts['Count']):
    ax.text(
        bar.get_width() + 10,
        bar.get_y() + bar.get_height()/2,
        str(count),
        va='center',
        fontsize=9,
    )
fig.tight_layout()
fig.savefig('class_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Grid with five columns and as many rows as needed to give every class one cell.
n_cols = 5
n_rows = -(-len(class_names) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3*n_rows))
axes = axes.flatten()

# Show the first listed image of each class; a class without images keeps an empty cell.
for ax, class_name in zip(axes, class_names):
    class_path = os.path.join(dataset_path, class_name)
    images = [f for f in os.listdir(class_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    if images:
        ax.imshow(plt.imread(os.path.join(class_path, images[0])))
        # Underscores become line breaks; the title is capped at 30 characters.
        ax.set_title(class_name.replace('_', '\n')[:30], fontsize=8)
    ax.axis('off')

# Hide the unused cells at the end of the grid.
for ax in axes[len(class_names):]:
    ax.axis('off')

plt.suptitle('Sample Images from Each Disease Class', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('sample_images.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Data Preprocessing

In [None]:
# Input resolution and batch size shared by the training and validation pipelines.
img_size = (224, 224)
batch_size = 32
# Fraction of each class folder held out for validation. Both generators below must
# use the same value so they partition the files identically.
validation_split = 0.2

# Training pipeline: rescale to [0, 1] plus random shear / zoom / horizontal flip.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_split,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation pipeline: rescale only. Drawing the validation subset from the augmenting
# generator would score the model on randomly distorted images, making the validation
# metrics noisy and different on every pass.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_split,
)

# Backward-compatible alias for code that still refers to the original single generator.
datagen = train_datagen

# NOTE(review): Keras' MobileNetV2 `preprocess_input` scales pixels to [-1, 1], whereas
# rescale=1/255 yields [0, 1]. Left unchanged because the plotting and prediction code
# in this notebook assumes [0, 1] images; confirm whether this is intended.

train_data = train_datagen.flow_from_directory(
    dataset_path,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='training',
)

# `shuffle` keeps its default here, so batches come in shuffled order and
# `val_data.classes` (directory order) does NOT line up with the iteration order.
val_data = val_datagen.flow_from_directory(
    dataset_path,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
)

# Use the class count discovered by the generator as the authoritative value.
num_classes = train_data.num_classes
print(f"\nNumber of classes: {num_classes}")
print(f"Training samples: {train_data.samples}")
print(f"Validation samples: {val_data.samples}")

In [None]:
# Pull one batch from the training generator and invert the name -> index mapping
# so one-hot labels can be turned back into class names.
images, labels = next(train_data)
class_indices = {index: name for name, index in train_data.class_indices.items()}

# 2 x 4 grid: the first eight images of the batch.
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
axes = axes.flatten()

for i, ax in enumerate(axes):
    ax.imshow(images[i])
    # Class name with underscores replaced by spaces, capped at 25 characters.
    short_name = class_indices[np.argmax(labels[i])].replace('_', ' ')[:25]
    ax.set_title(f'{short_name}', fontsize=9)
    ax.axis('off')

plt.suptitle('Augmented Training Images', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('augmented_samples.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. MobileNetV2 Transfer Learning Model

In [None]:
# ImageNet-pretrained MobileNetV2 without its classification head, frozen so that
# only the layers added below are trained.
base_model = MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base_model.trainable = False

# Classification head: global average pooling -> 128-unit ReLU layer -> dropout -> softmax over the classes.
model = Sequential()
model.add(base_model)
model.add(GlobalAveragePooling2D())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(num_classes, activation='softmax'))

# Categorical cross-entropy matches the one-hot labels produced with class_mode='categorical'.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=0.0001),
    metrics=['accuracy'],
)

model.summary()

In [None]:
# Train for 15 epochs, evaluating on the validation generator after each epoch;
# the returned History object feeds the training-curve plots.
history = model.fit(x=train_data, epochs=15, validation_data=val_data)

In [None]:
# One figure per metric: training curve against validation curve, saved to PNG and then shown.
# Each spec is (train key, val key, train legend, val legend, title, y-axis label, output file).
curve_specs = (
    ('accuracy', 'val_accuracy', 'Train Accuracy', 'Val Accuracy',
     'Model Accuracy', 'Accuracy', 'mobilenet_accuracy.png'),
    ('loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Model Loss', 'Loss', 'mobilenet_loss.png'),
)

for train_key, val_key, train_legend, val_legend, title, y_label, out_file in curve_specs:
    plt.plot(history.history[train_key], label=train_legend)
    plt.plot(history.history[val_key], label=val_legend)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.savefig(out_file, dpi=150, bbox_inches='tight')
    plt.show()

## 6. Model Evaluation

In [None]:
# Final loss and accuracy on the validation generator (no progress bar).
val_loss, val_acc = model.evaluate(val_data, verbose=0)
print(
    f"Validation Loss: {val_loss:.4f}",
    f"Validation Accuracy: {val_acc:.4f} ({val_acc*100:.2f}%)",
    sep="\n",
)

In [None]:
# Collect predictions together with the ground-truth labels of the SAME batches.
#
# Why not `model.predict(val_data)` + `val_data.classes`: the validation generator is
# created with the default shuffle setting, so batches arrive in shuffled order while
# `val_data.classes` is in directory order. Pairing the two compares each prediction
# with the label of an unrelated image and corrupts the classification report and the
# confusion matrix. Reading labels from the batch that produced the prediction is
# correct whatever the generator's shuffle setting is.
val_data.reset()
batch_predictions = []
batch_true_classes = []
for batch_index in range(len(val_data)):
    # Indexing the generator returns (images, one-hot labels) for that batch.
    batch_images, batch_labels = val_data[batch_index]
    batch_predictions.append(model.predict_on_batch(batch_images))
    batch_true_classes.append(np.argmax(batch_labels, axis=1))

# Per-sample class probabilities, predicted class indices and true class indices,
# all in the same order.
predictions = np.concatenate(batch_predictions, axis=0)
predicted_classes = np.argmax(predictions, axis=1)
true_classes = np.concatenate(batch_true_classes, axis=0)
# Class names; the evaluation cells use position i as the name of class index i.
class_labels = list(val_data.class_indices.keys())

In [None]:
print("\nClassification Report:")
print("="*80)
# Pass the full list of class indices explicitly. Without `labels`, scikit-learn
# derives the classes from the values present in y_true / y_pred and raises
# ValueError when that count differs from len(target_names), which happens as soon as
# one class is absent from both. `zero_division=0` reports 0 for classes without any
# true or predicted samples.
report = classification_report(
    true_classes,
    predicted_classes,
    labels=np.arange(len(class_labels)),
    target_names=class_labels,
    digits=4,
    zero_division=0,
)
print(report)

# Save the same report to disk; an explicit encoding keeps non-ASCII class names
# from depending on the platform default.
with open('classification_report.txt', 'w', encoding='utf-8') as f:
    f.write("Classification Report - MobileNetV2\n")
    f.write("="*80 + "\n")
    f.write(report)

In [None]:
# Build the matrix over ALL class indices so it always has one row and one column per
# entry of `class_labels`. Without `labels`, classes absent from both the true and the
# predicted labels are dropped, and the tick labels below no longer match the matrix.
cm = confusion_matrix(
    true_classes,
    predicted_classes,
    labels=np.arange(len(class_labels)),
)

# Annotated heatmap: rows are true classes, columns are predicted classes.
plt.figure(figsize=(16, 14))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
plt.title('Confusion Matrix - MobileNetV2', fontsize=14, fontweight='bold')
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=90, fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Per-class accuracy = correctly classified samples of a class / all true samples of it
# (diagonal of the confusion matrix over its row sums).
samples_per_class = cm.sum(axis=1)
# A class with no true samples has no defined accuracy: store NaN explicitly instead
# of relying on a 0/0 division.
per_class_accuracy = np.array([
    hits / total if total else np.nan
    for hits, total in zip(cm.diagonal(), samples_per_class)
])
accuracy_df = pd.DataFrame({
    'Class': class_labels,
    'Accuracy': per_class_accuracy,
    'Samples': samples_per_class
}).sort_values('Accuracy')

print("Per-Class Accuracy (sorted):")
print(accuracy_df.to_string(index=False))

plt.figure(figsize=(12, 8))
# Traffic-light colouring: red below 80%, orange below 90%, green otherwise.
# NaN is handled first; every comparison with NaN is False, so an undefined accuracy
# would otherwise fall through to 'green'.
colors = [
    'gray' if np.isnan(acc)
    else 'red' if acc < 0.8
    else 'orange' if acc < 0.9
    else 'green'
    for acc in accuracy_df['Accuracy']
]
plt.barh(accuracy_df['Class'], accuracy_df['Accuracy'], color=colors)
plt.axvline(x=0.9, color='green', linestyle='--', label='90% threshold')
plt.axvline(x=0.8, color='orange', linestyle='--', label='80% threshold')
plt.xlabel('Accuracy', fontsize=12)
plt.ylabel('Disease Class', fontsize=12)
plt.title('Per-Class Accuracy', fontsize=14, fontweight='bold')
plt.legend()
plt.tight_layout()
plt.savefig('per_class_accuracy.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Sample Predictions

In [None]:
# Take one validation batch, predict it, and invert the name -> index mapping so
# indices can be displayed as class names.
val_data.reset()
images, labels = next(val_data)
preds = model.predict(images, verbose=0)
class_indices = {index: name for name, index in val_data.class_indices.items()}

# 3 x 4 grid: the first twelve images of the batch.
fig, axes = plt.subplots(3, 4, figsize=(16, 12))
axes = axes.flatten()

for i, ax in enumerate(axes):
    ax.imshow(images[i])
    true_idx = np.argmax(labels[i])
    pred_idx = np.argmax(preds[i])
    # Score of the predicted class, as a percentage.
    confidence = preds[i][pred_idx] * 100
    true_label = class_indices[true_idx].replace('_', ' ')[:25]
    pred_label = class_indices[pred_idx].replace('_', ' ')[:25]
    # Title colour marks whether the prediction matches the true class.
    color = 'red' if true_idx != pred_idx else 'green'
    ax.set_title(
        f'True: {true_label}\nPred: {pred_label}\nConf: {confidence:.1f}%',
        fontsize=9,
        color=color,
    )
    ax.axis('off')

plt.suptitle('Sample Predictions (Green=Correct, Red=Incorrect)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('sample_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Save Model

In [None]:
import json

# Write the trained model to a single HDF5 file.
model.save('plant_disease_mobilenetv2.h5')
print("Model saved as 'plant_disease_mobilenetv2.h5'")

# Store the class names next to the model so predicted indices can be mapped back to names.
with open('class_labels.json', 'w') as f:
    f.write(json.dumps(class_labels, indent=2))
print("Class labels saved as 'class_labels.json'")

## 9. Test Prediction Function

In [None]:
def predict_image(img_path):
    """Classify a single image file with the trained model and print the result.

    Args:
        img_path: Path of the image file to classify.

    Returns:
        A ``(predicted_class, confidence)`` tuple: the entry of ``class_labels``
        with the highest predicted score, and that score multiplied by 100.
    """
    loaded = image.load_img(img_path, target_size=img_size)
    # Divide by 255 and add a leading batch axis so the model receives a batch of one.
    model_input = np.expand_dims(image.img_to_array(loaded) / 255.0, axis=0)
    prediction = model.predict(model_input)

    best_index = np.argmax(prediction)
    predicted_class = class_labels[best_index]
    confidence = np.max(prediction) * 100
    print(f"Predicted: {predicted_class} ({confidence:.2f}% confidence)")
    return predicted_class, confidence

In [None]:
# Smoke test: run predict_image on the first listed image of the first class folder.
test_class = class_names[0]
test_class_path = os.path.join(dataset_path, test_class)
test_images = [
    f
    for f in os.listdir(test_class_path)
    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
]
# Nothing is run when the folder contains no image files.
if test_images:
    test_img_path = os.path.join(test_class_path, test_images[0])
    print(f"Testing on: {test_img_path}")
    predict_image(test_img_path)

## 10. Summary

In [None]:
# Closing banner.
separator = "="*80
print("\n" + separator)
print("PROJECT COMPLETE")
print(separator)

# Dataset size and final validation metrics.
print(f"\nDataset: {train_data.samples + val_data.samples} images across {num_classes} classes")
print(f"Training samples: {train_data.samples}")
print(f"Validation samples: {val_data.samples}")
print(f"\nFinal MobileNetV2 Results:")
print(f"  Validation Accuracy: {val_acc*100:.2f}%")
print(f"  Validation Loss: {val_loss:.4f}")

# Artefacts written by the earlier cells.
print(f"\nFiles Generated:")
for line in (
    "  - plant_disease_mobilenetv2.h5",
    "  - class_labels.json",
    "  - class_distribution.png",
    "  - sample_images.png",
    "  - augmented_samples.png",
    "  - mobilenet_accuracy.png",
    "  - mobilenet_loss.png",
    "  - confusion_matrix.png",
    "  - per_class_accuracy.png",
    "  - sample_predictions.png",
    "  - classification_report.txt",
):
    print(line)