In [None]:
# ================================
# Plant Disease Classification Project
# Single-file Python implementation
# ================================

# -------- Imports --------
import os, json, random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
import kagglehub

# -------- Reproducibility --------
# Seed the Python, NumPy and TensorFlow random number generators with the
# same fixed value (0), in that order.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(0)

# -------- Download Dataset --------
# The value returned by kagglehub.dataset_download is used below as the
# directory to search for the extracted dataset.
print("Downloading PlantVillage dataset...")
path = kagglehub.dataset_download("abdallahalidev/plantvillage-dataset")

# Locate dataset root: the first directory (os.walk is top-down) whose path
# contains "plantvillage dataset" AND that actually holds a "color"
# sub-folder. Requiring the sub-folder guarantees that color_dir below
# exists, instead of failing later when the image generators try to read it.
dataset_root = None
for root, dirs, files in os.walk(path):
    if "plantvillage dataset" in root and "color" in dirs:
        dataset_root = root
        break

if dataset_root is None:
    raise FileNotFoundError("PlantVillage dataset folder not found")

# Only the colour variant of the images is used by the rest of the script.
color_dir = os.path.join(dataset_root, "color")
print("Dataset loaded from:", color_dir)

# -------- Image Parameters --------
IMG_SIZE = 224    # images are resized to IMG_SIZE x IMG_SIZE pixels
BATCH_SIZE = 32   # images per generator batch
EPOCHS = 5        # number of epochs passed to model.fit

# -------- Data Generators --------
# A single generator configuration serves both subsets: pixel values are
# multiplied by 1/255 and 20% of the images are reserved for validation.
datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# Arguments common to the training and validation iterators.
_flow_kwargs = dict(
    directory=color_dir,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode="categorical",
)

# Training batches are shuffled.
train_gen = datagen.flow_from_directory(
    subset="training", shuffle=True, **_flow_kwargs
)

# Validation batches keep a fixed order so that predictions made over the
# iterator line up with val_gen.classes during evaluation.
val_gen = datagen.flow_from_directory(
    subset="validation", shuffle=False, **_flow_kwargs
)

NUM_CLASSES = train_gen.num_classes
print("Number of classes:", NUM_CLASSES)

# -------- CNN Model --------
# Small convolutional classifier: two Conv2D + MaxPooling2D stages, then a
# dense head ending in a softmax over NUM_CLASSES outputs.
model = models.Sequential([
    # Declare the input with an explicit Input object rather than passing
    # input_shape= to the first layer; newer Keras releases warn about the
    # latter for Sequential models. The architecture is unchanged.
    layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),

    layers.Conv2D(32, (3, 3), activation="relu"),
    layers.MaxPooling2D(2, 2),

    layers.Conv2D(64, (3, 3), activation="relu"),
    layers.MaxPooling2D(2, 2),

    layers.Flatten(),
    layers.Dense(256, activation="relu"),
    layers.Dense(NUM_CLASSES, activation="softmax"),
])

# categorical_crossentropy pairs with the one-hot labels produced by the
# generators (class_mode="categorical").
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

model.summary()

# -------- Train Model --------
# Fit on the training iterator for EPOCHS epochs with the validation
# iterator supplied as validation data; the returned History object is what
# the accuracy/loss plots read from.
history = model.fit(x=train_gen, validation_data=val_gen, epochs=EPOCHS)

# -------- Accuracy & Loss Plots --------
# One figure per metric; in each, the training curve is drawn first and the
# validation curve second, matching the legend order.
_curves = (
    ("accuracy", "val_accuracy", "Training vs Validation Accuracy", "Accuracy"),
    ("loss", "val_loss", "Training vs Validation Loss", "Loss"),
)
for _train_key, _val_key, _title, _ylabel in _curves:
    plt.figure()
    plt.plot(history.history[_train_key])
    plt.plot(history.history[_val_key])
    plt.title(_title)
    plt.xlabel("Epoch")
    plt.ylabel(_ylabel)
    plt.legend(["Train", "Validation"])
    plt.show()

# -------- Evaluation --------
# Predict over the whole validation iterator and compare the arg-max class
# of each prediction with the iterator's ground-truth class ids.
val_gen.reset()
pred_probs = model.predict(val_gen)
y_pred = np.argmax(pred_probs, axis=1)
y_true = val_gen.classes

# Order class names by their integer index explicitly instead of relying on
# the dict's iteration order.
class_names = sorted(val_gen.class_indices, key=val_gen.class_indices.get)

# Passing the full label list keeps the matrix and the report aligned with
# class_names even if some class never occurs in y_true / y_pred.
label_ids = np.arange(len(class_names))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)

# ConfusionMatrixDisplay.plot() creates its own figure unless it is given an
# axes object, so the sized axes must be passed in; otherwise the 10x10
# figure stays empty and the matrix is drawn at the default size.
fig, ax = plt.subplots(figsize=(10, 10))
disp = ConfusionMatrixDisplay(cm, display_labels=class_names)
disp.plot(ax=ax, include_values=False, cmap="Blues", xticks_rotation=90)
ax.set_title("Confusion Matrix")
plt.show()

print("\nClassification Report:\n")
print(classification_report(
    y_true, y_pred, labels=label_ids, target_names=class_names
))

# -------- Sample Predictions --------
# Take one batch from the validation iterator and show up to six images with
# their predicted (P) and true (T) class names.
# NOTE(review): val_gen was created with shuffle=False, so this batch is
# presumably not a random sample of the validation set -- confirm whether
# more varied examples are wanted.
images, labels = next(val_gen)
preds = model.predict(images)

# Never index past the end of the batch (it may hold fewer than 6 images).
n_show = min(6, len(images))

plt.figure(figsize=(12, 7))
for i in range(n_show):
    plt.subplot(2, 3, i + 1)
    plt.imshow(images[i])
    pred_class = class_names[np.argmax(preds[i])]
    true_class = class_names[np.argmax(labels[i])]  # labels are one-hot rows
    plt.title(f"P: {pred_class}\nT: {true_class}", fontsize=9)
    plt.axis("off")
plt.tight_layout()
plt.show()

# -------- Save Model & Classes --------
model.save("plant_disease_cnn.h5")

# Invert the generator's name -> index mapping so a predicted index can be
# turned back into a class name.
class_indices = {v: k for k, v in train_gen.class_indices.items()}

# Write through a context manager so the file is flushed and closed
# deterministically. JSON object keys are always strings, so the integer
# indices are stored as "0", "1", ... and must be converted back to int by
# any code that reloads this file.
with open("class_indices.json", "w", encoding="utf-8") as _fh:
    json.dump(class_indices, _fh, indent=2)

print("Model saved as plant_disease_cnn.h5")
print("Class indices saved as class_indices.json")

# -------- Single Image Prediction --------
def predict_image(image_path) -> str:
    """Return the predicted class name for a single image file.

    Uses the module-level ``model``, ``IMG_SIZE`` and ``class_indices``
    (index -> class name), so those must already be defined.

    Args:
        image_path: Location of the image, passed straight to
            ``PIL.Image.open``.

    Returns:
        The class name stored in ``class_indices`` for the highest-scoring
        model output.
    """
    # Open inside a context manager so the underlying file is closed as soon
    # as the pixel data has been copied by convert().
    with Image.open(image_path) as raw:
        # The training generators do not override the Keras resize
        # interpolation, which defaults to nearest-neighbour; request the same
        # filter here so inference preprocessing matches training rather than
        # using Pillow's own default resampling.
        img = raw.convert("RGB").resize((IMG_SIZE, IMG_SIZE), Image.NEAREST)

    # Same 1/255 scaling as the data generators, plus a leading batch axis.
    img = np.asarray(img, dtype="float32") / 255.0
    img = np.expand_dims(img, axis=0)

    pred = model.predict(img)
    # Convert the NumPy integer to a plain int before the dict lookup.
    return class_indices[int(np.argmax(pred))]

# Example (change image path)
# test_image = "test_leaf.jpg"
# print("Predicted Disease:", predict_image(test_image))
