1. Preprocessing & Data Augmentation


In [None]:
# 1. Preprocessing & Data Augmentation
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# Input resolution fed to the network; 224x224 is MobileNetV2's default input size.
IMG_SIZE = 224

# Root of the extracted dataset (adjust to wherever you extracted it).
# flow_from_directory treats each sub-folder found here as one class.
dataset_path = "plant_data"

# Show the top-level entries so a wrong path or an unexpected layout is
# obvious before any generator is built.
top_level_entries = os.listdir(dataset_path)
print("Dataset folders:", top_level_entries)

# Batch size shared by the training and validation generators so they stay in
# sync. Note: smaller batches give noisier (not smoother) gradient estimates;
# 16 is presumably chosen to fit accelerator memory -- adjust as needed.
BATCH_SIZE = 16

# Training-time pipeline: MobileNetV2's own preprocess_input plus random
# geometric augmentation. validation_split=0.2 reserves 20% of each class
# folder for validation. Keras picks that subset by file order (not at random),
# which is what lets the separate validation generator below, built with the
# same split, see exactly the complementary images.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=0.2,
)

# Training subset (~80% of the dataset), reshuffled every epoch.
train_generator = train_datagen.flow_from_directory(
    dataset_path,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="training",
    shuffle=True,
)

# Validation pipeline: identical preprocessing, no augmentation. Its
# validation_split must match train_datagen's so the two subsets are disjoint.
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=0.2,
)

# Validation subset (~20% of the dataset). shuffle=False keeps a fixed file
# order: shuffling buys nothing for evaluation, and a fixed order keeps
# model.predict(val_generator) aligned with val_generator.classes / .filenames
# (needed for confusion matrices and per-image error analysis).
val_generator = val_datagen.flow_from_directory(
    dataset_path,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="validation",
    shuffle=False,
)

# Label -> index mapping Keras inferred from the class sub-folder names.
print("Class Indices:", train_generator.class_indices)
print("Total classes:", len(train_generator.class_indices))



2. Building a MobileNetV2-Based Model


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

# One output unit per class folder discovered by the training generator.
num_classes = len(train_generator.class_indices)

# ImageNet-pretrained MobileNetV2 backbone without its original classifier
# head, frozen so that only the new head below is trained for now.
base_model = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)
base_model.trainable = False

inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
# training=False runs the backbone's BatchNormalization layers in inference mode.
x = base_model(inputs, training=False)

# New classification head, applied in order on top of the backbone's features.
head = [
    layers.GlobalAveragePooling2D(),
    layers.BatchNormalization(),                       # normalise pooled features
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.3),                               # regularisation
    layers.Dense(num_classes, activation='softmax'),
]
for head_layer in head:
    x = head_layer(x)
outputs = x

model = tf.keras.Model(inputs, outputs)

# Learning rate starts at 1e-3 and is scaled by 0.9 per 10,000 optimizer
# steps (decayed continuously, since staircase is left at its default False).
initial_lr = 0.001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_lr,
    decay_steps=10000,
    decay_rate=0.9,
)

# One-hot labels (class_mode="categorical") pair with categorical crossentropy.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

model.summary()


3. Training Your Model

In [None]:
from tensorflow.keras.callbacks import EarlyStopping

# Stop once val_loss has gone 3 epochs without improving, and keep the
# weights from the best epoch rather than the last one.
early_stop = EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)

# Train the head on top of the frozen backbone. epochs is only an upper bound;
# EarlyStopping may end training sooner.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=10,
    callbacks=[early_stop],
    verbose=1,  # per-epoch progress output
)

# In Keras 2.x, restore_best_weights is applied only when early stopping
# actually triggers; if all epochs run, the model would otherwise keep the
# final epoch's weights. Restore the best-epoch weights explicitly so that
# later evaluation/saving always uses them (re-applying them is harmless when
# the callback already restored them).
if early_stop.best_weights is not None:
    model.set_weights(early_stop.best_weights)



4. Plotting Training & Validation Accuracy and Loss

In [None]:
import matplotlib.pyplot as plt

def _plot_metric(history, metric, name):
    """Plot the training and validation curves of one metric recorded by model.fit.

    Args:
        history: the History object returned by ``model.fit``.
        metric: key in ``history.history`` (e.g. ``'accuracy'``); the matching
            validation series is read from ``'val_' + metric``.
        name: human-readable metric name used in the title, y-axis label and
            legend entries.
    """
    train_values = history.history[metric]
    val_values = history.history[f"val_{metric}"]
    # Keras counts epochs from 1; plotting against bare list indices would
    # label the first epoch as 0 on an axis titled "Epochs".
    epochs = list(range(1, len(train_values) + 1))

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_values, label=f"Train {name}", marker='o')
    plt.plot(epochs, val_values, label=f"Validation {name}", marker='o')
    plt.xticks(epochs)  # whole-number epoch ticks only
    plt.legend()
    plt.title(f"Model {name}")
    plt.xlabel("Epochs")
    plt.ylabel(name)
    plt.grid(True)
    plt.show()


# A widening gap between the train and validation curves is a common sign of
# overfitting.
_plot_metric(history, 'accuracy', "Accuracy")
_plot_metric(history, 'loss', "Loss")
