# 🥥 Coconut Mite Detection - High Accuracy Model Training

**Optimized for Google Colab with GPU**

## Features:
- Multiple model architectures comparison
- Advanced data augmentation
- Learning rate scheduling
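- K-Fold Cross Validation (configurable via `use_kfold`; not yet used by the training cells below)
- Ensemble predictions (planned; not yet implemented below)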
- TensorFlow Lite export for mobile

---

## 1. Setup & GPU Check

In [None]:
# Mount Google Drive (for Colab).
# Outside Colab the `google.colab` package does not exist; skip the mount instead
# of crashing so the notebook can still run against a local DATA_DIR.
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
    print("google.colab not available - skipping Google Drive mount")
else:
    IN_COLAB = True
    drive.mount('/content/drive')

In [None]:
# Check GPU availability
import tensorflow as tf

print("TensorFlow Version:", tf.__version__)
print("\n" + "="*50)
print("GPU STATUS")
print("="*50)

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"✅ GPU Available: {gpus}")
    # Enable memory growth so TensorFlow allocates GPU memory on demand.
    # This must happen before the GPU is initialised; re-running this cell later
    # in the same session raises RuntimeError, which we report instead of crashing.
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            print(f"⚠️ Could not enable memory growth for {gpu.name}: {err}")
else:
    print("❌ No GPU found! Go to Runtime > Change runtime type > GPU")

In [None]:
# Install additional libraries
# NOTE(review): versions are unpinned, so every run installs whatever release is
# current. The augmentation cell uses argument names such as
# GaussNoise(var_limit=...) and CoarseDropout(max_holes=..., fill_value=...);
# newer albumentations releases may have renamed these — consider pinning a
# version that matches that cell.
!pip install -q albumentations
!pip install -q scikit-learn
!pip install -q seaborn

In [None]:
# Imports
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import datetime
import json
import cv2
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

# TensorFlow & Keras
from tensorflow import keras
from tensorflow.keras import layers, Model, Sequential
from tensorflow.keras.applications import (
    MobileNetV2,
    EfficientNetB0,
    EfficientNetB3,
    ResNet50V2,
    InceptionV3,
    Xception
)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import (
    EarlyStopping, 
    ModelCheckpoint, 
    ReduceLROnPlateau, 
    TensorBoard,
    LearningRateScheduler
)
from tensorflow.keras.optimizers import Adam, AdamW
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy, AUC

# Sklearn
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    roc_curve, 
    auc,
    precision_recall_curve
)
from sklearn.utils.class_weight import compute_class_weight

# Albumentations for advanced augmentation
import albumentations as A

print("\n‚úÖ All libraries imported successfully!")

## 2. Configuration

In [None]:
# ============================================
# CONFIGURATION - MODIFY THESE PATHS
# ============================================

# Option 1: Google Drive path
# DATA_DIR = Path('/content/drive/MyDrive/CoconutML/data/raw/pest')
# MODEL_DIR = Path('/content/drive/MyDrive/CoconutML/models/coconut_mite')

# Option 2: Paths relative to the notebook's current working directory.
# Relative paths resolve against the process working directory (check with
# `Path.cwd()`), so make sure '../data' really points at the uploaded data.
DATA_DIR = Path('../data/raw/pest')
MODEL_DIR = Path('../models/coconut_mite')

# ============================================
# MODEL CONFIGURATION
# ============================================
# Class sub-folder names under DATA_DIR; an image's label is the index of its
# folder name in this list.
CLASS_NAMES = ['coconut_mite', 'healthy']

CONFIG = {
    # Image settings
    'img_size': 224,           # Image size (224 for MobileNet/EfficientNet, 299 for Inception/Xception)
    'channels': 3,
    
    # Training settings
    'batch_size': 32,          # Reduce if OOM error
    'epochs': 100,             # Will early stop anyway
    'learning_rate': 1e-4,
    'min_lr': 1e-7,
    
    # Data split
    'test_split': 0.15,        # 15% for final testing
    'val_split': 0.15,         # 15% for validation
    
    # Cross validation
    # NOTE(review): no cell in this notebook reads use_kfold / n_folds yet.
    'use_kfold': False,        # Set True for K-Fold CV
    'n_folds': 5,
    
    # Classes (num_classes is derived so it can never disagree with the list)
    'classes': CLASS_NAMES,
    'num_classes': len(CLASS_NAMES),
    
    # Seed for reproducibility
    'seed': 42
}

# Set seeds
np.random.seed(CONFIG['seed'])
tf.random.set_seed(CONFIG['seed'])
random.seed(CONFIG['seed'])

# Create directories
MODEL_DIR.mkdir(parents=True, exist_ok=True)

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## 3. Data Loading & Analysis

In [None]:
def load_dataset(data_dir, classes, extensions=('.jpg', '.jpeg', '.png')):
    """Collect image file paths and integer labels from per-class sub-folders.

    Looks for ``data_dir/<class_name>/`` for every entry of ``classes``; an
    image's label is the index of its class name in ``classes``. Missing class
    folders are reported and skipped.

    Paths are sorted within each class. ``Path.glob`` yields entries in a
    filesystem-dependent order, so without sorting a seeded train/val/test split
    performed on the result would not be reproducible across machines.

    Args:
        data_dir: root directory (``str`` or ``Path``).
        classes: class folder names, in label order.
        extensions: file suffixes to accept (compared case-insensitively).

    Returns:
        (images, labels): numpy arrays of path strings and integer labels,
        aligned element-wise.
    """
    data_dir = Path(data_dir)
    allowed = {ext.lower() for ext in extensions}
    images = []
    labels = []

    for class_idx, class_name in enumerate(classes):
        class_dir = data_dir / class_name
        if not class_dir.exists():
            print(f"⚠️ Warning: {class_dir} does not exist!")
            continue

        for img_path in sorted(class_dir.glob('*')):
            # Skip sub-directories even if their name ends in an image suffix.
            if img_path.is_file() and img_path.suffix.lower() in allowed:
                images.append(str(img_path))
                labels.append(class_idx)

    return np.array(images), np.array(labels, dtype=int)

# Load dataset
X, y = load_dataset(DATA_DIR, CONFIG['classes'])

# Fail fast: with no images the percentage below would divide by zero and the
# stratified split in the next cell would fail with a far less obvious error.
if len(X) == 0:
    raise FileNotFoundError(
        f"No images found in {DATA_DIR} for classes {CONFIG['classes']}; "
        "check DATA_DIR in the configuration cell."
    )

print("="*50)
print("DATASET SUMMARY")
print("="*50)
print(f"\nTotal images: {len(X):,}")
for idx, class_name in enumerate(CONFIG['classes']):
    count = np.sum(y == idx)
    print(f"  {class_name}: {count:,} ({count/len(X)*100:.1f}%)")

In [None]:
# Train/Val/Test Split.
# The test set is carved off first; the remainder is then split so that the
# validation set is CONFIG['val_split'] of the *whole* dataset, not of the remainder.
val_fraction_of_rest = CONFIG['val_split']/(1-CONFIG['test_split'])

X_temp, X_test, y_temp, y_test = train_test_split(
    X, y,
    test_size=CONFIG['test_split'],
    stratify=y,
    random_state=CONFIG['seed'],
)

X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp,
    test_size=val_fraction_of_rest,
    stratify=y_temp,
    random_state=CONFIG['seed'],
)

print("\nData Split:")
for split_label, split_paths in (("Training:", X_train), ("Validation:", X_val), ("Test:", X_test)):
    # Pad labels to a common width so the counts line up in one column
    print(f"  {split_label:<12}{len(split_paths):,} images")

In [None]:
# Calculate class weights with sklearn's 'balanced' heuristic so that rarer
# classes contribute proportionally more to the training loss.
present_classes = np.unique(y_train)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=y_train
)
# Key by the actual label value rather than by position in np.unique(): if a
# class were missing from y_train, enumerate() would silently shift weights onto
# the wrong labels. Plain int/float keys and values keep Keras and printing simple.
class_weights_dict = {
    int(label): float(weight)
    for label, weight in zip(present_classes, class_weights)
}
print(f"\nClass weights: {class_weights_dict}")

## 4. Advanced Data Augmentation

In [None]:
# Albumentations augmentation pipeline
def get_augmentation(mode='train'):
    """Return an albumentations ``Compose`` pipeline for the requested split.

    ``mode == 'train'`` yields the stochastic training pipeline (geometric,
    colour, blur/noise and coarse-dropout augmentations). Every other value
    yields the deterministic evaluation pipeline. Both pipelines resize to a
    ``CONFIG['img_size']`` square and finish with ImageNet mean/std
    normalization.

    NOTE(review): argument names such as ``GaussNoise(var_limit=...)`` and
    ``CoarseDropout(max_holes=..., fill_value=...)`` follow the older
    albumentations API; newer releases may have renamed them — confirm against
    the installed version.
    """
    side = CONFIG['img_size']
    imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}

    if mode != 'train':
        # Evaluation: deterministic resize + normalize only
        return A.Compose([A.Resize(side, side), A.Normalize(**imagenet_stats)])

    resize = A.Resize(side, side)
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.5),
    ]
    # At most one colour perturbation per image
    colour = A.OneOf(
        [
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1),
            A.ColorJitter(p=1),
        ],
        p=0.5,
    )
    # At most one blur/noise corruption per image
    blur_or_noise = A.OneOf(
        [
            A.GaussianBlur(blur_limit=(3, 5), p=1),
            A.GaussNoise(var_limit=(10, 50), p=1),
            A.ISONoise(p=1),
        ],
        p=0.3,
    )
    # Cut-out style occlusion with holes up to 1/8 of the image side
    cutout = A.CoarseDropout(
        max_holes=8,
        max_height=side // 8,
        max_width=side // 8,
        fill_value=0,
        p=0.3,
    )
    return A.Compose([resize, *geometric, colour, blur_or_noise, cutout, A.Normalize(**imagenet_stats)])

print("‚úÖ Augmentation pipelines created!")

In [None]:
# Data Generator class
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that reads, augments and batches images from disk.

    Each item is an ``(X, y)`` batch. ``X`` stacks the augmentation outputs as
    float32, and ``y`` holds one-hot labels with ``num_classes`` columns. The
    final batch may be smaller than ``batch_size``.

    Args:
        image_paths: image file paths (any sequence; stored as a numpy array).
        labels: integer class labels aligned with ``image_paths``.
        batch_size: number of images per batch.
        augmentation: callable invoked as ``augmentation(image=rgb_image)`` that
            returns a dict with an ``'image'`` entry (an albumentations
            ``Compose`` fits this).
        num_classes: width of the one-hot label vectors.
        shuffle: if True, shuffle the sample order at construction and at the
            end of every epoch.
        **kwargs: forwarded to the Keras base class (e.g. ``workers`` on Keras 3).

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, batch_size, augmentation, num_classes, shuffle=True, **kwargs):
        # Keras 3 (where keras.utils.Sequence is PyDataset) expects the base
        # __init__ to run; calling it is harmless on older Keras versions.
        super().__init__(**kwargs)
        # numpy arrays are required for the fancy indexing in __getitem__.
        self.image_paths = np.asarray(image_paths)
        self.labels = np.asarray(labels)
        if len(self.image_paths) != len(self.labels):
            raise ValueError(
                f"image_paths ({len(self.image_paths)}) and labels ({len(self.labels)}) "
                "must have the same length"
            )
        self.batch_size = batch_size
        self.augmentation = augmentation
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.image_paths))
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, index):
        """Load, augment and one-hot encode batch number ``index``."""
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_images = [self._load_and_augment(path) for path in self.image_paths[batch_indexes]]

        # Shape comes from the augmentation output, not from a global config,
        # so the generator works for any pipeline output size.
        X = np.stack(batch_images).astype(np.float32, copy=False)
        y = keras.utils.to_categorical(self.labels[batch_indexes], num_classes=self.num_classes)

        return X, y

    def _load_and_augment(self, path):
        """Read one image as RGB and run it through the augmentation pipeline."""
        img = cv2.imread(str(path))
        # cv2.imread returns None (no exception) for missing/corrupt files.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return self.augmentation(image=img)['image']

    def on_epoch_end(self):
        """Reshuffle sample order between epochs when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

# Create generators.
# Only the training split is augmented and shuffled; validation and test use the
# deterministic pipeline in a fixed order so predictions stay aligned with labels.
train_gen = DataGenerator(
    image_paths=X_train,
    labels=y_train,
    batch_size=CONFIG['batch_size'],
    augmentation=get_augmentation('train'),
    num_classes=CONFIG['num_classes'],
    shuffle=True,
)

val_gen = DataGenerator(
    image_paths=X_val,
    labels=y_val,
    batch_size=CONFIG['batch_size'],
    augmentation=get_augmentation('val'),
    num_classes=CONFIG['num_classes'],
    shuffle=False,
)

test_gen = DataGenerator(
    image_paths=X_test,
    labels=y_test,
    batch_size=CONFIG['batch_size'],
    augmentation=get_augmentation('val'),
    num_classes=CONFIG['num_classes'],
    shuffle=False,
)

print(f"\n‚úÖ Data generators created!")
for gen_label, generator in (("Training", train_gen), ("Validation", val_gen), ("Test", test_gen)):
    print(f"  {gen_label} batches: {len(generator)}")

In [None]:
# Visualize augmented samples: the resized original followed by seven random
# draws from the training pipeline, de-normalized back to [0, 1] for display.
sample_img_path = X_train[0]
sample_img = cv2.cvtColor(cv2.imread(sample_img_path), cv2.COLOR_BGR2RGB)

aug = get_augmentation('train')
_vis_side = CONFIG['img_size']
_vis_mean = np.array([0.485, 0.456, 0.406])
_vis_std = np.array([0.229, 0.224, 0.225])

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

original_ax, *augmented_axes = axes
original_ax.imshow(cv2.resize(sample_img, (_vis_side, _vis_side)))
original_ax.set_title('Original', fontsize=12, fontweight='bold')
original_ax.axis('off')

for i, ax in enumerate(augmented_axes, start=1):
    augmented = aug(image=sample_img)['image']
    # Undo the mean/std normalization so imshow gets displayable [0, 1] values
    ax.imshow(np.clip(augmented * _vis_std + _vis_mean, 0, 1))
    ax.set_title(f'Augmented {i}', fontsize=11)
    ax.axis('off')

plt.suptitle('Data Augmentation Examples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 5. Model Architecture Comparison

In [None]:
def create_model(backbone_name, input_shape, num_classes, dropout_rate=0.5,
                 input_mean=(0.485, 0.456, 0.406), input_std=(0.229, 0.224, 0.225)):
    """Create a transfer-learning classifier on top of an ImageNet backbone.

    The model accepts images that were already normalized as
    ``(pixel / 255 - input_mean) / input_std``, which is what the notebook's
    albumentations ``Normalize`` step produces. Keras' ImageNet backbones were
    not trained on that representation. EfficientNet rescales raw [0, 255]
    pixels internally, and MobileNetV2 / ResNet50V2 / InceptionV3 / Xception
    expect pixels mapped to [-1, 1]. The model therefore first undoes the
    mean/std normalization and then applies each backbone's own input scaling.
    This way the frozen backbone sees the input range its weights expect.

    Args:
        backbone_name: one of the keys of ``backbones`` below.
        input_shape: (height, width, channels) of the model input.
        num_classes: number of softmax outputs.
        dropout_rate: dropout after the 512-unit layer; 0.6x this rate is used
            after the 256-unit layer.
        input_mean, input_std: per-channel statistics the inputs were
            normalized with (defaults match the notebook's pipeline).

    Returns:
        (model, base_model): the full classifier and the frozen backbone, so
        callers can unfreeze backbone layers for fine-tuning.

    Raises:
        ValueError: if ``backbone_name`` is not supported.
    """

    # Backbone selection
    backbones = {
        'MobileNetV2': MobileNetV2,
        'EfficientNetB0': EfficientNetB0,
        'EfficientNetB3': EfficientNetB3,
        'ResNet50V2': ResNet50V2,
        'InceptionV3': InceptionV3,
        'Xception': Xception
    }

    # (scale, offset) mapping raw [0, 255] pixels into each backbone's expected
    # range ('tf'-style preprocessing -> [-1, 1]). None means the backbone
    # contains its own rescaling layers and wants raw [0, 255] input.
    pixel_scaling = {
        'MobileNetV2': (1.0 / 127.5, -1.0),
        'EfficientNetB0': None,
        'EfficientNetB3': None,
        'ResNet50V2': (1.0 / 127.5, -1.0),
        'InceptionV3': (1.0 / 127.5, -1.0),
        'Xception': (1.0 / 127.5, -1.0),
    }

    if backbone_name not in backbones:
        raise ValueError(f"Unknown backbone: {backbone_name}")

    # Load base model
    base_model = backbones[backbone_name](
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )

    # Freeze base model
    base_model.trainable = False

    # Build model
    inputs = keras.Input(shape=input_shape)

    # Map normalized inputs back to raw pixels: x * (255 * std) + 255 * mean.
    x = layers.Normalization(
        axis=-1,
        mean=[255.0 * m for m in input_mean],
        variance=[(255.0 * s) ** 2 for s in input_std],
        invert=True,
        name='undo_input_normalization',
    )(inputs)

    scaling = pixel_scaling[backbone_name]
    if scaling is not None:
        scale, offset = scaling
        x = layers.Rescaling(scale=scale, offset=offset, name='backbone_input_scaling')(x)

    # training=False keeps the backbone's BatchNorm layers in inference mode,
    # even after the backbone is later unfrozen for fine-tuning.
    x = base_model(x, training=False)

    # Global pooling + classification head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(0.01))(x)
    x = layers.Dropout(dropout_rate)(x)
    x = layers.Dense(256, activation='relu', kernel_regularizer=keras.regularizers.l2(0.01))(x)
    x = layers.Dropout(dropout_rate * 0.6)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs, name=f'{backbone_name}_classifier')

    return model, base_model

print("✅ Model creation function ready!")

In [None]:
# Learning rate scheduler
def cosine_annealing_schedule(epoch, lr, total_epochs=100, min_lr=1e-7):
    """Return the cosine-decayed learning rate for ``epoch``.

    Plain cosine decay (no warm restarts). The rate is ``lr`` at epoch 0 and
    reaches ``min_lr`` at ``total_epochs``. Epochs beyond ``total_epochs`` are
    clamped, so the rate stays at ``min_lr`` instead of rising again.

    Args:
        epoch: zero-based epoch index.
        lr: peak (starting) learning rate.
        total_epochs: length of the decay; values <= 0 return ``min_lr``.
        min_lr: floor of the schedule.
    """
    if total_epochs <= 0:
        return min_lr
    progress = min(epoch, total_epochs) / total_epochs
    return min_lr + (lr - min_lr) * (1 + np.cos(np.pi * progress)) / 2


def make_cosine_schedule(total_epochs, min_lr):
    """Build a ``schedule(epoch, lr)`` function for ``LearningRateScheduler``.

    The cosine curve is anchored to the learning rate passed in on the first
    call, i.e. the optimizer's LR when ``fit`` starts. A phase compiled with a
    lower LR (such as fine-tuning at LR/10) therefore keeps that lower scale,
    instead of being reset to the global ``CONFIG['learning_rate']`` every epoch.
    """
    anchor = {}

    def schedule(epoch, lr):
        base_lr = anchor.setdefault('base_lr', float(lr))
        return float(cosine_annealing_schedule(epoch, base_lr, total_epochs, min_lr))

    return schedule


def get_callbacks(model_name, patience=15, monitor='val_loss'):
    """Return early-stopping, checkpoint and cosine-LR callbacks for one ``fit`` call.

    EarlyStopping and ModelCheckpoint watch the same ``monitor`` metric, so the
    saved checkpoint matches the weights EarlyStopping restores. No
    ReduceLROnPlateau is included: the scheduler sets the LR at the start of
    every epoch, which would silently undo any plateau reduction.

    Args:
        model_name: prefix for the checkpoint file in ``MODEL_DIR``.
        patience: epochs without improvement before stopping.
        monitor: metric watched by early stopping and checkpointing.
    """
    return [
        EarlyStopping(
            monitor=monitor,
            patience=patience,
            restore_best_weights=True,
            verbose=1
        ),
        ModelCheckpoint(
            filepath=str(MODEL_DIR / f'{model_name}_best.keras'),
            monitor=monitor,
            save_best_only=True,
            verbose=1
        ),
        LearningRateScheduler(
            make_cosine_schedule(CONFIG['epochs'], CONFIG['min_lr']),
            verbose=0
        )
    ]

print("✅ Callbacks configured!")

## 6. Train Multiple Models & Compare

In [None]:
# Select models to compare: flip a flag to True to include that backbone.
# Every key must be a backbone name understood by create_model().
_BACKBONE_CHOICES = {
    'EfficientNetB0': True,    # Best balance of accuracy and size
    'EfficientNetB3': False,   # Higher accuracy, larger model
    'MobileNetV2': False,      # Fastest, good for mobile
    'ResNet50V2': False,       # Classic, very accurate
    'Xception': False,         # Very accurate, larger model
}
MODELS_TO_TRAIN = [name for name, enabled in _BACKBONE_CHOICES.items() if enabled]

print(f"Models to train: {MODELS_TO_TRAIN}")

In [None]:
def train_model(backbone_name, train_gen, val_gen, class_weights, epochs=100, head_epochs=20):
    """Build and train a ``backbone_name`` classifier in two phases.

    Phase 1 trains only the new classification head (backbone frozen) for up to
    ``head_epochs`` epochs. Phase 2 unfreezes the top 30% of backbone layers,
    recompiles with Adam at ``CONFIG['learning_rate'] / 10``, and continues the
    epoch count from where phase 1 stopped up to ``epochs``. Each phase uses
    the callbacks returned by ``get_callbacks``.

    Args:
        backbone_name: backbone key accepted by ``create_model``.
        train_gen, val_gen: training / validation data passed to ``fit``.
        class_weights: ``{class_index: weight}`` passed to ``fit``.
        epochs: final epoch index for phase 2 (overall epoch budget).
        head_epochs: epochs for the head-only phase 1.

    Returns:
        (model, history): the trained model and a dict concatenating the
        per-epoch metrics of both phases (phase 2 may contribute nothing if
        ``epochs`` is not larger than the number of phase-1 epochs run).
    """

    print("\n" + "="*60)
    print(f"TRAINING: {backbone_name}")
    print("="*60)

    input_shape = (CONFIG['img_size'], CONFIG['img_size'], 3)
    model, base_model = create_model(backbone_name, input_shape, CONFIG['num_classes'])

    # Phase 1: Train classification head only
    print("\n📌 Phase 1: Training classification head...")
    model.compile(
        optimizer=Adam(learning_rate=CONFIG['learning_rate']),
        loss=CategoricalCrossentropy(label_smoothing=0.1),
        metrics=['accuracy', AUC(name='auc')]
    )

    history1 = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=head_epochs,
        class_weight=class_weights,
        callbacks=get_callbacks(f'{backbone_name}_phase1', patience=10),
        verbose=1
    )

    # Phase 2: Fine-tune top layers
    print("\n📌 Phase 2: Fine-tuning top layers...")
    base_model.trainable = True

    # Keep the earliest 70% of backbone layers frozen; only the top 30% adapt.
    fine_tune_from = int(len(base_model.layers) * 0.7)
    for layer in base_model.layers[:fine_tune_from]:
        layer.trainable = False

    # Recompiling is required for the trainable changes to take effect.
    model.compile(
        optimizer=Adam(learning_rate=CONFIG['learning_rate'] / 10),
        loss=CategoricalCrossentropy(label_smoothing=0.1),
        metrics=['accuracy', AUC(name='auc')]
    )

    history2 = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=epochs,
        # Continue epoch numbering after however many epochs phase 1 actually ran.
        initial_epoch=len(history1.history['loss']),
        class_weight=class_weights,
        callbacks=get_callbacks(f'{backbone_name}_phase2', patience=15),
        verbose=1
    )

    # Combine histories. .get() tolerates an empty phase-2 history (fit returns
    # no metrics when initial_epoch >= epochs).
    history = {
        key: history1.history[key] + history2.history.get(key, [])
        for key in history1.history
    }

    return model, history

print("✅ Training function ready!")

In [None]:
# Train every backbone in MODELS_TO_TRAIN and score it on the held-out test set.
results = {}

for backbone in MODELS_TO_TRAIN:
    model, history = train_model(
        backbone,
        train_gen,
        val_gen,
        class_weights_dict,
        epochs=CONFIG['epochs'],
    )

    # evaluate() returns [loss, accuracy, auc] in the order the model was compiled with
    test_loss, test_acc, test_auc = model.evaluate(test_gen, verbose=0)

    results[backbone] = dict(
        model=model,
        history=history,
        test_loss=test_loss,
        test_accuracy=test_acc,
        test_auc=test_auc,
    )

    print(f"\n‚úÖ {backbone} - Test Accuracy: {test_acc*100:.2f}%, Test AUC: {test_auc:.4f}")

In [None]:
# Compare results.
# Models are ranked by *validation* accuracy. Picking the winner by test accuracy
# would let the test set influence model selection and bias the reported score.
print("\n" + "="*60)
print("MODEL COMPARISON RESULTS")
print("="*60)

# evaluate() returns [loss, accuracy, auc] in compile order; index 1 is accuracy.
val_accuracy = {
    name: r['model'].evaluate(val_gen, verbose=0)[1] * 100
    for name, r in results.items()
}

comparison_df = pd.DataFrame({
    'Model': list(results.keys()),
    'Val Accuracy': [val_accuracy[name] for name in results],
    'Test Accuracy': [r['test_accuracy']*100 for r in results.values()],
    'Test AUC': [r['test_auc'] for r in results.values()],
    'Test Loss': [r['test_loss'] for r in results.values()]
}).sort_values('Val Accuracy', ascending=False)

print(comparison_df.to_string(index=False))

# Best model (chosen on validation; its test accuracy remains an unbiased estimate)
best_model_name = comparison_df.iloc[0]['Model']
best_accuracy = comparison_df.iloc[0]['Test Accuracy']
print(f"\n🏆 Best Model: {best_model_name} with {best_accuracy:.2f}% test accuracy (selected by validation accuracy)")

## 7. Detailed Evaluation of Best Model

In [None]:
# Plot train-vs-validation curves (accuracy, loss, AUC) for the selected model.
best_model = results[best_model_name]['model']
best_history = results[best_model_name]['history']

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

curve_specs = [('accuracy', 'Accuracy'), ('loss', 'Loss'), ('auc', 'AUC')]
for ax, (metric_key, metric_label) in zip(axes, curve_specs):
    ax.plot(best_history[metric_key], label='Train', linewidth=2)
    ax.plot(best_history[f'val_{metric_key}'], label='Validation', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric_label)
    ax.set_title(f'{best_model_name} - {metric_label}', fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(MODEL_DIR / 'training_history.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Predict the whole test set and derive hard labels from the class probabilities.
y_pred_proba = best_model.predict(test_gen, verbose=1)
y_pred = y_pred_proba.argmax(axis=1)
# test_gen was built with shuffle=False, so predictions line up with y_test.
y_true = y_test

report_banner = "=" * 60
print("\n" + report_banner)
print("CLASSIFICATION REPORT")
print(report_banner)
print(classification_report(y_true, y_pred, target_names=CONFIG['classes']))

In [None]:
# Confusion Matrix: raw counts next to a row-normalized view (each row shows how
# the samples of one actual class were distributed over the predicted classes).
cm = confusion_matrix(y_true, y_pred)
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    (cm, 'd', 'Confusion Matrix (Counts)'),
    (cm_norm, '.2%', 'Confusion Matrix (Normalized)'),
]
for ax, (matrix, cell_fmt, panel_title) in zip(axes, panels):
    sns.heatmap(matrix, annot=True, fmt=cell_fmt, cmap='Blues', ax=ax,
                xticklabels=CONFIG['classes'], yticklabels=CONFIG['classes'])
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(panel_title, fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig(MODEL_DIR / 'confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ROC Curve & Precision-Recall Curve
from sklearn.metrics import average_precision_score

# Both curves score class index 1 as the positive class, using its softmax
# probability; binarizing explicitly keeps this one-vs-rest if classes are added.
pos_idx = 1
pos_name = CONFIG['classes'][pos_idx]
pos_scores = y_pred_proba[:, pos_idx]
y_true_pos = (y_true == pos_idx).astype(int)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# ROC Curve
fpr, tpr, _ = roc_curve(y_true_pos, pos_scores)
roc_auc = auc(fpr, tpr)

axes[0].plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
axes[0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
axes[0].set_xlim([0.0, 1.0])
axes[0].set_ylim([0.0, 1.05])
axes[0].set_xlabel('False Positive Rate')
axes[0].set_ylabel('True Positive Rate')
axes[0].set_title(f'ROC Curve (positive: {pos_name})', fontsize=12, fontweight='bold')
axes[0].legend(loc='lower right')
axes[0].grid(True, alpha=0.3)

# Precision-Recall Curve
precision, recall, _ = precision_recall_curve(y_true_pos, pos_scores)
# Summarize with average precision rather than auc(recall, precision): the
# trapezoidal rule linearly interpolates between PR points and can overstate
# the area (see the scikit-learn average_precision_score docs).
pr_auc = average_precision_score(y_true_pos, pos_scores)

axes[1].plot(recall, precision, color='green', lw=2, label=f'PR curve (AP = {pr_auc:.4f})')
axes[1].set_xlim([0.0, 1.0])
axes[1].set_ylim([0.0, 1.05])
axes[1].set_xlabel('Recall')
axes[1].set_ylabel('Precision')
axes[1].set_title(f'Precision-Recall Curve (positive: {pos_name})', fontsize=12, fontweight='bold')
axes[1].legend(loc='lower left')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(MODEL_DIR / 'roc_pr_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Save Best Model

In [None]:
# Save the selected model as both a native .keras archive and a legacy HDF5 file.
print("="*60)
print("SAVING BEST MODEL")
print("="*60)

for file_name, format_label in (('coconut_mite_best.keras', 'Keras model'), ('coconut_mite_best.h5', 'H5 model')):
    target_path = MODEL_DIR / file_name
    best_model.save(target_path)
    print(f"‚úÖ {format_label} saved: {target_path}")

In [None]:
# Convert to TensorFlow Lite (optimized for mobile)
print("\nConverting to TensorFlow Lite...")


def _export_tflite(keras_model, out_path, configure=None):
    """Convert ``keras_model`` to TFLite, write it to ``out_path`` and return the bytes.

    ``configure``, if given, receives the converter so it can set optimization
    options before conversion.
    """
    tflite_converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    if configure is not None:
        configure(tflite_converter)
    flatbuffer = tflite_converter.convert()
    with open(out_path, 'wb') as handle:
        handle.write(flatbuffer)
    return flatbuffer


def _use_float16_weights(tflite_converter):
    """Enable post-training float16 quantization (smaller file; speed depends on the runtime/delegate)."""
    tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_converter.target_spec.supported_types = [tf.float16]


def _print_file_size(path):
    """Print the on-disk size of ``path`` in MB."""
    print(f"   Size: {os.path.getsize(path) / (1024*1024):.2f} MB")


# Standard (float32) TFLite
tflite_path = MODEL_DIR / 'coconut_mite_model.tflite'
tflite_model = _export_tflite(best_model, tflite_path)
print(f"‚úÖ TFLite model saved: {tflite_path}")
_print_file_size(tflite_path)

# Float16-quantized TFLite
tflite_quant_path = MODEL_DIR / 'coconut_mite_model_quantized.tflite'
tflite_quant_model = _export_tflite(best_model, tflite_quant_path, configure=_use_float16_weights)
print(f"‚úÖ Quantized TFLite saved: {tflite_quant_path}")
_print_file_size(tflite_quant_path)

In [None]:
# Save model metadata needed to reproduce the input preprocessing at inference time.
model_info = {
    'model_name': best_model_name,
    'input_shape': [CONFIG['img_size'], CONFIG['img_size'], 3],
    'classes': CONFIG['classes'],
    'class_indices': {c: i for i, c in enumerate(CONFIG['classes'])},
    'test_accuracy': float(results[best_model_name]['test_accuracy']),
    'test_auc': float(results[best_model_name]['test_auc']),
    # Training/inference images are converted from OpenCV's BGR to RGB first.
    'color_order': 'RGB',
    'normalization': {
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        # mean/std apply after dividing pixel values by this (albumentations
        # Normalize default); consumers need it to reproduce the input exactly.
        'max_pixel_value': 255.0,
    },
    # Timezone-aware so the timestamp is unambiguous across machines.
    'training_date': datetime.datetime.now(datetime.timezone.utc).isoformat(),
    'version': '2.0.0'
}

with open(MODEL_DIR / 'model_info.json', 'w', encoding='utf-8') as f:
    json.dump(model_info, f, indent=2)

print(f"\n✅ Model info saved: {MODEL_DIR / 'model_info.json'}")

## 9. Test Predictions

In [None]:
def predict_single_image(model, image_path, class_names):
    """Classify one image file with ``model``.

    The image is read with OpenCV, converted to RGB and passed through the
    evaluation pipeline from ``get_augmentation('val')`` before prediction.

    Args:
        model: trained Keras classifier with softmax outputs.
        image_path: path to the image (``str`` or ``Path``).
        class_names: class labels indexed like the model outputs.

    Returns:
        (predicted_class_name, confidence, rgb_image): ``confidence`` is the
        softmax probability of the predicted class; ``rgb_image`` is the
        un-preprocessed RGB array, handy for display.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(str(image_path))
    # cv2.imread returns None (no exception) for missing/corrupt files.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Preprocess exactly like the validation/test data
    processed = get_augmentation('val')(image=img)['image']
    img_batch = np.expand_dims(processed, axis=0)

    # Predict
    probabilities = model.predict(img_batch, verbose=0)[0]
    pred_class = int(np.argmax(probabilities))
    confidence = float(probabilities[pred_class])

    return class_names[pred_class], confidence, img

# Show predictions for up to six random test images (green = correct, red = wrong).
n_samples = min(6, len(X_test))
# Sample indices rather than paths so the ground-truth label is looked up
# directly (no O(n) search, and correct even if a path appears twice).
sample_indices = np.random.choice(len(X_test), n_samples, replace=False)
test_samples = X_test[sample_indices]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for ax, sample_idx in zip(axes, sample_indices):
    img_path = X_test[sample_idx]
    pred_class, confidence, img = predict_single_image(best_model, img_path, CONFIG['classes'])
    true_class = CONFIG['classes'][y_test[sample_idx]]

    img_resized = cv2.resize(img, (CONFIG['img_size'], CONFIG['img_size']))
    ax.imshow(img_resized)

    color = 'green' if pred_class == true_class else 'red'
    ax.set_title(
        f'Pred: {pred_class} ({confidence:.1%})\nTrue: {true_class}',
        fontsize=11, color=color, fontweight='bold'
    )
    ax.axis('off')

# Blank out grid cells left unused when the test set has fewer than six images.
for ax in axes[n_samples:]:
    ax.axis('off')

plt.suptitle('Sample Test Predictions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(MODEL_DIR / 'sample_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

## 10. Summary

In [None]:
# Final summary of the selected model's metrics and the artifacts written to MODEL_DIR.
print("\n" + "="*60)
print("🎉 TRAINING COMPLETE!")
print("="*60)

print(f"""
📊 RESULTS SUMMARY
{'─'*40}
Best Model:      {best_model_name}
Test Accuracy:   {results[best_model_name]['test_accuracy']*100:.2f}%
Test AUC:        {results[best_model_name]['test_auc']:.4f}
ROC AUC:         {roc_auc:.4f}
PR AUC:          {pr_auc:.4f}

📁 SAVED FILES (in {MODEL_DIR})
{'─'*40}
• coconut_mite_best.keras       - Full Keras model
• coconut_mite_best.h5          - H5 format
• coconut_mite_model.tflite     - TensorFlow Lite
• coconut_mite_model_quantized.tflite - Quantized (smaller)
• model_info.json               - Model metadata
• training_history.png          - Training curves
• confusion_matrix.png          - Evaluation results
• roc_pr_curves.png             - ROC & PR curves
• sample_predictions.png        - Sample test predictions
• <backbone>_phase1/2_best.keras - Per-phase training checkpoints

🚀 NEXT STEPS
{'─'*40}
1. Deploy model via Flask API
2. Integrate TFLite with React Native app
3. Train models for other pest types
""")