In [None]:
# @title 1. Setup Environment and Imports
# Make sure you are running this in a Colab environment with GPU enabled (Runtime -> Change runtime type)

# Clone the repository (if not already mounted via Google Drive)
# This ensures you have access to your src/ and scripts/ directories
# !git clone https://github.com/UOM-AgroAI-Project/agro-ai-copilot.git
# %cd agro-ai-copilot/module1-edge-ai

# For now, let's assume we navigate to the right directory if running from drive
# or manually clone/cd.

# Mount Google Drive (Recommended for persistent data storage)
from google.colab import drive
drive.mount('/content/drive')

# Define project paths - ADJUST THESE BASED ON YOUR GOOGLE DRIVE SETUP!
# Assuming 'agro-ai-copilot' is cloned/placed under MyDrive
project_root = '/content/drive/MyDrive/agro-ai-copilot/module1-edge-ai'
# Fallback for temporary Colab runtime if repo is not on Drive
if not os.path.exists(project_root):
    project_root = '/content/module1-edge-ai' # Local Colab path if not mounted/cloned to drive
    # If running this path, you might need to manually copy src/ and scripts/
    # or clone the repo into /content/agro-ai-copilot and adjust accordingly.

# Add project_root to Python path to import from src/
import sys
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Install dependencies
%cd {project_root}
!pip install -r requirements.txt
!pip install split-folders tqdm # For data splitting notebook, if not already
!pip install opencv-python-headless # Useful for some image processing

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt

# Import your custom modules
from src.data_utils import create_tf_dataset, get_class_names, prepare_dataset
from src.models import build_fp32_efficientnet_model, IMG_SIZE
from src.loss_functions import WeightedFocalLoss

print("Environment setup and imports complete.")
print(f"TensorFlow version: {tf.__version__}")
print(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}")

In [None]:
# @title 2. Data Loading and Preparation

# Paths to the PlantVillage subset generated by mvp_data_prep.ipynb.
# Ensure mvp_data_prep.ipynb has been run and created this structure!
data_base_dir = os.path.join(project_root, 'data', 'PlantVillage_Subset')

train_data_dir = os.path.join(data_base_dir, 'train')
val_data_dir = os.path.join(data_base_dir, 'val')
test_data_dir = os.path.join(data_base_dir, 'test')

# Fail fast: without these split directories, every later step would crash with a
# less helpful error from the dataset loader.
missing_dirs = [d for d in (train_data_dir, val_data_dir, test_data_dir) if not os.path.isdir(d)]
if missing_dirs:
    raise FileNotFoundError(
        f"Data directories not found: {missing_dirs}. "
        "Please ensure 'mvp_data_prep.ipynb' has been run successfully and paths are correct."
    )
print(f"Loading data from: {data_base_dir}")

BATCH_SIZE = 32  # Adjust based on GPU memory
IMG_HEIGHT, IMG_WIDTH = IMG_SIZE, IMG_SIZE  # From models.py

# Create datasets (only the training split is shuffled)
train_ds_raw = create_tf_dataset(train_data_dir, (IMG_HEIGHT, IMG_WIDTH), BATCH_SIZE, shuffle=True)
val_ds_raw = create_tf_dataset(val_data_dir, (IMG_HEIGHT, IMG_WIDTH), BATCH_SIZE, shuffle=False)
test_ds_raw = create_tf_dataset(test_data_dir, (IMG_HEIGHT, IMG_WIDTH), BATCH_SIZE, shuffle=False)  # For final eval

# Get class names
class_names = get_class_names(train_data_dir)
num_classes = len(class_names)
print(f"Found {num_classes} classes: {class_names}")

# Apply preprocessing; augmentation only for the training split
train_ds = prepare_dataset(train_ds_raw, IMG_HEIGHT, IMG_WIDTH, augment=True)
val_ds = prepare_dataset(val_ds_raw, IMG_HEIGHT, IMG_WIDTH, augment=False)
test_ds = prepare_dataset(test_ds_raw, IMG_HEIGHT, IMG_WIDTH, augment=False)

print("Datasets created and prepared.")

# Optional: Visualize a batch
PREVIEW_GRID = 3  # show up to PREVIEW_GRID x PREVIEW_GRID training images
fig, axes = plt.subplots(PREVIEW_GRID, PREVIEW_GRID, figsize=(10, 10))
for images, labels in train_ds.take(1):
    n_show = min(PREVIEW_GRID * PREVIEW_GRID, int(images.shape[0]))
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i >= n_show:
            continue  # a batch may hold fewer images than grid cells
        label = labels[i].numpy()
        # Support both sparse integer labels and one-hot label vectors.
        class_idx = int(np.argmax(label)) if np.ndim(label) > 0 else int(label)
        # imshow needs uint8; clip first so augmented values outside [0, 255]
        # saturate instead of wrapping around on the cast.
        # NOTE(review): assumes prepare_dataset keeps pixels in [0, 255] (no rescaling) — confirm.
        ax.imshow(np.clip(images[i].numpy(), 0, 255).astype("uint8"))
        ax.set_title(class_names[class_idx])
plt.show()

In [None]:
# @title 3. Model Definition and Compilation

LEARNING_RATE = 1e-3  # Adam learning rate
FOCAL_GAMMA = 2.0     # focusing parameter of the focal loss

model = build_fp32_efficientnet_model(num_classes)
model.summary()

# Per-class alpha weights for the Weighted Focal Loss.
# MVP placeholder: every class gets the same weight (0.5).
dummy_alpha = np.full(num_classes, 0.5)
# For real class-frequency weighting, use inverse class frequencies, e.g.:
#   class_counts = np.bincount(
#       # assumes sparse integer labels; use np.argmax(label) for one-hot labels
#       [int(label) for _, label in train_ds_raw.unbatch().as_numpy_iterator()],
#       minlength=num_classes,  # keep one entry per class even if a class is absent
#   )
#   inverse_frequencies = 1.0 / np.maximum(class_counts, 1)  # avoid division by zero
#   alpha_weights = inverse_frequencies / inverse_frequencies.sum()  # normalise to sum to 1
# You might want to smooth these or cap extreme values.

# Compile the model with the project's custom loss
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=WeightedFocalLoss(gamma=FOCAL_GAMMA, alpha=dummy_alpha),
    metrics=['accuracy']
)

print("Model defined and compiled.")

In [None]:
# @title 4. Model Training

EPOCHS = 10  # Start with a small number for MVP to quickly see results

# ModelCheckpoint: save the full model whenever val_accuracy improves.
checkpoint_dir = os.path.join(project_root, 'trained_models')
os.makedirs(checkpoint_dir, exist_ok=True)  # don't rely on Keras to create the folder
checkpoint_filepath = os.path.join(checkpoint_dir, 'fp32_mvp_best_model.h5')
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,  # Save the entire model
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1
)

# EarlyStopping to prevent overfitting.
# NOTE(review): this monitors val_loss while the checkpoint monitors val_accuracy,
# so the restored in-memory weights and the saved "best" model can differ.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,  # Epochs with no improvement after which training stops
    restore_best_weights=True,
    verbose=1
)

print(f"Starting training for {EPOCHS} epochs...")
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=[model_checkpoint_callback, early_stopping_callback]
)

if os.path.exists(checkpoint_filepath):
    print("Training complete. Model saved to:", checkpoint_filepath)
else:
    print("Training complete, but no checkpoint was written to:", checkpoint_filepath)

# Plot training history: one panel per metric, training vs. validation curves.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))
for ax, metric, label in ((ax_acc, 'accuracy', 'Accuracy'), (ax_loss, 'loss', 'Loss')):
    ax.plot(history.history[metric], label=f'Training {label}')
    ax.plot(history.history[f'val_{metric}'], label=f'Validation {label}')
    ax.set(title=f'Training and Validation {label}', xlabel='Epoch', ylabel=label)
    ax.legend()
plt.show()

In [None]:
# @title 5. Basic Model Evaluation on Test Set

print("\n--- Evaluating MVP Model on Test Set ---")
# Load the best saved model for evaluation. WeightedFocalLoss is a project class,
# not a Keras built-in, so it must be supplied for deserialization.
best_model = tf.keras.models.load_model(
    checkpoint_filepath,
    custom_objects={'WeightedFocalLoss': WeightedFocalLoss}
)
# With metrics=['accuracy'], evaluate() returns [loss, accuracy].
loss, accuracy = best_model.evaluate(test_ds)
print(f"Test Loss: {loss:.4f}")
print(f"Test Accuracy: {accuracy:.4f}")

# Save class names (one per line, in label-index order) for later use, e.g. inference.
# Explicit UTF-8 so the file does not depend on the platform's default encoding.
class_names_path = os.path.join(project_root, 'data', 'class_names.txt')
with open(class_names_path, 'w', encoding='utf-8') as f:
    f.writelines(f"{name}\n" for name in class_names)
print(f"Class names saved to: {class_names_path}")

In [None]:
# @title 6. Mock Inference Example with Saved Model

# Take the first image of the first test batch for demonstration.
first_batch = next(iter(test_ds.take(1)), None)
if first_batch is None:
    raise ValueError("Test dataset is empty; cannot run mock inference.")
images, labels = first_batch
sample_image = images[0]
sample_label = labels[0]
true_label = sample_label.numpy()
# Support both sparse integer labels and one-hot label vectors.
true_class_idx = int(np.argmax(true_label)) if np.ndim(true_label) > 0 else int(true_label)

# Add a batch dimension to the single image
sample_image_batch = tf.expand_dims(sample_image, axis=0)

# Make prediction.
# NOTE(review): reads the output as per-class probabilities (softmax head) — confirm in src/models.py.
predictions = best_model.predict(sample_image_batch)
predicted_class_idx = int(np.argmax(predictions[0]))
predicted_confidence = float(np.max(predictions[0]))

print("\n--- Mock Inference ---")
print(f"True Class: {class_names[true_class_idx]}")
print(f"Predicted Class: {class_names[predicted_class_idx]}")
print(f"Predicted Confidence: {predicted_confidence:.4f}")

# Clip before the uint8 cast so out-of-range pixel values saturate instead of wrapping.
plt.imshow(np.clip(sample_image.numpy(), 0, 255).astype("uint8"))
plt.title(f"Predicted: {class_names[predicted_class_idx]} ({predicted_confidence:.2f})\nTrue: {class_names[true_class_idx]}")
plt.axis("off")
plt.show()

print("\nModule 1 FP32 MVP training and basic evaluation complete!")