In [1]:
import tensorflow as tf
from tensorflow.keras import models, layers
from tensorflow.keras.callbacks import EarlyStopping
import os

# --- Configuration Parameters ---
BATCH_SIZE = 32
IMAGE_SIZE = 256
CHANNELS = 3 # RGB images
EPOCHS = 50 # Keep a higher number, EarlyStopping will manage it

# --- 1. Load Dataset ---
# Ensure your 'PlantVillage' directory is correctly structured:
# PlantVillage/
# ├── Potato___Early_blight/
# │   ├── image1.jpg
# │   └── ...
# ├── Potato___Late_blight/
# │   ├── imageA.jpg
# │   └── ...
# └── Potato___healthy/
#     ├── imageX.jpg
#     └── ...

print("Loading dataset...")
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    "PlantVillage",
    seed=123,
    shuffle=True,
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE
)

class_names = dataset.class_names
print(f"Detected class names: {class_names}")
print(f"Number of batches in dataset: {len(dataset)}")

# --- 2. Split Dataset into Training, Validation, and Test Sets ---
# It's crucial to have a separate validation set for EarlyStopping
# and a test set for final, unbiased evaluation.

# Calculate the size of each split
dataset_size = tf.data.experimental.cardinality(dataset).numpy() * BATCH_SIZE
train_size = int(0.8 * dataset_size) # 80% for training
val_size = int(0.1 * dataset_size)   # 10% for validation
test_size = dataset_size - train_size - val_size # Remaining 10% for testing

print(f"Total dataset size: {dataset_size} images")
print(f"Training set size: {train_size} images")
print(f"Validation set size: {val_size} images")
print(f"Test set size: {test_size} images")

# Take batches for each split
train_ds = dataset.take(int(train_size / BATCH_SIZE))
val_ds = dataset.skip(int(train_size / BATCH_SIZE)).take(int(val_size / BATCH_SIZE))
test_ds = dataset.skip(int(train_size / BATCH_SIZE) + int(val_size / BATCH_SIZE))

# Optimize dataset loading for performance
# `cache()` keeps images in memory after first epoch
# `prefetch()` overlaps data preprocessing and model execution
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

# --- 3. Data Augmentation and Preprocessing Layers ---
# These layers are added directly into the model for consistency
# and to ensure they are saved with the model for inference.

resize_and_rescale = tf.keras.Sequential([
  layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
  layers.Rescaling(1./255) # Normalize pixel values to [0, 1]
])

# Data augmentation layers (applied only during training)
data_augmentation = tf.keras.Sequential([
  layers.RandomFlip("horizontal_and_vertical"),
  layers.RandomRotation(0.2),
  layers.RandomZoom(0.2),
  layers.RandomContrast(0.2),
  layers.RandomBrightness(0.2),
])

# --- 4. Build the Model ---
# Using a more robust CNN architecture with Dropout for regularization.

num_classes = len(class_names)

model = models.Sequential([
    # Preprocessing and Augmentation as the first layers
    resize_and_rescale,
    data_augmentation, # Apply augmentation only during training

    # Convolutional Block 1
    layers.Conv2D(32, (3,3), activation='relu', input_shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS)),
    layers.MaxPooling2D((2,2)),
    layers.Dropout(0.25), # Added Dropout

    # Convolutional Block 2
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Dropout(0.25), # Added Dropout

    # Convolutional Block 3
    layers.Conv2D(128, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Dropout(0.25), # Added Dropout

    # Flatten the output for the Dense layers
    layers.Flatten(),

    # Dense Layers
    layers.Dense(256, activation='relu'), # Increased neurons
    layers.Dropout(0.5), # Higher Dropout for the dense layer
    layers.Dense(num_classes, activation='softmax') # Output layer with softmax
])

# --- 5. Compile the Model ---
# Using Adam optimizer and SparseCategoricalCrossentropy for integer labels.

model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), # from_logits=False because softmax outputs probabilities
    metrics=['accuracy']
)

model.build(input_shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
model.summary()

# --- 6. Train the Model ---
# EarlyStopping callback to prevent overfitting.
# It monitors 'val_accuracy' and stops if it doesn't improve for 'patience' epochs.

early_stopping_callback = EarlyStopping(
    monitor='val_accuracy',
    patience=10, # Number of epochs with no improvement after which training will be stopped.
    restore_best_weights=True # Restores model weights from the epoch with the best value of the monitored quantity.
)

print("\nStarting model training...")
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=val_ds,
    callbacks=[early_stopping_callback]
)

# --- 7. Evaluate the Model on the Test Set ---
print("\nEvaluating model on test set...")
loss, accuracy = model.evaluate(test_ds)
print(f"Test Loss: {loss:.4f}")
print(f"Test Accuracy: {accuracy:.4f}")

# --- 8. Export the Model for Serving ---
# Auto-increment versioning for serving
os.makedirs("exported_models", exist_ok=True)
existing_versions = [int(d) for d in os.listdir("exported_models") if d.isdigit()]
model_version = max(existing_versions + [0]) + 1
export_path = f"exported_models/{model_version}"

# Save the model in SavedModel format
# The preprocessing and augmentation layers are now part of the saved model.
print(f"\nExporting model to: {export_path}")
model.export(export_path)
print("Model exported successfully!")



Loading dataset...
Found 2152 files belonging to 3 classes.
Detected class names: ['Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy']
Number of batches in dataset: 68
Total dataset size: 2176 images
Training set size: 1740 images
Validation set size: 217 images
Test set size: 219 images


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)



Starting model training...
Epoch 1/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 2s/step - accuracy: 0.4546 - loss: 88.9989 - val_accuracy: 0.4896 - val_loss: 1.0122
Epoch 2/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 1s/step - accuracy: 0.4633 - loss: 0.9621 - val_accuracy: 0.4479 - val_loss: 0.9693
Epoch 3/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 1s/step - accuracy: 0.4437 - loss: 0.9658 - val_accuracy: 0.4479 - val_loss: 0.8924
Epoch 4/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 1s/step - accuracy: 0.4300 - loss: 0.9000 - val_accuracy: 0.4479 - val_loss: 0.8701
Epoch 5/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 1s/step - accuracy: 0.4625 - loss: 0.8999 - val_accuracy: 0.4479 - val_loss: 0.8670
Epoch 6/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 1s/step - accuracy: 0.4544 - loss: 0.8972 - val_accuracy: 0.4479 - val_loss: 0.8651
Epoch 7/50


INFO:tensorflow:Assets written to: exported_models/4\assets


Saved artifact at 'exported_models/4'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 3), dtype=tf.float32, name=None)
Captures:
  2146632591920: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146632594032: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146632595616: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146632597200: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146632730048: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146633051552: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146633052432: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146633051904: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146633048736: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146633054192: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2146633052080: TensorSpec(shape=(