# Coconut Mite Detection - Model Training

Train a CNN model to detect coconut mite infestation using transfer learning.

## Model Architecture
- **Base Model:** MobileNetV2 (pretrained on ImageNet)
- **Task:** Binary Classification (Mite Infected vs Healthy)
- **Target:** Mobile deployment via TensorFlow Lite

## 1. Setup & Configuration

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import datetime

# TensorFlow
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from tensorflow.keras.optimizers import Adam

# Sklearn for metrics
from sklearn.metrics import classification_report, confusion_matrix

# Report the installed TensorFlow version and the GPUs TensorFlow can see
# (an empty list means TensorFlow found no GPU).
print(f"TensorFlow Version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

In [None]:
# Configuration
# Class folder names under data_dir, in label-index order. `num_classes` is
# derived from this list so the two settings can never disagree.
_CLASSES = ['coconut_mite', 'healthy']

CONFIG = {
    # Data paths (relative to the notebook's working directory)
    'data_dir': Path('../data/raw/pest'),
    'model_dir': Path('../models/coconut_mite'),

    # Image settings (size fed to MobileNetV2)
    'img_height': 224,
    'img_width': 224,
    'channels': 3,

    # Training settings
    'batch_size': 32,
    'epochs': 50,  # total epochs across both training phases
    'learning_rate': 0.0001,
    'validation_split': 0.2,

    # Classes
    'classes': _CLASSES,
    'num_classes': len(_CLASSES),
}

# Create model directory if not exists
CONFIG['model_dir'].mkdir(parents=True, exist_ok=True)

print("Configuration loaded!")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## 2. Data Loading & Preprocessing

In [None]:
# Check data availability
# Extensions accepted by ImageDataGenerator.flow_from_directory, matched
# case-insensitively so these counts agree with what the generators load.
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff'}


def count_images(class_dir):
    """Return the number of image files under ``class_dir`` (0 if it is missing).

    Nested folders are included because flow_from_directory also walks into
    subdirectories of each class folder.
    """
    class_dir = Path(class_dir)
    if not class_dir.is_dir():
        return 0
    return sum(
        1
        for p in class_dir.rglob('*')
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


mite_path = CONFIG['data_dir'] / 'coconut_mite'
healthy_path = CONFIG['data_dir'] / 'healthy'

mite_count = count_images(mite_path)
healthy_count = count_images(healthy_path)

print(f"Coconut Mite images: {mite_count:,}")
print(f"Healthy images: {healthy_count:,}")

# Both classes are required for a meaningful binary classifier.
if healthy_count == 0:
    print("\n⚠️ WARNING: No healthy images found!")
    print("Please upload healthy images before training.")
if mite_count == 0:
    print("\n⚠️ WARNING: No coconut mite images found!")
    print("Please upload coconut mite images before training.")
if mite_count and healthy_count:
    print(f"\n✅ Total images: {mite_count + healthy_count:,}")

In [None]:
# Data generators with augmentation.
# Both generators share the same rescaling (pixels -> [0, 1]) and the same
# validation fraction, so the 'training' and 'validation' subsets requested
# later split the directory consistently. Only the training generator adds
# random geometric augmentation.
_shared_gen_args = dict(
    rescale=1./255,
    validation_split=CONFIG['validation_split'],
)

train_datagen = ImageDataGenerator(
    **_shared_gen_args,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest',
)

val_datagen = ImageDataGenerator(**_shared_gen_args)

In [None]:
# Load training and validation data from the same directory. Both flows use
# identical sizing/labelling options; they differ in subset and shuffling,
# and inherit augmentation (or not) from their generator.
_flow_args = dict(
    directory=CONFIG['data_dir'],
    target_size=(CONFIG['img_height'], CONFIG['img_width']),
    batch_size=CONFIG['batch_size'],
    class_mode='categorical',
    classes=CONFIG['classes'],
)

# Training subset: reshuffled every epoch.
train_generator = train_datagen.flow_from_directory(
    subset='training', shuffle=True, **_flow_args
)
# Validation subset: directory order is kept so predictions line up with
# validation_generator.classes during evaluation.
validation_generator = val_datagen.flow_from_directory(
    subset='validation', shuffle=False, **_flow_args
)

print(f"\nTraining samples: {train_generator.samples}")
print(f"Validation samples: {validation_generator.samples}")
print(f"\nClass indices: {train_generator.class_indices}")

In [None]:
# Calculate class weights for imbalanced data
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' weights each class by n_samples / (n_classes * n_samples_in_class),
# so the rarer class contributes proportionally more to the loss.
present_classes = np.unique(train_generator.classes)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=train_generator.classes,
)
# Key by the actual label value rather than the position in present_classes:
# if a class has no training images, the two differ and enumerate() would
# assign the weight to the wrong class.
class_weights_dict = {
    int(label): float(weight)
    for label, weight in zip(present_classes, class_weights)
}

print(f"Class weights: {class_weights_dict}")

In [None]:
# Visualize sample batch (up to 8 images from one augmented training batch)
sample_batch, sample_labels = next(train_generator)

fig, axes = plt.subplots(2, 4, figsize=(15, 8))
axes = axes.flatten()

class_names = list(train_generator.class_indices.keys())

# A batch can hold fewer than 8 images (tiny dataset or small batch_size):
# fill only as many panels as there are images and blank out the rest
# instead of raising IndexError.
for idx, ax in enumerate(axes):
    if idx < len(sample_batch):
        ax.imshow(sample_batch[idx])
        label_idx = np.argmax(sample_labels[idx])  # labels are one-hot
        ax.set_title(f'Class: {class_names[label_idx]}', fontsize=10)
    ax.axis('off')

plt.suptitle('Sample Training Batch (After Augmentation)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 3. Model Architecture

In [None]:
def create_model(input_shape, num_classes):
    """
    Build a MobileNetV2 transfer-learning classifier.

    Args:
        input_shape: (height, width, channels) of the images fed to the model.
            Pixel values are expected in [0, 1], matching the 1/255 rescaling
            done by the data generators.
        num_classes: Number of softmax output units.

    Returns:
        (model, base_model): the full Keras model and the MobileNetV2 backbone
        (frozen here), returned separately so it can be unfrozen for fine-tuning.
    """
    # ImageNet-pretrained backbone without its classifier; frozen for phase 1.
    backbone = MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet',
    )
    backbone.trainable = False

    inputs = keras.Input(shape=input_shape)

    # Inputs arrive in [0, 1]; scale back to [0, 255] so MobileNetV2's own
    # preprocess_input maps them to the [-1, 1] range the backbone expects.
    features = keras.applications.mobilenet_v2.preprocess_input(inputs * 255)
    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, including after layers are unfrozen for fine-tuning.
    features = backbone(features, training=False)

    # Classification head: pooling -> BatchNorm -> two Dense/Dropout stages.
    head = layers.GlobalAveragePooling2D()(features)
    head = layers.BatchNormalization()(head)
    for units, drop_rate in ((256, 0.5), (128, 0.3)):
        head = layers.Dense(units, activation='relu')(head)
        head = layers.Dropout(drop_rate)(head)

    outputs = layers.Dense(num_classes, activation='softmax')(head)

    return Model(inputs, outputs), backbone


# Create model
input_shape = (CONFIG['img_height'], CONFIG['img_width'], CONFIG['channels'])
model, base_model = create_model(input_shape, CONFIG['num_classes'])

model.summary()

In [None]:
# Compile model: Adam at the configured learning rate, with categorical
# cross-entropy to match the one-hot labels (class_mode='categorical').
model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=Adam(learning_rate=CONFIG['learning_rate']),
)

print("Model compiled successfully!")

## 4. Training Callbacks

In [None]:
# Define callbacks (the same list is passed to both training phases)
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

callbacks = [
    # Stop after 10 epochs without val_loss improvement and roll back to the
    # best weights seen.
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1,
    ),
    # Keep the checkpoint with the highest validation accuracy.
    # NOTE(review): this monitors val_accuracy while EarlyStopping restores
    # the best val_loss weights, so best_model.keras and the final saved
    # model may come from different epochs.
    ModelCheckpoint(
        filepath=str(CONFIG['model_dir'] / 'best_model.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1,
    ),
    # Multiply the learning rate by 0.2 after 5 stagnant val_loss epochs.
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=5,
        min_lr=1e-7,
        verbose=1,
    ),
    # Per-run TensorBoard logs, with weight histograms every epoch.
    TensorBoard(
        log_dir=str(CONFIG['model_dir'] / 'logs' / timestamp),
        histogram_freq=1,
    ),
]

print("Callbacks configured!")

## 5. Model Training - Phase 1 (Feature Extraction)

In [None]:
# Phase 1: Train only the classification head
print("=" * 60, "PHASE 1: Training Classification Head (Base Model Frozen)", "=" * 60, sep="\n")

# The backbone was frozen in create_model, so only the head learns here.
# Class weights compensate for class imbalance in the training subset.
history_phase1 = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=20,
    class_weight=class_weights_dict,
    callbacks=callbacks,
    verbose=1,
)

## 6. Model Training - Phase 2 (Fine-tuning)

In [None]:
# Phase 2: Fine-tune top layers of base model
print("\n" + "=" * 60, "PHASE 2: Fine-tuning Top Layers", "=" * 60, sep="\n")

# Unfreeze the backbone, then re-freeze all but its last 30 layers so only
# the top of MobileNetV2 adapts to the new data.
base_model.trainable = True
for layer in base_model.layers[:-30]:
    layer.trainable = False

# Trainability changes only take effect after recompiling; a 10x smaller
# learning rate nudges the pretrained weights instead of overwriting them.
model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=Adam(learning_rate=CONFIG['learning_rate'] / 10),
)

# Resume the epoch counter where phase 1 stopped; CONFIG['epochs'] is the
# total number of epochs across both phases.
history_phase2 = model.fit(
    train_generator,
    validation_data=validation_generator,
    initial_epoch=len(history_phase1.history['loss']),
    epochs=CONFIG['epochs'],
    class_weight=class_weights_dict,
    callbacks=callbacks,
    verbose=1,
)

## 7. Training History Visualization

In [None]:
# Combine histories
def combine_histories(h1, h2):
    """Concatenate two Keras History objects metric-by-metric.

    Metric names are taken from ``h1``. A metric missing from ``h2`` adds no
    epochs, which covers a phase-2 fit() that ran zero epochs (its history
    dict is empty when phase 1 already used up CONFIG['epochs']).
    """
    return {
        key: list(values) + list(h2.history.get(key, []))
        for key, values in h1.history.items()
    }


history = combine_histories(history_phase1, history_phase2)

# Index of the last phase-1 epoch; fine-tuning begins right after it.
finetune_start = len(history_phase1.history['loss']) - 1

# Plot training history: one panel per metric, train vs. validation.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, (metric, label) in zip(axes, (('accuracy', 'Accuracy'), ('loss', 'Loss'))):
    ax.plot(history[metric], label=f'Training {label}', linewidth=2)
    ax.plot(history[f'val_{metric}'], label=f'Validation {label}', linewidth=2)
    ax.axvline(x=finetune_start, color='r', linestyle='--', label='Fine-tuning Start')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.set_title(f'Model {label}', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(CONFIG['model_dir'] / 'training_history.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Model Evaluation

In [None]:
# Evaluate on validation set
print("\n".join(["=" * 50, "MODEL EVALUATION", "=" * 50]))

# evaluate() returns [loss, accuracy] because the model was compiled with
# metrics=['accuracy']; val_accuracy is reused in the final summary cell.
val_loss, val_accuracy = model.evaluate(validation_generator, verbose=0)
print(f"\nValidation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")

In [None]:
# Generate predictions
# Rewind the unshuffled validation generator so predictions come out in the
# same order as validation_generator.classes.
validation_generator.reset()
predictions = model.predict(validation_generator, verbose=1)
predicted_classes = np.argmax(predictions, axis=1)
true_classes = validation_generator.classes
class_names = list(validation_generator.class_indices.keys())

# Classification report
print("\n" + "=" * 50)
print("CLASSIFICATION REPORT")
print("=" * 50)
# Pass the full label set explicitly: if a class appears in neither
# true_classes nor predicted_classes, classification_report would infer fewer
# labels than target_names and raise ValueError. zero_division=0 reports
# undefined precision/recall for such a class as 0 without warnings.
print(classification_report(
    true_classes,
    predicted_classes,
    labels=list(range(len(class_names))),
    target_names=class_names,
    zero_division=0,
))

In [None]:
# Confusion Matrix
# Fix the label set so the matrix is always len(class_names) x len(class_names)
# and lines up with the tick labels, even if a class is absent from this split.
cm = confusion_matrix(
    true_classes,
    predicted_classes,
    labels=list(range(len(class_names))),
)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted', fontsize=12)
plt.ylabel('Actual', fontsize=12)
plt.title('Confusion Matrix - Coconut Mite Detection', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(CONFIG['model_dir'] / 'confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Save Model

In [None]:
# Save model in different formats
print("=" * 50)
print("SAVING MODEL")
print("=" * 50)

# Native Keras (.keras) format
keras_path = CONFIG['model_dir'] / 'coconut_mite_model.keras'
model.save(keras_path)
print(f"✅ Keras model saved: {keras_path}")

# Legacy HDF5 (.h5) format
h5_path = CONFIG['model_dir'] / 'coconut_mite_model.h5'
model.save(h5_path)
print(f"✅ H5 model saved: {h5_path}")

In [None]:
# Convert to TensorFlow Lite (for mobile deployment)
# The converted model keeps the Keras model's input contract: RGB pixels
# scaled to [0, 1] (MobileNetV2 normalization happens inside the graph).
print("\nConverting to TensorFlow Lite...")

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# With no representative dataset, Optimize.DEFAULT applies dynamic-range
# (weight) quantization, shrinking the file for mobile use.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save TFLite model (convert() returns the serialized flatbuffer as bytes)
tflite_path = CONFIG['model_dir'] / 'coconut_mite_model.tflite'
tflite_path.write_bytes(tflite_model)

print(f"✅ TFLite model saved: {tflite_path}")
print(f"   TFLite model size: {tflite_path.stat().st_size / (1024*1024):.2f} MB")

In [None]:
# Save class labels, input shape and preprocessing contract for API/mobile clients
import json

labels_info = {
    'class_indices': train_generator.class_indices,
    'class_names': class_names,
    'input_shape': [CONFIG['img_height'], CONFIG['img_width'], CONFIG['channels']],
    # The exported models expect RGB pixels scaled to [0, 1] (the training
    # generators rescale by 1/255). MobileNetV2's [-1, 1] normalization is
    # applied inside the model, so clients must NOT apply it again.
    'preprocessing': {
        'color_mode': 'rgb',
        'rescale': 1.0 / 255,
        'input_range': [0.0, 1.0],
    },
    'model_version': '1.0.0',
    'training_date': datetime.datetime.now().isoformat(),
}

info_path = CONFIG['model_dir'] / 'model_info.json'
with open(info_path, 'w', encoding='utf-8') as f:
    json.dump(labels_info, f, indent=2)

print(f"✅ Model info saved: {info_path}")

## 10. Test Prediction

In [None]:
def predict_image(model, image_path, target_size=(224, 224), class_labels=None):
    """Classify a single image file.

    The image is loaded as RGB, resized to ``target_size`` and scaled to
    [0, 1], mirroring the 1/255 rescaling used by the training generators.

    Args:
        model: Trained Keras classifier (softmax output).
        image_path: Path to the image file.
        target_size: (height, width) to resize to before inference.
        class_labels: Class names indexed by output position. Defaults to the
            notebook-level ``class_names`` list.

    Returns:
        (predicted class name, softmax probability of that class, loaded image)
    """
    from tensorflow.keras.preprocessing import image

    if class_labels is None:
        class_labels = class_names

    img = image.load_img(image_path, target_size=target_size)
    batch = np.expand_dims(image.img_to_array(img), axis=0) / 255.0

    probs = model.predict(batch, verbose=0)[0]
    predicted_class = int(np.argmax(probs))

    return class_labels[predicted_class], probs[predicted_class], img


# Sanity-check a few coconut-mite images.
# NOTE: these come from the raw data directory that also feeds the training
# split, so this is a smoke test, not a held-out evaluation.
sample_images = sorted(mite_path.glob('*.jpg'))[:3]

if sample_images:
    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when only a single sample image is found.
    fig, axes = plt.subplots(1, len(sample_images), figsize=(15, 5), squeeze=False)

    for ax, img_path in zip(axes[0], sample_images):
        pred_class, confidence, img = predict_image(
            model,
            img_path,
            target_size=(CONFIG['img_height'], CONFIG['img_width']),
        )

        ax.imshow(img)
        color = 'green' if pred_class == 'healthy' else 'red'
        ax.set_title(f'Prediction: {pred_class}\nConfidence: {confidence:.2%}',
                     fontsize=11, color=color, fontweight='bold')
        ax.axis('off')

    plt.suptitle('Sample Predictions', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

## 11. Summary

### Training Complete!

**Saved Files:**
- `coconut_mite_model.keras` - Full Keras model
- `coconut_mite_model.h5` - Legacy HDF5 format model
- `coconut_mite_model.tflite` - TensorFlow Lite (for mobile)
- `best_model.keras` - Best checkpoint by validation accuracy (written during training)
- `model_info.json` - Model metadata and class labels
- `training_history.png` - Training curves
- `confusion_matrix.png` - Evaluation results
- `logs/<timestamp>/` - TensorBoard training logs

**Next Steps:**
1. Deploy model via Flask API
2. Integrate TFLite model with React Native app
3. Train models for other pest types

In [None]:
# Final summary banner (val_accuracy comes from the evaluation cell).
print("\n" + "=" * 60)
print("🎉 COCONUT MITE DETECTION MODEL TRAINING COMPLETE!")
print("=" * 60)
print(f"\n📊 Final Validation Accuracy: {val_accuracy*100:.2f}%")
print(f"📁 Models saved to: {CONFIG['model_dir'].absolute()}")
print("\n👉 Next: Run Flask API to serve the model")