# Implementation of a Mobile-Based CNN (TensorFlow/Keras) for Detecting Rice Leaf Diseases

This notebook is a **full TensorFlow/Keras rewrite** of the original PyTorch notebook.
It covers: data loading with `tf.data`, Albumentations-based augmentation, transfer learning (a MobileNetV2 backbone; an EfficientNet from `keras.applications` can be swapped in), training,
evaluation with a confusion matrix, and **TFLite conversion** for mobile deployment.


In [None]:
# Core Python
import sys, platform, itertools
from pathlib import Path

# Numerical & Data Handling
import numpy as np

# Visualization
import matplotlib.pyplot as plt

# TensorFlow / Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Metrics / Evaluation
from sklearn.metrics import confusion_matrix, classification_report

# Image Handling
from PIL import Image

# Albumentations
import albumentations as A


## 2) Environment Check

In [None]:
# Report the runtime stack so results can be tied to a specific environment.
# Each probe is evaluated only when its line is printed, in this order.
_env_probes = {
    "Python:": lambda: sys.version,
    "TensorFlow:": lambda: tf.__version__,
    "Platform:": lambda: platform.platform(),
    "GPU(s):": lambda: tf.config.list_physical_devices('GPU'),
}
for _probe_label, _probe in _env_probes.items():
    print(_probe_label, _probe())

## 3) Configuration

In [None]:
# ---- Training hyperparameters ----
IMG_SIZE = (224, 224)   # (height, width) every image is resized to
BATCH_SIZE = 32
EPOCHS = 20
VAL_SPLIT = 0.2         # fraction of DATA_DIR held out for validation
SEED = 42               # shared by both dataset splits so they do not overlap

# ---- Filesystem locations ----
DATA_DIR = Path("data")        # Change this to your dataset root folder
OUTPUT_DIR = Path("outputs")   # checkpoints, SavedModel and .tflite files go here
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)



## 4) Data Loading

In [None]:
# Load DATA_DIR twice with the same split fraction and seed so the "training" and
# "validation" subsets are complementary views of one shuffled file list.
# tf.keras.utils.image_dataset_from_directory is the maintained entry point; the
# tf.keras.preprocessing alias used previously is deprecated.
_split_kwargs = dict(
    validation_split=VAL_SPLIT,
    seed=SEED,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)

train_ds = tf.keras.utils.image_dataset_from_directory(
    str(DATA_DIR), subset="training", **_split_kwargs
)
val_ds = tf.keras.utils.image_dataset_from_directory(
    str(DATA_DIR), subset="validation", **_split_kwargs
)

# Class order (alphabetical sub-folder names) defines the integer label of each class.
class_names = train_ds.class_names
num_classes = len(class_names)
print("Classes:", class_names)


## 5) Data Augmentation

In [None]:
def _random_resized_crop(**kwargs):
    """Build an ``A.RandomResizedCrop`` that outputs ``IMG_SIZE`` on any Albumentations version.

    Newer Albumentations releases take the output size as ``size=(height, width)``,
    while older releases only accept separate ``height=`` / ``width=`` keywords
    (and raise ``TypeError`` on ``size=``).
    """
    try:
        return A.RandomResizedCrop(size=IMG_SIZE, **kwargs)
    except TypeError:
        # Older Albumentations: no ``size`` keyword
        return A.RandomResizedCrop(height=IMG_SIZE[0], width=IMG_SIZE[1], **kwargs)


# Define Albumentations pipeline (applied per image, to the training split only)
albumentations_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),  # keeps the shape only because IMG_SIZE is square
    A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    _random_resized_crop(scale=(0.8, 1.0), p=0.5),
])

## 6) Apply Albumentations in tf.data

In [None]:
def albumentations_augment(image, label):
    """Run ``albumentations_transform`` eagerly on one image or on a batch of images.

    Called through ``tf.py_function``, so ``image`` is an eager tensor. It may be a
    single ``(H, W, 3)`` image or an ``(N, H, W, 3)`` batch. ``train_ds`` is built
    with ``batch_size=BATCH_SIZE``, so it yields batches, while Albumentations
    handles one image per call; a batch is therefore augmented image by image.

    ``image_dataset_from_directory`` yields unscaled float32 pixels (0-255), but
    Albumentations treats float32 images as [0, 1]-scaled. Pixels are therefore
    round-tripped through uint8 so intensity transforms work on the right scale.
    The result is float32 again, still in 0-255.

    Args:
        image: eager float32 tensor, ``(H, W, 3)`` or ``(N, H, W, 3)``.
        label: label tensor, passed through unchanged.

    Returns:
        ``(augmented_image, label)``; the image has the same rank as the input.
    """
    pixels = image.numpy()
    is_single = pixels.ndim == 3
    batch = pixels[np.newaxis] if is_single else pixels
    batch_u8 = np.clip(np.rint(batch), 0, 255).astype(np.uint8)
    augmented = np.stack(
        [albumentations_transform(image=img)["image"] for img in batch_u8]
    ).astype(np.float32)
    if is_single:
        augmented = augmented[0]
    return tf.convert_to_tensor(augmented, dtype=tf.float32), label


def augment_with_albumentations(image, label):
    """``tf.data``-mappable wrapper around :func:`albumentations_augment`."""
    aug_img, aug_label = tf.py_function(
        albumentations_augment, [image, label], [tf.float32, label.dtype]
    )
    # py_function outputs carry no static shape; restore the input shapes (batch dim
    # included) so Keras sees fully-defined image and label shapes. Every transform
    # in the pipeline yields IMG_SIZE as long as IMG_SIZE is square.
    aug_img.set_shape(image.shape)
    aug_label.set_shape(label.shape)
    return aug_img, aug_label


# Apply only to the training dataset. train_ds is already batched and shuffled by
# image_dataset_from_directory (shuffle=True is its default). A .shuffle() after
# batching would only reorder whole batches while buffering up to 1000 full float32
# batches in memory, so it is intentionally omitted.
train_ds = train_ds.map(augment_with_albumentations, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)

## 7) Visualize Augmented Sample

In [None]:
# Show up to nine augmented training images (3x3 grid) from the first batch,
# each titled with its class name.
for sample_images, sample_labels in train_ds.take(1):
    plt.figure(figsize=(8, 8))
    n_shown = min(9, sample_images.shape[0])
    for idx in range(n_shown):
        plt.subplot(3, 3, idx + 1)
        # Pixels are float; cast to uint8 so imshow treats them as 0-255 RGB
        plt.imshow(sample_images[idx].numpy().astype("uint8"))
        plt.title(class_names[int(sample_labels[idx])])
        plt.axis("off")
    plt.show()


## 8) Model (MobileNetV2 Backbone)

In [None]:
INPUT_SHAPE = (*IMG_SIZE, 3)

# ImageNet-pretrained MobileNetV2 feature extractor without its classifier head,
# frozen so only the new head below is trained.
base = keras.applications.MobileNetV2(
    weights="imagenet",
    include_top=False,
    input_shape=INPUT_SHAPE,
)
base.trainable = False

# The model takes raw pixels; MobileNetV2's own preprocessing runs inside the graph,
# so exported/converted models need no external preprocessing step.
inputs = keras.Input(shape=INPUT_SHAPE)
scaled = keras.applications.mobilenet_v2.preprocess_input(inputs)
features = base(scaled, training=False)  # keep BatchNorm layers in inference mode
pooled = layers.GlobalAveragePooling2D()(features)
regularized = layers.Dropout(0.2)(pooled)
outputs = layers.Dense(num_classes, activation="softmax")(regularized)

model = keras.Model(inputs=inputs, outputs=outputs, name="riceleaf_mobilenetv2")
model.summary()

## 9) Compile & Callbacks

In [None]:
# Adam at 1e-3 with sparse (integer-label) cross-entropy.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

checkpoint_path = str(OUTPUT_DIR.joinpath("best_model.keras").resolve())

# Stop after 8 epochs without val_accuracy gains (rolling back to the best weights),
# halve the LR after 4 stagnant val_loss epochs, and keep the best-val_accuracy model on disk.
early_stop = EarlyStopping(monitor="val_accuracy", patience=8, restore_best_weights=True)
lr_on_plateau = ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=4, min_lr=1e-6, verbose=1)
best_checkpoint = ModelCheckpoint(checkpoint_path, monitor="val_accuracy", save_best_only=True, verbose=1)

callbacks = [early_stop, lr_on_plateau, best_checkpoint]



## 10) Train

In [None]:
# Train on the augmented split, validating every epoch; `history.history` keeps
# the per-epoch metric series used for the curves below.
fit_options = dict(validation_data=val_ds, epochs=EPOCHS, callbacks=callbacks)
history = model.fit(train_ds, **fit_options)

## 11) Plot Curves

In [None]:
# Per-epoch series recorded by model.fit (an empty list if a metric was not logged).
acc, val_acc, loss, val_loss = (
    history.history.get(metric_key, [])
    for metric_key in ("accuracy", "val_accuracy", "loss", "val_loss")
)

# One figure per quantity, train vs. validation; both are displayed together at the end.
for panel_title, series in (
    ("Accuracy", ((acc, "train_acc"), (val_acc, "val_acc"))),
    ("Loss", ((loss, "train_loss"), (val_loss, "val_loss"))),
):
    plt.figure()
    for values, legend_label in series:
        plt.plot(values, label=legend_label)
    plt.xlabel("Epoch")
    plt.ylabel(panel_title)
    plt.legend()
    plt.title(panel_title)
plt.show()


## 12) Evaluation

In [None]:
# Collect ground-truth labels and predicted class indices on val_ds.
# Both come from the same pass over the dataset, so they stay aligned even if
# val_ds yields its elements in a different order on each pass.
true_chunks, pred_chunks = [], []
for batch_imgs, batch_labels in val_ds:
    # Direct call instead of model.predict(): predict() is meant for whole datasets,
    # and Keras recommends __call__ when looping over individual batches.
    probs = model(batch_imgs, training=False)
    true_chunks.append(np.asarray(batch_labels))
    pred_chunks.append(np.argmax(np.asarray(probs), axis=1))

if not true_chunks:
    raise ValueError("val_ds is empty; nothing to evaluate")

y_true = np.concatenate(true_chunks)
y_pred = np.concatenate(pred_chunks)

# Pin the label set to every class. Without this, the report raises (and the matrix
# shrinks out of line with the tick labels) whenever a class is absent from both
# y_true and y_pred in this validation split.
label_ids = np.arange(len(class_names))
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
print(classification_report(
    y_true, y_pred, labels=label_ids, target_names=class_names, zero_division=0
))

# Plot CM (itertools is imported at the top of the notebook)
plt.figure(figsize=(6, 6))
plt.imshow(cm, interpolation='nearest')
plt.title("Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45, ha="right")
plt.yticks(tick_marks, class_names)

# Annotate each cell with its count; switch text colour on dark cells for readability
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j, i, format(cm[i, j], 'd'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.show()


## 13) Save Keras Model

In [None]:
# Write an inference SavedModel directory; the TFLite cell below reads it with
# TFLiteConverter.from_saved_model.
saved_model_dir = OUTPUT_DIR / "saved_model"
saved_model_dir.mkdir(exist_ok=True, parents=True)

if hasattr(model, "export"):
    # Keras 3 no longer writes a SavedModel from model.save(<directory>) (it
    # expects a .keras/.h5 file); model.export() is its SavedModel path.
    model.export(str(saved_model_dir))
else:
    # Older tf.keras without export(): model.save on a directory writes a SavedModel.
    model.save(saved_model_dir, include_optimizer=False)
print("SavedModel ->", saved_model_dir.resolve())


## 14) TFLite Conversion (Float16 & INT8 options)

In [None]:
tflite_float16 = OUTPUT_DIR / "model_float16.tflite"
tflite_int8 = OUTPUT_DIR / "model_int8.tflite"


def _saved_model_converter():
    """Return a fresh TFLiteConverter for ``saved_model_dir`` with DEFAULT optimizations on."""
    conv = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
    conv.optimizations = [tf.lite.Optimize.DEFAULT]
    return conv


def representative_data_gen():
    """Yield calibration inputs for INT8 quantization: the first 50 batches of train_ds."""
    for calib_images, _ in train_ds.take(50):
        yield [tf.cast(calib_images, tf.float32)]


# Float16 quantization (good trade-off for mobile GPUs/NPUs)
converter = _saved_model_converter()
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
tflite_float16.write_bytes(tflite_model)
print("Wrote:", tflite_float16.resolve())

# Full INT8 quantization: weights and activations calibrated on representative data,
# restricted to INT8 builtin ops, with uint8 model input and output tensors.
converter = _saved_model_converter()
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_int8_model = converter.convert()
tflite_int8.write_bytes(tflite_int8_model)
print("Wrote:", tflite_int8.resolve())


## 15) Single Image Inference Demo

In [None]:
def load_image(path, target_size=IMG_SIZE):
    """Load one image file as a model-ready batch of size 1.

    The model in this notebook applies ``mobilenet_v2.preprocess_input`` inside
    its own graph, so this helper returns raw RGB pixel values (float32, 0-255).
    Preprocessing here as well would apply that scaling twice and corrupt predictions.

    Args:
        path: image file path (anything ``PIL.Image.open`` accepts).
        target_size: ``(height, width)``, the same convention as ``IMG_SIZE``.

    Returns:
        ``np.ndarray`` of shape ``(1, height, width, 3)``, dtype float32.
    """
    height, width = target_size
    # Context manager closes the underlying file; convert() returns an independent copy.
    with Image.open(path) as img:
        # PIL's resize() takes (width, height), the reverse of target_size
        rgb = img.convert("RGB").resize((width, height), Image.BILINEAR)
    arr = np.asarray(rgb, dtype=np.float32)
    return np.expand_dims(arr, 0)

# Example:
# test_path = DATA_DIR / class_names[0] / "some_image.jpg"
# x = load_image(test_path)
# preds = model.predict(x)
# print("Pred:", class_names[int(np.argmax(preds))], preds.max())


## 16) Notes on Porting from PyTorch → TensorFlow

- `torch.utils.data.Dataset/DataLoader` → `tf.data` via `image_dataset_from_directory`.
- Manual training loops (`optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`) → `model.fit()` with callbacks.
- `torchvision.models` (e.g., ResNet, MobileNet) → `tf.keras.applications` equivalents.
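- Saving: `torch.save(state_dict)` → a TensorFlow SavedModel (`model.export()` in Keras 3, `model.save()` on a directory in older tf.keras), followed by TFLite export.
- Mixed precision: `with torch.cuda.amp.autocast()` → a `tf.keras.mixed_precision` policy (optional; not used in this notebook).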
