# 🌿 Plant Disease Detection — Immersive Modeling Notebook

This notebook is your one-stop, end-to-end pipeline to build, evaluate, and export state-of-the-art plant disease classifiers. It includes robust data handling, model baselines, advanced transfer learning, ensembles, uncertainty, and explainability.

What you'll get:
- Clean data loading and splitting
- Strong baselines (MobileNetV2)
- Advanced models (EfficientNetB0, ResNet50V2, InceptionV3)
- Optional lightweight ensemble
- Mixed precision for speed
- Class imbalance handling
- Checkpoints + Early stopping
- Grad-CAM explainability
- Export to SavedModel and Keras 3 formats

In [None]:
# Core imports
import os, sys, math, json, random, pathlib, gc
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision

# Model families
from tensorflow.keras.applications import (
    MobileNetV2,
    EfficientNetB0,
    ResNet50V2,
    InceptionV3,
)

# Reproducibility: seed Python's `random`, NumPy and TensorFlow with the
# same value so that reruns start from the same random state.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_rng(SEED)

# Record the library versions in the notebook output.
for _lib_label, _lib in (("TensorFlow:", tf), ("Keras:", keras)):
    print(_lib_label, _lib.__version__)

In [None]:
# Configuration — adjust as needed
# Data
DATA_PATH = "data"  # root folder with class subfolders
IMG_SIZE, BATCH_SIZE = 256, 32

# Optimisation
EPOCHS, LEARNING_RATE = 25, 3e-4

# Outputs (checkpoints, logs and exports all live under MODEL_OUT_DIR)
MODEL_OUT_DIR, EXPORT_TAG = "models_out", "v1"
os.makedirs(MODEL_OUT_DIR, exist_ok=True)

# Ask Keras for the "mixed_float16" global policy. This is best-effort: a
# failure is only reported, and the notebook carries on with whatever
# policy was already active.
try:
    mixed_precision.set_global_policy("mixed_float16")
except Exception as e:
    print("Mixed precision not enabled:", e)
else:
    print("Mixed precision enabled.")

In [None]:
# Data loading
from tensorflow.keras.preprocessing import image_dataset_from_directory

# Both subsets are read from the same directory with the same split
# fraction and seed; only `subset` differs between the two calls.
_split_kwargs = dict(
    validation_split=0.2,
    seed=SEED,
    image_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    label_mode="categorical",
)
train_ds = image_dataset_from_directory(DATA_PATH, subset="training", **_split_kwargs)
val_ds = image_dataset_from_directory(DATA_PATH, subset="validation", **_split_kwargs)

# Class metadata: names as reported by the loader, and name -> index lookup.
CLASS_NAMES = train_ds.class_names
NUM_CLASSES = len(CLASS_NAMES)
CLASS_TO_IDX = dict(zip(CLASS_NAMES, range(NUM_CLASSES)))
print("Classes (", NUM_CLASSES, "):", CLASS_NAMES)

In [None]:
# Performance pipeline and caching
AUTOTUNE = tf.data.AUTOTUNE


def norm(x):
    """Cast an image batch to float32 and divide it by 255.

    Assumes the loader yields pixel values on a 0-255 scale, so the
    result lies in [0, 1] — TODO confirm if the loader is changed.
    """
    return tf.cast(x, tf.float32) / 255.0


# Random augmentations, applied to the training split only.
aug = keras.Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.15),
    layers.RandomZoom(0.15),
    layers.RandomContrast(0.1),
    layers.GaussianNoise(0.05),
], name="augmentation")

# Training: normalize -> cache -> augment -> prefetch.
# cache() has to come BEFORE the random augmentation. Caching after it
# stores the first epoch's augmented batches and replays exactly those
# in every later epoch, so the model never sees fresh augmentations.
# NOTE: cache() still freezes the batch order and composition produced
# during the first pass over the data.
train_ds = (
    train_ds
    .map(lambda x, y: (norm(x), y), num_parallel_calls=AUTOTUNE)
    .cache()
    .map(lambda x, y: (aug(x, training=True), y), num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

# Validation: deterministic, so caching the normalized batches is safe.
val_ds = (
    val_ds
    .map(lambda x, y: (norm(x), y), num_parallel_calls=AUTOTUNE)
    .cache()
    .prefetch(AUTOTUNE)
)

In [None]:
# Compute class weights to mitigate imbalance
from collections import Counter

# Re-scan directory to compute class distribution.
# Files are matched the way Keras' directory loader matches them:
# case-insensitive extension, searched recursively below each class
# folder. The previous case-sensitive "*.jpg"/"*.jpeg"/"*.png" globs
# counted zero files for folders holding e.g. upper-case ".JPG" images,
# which then crashed the weight computation with ZeroDivisionError.
# NOTE: counts cover the whole directory, i.e. training AND validation
# images, not only the training split.
_IMAGE_SUFFIXES = frozenset({".bmp", ".gif", ".jpeg", ".jpg", ".png"})

counts = Counter()
root = pathlib.Path(DATA_PATH)
for cls in CLASS_NAMES:
    counts[CLASS_TO_IDX[cls]] += sum(
        1
        for path in (root / cls).rglob("*")
        if path.is_file() and path.suffix.lower() in _IMAGE_SUFFIXES
    )

total = sum(counts.values())
# "Balanced" weighting: total / (num_classes * class_count). A class with
# no counted images gets the neutral weight 1.0 instead of dividing by
# zero; every index 0..NUM_CLASSES-1 keeps an entry in the dict.
class_weights = {
    i: (total / (NUM_CLASSES * counts[i]) if counts[i] else 1.0)
    for i in range(NUM_CLASSES)
}
print("Class weights:", class_weights)

In [None]:
# Utility: compile and callbacks

def build_head(x, num_classes):
    """Attach the classification head to a backbone feature tensor.

    The head is global average pooling, dropout with rate 0.3, and a
    softmax ``Dense`` layer. The ``Dense`` layer is created with
    ``dtype="float32"`` explicitly, independent of the global dtype policy.

    Args:
        x: Symbolic feature tensor to pool (``GlobalAveragePooling2D`` is
            applied to it, so presumably a 4-D feature map).
        num_classes: Number of output units of the softmax layer.

    Returns:
        The symbolic output tensor of the softmax layer.
    """
    pooled = layers.GlobalAveragePooling2D()(x)
    regularized = layers.Dropout(0.3)(pooled)
    classifier = layers.Dense(num_classes, activation="softmax", dtype="float32")
    return classifier(regularized)


def compile_model(model: keras.Model, lr=LEARNING_RATE):
    """Compile ``model`` in place with the notebook's standard settings.

    Uses Adam, categorical cross-entropy, and two metrics: accuracy and
    top-3 categorical accuracy (reported as ``top3``).

    Args:
        model: The Keras model to compile.
        lr: Learning rate handed to Adam.

    Returns:
        The same ``model`` object, so calls can be chained.
    """
    compile_kwargs = {
        "optimizer": keras.optimizers.Adam(learning_rate=lr),
        "loss": "categorical_crossentropy",
        "metrics": [
            "accuracy",
            keras.metrics.TopKCategoricalAccuracy(k=3, name="top3"),
        ],
    }
    model.compile(**compile_kwargs)
    return model


def get_callbacks(name_prefix: str):
    """Build the list of training callbacks for one model.

    Args:
        name_prefix: Tag used in the checkpoint file name
            (``<name_prefix>_best.keras``) and the TensorBoard log folder
            (``logs_<name_prefix>``), both placed under ``MODEL_OUT_DIR``.

    Returns:
        A list with, in order: a best-only checkpoint on ``val_accuracy``,
        early stopping on ``val_accuracy`` (patience 5, best weights
        restored), learning-rate halving on a ``val_loss`` plateau
        (patience 2), and a TensorBoard logger.
    """
    best_path = os.path.join(MODEL_OUT_DIR, f"{name_prefix}_best.keras")
    log_dir = os.path.join(MODEL_OUT_DIR, f"logs_{name_prefix}")
    return [
        keras.callbacks.ModelCheckpoint(
            best_path,
            monitor="val_accuracy",
            save_best_only=True,
            mode="max",
            verbose=1,
        ),
        keras.callbacks.EarlyStopping(
            monitor="val_accuracy", patience=5, restore_best_weights=True
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=2, verbose=1
        ),
        keras.callbacks.TensorBoard(log_dir=log_dir),
    ]

In [None]:
# Baseline model — MobileNetV2
base = MobileNetV2(include_top=False, weights="imagenet", input_shape=(IMG_SIZE, IMG_SIZE, 3))
base.trainable = False  # start with a frozen backbone; only the head trains
inp = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
# The data pipeline's `norm` divides pixels by 255, giving [0, 1], but the
# ImageNet weights of MobileNetV2 were trained on inputs scaled to [-1, 1].
# Map [0, 1] -> [-1, 1] inside the model so the pretrained features see
# the range they expect (and the exported model carries the step with it).
x = layers.Rescaling(2.0, offset=-1.0, name="to_mobilenet_range")(inp)
# training=False keeps the backbone's BatchNorm layers in inference mode.
x = base(x, training=False)
out = build_head(x, NUM_CLASSES)
model_mobilenet = keras.Model(inp, out, name="MobileNetV2_Base")
compile_model(model_mobilenet)
model_mobilenet.summary()

In [None]:
# Train baseline: head-only training on the frozen MobileNetV2 backbone,
# with class weighting and the standard callback set.
_baseline_fit_kwargs = dict(
    validation_data=val_ds,
    epochs=EPOCHS,
    class_weight=class_weights,
    callbacks=get_callbacks("mobilenet"),
)
hist_mobilenet = model_mobilenet.fit(train_ds, **_baseline_fit_kwargs)

In [None]:
# Fine-tune baseline: unfreeze the backbone, then re-freeze everything
# except its last 30 layers.
base.trainable = True
_still_frozen = base.layers[:-30]
for layer in _still_frozen:
    layer.trainable = False

# Compile again after changing the trainable flags, this time with a
# smaller learning rate for fine-tuning.
compile_model(model_mobilenet, lr=1e-4)

_finetune_fit_kwargs = dict(
    validation_data=val_ds,
    epochs=10,
    class_weight=class_weights,
    callbacks=get_callbacks("mobilenet_ft"),
)
hist_mobilenet_ft = model_mobilenet.fit(train_ds, **_finetune_fit_kwargs)

In [None]:
# Advanced models — EfficientNetB0, ResNet50V2, InceptionV3

def build_transfer_model(backbone, preprocess_fn, name_prefix, input_scale=255.0):
    """Wrap a pretrained backbone into a compiled, frozen-backbone classifier.

    The notebook's data pipeline divides pixels by 255 (see ``norm``), so
    images reach the model in [0, 1]. The ``keras.applications``
    ``preprocess_input`` functions are written for raw 0-255 pixels, so
    the input is scaled back up by ``input_scale`` before ``preprocess_fn``
    runs. Without this, EfficientNet (whose ``preprocess_input`` is a
    pass-through) would receive near-black images, and ResNetV2/Inception
    would receive values squeezed into a sliver near -1.

    Args:
        backbone: Headless ``keras.applications`` model; it is frozen here.
        preprocess_fn: The backbone family's ``preprocess_input`` function.
        name_prefix: Name given to the resulting model.
        input_scale: Factor that maps the pipeline's pixel scale back to
            0-255. Keep the default for [0, 1] inputs; pass ``1.0`` if the
            pipeline already delivers 0-255 pixels.

    Returns:
        The compiled ``keras.Model``.
    """
    backbone.trainable = False
    inp = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    x = inp
    if input_scale != 1.0:
        # Undo the pipeline's 1/255 scaling so preprocess_fn sees 0-255.
        x = layers.Rescaling(input_scale)(x)
    x = preprocess_fn(x)
    # training=False keeps the backbone's BatchNorm layers in inference mode.
    x = backbone(x, training=False)
    out = build_head(x, NUM_CLASSES)
    m = keras.Model(inp, out, name=name_prefix)
    return compile_model(m)

# All three backbones are created headless, with ImageNet weights and the
# notebook-wide input resolution; each is paired with the
# `preprocess_input` function of its own family.
_backbone_kwargs = dict(
    include_top=False,
    weights="imagenet",
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)

m_eff = build_transfer_model(
    EfficientNetB0(**_backbone_kwargs),
    tf.keras.applications.efficientnet.preprocess_input,
    "EfficientNetB0",
)

m_res = build_transfer_model(
    ResNet50V2(**_backbone_kwargs),
    tf.keras.applications.resnet_v2.preprocess_input,
    "ResNet50V2",
)

m_inc = build_transfer_model(
    InceptionV3(**_backbone_kwargs),
    tf.keras.applications.inception_v3.preprocess_input,
    "InceptionV3",
)

In [None]:
# Train each advanced model for up to EPOCHS epochs (the callbacks may stop
# earlier) and keep the per-epoch metric history under the model's name.
HISTS = {}
_advanced_models = {
    "EfficientNetB0": m_eff,
    "ResNet50V2": m_res,
    "InceptionV3": m_inc,
}
for name, m in _advanced_models.items():
    print(f"\nTraining {name}...")
    hist = m.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS,
        class_weight=class_weights,
        callbacks=get_callbacks(name),
    )
    HISTS[name] = hist.history
    # Run the garbage collector between models.
    gc.collect()

In [None]:
# Optional lightweight ensemble — averages the members' predictions
class SimpleEnsemble(keras.Model):
    """Ensemble that returns the element-wise mean of its members' outputs.

    The members built in this notebook end in a softmax layer (see
    ``build_head``), so what gets averaged are class probabilities, not
    logits. The ensemble adds no layers of its own.
    """

    def __init__(self, models):
        """Store the member models.

        Args:
            models: List of models that accept the same input and produce
                outputs of the same shape.
        """
        super().__init__()
        self.models = models

    def call(self, x, training=False):
        """Run every member on ``x`` and average the results.

        Args:
            x: Input batch, forwarded unchanged to each member.
            training: Forwarded to each member's call.

        Returns:
            The mean over members of their outputs.
        """
        member_outputs = []
        for member in self.models:
            member_outputs.append(member(x, training=training))
        stacked = tf.stack(member_outputs, axis=0)
        return tf.reduce_mean(stacked, axis=0)


ensemble = SimpleEnsemble([m_eff, m_res, m_inc])
# Compiled with the same optimizer type, loss and metrics as the single models.
ensemble.compile(
    optimizer=keras.optimizers.Adam(LEARNING_RATE),
    loss="categorical_crossentropy",
    metrics=["accuracy", keras.metrics.TopKCategoricalAccuracy(k=3, name="top3")],
)

In [None]:
# Score every candidate on the validation set, then pick the one with the
# highest validation accuracy (ties go to the earliest entry).
VAL_SCORES = {}
_single_models = {
    "MobileNetV2": model_mobilenet,
    "EfficientNetB0": m_eff,
    "ResNet50V2": m_res,
    "InceptionV3": m_inc,
}
for name, m in _single_models.items():
    loss, acc, top3 = m.evaluate(val_ds, verbose=0)
    VAL_SCORES[name] = dict(loss=float(loss), acc=float(acc), top3=float(top3))

# Ensemble score
loss, acc, top3 = ensemble.evaluate(val_ds, verbose=0)
VAL_SCORES["Ensemble"] = dict(loss=float(loss), acc=float(acc), top3=float(top3))

print(json.dumps(VAL_SCORES, indent=2))

BEST_NAME, _best_scores = max(VAL_SCORES.items(), key=lambda item: item[1]["acc"])
print("Best model:", BEST_NAME, VAL_SCORES[BEST_NAME])

In [None]:
# Export the best-scoring model (which may be the ensemble)
BEST = {
    "MobileNetV2": model_mobilenet,
    "EfficientNetB0": m_eff,
    "ResNet50V2": m_res,
    "InceptionV3": m_inc,
    "Ensemble": ensemble,
}[BEST_NAME]

save_dir = os.path.join(MODEL_OUT_DIR, BEST_NAME + "_" + EXPORT_TAG)
os.makedirs(save_dir, exist_ok=True)

# SavedModel (for TF Serving and Keras 3 TFSMLayer).
# Keras 3 rejects `save(..., save_format="tf")` for a directory path; its
# replacement for writing a SavedModel is `Model.export()`. Keras 2 keeps
# the original call so the artifact it produces is unchanged.
savedmodel_path = os.path.join(save_dir, "savedmodel")
if int(keras.__version__.split(".")[0]) >= 3:
    BEST.export(savedmodel_path)
else:
    BEST.save(savedmodel_path, save_format="tf")
print("Saved SavedModel at:", savedmodel_path)

# Native Keras `.keras` archive
keras_path = os.path.join(save_dir, BEST_NAME + ".keras")
BEST.save(keras_path)
print("Saved .keras at:", keras_path)

# Write meta info needed to interpret the exported model's outputs
with open(os.path.join(save_dir, "meta.json"), "w", encoding="utf-8") as f:
    json.dump({
        "best": BEST_NAME,
        "val_scores": VAL_SCORES,
        "classes": CLASS_NAMES,
        "img_size": IMG_SIZE,
        "export_tag": EXPORT_TAG,
    }, f, indent=2)
print("Export complete.")

In [None]:
# Grad-CAM explainability for a batch
import matplotlib.pyplot as plt

def _find_cam_tensor(model):
    """Pick the symbolic feature-map tensor Grad-CAM differentiates against.

    Returns the output of the last top-level ``Conv2D`` layer if there is
    one. Otherwise returns the input of the last layer that consumes a 4-D
    tensor, or ``None`` if no such tensor exists.
    """
    for layer in reversed(model.layers):
        if isinstance(layer, layers.Conv2D):
            return layer.output

    # Models that wrap a pretrained backbone as one nested layer have no
    # top-level Conv2D. The input of the first head layer (e.g. the global
    # pooling layer) is the backbone's final feature map, expressed in the
    # OUTER model's graph, so a sub-model can be built from it.
    model_inputs = list(model.inputs)
    for layer in reversed(model.layers):
        try:
            candidate = layer.input
        except (AttributeError, ValueError, RuntimeError):
            continue  # layer has no single, well-defined input tensor
        if isinstance(candidate, (list, tuple)):
            continue
        if any(candidate is tensor for tensor in model_inputs):
            continue  # the raw image is not a feature map
        rank = getattr(candidate.shape, "rank", None)
        if rank is None:
            try:
                rank = len(candidate.shape)
            except (TypeError, ValueError):
                continue
        if rank == 4:
            return candidate
    return None


def grad_cam(model, images, class_index=None, layer_name=None):
    """Compute Grad-CAM heatmaps for a batch of images.

    Args:
        model: A functional Keras model (one with symbolic ``inputs``).
        images: Batch of images; assumed to be shaped
            (batch, height, width, channels) — TODO confirm for new callers.
        class_index: Class whose score is explained. ``None`` explains, for
            each image separately, that image's own top predicted class.
        layer_name: Name of a top-level layer whose output is the feature
            map to use. ``None`` selects one automatically.

    Returns:
        Float tensor shaped (batch, height, width, 1) with values in
        [0, 1], each image normalised by its own maximum.

    Raises:
        ValueError: If ``model`` has no symbolic inputs, no feature map can
            be found, or the score does not depend on the feature map.
    """
    if not getattr(model, "inputs", None):
        raise ValueError(
            "grad_cam needs a functional model with symbolic inputs; "
            "for an ensemble, pass one of its member models instead."
        )

    if layer_name is None:
        target = _find_cam_tensor(model)
        if target is None:
            raise ValueError(
                "Could not locate a 4-D feature map; pass layer_name explicitly."
            )
    else:
        target = model.get_layer(layer_name).output

    grad_model = keras.Model(model.inputs, [target, model.output])
    images = tf.cast(tf.convert_to_tensor(images), tf.float32)

    with tf.GradientTape() as tape:
        # Watch the input explicitly. A tape only tracks trainable
        # variables by default, so with a frozen backbone the path through
        # the feature map would go unrecorded and the gradient be None.
        tape.watch(images)
        conv_out, preds = grad_model(images, training=False)
        if class_index is None:
            # Per-image top class, not the first image's class for all.
            class_index = tf.argmax(preds, axis=-1)
            scores = tf.gather(preds, class_index, axis=1, batch_dims=1)
        else:
            scores = preds[:, class_index]

    grads = tape.gradient(scores, conv_out)
    if grads is None:
        raise ValueError(
            "The class score does not depend on the selected feature map."
        )

    # Do the heatmap arithmetic in float32 even when the feature map was
    # produced in a lower-precision dtype.
    conv_out = tf.cast(conv_out, tf.float32)
    grads = tf.cast(grads, tf.float32)

    channel_weights = tf.reduce_mean(grads, axis=(1, 2))
    cam = tf.reduce_sum(channel_weights[:, None, None, :] * conv_out, axis=-1)
    cam = tf.maximum(cam, 0)
    # Normalise per image so one strong activation in the batch does not
    # flatten every other heatmap.
    cam = cam / (tf.reduce_max(cam, axis=(1, 2), keepdims=True) + 1e-8)
    cam = tf.image.resize(cam[..., None], tf.shape(images)[1:3])
    return cam

# Show Grad-CAM for one validation batch: up to six images on the top row,
# and the same images with their heatmap overlaid on the bottom row.
for batch in val_ds.take(1):
    imgs, labels = batch
    cams = grad_cam(BEST, imgs)
    n_cols = 6
    plt.figure(figsize=(12,6))
    for i in range(min(n_cols, imgs.shape[0])):
        picture = imgs[i].numpy()

        # Top row: the plain image.
        plt.subplot(2, n_cols, i + 1)
        plt.imshow(picture)
        plt.axis('off')

        # Bottom row: image with the semi-transparent heatmap on top.
        plt.subplot(2, n_cols, n_cols + i + 1)
        heat = tf.squeeze(cams[i]).numpy()
        plt.imshow(picture)
        plt.imshow(heat, cmap='jet', alpha=0.35)
        plt.axis('off')
    plt.suptitle(f"Grad-CAM on {BEST_NAME}")
    plt.show()