# 🌿 FloraGuard AI - Professional Model Training

This notebook trains a high-performance Plant Disease Detection model using **Transfer Learning** with **MobileNetV2**.

### Why this approach works well:
- **MobileNetV2**: A lightweight convolutional architecture designed for fast, accurate inference on mobile and web devices.
- **Transfer Learning**: Starts from weights pretrained on ImageNet (over a million labeled images), so the model reuses general visual features instead of learning them from scratch.
- **Data Augmentation**: Simulates real-world variation (flips, rotation, zoom, contrast) during training to reduce overfitting.

### Instructions:
1.  Run all cells in order.
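2.  The notebook downloads the PlantVillage dataset automatically via TensorFlow Datasets.
3.  It trains the model, saves it as `model.tflite`, and writes the class names to `labels.txt`.
4.  Download `model.tflite` and `labels.txt` (in Colab, the last cell triggers the downloads) and use them in your web app.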

In [None]:
# @title 1. Setup & Install Dependencies
!pip install -q tensorflow tensorflow-datasets matplotlib numpy

In [None]:
# @title 2. Download Dataset (PlantVillage)
import tensorflow_datasets as tfds
import tensorflow as tf
import os

print("Downloading PlantVillage dataset... This may take a few minutes.")

# Load PlantVillage from TensorFlow Datasets. All three splits are carved from
# the TFDS 'train' split: 80% train, 10% validation, 10% test.
# as_supervised=True yields (image, label) tuples instead of feature dicts.
(train_ds, val_ds, test_ds), metadata = tfds.load(
    'plant_village',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

num_classes = metadata.features['label'].num_classes
class_names = metadata.features['label'].names

print("Dataset loaded successfully!")
print(f"Total Classes: {num_classes}")
print(f"Class Names: {class_names}")

# Save labels one per line, in label-index order (ClassLabel.names is indexed
# by the integer label). Line i therefore names the class of output index i.
# Write UTF-8 explicitly so the file does not depend on the platform's default
# encoding.
with open('labels.txt', 'w', encoding='utf-8') as f:
    f.writelines(f"{name}\n" for name in class_names)
print("labels.txt saved.")

In [None]:
# @title 3. Preprocessing & Data Augmentation
# Input resolution matching MobileNetV2's 224x224 ImageNet weights, and the
# batch size shared by every split.
IMG_SIZE = 224
BATCH_SIZE = 32

# Deterministic preprocessing applied to every split: resize each image to
# IMG_SIZE x IMG_SIZE, then multiply pixels by 1/255. This presumably maps
# uint8 [0, 255] input to [0, 1].
resize_and_rescale = tf.keras.Sequential()
resize_and_rescale.add(tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE))
resize_and_rescale.add(tf.keras.layers.Rescaling(1. / 255))

# Random augmentations (flips on both axes, rotation, zoom, contrast) meant
# for training batches only. They are active only when called with
# training=True.
data_augmentation = tf.keras.Sequential()
data_augmentation.add(tf.keras.layers.RandomFlip("horizontal_and_vertical"))
data_augmentation.add(tf.keras.layers.RandomRotation(0.2))
data_augmentation.add(tf.keras.layers.RandomZoom(0.2))
data_augmentation.add(tf.keras.layers.RandomContrast(0.2))

def prepare(ds, shuffle=False, augment=False, shuffle_buffer=1000):
  """Turn an unbatched (image, label) dataset into a batched input pipeline.

  Args:
    ds: `tf.data.Dataset` of (image, label) pairs, e.g. as returned by
      `tfds.load(..., as_supervised=True)`.
    shuffle: If True, shuffle examples before preprocessing and batching.
      Intended for the training split only.
    augment: If True, apply `data_augmentation` (in training mode) to each
      batch. Intended for the training split only.
    shuffle_buffer: Number of examples held in the shuffle buffer.

  Returns:
    A prefetched `tf.data.Dataset` yielding `(images, labels)` batches of up to
    `BATCH_SIZE` examples. Images are passed through `resize_and_rescale`.
  """
  # Shuffle *before* resizing/rescaling. The buffer then holds the raw decoded
  # images rather than 224x224x3 float32 tensors (~0.6 MB each, ~600 MB for
  # 1000 examples), giving the same randomness at much lower memory cost.
  if shuffle:
    ds = ds.shuffle(shuffle_buffer)

  # Resize and rescale every example.
  ds = ds.map(lambda x, y: (resize_and_rescale(x), y),
              num_parallel_calls=tf.data.AUTOTUNE)

  ds = ds.batch(BATCH_SIZE)

  # Augment whole batches. training=True is required, otherwise the Random*
  # layers act as identity.
  if augment:
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y),
                num_parallel_calls=tf.data.AUTOTUNE)

  # Overlap input preparation with model execution.
  return ds.prefetch(buffer_size=tf.data.AUTOTUNE)

# Only the training split is shuffled and augmented. Validation and test data
# are just resized, rescaled, batched and prefetched.
train_ds = prepare(train_ds, augment=True, shuffle=True)
val_ds, test_ds = prepare(val_ds), prepare(test_ds)

In [None]:
# @title 4. Build Model (MobileNetV2)
# ImageNet-pretrained MobileNetV2 without its ImageNet classification head.
# It serves as a feature extractor producing a spatial feature map.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,
    weights='imagenet'
)

base_model.trainable = False  # Freeze the backbone; only the new head trains.

# Use the functional API so the backbone call can pin training=False.
inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
# The input pipeline rescales pixels by 1/255 (to [0, 1]), but MobileNetV2's
# ImageNet weights were trained on inputs in [-1, 1] (see
# tf.keras.applications.mobilenet_v2.preprocess_input). Mapping [0, 1] -> [-1, 1]
# inside the model fixes the mismatch. The model's own expected input range
# stays [0, 1], as before.
x = tf.keras.layers.Rescaling(2.0, offset=-1.0)(inputs)
# training=False keeps the backbone's BatchNormalization layers in inference
# mode (using their pretrained moving statistics), even after the backbone is
# later unfrozen for fine-tuning.
x = base_model(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)

# Integer labels (as produced by the TFDS pipeline) -> sparse categorical loss.
model.compile(
  optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy']
)

model.summary()

In [None]:
# @title 5. Train Model
# Phase 1: train the model with the backbone frozen. Start with 10 epochs and
# raise this if validation accuracy is still climbing at the end.
epochs = 10

history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

In [None]:
# @title 6. Fine-Tuning (Optional but Recommended)
print("Unfreezing base model layers for fine-tuning...")
base_model.trainable = True

# Fine-tune from this layer index onwards; earlier layers stay frozen.
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer.
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable = False

# Keep BatchNormalization layers frozen as well. A non-trainable BN layer runs
# in inference mode in TF2 Keras, so its pretrained moving mean/variance are
# not re-estimated from the new batches. Re-estimating them can badly degrade
# the pretrained features.
for layer in base_model.layers[fine_tune_at:]:
  if isinstance(layer, tf.keras.layers.BatchNormalization):
    layer.trainable = False

# Recompile so the new trainable flags take effect. Use a learning rate 10x
# lower than phase 1 so the pretrained weights are only nudged.
model.compile(
  optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.00001),
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy']
)

fine_tune_epochs = 5
total_epochs = epochs + fine_tune_epochs

# Resume the epoch counter where phase 1 stopped. `history.epoch` is 0-based,
# so len(history.epoch) is the next unused epoch index. history.epoch[-1]
# would repeat the last index and run fine_tune_epochs + 1 epochs.
history_fine = model.fit(
  train_ds,
  validation_data=val_ds,
  initial_epoch=len(history.epoch),
  epochs=total_epochs
)

In [None]:
# @title 7. Convert to TFLite
# Convert the trained Keras model into a TensorFlow Lite flatbuffer. No
# `optimizations` are set, so no post-training quantization is applied.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Write the serialized model bytes to disk.
with open('model.tflite', 'wb') as tflite_file:
  tflite_file.write(tflite_model)

print("Model converted and saved as 'model.tflite'")

In [None]:
# @title 8. Download Files
from google.colab import files

# Download the two artifacts the web app needs from the Colab runtime,
# model first, then labels.
for artifact in ('model.tflite', 'labels.txt'):
  files.download(artifact)