# Plant Disease Detection Model Training
## EfficientNetB0 Transfer Learning for PlantVillage Dataset

This notebook trains a plant disease detection model using transfer learning with EfficientNetB0.

In [None]:
# Import necessary libraries
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
from pathlib import Path

# Import custom modules
import sys
sys.path.append('../backend')
from disease_model import PlantDiseaseModel
from utils.preprocessing import create_data_generators, split_dataset
from utils.evaluation import ModelEvaluator

# Record the TensorFlow build and the GPUs it can see, so the outputs below
# can be tied to a concrete runtime environment.
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU available: {gpu_devices}")

## 1. Dataset Preparation

In [None]:
# Dataset configuration
DATASET_PATH = 'data/PlantVillage'  # Update this path
BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
EPOCHS_INITIAL = 10
EPOCHS_FINE_TUNE = 15

# Create train/val/test splits if they don't exist.
# The split directories are siblings of the dataset directory named
# "<DATASET_PATH>_train", "<DATASET_PATH>_val" and "<DATASET_PATH>_test".
if not os.path.exists(f"{DATASET_PATH}_train"):
    print("Splitting dataset into train/val/test...")
    split_dataset(DATASET_PATH, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)
    print("Dataset split completed!")
else:
    # A train directory alone does not prove the split finished: an
    # interrupted run can leave val/test missing. Fail here with a clear
    # message rather than silently re-splitting over existing data or
    # failing later with a less obvious error.
    missing_split_dirs = [
        split_dir
        for split_dir in (f"{DATASET_PATH}_val", f"{DATASET_PATH}_test")
        if not os.path.exists(split_dir)
    ]
    if missing_split_dirs:
        raise FileNotFoundError(
            f"Incomplete dataset split: {DATASET_PATH}_train exists but "
            f"{', '.join(missing_split_dirs)} is missing. Remove the partial "
            "split directories and re-run this cell."
        )
    print("Dataset splits already exist.")

In [None]:
# Create data generators
train_generator, val_generator, test_generator = create_data_generators(
    train_dir=f"{DATASET_PATH}_train",
    val_dir=f"{DATASET_PATH}_val",
    test_dir=f"{DATASET_PATH}_test",
    batch_size=BATCH_SIZE,
    target_size=IMAGE_SIZE
)

# Get class information.
# class_indices maps class name -> integer label. Order the names by that
# integer explicitly so class_names[i] is the class with label i, instead of
# relying on the dict's insertion order.
class_indices = train_generator.class_indices
class_names = sorted(class_indices, key=class_indices.get)
num_classes = len(class_names)

# Every split must use the same name -> label mapping; otherwise the
# validation/test labels would be decoded against the wrong class names.
for split_name, split_generator in (("validation", val_generator), ("test", test_generator)):
    assert split_generator.class_indices == class_indices, (
        f"The {split_name} split does not have the same classes as the training split"
    )

print(f"Number of classes: {num_classes}")
print(f"Training samples: {train_generator.samples}")
print(f"Validation samples: {val_generator.samples}")
print(f"Test samples: {test_generator.samples}")
print(f"Class names: {class_names[:5]}...")  # Show first 5 classes

## 2. Model Architecture

In [None]:
# Instantiate the model wrapper for the detected number of classes. The input
# shape is the configured image size with a trailing 3 appended (presumably
# the RGB channel axis).
INITIAL_LEARNING_RATE = 0.001
model_input_shape = tuple(IMAGE_SIZE) + (3,)

model = PlantDiseaseModel(num_classes=num_classes, input_shape=model_input_shape)

# Build the network, then compile it with the initial learning rate.
model.build_model()
model.compile_model(learning_rate=INITIAL_LEARNING_RATE)

# Print the layer-by-layer summary of the underlying Keras model.
model.model.summary()

## 3. Initial Training (Frozen Base)

In [None]:
# First training stage: fit for EPOCHS_INITIAL epochs on the training
# generator, validating against the validation generator.
print("Starting initial training with frozen EfficientNetB0 base...")

initial_training_args = dict(
    train_data=train_generator,
    val_data=val_generator,
    epochs=EPOCHS_INITIAL,
)
initial_history = model.train_initial(**initial_training_args)

print("Initial training completed!")

In [None]:
# Plot initial training history
# The helper draws the curves; the super-title is then attached to whatever
# figure pyplot considers current, and plt.show() renders it.
# NOTE(review): this ordering assumes plot_training_history() leaves its
# figure open and current (i.e. does not call plt.show() itself) — confirm,
# otherwise the super-title lands on a new empty figure.
model.plot_training_history()
plt.suptitle('Initial Training (Frozen Base)')
plt.show()

## 4. Fine-Tuning

In [None]:
# Second training stage: continue training for EPOCHS_FINE_TUNE epochs with
# part of the base network unfrozen.
UNFREEZE_LAYERS = 50  # Unfreeze top 50 layers

print("Starting fine-tuning...")

fine_tune_args = dict(
    train_data=train_generator,
    val_data=val_generator,
    epochs=EPOCHS_FINE_TUNE,
    unfreeze_layers=UNFREEZE_LAYERS,
)
fine_tune_history = model.fine_tune(**fine_tune_args)

print("Fine-tuning completed!")

In [None]:
# Plot combined training history: a 2x2 grid with one row per training stage
# and one column per metric.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# (title prefix, History object) for each grid row, top to bottom.
training_stages = (
    ('Initial Training', initial_history),
    ('Fine-Tuning', fine_tune_history),
)
# (history key, title suffix) for each grid column, left to right.
metric_panels = (
    ('accuracy', 'Accuracy'),
    ('loss', 'Loss'),
)

for row, (stage_title, stage_history) in enumerate(training_stages):
    for col, (metric_key, metric_title) in enumerate(metric_panels):
        panel = axes[row, col]
        # The validation curve is stored under the same key prefixed "val_".
        panel.plot(stage_history.history[metric_key], label='Training')
        panel.plot(stage_history.history[f'val_{metric_key}'], label='Validation')
        panel.set_title(f'{stage_title} - {metric_title}')
        panel.legend()

plt.tight_layout()
plt.show()

## 5. Model Evaluation

In [None]:
# Score the trained model on the held-out test generator and print the
# headline numbers from the returned report.
print("Evaluating model on test set...")

evaluator = ModelEvaluator(class_names)
report, cm = model.evaluate_model(test_generator, class_names)

print(f"Test Accuracy: {report['accuracy']:.4f}")

# The remaining metrics are read from the report's 'weighted avg' entry.
weighted_avg = report['weighted avg']
weighted_metrics = (
    ('Precision', 'precision'),
    ('Recall', 'recall'),
    ('F1-Score', 'f1-score'),
)
for metric_label, metric_key in weighted_metrics:
    print(f"Test {metric_label}: {weighted_avg[metric_key]:.4f}")

In [None]:
# Plot confusion matrix, limited to at most `top_classes` classes so the
# figure stays readable.
top_classes = 20
shown_classes = min(num_classes, top_classes)

# Slicing keeps the leading rows/columns of the matrix, i.e. the first
# classes in `class_names` order -- not the most frequent or best-scoring
# ones. When there are fewer classes than the limit, the slice is the
# whole matrix.
cm_subset = cm[:shown_classes, :shown_classes]
class_names_subset = class_names[:shown_classes]

# Make the title state what is actually drawn instead of always "Top 20".
if shown_classes < num_classes:
    cm_title = f'Confusion Matrix (First {shown_classes} of {num_classes} Classes)'
else:
    cm_title = f'Confusion Matrix (All {num_classes} Classes)'

evaluator_subset = ModelEvaluator(class_names_subset)
plt.figure(figsize=(15, 12))
evaluator_subset.plot_confusion_matrix(cm_subset, cm_title)
plt.show()

In [None]:
# Render the classification report produced by the evaluation step as a figure.
report_title = 'Classification Report'
evaluator.plot_classification_report(report, report_title)
plt.show()

## 6. Model Saving and Conversion

In [None]:
# Output locations for the artifacts written by this step.
KERAS_MODEL_PATH = 'plant_disease_model.h5'
CLASS_NAMES_PATH = 'class_names.json'

# Persist the trained Keras model.
model.save_model(KERAS_MODEL_PATH)

# Convert to TensorFlow Lite; the helper returns the path of the file it wrote.
tflite_path = model.convert_to_tflite()

# Store the class names as a JSON list, in the same order as `class_names`.
with open(CLASS_NAMES_PATH, 'w') as class_names_file:
    json.dump(class_names, class_names_file, indent=2)

print("Model saved successfully!")
print(f"Keras model: {KERAS_MODEL_PATH}")
print(f"TFLite model: {tflite_path}")
print(f"Class names: {CLASS_NAMES_PATH}")

## 7. Model Testing with Sample Images

In [None]:
# Test with some sample images
import cv2
from utils.preprocessing import ImagePreprocessor

# NOTE(review): neither `cv2` nor `preprocessor` is used in this cell; both
# are kept in case other cells rely on them.
preprocessor = ImagePreprocessor()

# Get a batch of test images.
# The builtin next() is used instead of the iterator's `.next()` method,
# which is not available on newer Keras iterators.
test_batch = next(test_generator)
test_images, test_labels = test_batch

# Make predictions
predictions = model.model.predict(test_images)

# Display results in a 2x4 grid
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.ravel()

# A batch can hold fewer images than the grid has panels (small test set,
# final partial batch, small BATCH_SIZE), so never index past its end.
num_shown = min(len(axes), len(test_images))

for i in range(num_shown):
    # Display image
    axes[i].imshow(test_images[i])

    # Get predictions; argmax over the label vector assumes one-hot labels.
    true_class_idx = np.argmax(test_labels[i])
    pred_class_idx = np.argmax(predictions[i])
    confidence = predictions[i][pred_class_idx]

    true_class = class_names[true_class_idx]
    pred_class = class_names[pred_class_idx]

    # Set title: green when the prediction matches the label, red otherwise
    color = 'green' if true_class_idx == pred_class_idx else 'red'
    axes[i].set_title(f'True: {true_class[:15]}...\nPred: {pred_class[:15]}... ({confidence:.2f})',
                     color=color, fontsize=8)
    axes[i].axis('off')

# Blank out any panels that did not receive an image.
for unused_ax in axes[num_shown:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

## 8. Performance Analysis

In [None]:
# Analyze per-class performance
#
# Labels and predictions must cover the same samples in the same order, so
# both are collected from the same batches in one pass over the test set.
# (Pairing full-test-set labels with the predictions of a single sample
# batch would give arrays of different lengths.)
test_generator.reset()  # start from the first batch

label_batches = []
prediction_batches = []
for batch_number in range(len(test_generator)):
    # Builtin next() rather than `.next()`, which newer Keras iterators lack.
    batch_images, batch_labels = next(test_generator)
    label_batches.append(batch_labels)
    prediction_batches.append(model.model.predict(batch_images, verbose=0))

all_test_labels = np.concatenate(label_batches)
all_test_predictions = np.concatenate(prediction_batches)

# NOTE(review): the arrays are passed in the same form as before (labels as
# yielded by the generator, raw model outputs) — confirm that this is what
# calculate_per_class_metrics expects.
per_class_metrics = evaluator.calculate_per_class_metrics(
    all_test_labels,
    all_test_predictions
)

# Create performance DataFrame: one row per class, best F1 first
performance_df = pd.DataFrame(per_class_metrics).T
performance_df = performance_df.sort_values('f1_score', ascending=False)

print("Top 10 Best Performing Classes:")
print(performance_df.head(10)[['precision', 'recall', 'f1_score']])

print("\nTop 10 Worst Performing Classes:")
print(performance_df.tail(10)[['precision', 'recall', 'f1_score']])

In [None]:
# Histogram of each per-class metric across all classes, one panel per
# metric, with a dashed red line marking the mean.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

metrics = ['precision', 'recall', 'f1_score']
for panel, metric in zip(axes, metrics):
    metric_values = performance_df[metric]
    metric_label = metric.title()
    metric_mean = metric_values.mean()

    panel.hist(metric_values, bins=20, alpha=0.7, edgecolor='black')
    panel.set_title(f'{metric_label} Distribution')
    panel.set_xlabel(metric_label)
    panel.set_ylabel('Number of Classes')
    panel.axvline(metric_mean, color='red', linestyle='--',
                  label=f'Mean: {metric_mean:.3f}')
    panel.legend()

plt.tight_layout()
plt.show()

## 9. Model Size and Inference Speed

In [None]:
import time
import os

# Number of timed inference calls per model.
N_TIMING_RUNS = 100

# Model size comparison
keras_size = os.path.getsize('plant_disease_model.h5') / (1024 * 1024)  # MB
tflite_size = os.path.getsize(tflite_path) / (1024 * 1024)  # MB

print(f"Keras model size: {keras_size:.2f} MB")
print(f"TFLite model size: {tflite_size:.2f} MB")
print(f"Size reduction: {((keras_size - tflite_size) / keras_size * 100):.1f}%")

# Inference speed test
test_image = test_images[:1]  # Single image (batch of one) from the sample batch

# Keras model speed.
# One untimed call first: the first prediction for a new input shape carries
# one-off setup cost that would otherwise be averaged into the result.
model.model.predict(test_image, verbose=0)

# perf_counter() is a monotonic, high-resolution clock meant for measuring
# short durations; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
for _ in range(N_TIMING_RUNS):
    _ = model.model.predict(test_image, verbose=0)
keras_time = (time.perf_counter() - start_time) / N_TIMING_RUNS

# TFLite model speed
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# set_tensor() rejects arrays whose dtype differs from the model input, so
# convert once up front. This is a plain cast for timing purposes only; it
# does not apply any quantization scaling.
tflite_input = np.asarray(test_image, dtype=input_details[0]['dtype'])

# Untimed warm-up invocation, mirroring the Keras measurement.
interpreter.set_tensor(input_details[0]['index'], tflite_input)
interpreter.invoke()

start_time = time.perf_counter()
for _ in range(N_TIMING_RUNS):
    interpreter.set_tensor(input_details[0]['index'], tflite_input)
    interpreter.invoke()
    _ = interpreter.get_tensor(output_details[0]['index'])
tflite_time = (time.perf_counter() - start_time) / N_TIMING_RUNS

print(f"\nInference Speed (average of {N_TIMING_RUNS} runs):")
print(f"Keras model: {keras_time*1000:.2f} ms")
print(f"TFLite model: {tflite_time*1000:.2f} ms")
print(f"Speed improvement: {(keras_time/tflite_time):.1f}x")

## 10. Final Summary

In [None]:
# Assemble a Markdown summary from values computed in earlier cells, print
# it, and write it to training_report.md.
weighted_avg = report['weighted avg']
total_epochs = EPOCHS_INITIAL + EPOCHS_FINE_TUNE
size_reduction_pct = (keras_size - tflite_size) / keras_size * 100
tflite_latency_ms = tflite_time*1000

final_report = f"""
# Plant Disease Detection Model Training Summary

## Dataset Information
- Total Classes: {num_classes}
- Training Samples: {train_generator.samples}
- Validation Samples: {val_generator.samples}
- Test Samples: {test_generator.samples}

## Model Architecture
- Base Model: EfficientNetB0 (ImageNet pretrained)
- Input Size: {IMAGE_SIZE}
- Total Parameters: {model.model.count_params():,}

## Training Configuration
- Initial Training Epochs: {EPOCHS_INITIAL}
- Fine-tuning Epochs: {EPOCHS_FINE_TUNE}
- Batch Size: {BATCH_SIZE}
- Total Training Epochs: {total_epochs}

## Final Performance
- Test Accuracy: {report['accuracy']:.4f}
- Test Precision: {weighted_avg['precision']:.4f}
- Test Recall: {weighted_avg['recall']:.4f}
- Test F1-Score: {weighted_avg['f1-score']:.4f}

## Model Optimization
- Keras Model Size: {keras_size:.2f} MB
- TFLite Model Size: {tflite_size:.2f} MB
- Size Reduction: {size_reduction_pct:.1f}%
- Inference Speed (TFLite): {tflite_latency_ms:.2f} ms

## Model Files
- Keras Model: plant_disease_model.h5
- TFLite Model: {tflite_path}
- Class Names: class_names.json
"""

print(final_report)

# Persist the same text next to the notebook.
with open('training_report.md', 'w') as report_file:
    report_file.write(final_report)

# Closing status lines.
closing_messages = (
    "\n✅ Training completed successfully!",
    "📊 All results and models have been saved.",
    "🚀 Ready for deployment!",
)
for closing_message in closing_messages:
    print(closing_message)