### Imports & Configuration

In [4]:
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import json

### Global constants

In [5]:
IMAGE_SIZE = 256
BATCH_SIZE = 32
CHANNELS = 3
EPOCHS = 1   # initial run
DATA_DIR = "PlantVillage"
MODEL_DIR = "../saved_models"
SEED = 42    # single seed for all stochastic steps (weight init, shuffling, augmentation)

os.makedirs(MODEL_DIR, exist_ok=True)

# Seeds Python's `random`, NumPy and TensorFlow in one call so a
# Restart & Run All reproduces the same data order and training run.
tf.keras.utils.set_random_seed(SEED)

### Load dataset

In [6]:
# Fixed seed for the loader's file shuffle. get_dataset_partitions_tf() splits by
# batch position (take/skip), so the file order must be reproducible across kernel
# sessions; otherwise a later resume_training() run in a new session could train on
# images that were in the test split before.
DATASET_SEED = 123

dataset = tf.keras.utils.image_dataset_from_directory(
    DATA_DIR,
    shuffle=True,
    seed=DATASET_SEED,
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE
)

class_names = dataset.class_names
n_classes = len(class_names)
print("Classes:", class_names)

# Save class names to JSON for FastAPI
# NOTE(review): this overwrites the label file before any model trained on these
# classes has been exported (latest_model.h5 is written later); resume_training()
# rewrites it after exporting. Confirm the serving side tolerates that window.
with open(os.path.join(MODEL_DIR, "class_names.json"), "w", encoding="utf-8") as f:
    json.dump(class_names, f)

Found 6627 files belonging to 8 classes.
Classes: ['Pepper Bell - Bacterial Spot', 'Pepper Bell - Healthy', 'Potato - Blight', 'Potato - Healthy', 'Rice - Downy Mildew', 'Rice - Healthy', 'Tomato - Healthy', 'Tomato - Leaf Mold']


### Partition dataset

In [7]:
def get_dataset_partitions_tf(ds, train_split=0.8, val_split=0.1, shuffle=True, shuffle_size=10000):
    """Split a finite, batched dataset into train / validation / test by batch position.

    Args:
        ds: dataset whose length is known (``len(ds)`` must work), e.g. the batched
            output of ``image_dataset_from_directory``.
        train_split: fraction of batches assigned to training.
        val_split: fraction of batches assigned to validation; the remaining
            batches form the test split.
        shuffle: if True, shuffle the batch order once (seed 12) before splitting.
        shuffle_size: shuffle buffer size, in batches.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: if a fraction is negative or the two fractions sum to more than 1.
    """
    if train_split < 0 or val_split < 0 or train_split + val_split > 1 + 1e-9:
        raise ValueError(
            f"train_split ({train_split}) and val_split ({val_split}) must be "
            "non-negative and sum to at most 1"
        )

    ds_size = len(ds)
    if shuffle:
        # take()/skip() below each iterate `ds` independently. With the default
        # reshuffle_each_iteration=True every one of those passes would see a
        # different batch order, so the three splits would overlap (test batches
        # leaking into training). Shuffle once and keep that order fixed.
        ds = ds.shuffle(shuffle_size, seed=12, reshuffle_each_iteration=False)

    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)

    train_dataset = ds.take(train_size)
    val_dataset = ds.skip(train_size).take(val_size)
    test_dataset = ds.skip(train_size + val_size)
    return train_dataset, val_dataset, test_dataset

AUTOTUNE = tf.data.AUTOTUNE

train_ds, val_ds, test_ds = get_dataset_partitions_tf(dataset)

# cache() keeps batches in memory after the first pass; only the training split
# additionally gets its batch order reshuffled, and every split prefetches.
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds, test_ds = [split.cache().prefetch(buffer_size=AUTOTUNE) for split in (val_ds, test_ds)]

### Data Augmentation and Rescaling

In [8]:
# Standalone preprocessing / augmentation pipelines.
# NOTE: the functional model in the next cell creates its own Resizing, Rescaling,
# RandomFlip and RandomRotation layers inline and does not use these two objects.
rescale_layers = [
    layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
    layers.Rescaling(1.0/255),
]
resize_and_rescale = tf.keras.Sequential(rescale_layers)

augmentation_layers = [
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
]
data_augmentation = tf.keras.Sequential(augmentation_layers)

### Functional API Model Definition

In [9]:
# Functional CNN: preprocessing/augmentation -> three conv/pool stages -> dense classifier.
inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))

# Preprocessing + augmentation, applied in this order.
preprocessing_stack = [
    layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
    layers.Rescaling(1.0/255),
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
]
x = inputs
for prep_layer in preprocessing_stack:
    x = prep_layer(x)

# Convolutional feature extractor: each stage is a 3x3 ReLU conv followed by max-pooling.
for n_filters in (32, 64, 128):
    x = layers.Conv2D(n_filters, (3, 3), activation='relu')(x)
    x = layers.MaxPooling2D()(x)

# Dense classifier with a softmax over the dataset's classes.
x = layers.Flatten()(x)
x = layers.Dense(128, activation='relu')(x)
outputs = layers.Dense(n_classes, activation='softmax')(x)

model = Model(inputs=inputs, outputs=outputs)

### Model Compile

In [10]:
# The model ends in a softmax (not logits), hence from_logits=False; the sparse
# loss takes integer class indices as labels.
sparse_ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

model.compile(
    loss=sparse_ce_loss,
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()

### Callbacks (Checkpoints + EarlyStopping)

In [11]:
# Save a checkpoint after every epoch (not only the best one); resume_training()
# later continues from the newest file matching this pattern.
checkpoint_path = os.path.join(MODEL_DIR, "model_epoch_{epoch:02d}.keras")
checkpoint_cb = ModelCheckpoint(filepath=checkpoint_path, save_freq="epoch", save_best_only=False)

# Stop after 5 epochs without val_accuracy improvement and roll back to the best weights.
early_stop_cb = EarlyStopping(monitor="val_accuracy", patience=5, restore_best_weights=True)

### Model Train

In [11]:
# Initial training run; checkpoint_cb writes the model_epoch_XX.keras files.
training_callbacks = [checkpoint_cb, early_stop_cb]
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=training_callbacks,
    verbose=1,
)

[1m62/62[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m152s[0m 2s/step - accuracy: 0.7906 - loss: 0.4907 - val_accuracy: 0.9330 - val_loss: 0.8202


### Evaluate

In [12]:
# Held-out evaluation: loss first, then the single compiled metric (accuracy).
test_loss, test_acc = model.evaluate(test_ds)
test_report = f"Test Accuracy: {test_acc:.2f}, Test Loss: {test_loss:.2f}"
print(test_report)

[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 268ms/step - accuracy: 0.8889 - loss: 0.2866
Test Accuracy: 0.89, Test Loss: 0.29


### Plot curves

In [None]:
acc, val_acc = history.history.get('accuracy', []), history.history.get('val_accuracy', [])
loss, val_loss = history.history.get('loss', []), history.history.get('val_loss', [])

# history.epoch lists the 0-based epoch indices that actually ran (they start at
# initial_epoch for resumed runs); +1 matches the "Epoch N/M" progress log.
epoch_numbers = [epoch + 1 for epoch in history.epoch]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (ax1, "Accuracy", ((acc, 'Train Acc'), (val_acc, 'Val Acc'))),
    (ax2, "Loss", ((loss, 'Train Loss'), (val_loss, 'Val Loss'))),
)
for ax, title, curves in panels:
    for values, label in curves:
        # Slice x to the series length so a missing metric (empty list) still plots cleanly.
        ax.plot(epoch_numbers[:len(values)], values, label=label)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(title)
    ax.legend()
plt.show()

---

### Export to an `.h5` file so we can upload it to GCP conveniently

In [13]:
# Single deployment artifact, overwritten on every export; the per-epoch
# .keras checkpoints remain alongside it in the same directory.
h5_save_path = os.path.join(MODEL_DIR, "latest_model.h5")
os.makedirs(os.path.dirname(h5_save_path), exist_ok=True)

model.save(h5_save_path)
print(f"[INFO] Final model exported as {h5_save_path}")



[INFO] Final model exported as ../saved_models\latest_model.h5


### Retraining Helper Function

In [12]:
import os, glob, json
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import tensorflow as tf

MODEL_DIR = "../saved_models"


def _find_latest_checkpoint():
    """Return ``(path, epoch)`` of the most recently written ``model_epoch_*.keras`` in MODEL_DIR.

    Raises:
        FileNotFoundError: if MODEL_DIR contains no matching checkpoint.
    """
    checkpoints = glob.glob(os.path.join(MODEL_DIR, "model_epoch_*.keras"))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints found in {MODEL_DIR}/")
    # getmtime is the last content write on every OS; getctime means creation time
    # on Windows but metadata-change time on Unix, so it is not a reliable "latest".
    latest = max(checkpoints, key=os.path.getmtime)
    # "model_epoch_50.keras" -> 50
    epoch = int(os.path.basename(latest).split("_")[-1].split(".")[0])
    return latest, epoch


def _is_compiled(model):
    """True if ``model`` already has an optimizer/loss attached (e.g. restored from a .keras file)."""
    # Keras 3 exposes `compiled`; tf.keras 2.x uses the private `_is_compiled`.
    compiled = getattr(model, "compiled", False) or getattr(model, "_is_compiled", False)
    return bool(compiled) and getattr(model, "optimizer", None) is not None


def _compile(model, optimizer="adam"):
    """Compile with this notebook's loss/metric: softmax outputs, integer labels, accuracy."""
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=["accuracy"],
    )


def _checkpoint_callbacks():
    """Per-epoch ``model_epoch_XX.keras`` checkpoints plus early stopping on val_accuracy."""
    return [
        ModelCheckpoint(
            filepath=os.path.join(MODEL_DIR, "model_epoch_{epoch:02d}.keras"),
            save_freq="epoch",
            save_best_only=False,
        ),
        EarlyStopping(monitor="val_accuracy", patience=5, restore_best_weights=True),
    ]


def _unique_layer_name(model, base):
    """Return ``base`` or ``base_<i>``, whichever first matches no layer name in ``model``."""
    taken = {layer.name for layer in model.layers}
    name, suffix = base, 1
    while name in taken:
        name = f"{base}_{suffix}"
        suffix += 1
    return name


def resume_training(train_ds, val_ds, extra_epochs=20, class_names=None):
    """
    Resume training from the most recent checkpoint for up to ``extra_epochs`` more epochs.

    If ``class_names`` implies a different class count than the checkpoint's output
    layer, the final Dense layer is replaced and training runs in two phases:
    (1) train the new head with earlier layers frozen, (2) fine-tune the full network
    at a low learning rate. Otherwise the checkpoint is trained further as-is,
    keeping the optimizer state stored in the .keras file.

    New ``model_epoch_XX.keras`` checkpoints are written every epoch of the resumed
    (or Phase 2) run. Always exports ``latest_model.h5`` and, when ``class_names`` is
    given, rewrites ``class_names.json`` for deployment.

    Args:
        train_ds: training dataset passed to ``model.fit``.
        val_ds: validation dataset passed to ``model.fit``.
        extra_epochs: maximum number of epochs to add; EarlyStopping may stop sooner.
        class_names: current class labels; ``None``/empty keeps the checkpoint's class count.

    Returns:
        ``(model, history)``: the trained model and the ``History`` of the final fit.

    Raises:
        FileNotFoundError: if no checkpoint exists in MODEL_DIR.
    """
    latest_checkpoint, last_epoch = _find_latest_checkpoint()
    print(f"[INFO] Resuming from {latest_checkpoint} (epoch {last_epoch})")
    old_model = load_model(latest_checkpoint)

    # --- Detect current number of classes ---
    old_num_classes = old_model.layers[-1].units
    n_classes = len(class_names) if class_names else old_num_classes

    if old_num_classes != n_classes:
        print(f"[INFO] Updating final layer for {n_classes} classes (was {old_num_classes})")
        # Explicit, unused layer name: an auto-named Dense created in a fresh session
        # can be called "dense" and clash with the loaded layer of the same name.
        head = Dense(n_classes, activation="softmax", name=_unique_layer_name(old_model, "predictions"))
        model = Model(inputs=old_model.input, outputs=head(old_model.layers[-2].output))

        # Phase 1: freeze everything below the penultimate layer so only the
        # penultimate layer and the new output layer learn.
        for layer in old_model.layers[:-2]:
            layer.trainable = False
        _compile(model)

        print("[INFO] Phase 1: Training new head only...")
        model.fit(
            train_ds,
            epochs=10,               # upper bound; early stopping usually ends it sooner
            validation_data=val_ds,
            callbacks=[EarlyStopping(monitor="val_accuracy", patience=2, restore_best_weights=True)],
            verbose=1,
        )

        # Phase 2: unfreeze all layers and recompile with a low learning rate for fine-tuning.
        for layer in model.layers:
            layer.trainable = True
        _compile(model, optimizer=tf.keras.optimizers.Adam(1e-5))
        print("[INFO] Phase 2: Fine-tuning full network...")
    else:
        model = old_model
        # A .keras checkpoint stores the compile config and optimizer state (e.g.
        # Adam moments); recompiling would reset them, which is a restart rather
        # than a resume. Only compile if the checkpoint came back uncompiled.
        if not _is_compiled(model):
            _compile(model)

    history = model.fit(
        train_ds,
        epochs=last_epoch + extra_epochs,
        initial_epoch=last_epoch,
        validation_data=val_ds,
        callbacks=_checkpoint_callbacks(),
        verbose=1,
    )

    # --- Export final .h5 file for GCP ---
    h5_export_path = os.path.join(MODEL_DIR, "latest_model.h5")
    model.save(h5_export_path)
    print(f"[INFO] Final model exported as {h5_export_path}")

    # --- Save updated class_names.json ---
    if class_names:
        class_names_path = os.path.join(MODEL_DIR, "class_names.json")
        with open(class_names_path, "w", encoding="utf-8") as f:
            json.dump(list(class_names), f)
        print(f"[INFO] Updated class_names.json saved with {len(class_names)} classes")

    # Report the epochs that actually ran; EarlyStopping can end before extra_epochs.
    epochs_run = len(history.epoch)
    print(f"[INFO] Extended training by {epochs_run} epochs (from {last_epoch} → {last_epoch + epochs_run})")
    return model, history

### Model Retraining

In [20]:
# Re-read the class list from the loaded dataset; if its size differs from the
# checkpoint's output layer, resume_training() rebuilds the classifier head.
class_names = dataset.class_names

# Resume training for up to 10 more epochs (its EarlyStopping callback may stop sooner)
EXTRA_EPOCHS = 10
model, history = resume_training(train_ds, val_ds, extra_epochs=EXTRA_EPOCHS, class_names=class_names)

# Evaluate again on the held-out split
test_loss, test_acc = model.evaluate(test_ds)
print(f"Test Accuracy after resume: {test_acc:.2f}, Test Loss: {test_loss:.2f}")

[INFO] Resuming from ../saved_models\model_epoch_50.keras (epoch 50)
Epoch 51/60
[1m166/166[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m222s[0m 1s/step - accuracy: 0.9769 - loss: 0.0707 - val_accuracy: 0.9547 - val_loss: 0.1650
Epoch 52/60
[1m166/166[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m308s[0m 2s/step - accuracy: 0.9699 - loss: 0.0896 - val_accuracy: 0.9156 - val_loss: 0.3484
Epoch 53/60
[1m166/166[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m212s[0m 1s/step - accuracy: 0.9701 - loss: 0.0830 - val_accuracy: 0.9500 - val_loss: 0.2992
Epoch 54/60
[1m166/166[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m213s[0m 1s/step - accuracy: 0.9854 - loss: 0.0472 - val_accuracy: 0.9844 - val_loss: 0.0520
Epoch 55/60
[1m166/166[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m214s[0m 1s/step - accuracy: 0.9807 - loss: 0.0557 - val_accuracy: 0.9672 - val_loss: 0.1401
Epoch 56/60
[1m166/166[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m218s[0m 1s/step - accuracy: 0.9837 - lo



[INFO] Final model exported as ../saved_models\latest_model.h5
[INFO] Updated class_names.json saved with 8 classes
[INFO] Extended training by 10 epochs (from 50 → 60)
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 313ms/step - accuracy: 0.9886 - loss: 0.0449
Test Accuracy after resume: 0.99, Test Loss: 0.04


---

### Next...