# Pest Detection Model - Training Notebook

This notebook provides an interactive environment for training the pest detection CNN model.

## Steps:
1. Setup and imports
2. Load and explore dataset
3. Build model
4. Train model
5. Evaluate results
6. Make predictions

## 1. Setup and Imports

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Make the project modules importable. '../src' is a relative entry, so it is
# resolved against the current working directory the notebook runs in.
sys.path.append('../src')

# Project-local modules; these imports rely on the path entry added above,
# so they must stay below the sys.path.append call.
from data_loader import PestDataLoader
from model import PestDetectionModel
from predict import PestPredictor

# Reached only if every import above succeeded
print("✅ Imports successful!")

## 2. Load and Explore Dataset

In [None]:
# Dataset location and input-pipeline settings reused by later cells
DATA_DIR = '../data/raw'
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# When the data directory is missing, ask the loader to generate a sample
# dataset in its place before loading.
if not os.path.exists(DATA_DIR):
    print("Creating sample dataset...")
    loader = PestDataLoader(DATA_DIR, img_size=IMG_SIZE, batch_size=BATCH_SIZE)
    classes = loader.create_sample_dataset(DATA_DIR)
    print(f"Sample dataset created with {len(classes)} classes")

# A fresh loader is created for the actual load, exactly as the bootstrap
# branch above may have left a loader built before the data existed.
loader = PestDataLoader(DATA_DIR, img_size=IMG_SIZE, batch_size=BATCH_SIZE)
train_gen, val_gen, class_names = loader.load_dataset()

# Report what was loaded, one line per fact
print(
    "\nDataset loaded:",
    f"  Classes: {len(class_names)}",
    f"  Class names: {class_names}",
    f"  Training samples: {train_gen.samples}",
    f"  Validation samples: {val_gen.samples}",
    sep="\n",
)

In [None]:
# Pull one batch from the training generator and show up to 16 of its images
# on a 4x4 grid, each titled with its class name.
images, labels = next(train_gen)

plt.figure(figsize=(15, 10))
for position, (image, label_vector) in enumerate(zip(images[:16], labels[:16]), start=1):
    plt.subplot(4, 4, position)
    plt.imshow(image)
    # argmax turns the label vector into a class index
    # (presumably one-hot encoded; confirm against the data loader)
    plt.title(class_names[np.argmax(label_vector)])
    plt.axis('off')
plt.tight_layout()
plt.show()

## 3. Build Model

In [None]:
# One output class per dataset class; input size matches the loader setting
n_classes = len(class_names)
model_builder = PestDetectionModel(num_classes=n_classes, img_size=IMG_SIZE)

# Keep both handles: the builder for helper methods, the model for fit/predict
model = model_builder.build_model()

# Print the architecture summary via the builder
model_builder.summary()

## 4. Train Model

In [None]:
# Epoch budgets for the two training phases (the second is used by the
# fine-tuning cell)
EPOCHS = 20
FINE_TUNE_EPOCHS = 10

# Callbacks are supplied by the model builder; checkpoint_path presumably
# controls where the best model is saved (confirm in get_callbacks).
callbacks = model_builder.get_callbacks(
    checkpoint_path='../models/pest_detector_best.h5',
)

# Phase 1: train on the training generator, validating each epoch
print("Starting initial training...")
initial_fit_kwargs = {
    'epochs': EPOCHS,
    'validation_data': val_gen,
    'callbacks': callbacks,
    'verbose': 1,
}
history1 = model.fit(train_gen, **initial_fit_kwargs)

print("\n✅ Initial training complete!")

In [None]:
# Phase 2: let the builder switch the model into fine-tuning mode, then keep
# training the same `model` object with the fine-tuning epoch budget.
print("Starting fine-tuning...")
model_builder.fine_tune(trainable_layers=20)

# NOTE(review): the callback objects from the initial phase are reused here;
# whether any of them carry state across fit() calls depends on get_callbacks.
fine_tune_fit_kwargs = {
    'epochs': FINE_TUNE_EPOCHS,
    'validation_data': val_gen,
    'callbacks': callbacks,
    'verbose': 1,
}
history2 = model.fit(train_gen, **fine_tune_fit_kwargs)

print("\n✅ Fine-tuning complete!")

print("\n✅ Fine-tuning complete!")

## 5. Evaluate Results

In [None]:
# Stitch the per-epoch metrics of both training phases into one timeline,
# plot accuracy and loss side by side, save the figure, and print the final
# validation numbers.

# Number of epochs the initial phase actually ran. This can be smaller than
# EPOCHS when a callback (e.g. early stopping) ended training early, so the
# phase boundary is taken from the recorded history rather than the budget.
initial_epochs_run = len(history1.history['loss'])

# Combine histories (phase 1 followed by phase 2) for each tracked metric
combined_history = {
    metric: history1.history[metric] + history2.history[metric]
    for metric in ('accuracy', 'val_accuracy', 'loss', 'val_loss')
}

# Plot training history: one panel per metric family
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

for ax, metric, pretty_name in ((ax1, 'accuracy', 'Accuracy'), (ax2, 'loss', 'Loss')):
    ax.plot(combined_history[metric], label='Training')
    ax.plot(combined_history[f'val_{metric}'], label='Validation')
    # The x axis is a 0-based epoch index, so the first fine-tuning epoch
    # sits at index `initial_epochs_run`.
    ax.axvline(x=initial_epochs_run, color='r', linestyle='--',
               label='Fine-tuning starts')
    ax.set_title(f'Model {pretty_name}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(pretty_name)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
# savefig does not create missing directories; make sure the target exists
os.makedirs('../models', exist_ok=True)
plt.savefig('../models/training_plot.png', dpi=300)
plt.show()

# Metrics of the last epoch that ran (not necessarily the best epoch)
print("\nFinal Results:")
print(f"  Validation Accuracy: {combined_history['val_accuracy'][-1]:.4f}")
print(f"  Validation Loss: {combined_history['val_loss'][-1]:.4f}")

In [None]:
# Persist the trained model in Keras and TFLite form, plus the class-name
# list, into '../models'.
import json

print("Saving models...")

# open() below fails if the output directory is missing, so create it up
# front; exist_ok makes this a no-op when it is already there.
os.makedirs('../models', exist_ok=True)

# Save Keras model
model_builder.save_model('../models/pest_detector_final.h5')

# Convert to TFLite
model_builder.convert_to_tflite('../models/pest_detector_final.tflite')

# Save class names alongside the models so consumers can map output indices
# back to labels
with open('../models/class_names.json', 'w', encoding='utf-8') as f:
    json.dump(class_names, f, indent=2)

print("\n✅ Models saved successfully!")

## 6. Make Predictions

In [None]:
# Show predictions for up to 8 validation images, titled with the true and
# predicted class; the title is green when they match and red otherwise.

# Initialize predictor from the best checkpoint.
# NOTE(review): `predictor` is loaded but the predictions below come from the
# in-memory `model`, not from this checkpoint. Confirm which one is intended.
predictor = PestPredictor('../models/pest_detector_best.h5')

# Get sample validation images
val_images, val_labels = next(val_gen)

# Run inference once for the whole displayed subset instead of calling
# predict() separately for every image.
n_shown = min(8, len(val_images))
if n_shown:
    batch_predictions = model.predict(np.asarray(val_images[:n_shown]), verbose=0)
else:
    batch_predictions = []

plt.figure(figsize=(15, 10))
for i in range(n_shown):
    predictions = batch_predictions[i]

    # Get true and predicted labels
    true_idx = np.argmax(val_labels[i])
    pred_idx = np.argmax(predictions)
    confidence = predictions[pred_idx]

    # Plot
    plt.subplot(2, 4, i + 1)
    plt.imshow(val_images[i])
    color = 'green' if true_idx == pred_idx else 'red'
    plt.title(f"True: {class_names[true_idx]}\nPred: {class_names[pred_idx]} ({confidence:.2%})",
              color=color, fontsize=9)
    plt.axis('off')

plt.tight_layout()
plt.show()

## Next Steps

1. **Collect More Data:** Add more images to improve accuracy
2. **Add More Classes:** Include additional pests and diseases
3. **Integrate with Backend:** Connect to vision service API
4. **Deploy to Mobile:** Use TFLite model in mobile app
5. **Monitor Performance:** Track real-world accuracy