In [3]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
import numpy as np
import os


# 1. SETUP & CONSTANTS

In [None]:
TRAIN_DIR = "/Users/felix/Documents/Data Science/06_Offical_project_DS/may25_bds_plants/05_data/original_data/2.1.1 New Plant Diseases/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/train"
VALID_DIR = "/Users/felix/Documents/Data Science/06_Offical_project_DS/may25_bds_plants/05_data/original_data/2.1.1 New Plant Diseases/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/valid"

IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32

# Training parameters
INITIAL_EPOCHS = 10
FINE_TUNE_EPOCHS = 10
LEARNING_RATE = 0.001
FINE_TUNE_LEARNING_RATE = 0.00001 # 1e-5

# 2. LOAD & PREPARE DATASET

In [None]:
# Load training data from the 'train' directory
train_dataset = tf.keras.utils.image_dataset_from_directory(
    TRAIN_DIR,
    shuffle=True, 
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE
)

# Load validation data from the 'valid' directory
validation_dataset = tf.keras.utils.image_dataset_from_directory(
    VALID_DIR,
    shuffle=False, 
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE
)

# Get class names from the dataset object
class_names = train_dataset.class_names
num_classes = len(class_names)
print("Found classes:", class_names)

# Create a pre-processing layer that will be part of the model
data_augmentation_and_preprocessing = tf.keras.Sequential([
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
    layers.Lambda(preprocess_input) # Use the official MobileNetV2 preprocessing function
])

# Optimize performance by prefetching data
AUTOTUNE = tf.data.AUTOTUNE
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)

# 3. BUILD THE MODEL 

In [None]:
def build_model(num_classes):
    base_model = MobileNetV2(input_shape=(224, 224, 3),
                             include_top=False,
                             weights='imagenet')
    base_model.trainable = False

    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = data_augmentation_and_preprocessing(inputs)
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = tf.keras.Model(inputs, outputs)
    return model, base_model

model, base_model = build_model(num_classes)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

print("\n--- MODEL SUMMARY (before fine-tuning) ---")
model.summary()



# 4. INITIAL TRAINING (Train only the new head)

In [None]:

print("\n--- STARTING INITIAL TRAINING (HEAD ONLY) ---")
history = model.fit(train_dataset,
                    epochs=INITIAL_EPOCHS,
                    validation_data=validation_dataset)

# 5. FINE-TUNING (Train the whole model with a low learning rate)

In [None]:

print("\n--- STARTING FINE-TUNING (UNFREEZING TOP LAYERS) ---")

base_model.trainable = True
print("Number of layers in the base model: ", len(base_model.layers))

fine_tune_at = 100
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=FINE_TUNE_LEARNING_RATE),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

print("\n--- MODEL SUMMARY (after unfreezing for fine-tuning) ---")
model.summary()

total_epochs = INITIAL_EPOCHS + FINE_TUNE_EPOCHS
history_fine_tune = model.fit(train_dataset,
                              epochs=total_epochs,
                              initial_epoch=history.epoch[-1],
                              validation_data=validation_dataset)

print("\n--- TRAINING COMPLETE ---")

Found 70295 files belonging to 38 classes.
Found 17572 files belonging to 38 classes.
Found classes: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold'


--- STARTING INITIAL TRAINING (HEAD ONLY) ---
Epoch 1/10
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m491s[0m 223ms/step - accuracy: 0.7671 - loss: 0.8640 - val_accuracy: 0.9074 - val_loss: 0.2921
Epoch 2/10
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m483s[0m 220ms/step - accuracy: 0.9189 - loss: 0.2582 - val_accuracy: 0.9342 - val_loss: 0.2052
Epoch 3/10
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m481s[0m 219ms/step - accuracy: 0.9256 - loss: 0.2249 - val_accuracy: 0.9345 - val_loss: 0.2048
Epoch 4/10
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m476s[0m 216ms/step - accuracy: 0.9313 - loss: 0.2100 - val_accuracy: 0.9347 - val_loss: 0.1977
Epoch 5/10
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m463s[0m 211ms/step - accuracy: 0.9316 - loss: 0.2036 - val_accuracy: 0.9391 - val_loss: 0.1815
Epoch 6/10
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m471s[0m 215ms/step - accuracy: 0.9334 - 

Epoch 10/20
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7626s[0m 3s/step - accuracy: 0.7913 - loss: 0.9953 - val_accuracy: 0.9381 - val_loss: 0.1926
Epoch 11/20
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1065s[0m 485ms/step - accuracy: 0.9275 - loss: 0.2177 - val_accuracy: 0.9527 - val_loss: 0.1390
Epoch 12/20
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m710s[0m 323ms/step - accuracy: 0.9495 - loss: 0.1538 - val_accuracy: 0.9615 - val_loss: 0.1088
Epoch 13/20
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m700s[0m 319ms/step - accuracy: 0.9607 - loss: 0.1156 - val_accuracy: 0.9710 - val_loss: 0.0861
Epoch 14/20
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28060s[0m 13s/step - accuracy: 0.9690 - loss: 0.0913 - val_accuracy: 0.9706 - val_loss: 0.0826
Epoch 15/20
[1m2197/2197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3840s[0m 2s/step - accuracy: 0.9744 - loss: 0.0758 - val_accuracy: 0.9751 - val_lo