# Coconut Mite Detection - 3-Class Model v9


## Key Improvements from v8:
- **Clean Dataset**: v4_clean with no data leaks or corrupted files
- **Class Weights**: Computed from training data to handle imbalance
- **Focal Loss**: Better handling of hard examples
- **Honest Evaluation**: No tricks, real performance metrics

| Parameter | Value |
|-----------|-------|
| Model | EfficientNetB0 (Transfer Learning) |
| Classes | coconut_mite, healthy, not_coconut |
| Dataset | v4_clean (13,781 images, no leaks) |
| Loss | Focal Loss with class weights |

---
**Author:** Research Team  
**Date:** 2025-12-25  
**Version:** v9 (Robust)

In [1]:
# ============================================================
# 1. IMPORTS & SETUP
# ============================================================
import os
import json
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from datetime import datetime
import random

import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from sklearn.utils.class_weight import compute_class_weight

# Reproducibility
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)
random.seed(SEED)

print("=" * 60)
print("  ENVIRONMENT")
print("=" * 60)
print(f"  TensorFlow: {tf.__version__}")
print(f"  GPU: {len(tf.config.list_physical_devices('GPU')) > 0}")
print("=" * 60)

  ENVIRONMENT
  TensorFlow: 2.20.0
  GPU: False


In [2]:
# ============================================================
# 2. CONFIGURATION
# ============================================================
BASE_DIR = os.path.dirname(os.getcwd())
DATA_DIR = os.path.join(BASE_DIR, 'data', 'raw', 'pest_mite', 'dataset_v4_clean')  # CLEAN DATASET!
MODEL_DIR = os.path.join(BASE_DIR, 'models', 'coconut_mite_v9')

TRAIN_DIR = os.path.join(DATA_DIR, 'train')
VAL_DIR = os.path.join(DATA_DIR, 'validation')
TEST_DIR = os.path.join(DATA_DIR, 'test')

# Hyperparameters - Conservative for honest evaluation
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 100  # More epochs, rely on early stopping
LEARNING_RATE = 0.0001
DROPOUT_RATE = 0.5
L2_REG = 0.01
PATIENCE = 10  # More patience for better convergence

CLASS_NAMES = ['coconut_mite', 'healthy', 'not_coconut']
NUM_CLASSES = 3

os.makedirs(MODEL_DIR, exist_ok=True)

print("=" * 60)
print("  CONFIGURATION - V9 ROBUST")
print("=" * 60)
print(f"\n  Dataset: {DATA_DIR}")
print(f"  Model:   {MODEL_DIR}")
print(f"\n  Hyperparameters:")
print(f"    Image Size:    {IMG_SIZE}x{IMG_SIZE}")
print(f"    Batch Size:    {BATCH_SIZE}")
print(f"    Max Epochs:    {EPOCHS}")
print(f"    Learning Rate: {LEARNING_RATE}")
print(f"    Dropout:       {DROPOUT_RATE}")
print(f"    L2 Reg:        {L2_REG}")
print(f"    Early Stop:    {PATIENCE} epochs")
print("=" * 60)

  CONFIGURATION - V9 ROBUST

  Dataset: d:\SLIIT\Reaserch Project\CoconutHealthMonitor\Research\ml\data\raw\pest_mite\dataset_v4_clean
  Model:   d:\SLIIT\Reaserch Project\CoconutHealthMonitor\Research\ml\models\coconut_mite_v9

  Hyperparameters:
    Image Size:    224x224
    Batch Size:    32
    Max Epochs:    100
    Learning Rate: 0.0001
    Dropout:       0.5
    L2 Reg:        0.01
    Early Stop:    10 epochs


In [3]:
# ============================================================
# 3. DATASET VERIFICATION
# ============================================================
def count_images(directory):
    counts = {}
    for cls in CLASS_NAMES:
        cls_dir = os.path.join(directory, cls)
        if os.path.exists(cls_dir):
            counts[cls] = len([f for f in os.listdir(cls_dir) 
                              if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        else:
            counts[cls] = 0
    return counts

train_counts = count_images(TRAIN_DIR)
val_counts = count_images(VAL_DIR)
test_counts = count_images(TEST_DIR)

print("=" * 60)
print("  DATASET v4_clean (NO DATA LEAKS!)")
print("=" * 60)
print(f"\n  {'Split':<12} {'mite':>8} {'healthy':>10} {'not_coco':>12} {'Total':>8}")
print("  " + "-" * 50)

for split, counts in [('train', train_counts), ('validation', val_counts), ('test', test_counts)]:
    total = sum(counts.values())
    print(f"  {split:<12} {counts['coconut_mite']:>8} {counts['healthy']:>10} {counts['not_coconut']:>12} {total:>8}")

total_all = sum(train_counts.values()) + sum(val_counts.values()) + sum(test_counts.values())
print("  " + "-" * 50)
print(f"  {'TOTAL':<12} {train_counts['coconut_mite']+val_counts['coconut_mite']+test_counts['coconut_mite']:>8} "
      f"{train_counts['healthy']+val_counts['healthy']+test_counts['healthy']:>10} "
      f"{train_counts['not_coconut']+val_counts['not_coconut']+test_counts['not_coconut']:>12} {total_all:>8}")
print("=" * 60)

  DATASET v4_clean (NO DATA LEAKS!)

  Split            mite    healthy     not_coco    Total
  --------------------------------------------------
  train            4739       4204         3985    12928
  validation         99         89          291      479
  test              100         89          185      374
  --------------------------------------------------
  TOTAL            4938       4382         4461    13781


In [4]:
# ============================================================
# 4. COMPUTE CLASS WEIGHTS
# ============================================================
# This helps the model pay equal attention to all classes

train_labels = []
for i, cls in enumerate(CLASS_NAMES):
    train_labels.extend([i] * train_counts[cls])

class_weights_array = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels
)

class_weights = {i: w for i, w in enumerate(class_weights_array)}

print("=" * 60)
print("  CLASS WEIGHTS (for balanced learning)")
print("=" * 60)
for i, cls in enumerate(CLASS_NAMES):
    print(f"  {cls}: {class_weights[i]:.4f}")
print("=" * 60)

  CLASS WEIGHTS (for balanced learning)
  coconut_mite: 0.9093
  healthy: 1.0251
  not_coconut: 1.0814


In [5]:
# ============================================================
# 5. DATA GENERATORS
# ============================================================

# Moderate augmentation - not too aggressive
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.15,
    height_shift_range=0.15,
    shear_range=0.15,
    zoom_range=0.15,
    horizontal_flip=True,
    fill_mode='nearest'
)

val_test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=True,
    seed=SEED
)

val_generator = val_test_datagen.flow_from_directory(
    VAL_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

test_generator = val_test_datagen.flow_from_directory(
    TEST_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

print("=" * 60)
print("  DATA GENERATORS")
print("=" * 60)
print(f"  Train: {train_generator.samples} samples, {len(train_generator)} batches")
print(f"  Val:   {val_generator.samples} samples")
print(f"  Test:  {test_generator.samples} samples")
print(f"\n  Class indices: {train_generator.class_indices}")
print("=" * 60)

Found 12928 images belonging to 3 classes.
Found 479 images belonging to 3 classes.
Found 374 images belonging to 3 classes.
  DATA GENERATORS
  Train: 12928 samples, 404 batches
  Val:   479 samples
  Test:  374 samples

  Class indices: {'coconut_mite': 0, 'healthy': 1, 'not_coconut': 2}


In [6]:
# ============================================================
# 6. FOCAL LOSS (Better for imbalanced/hard examples)
# ============================================================

class FocalLoss(tf.keras.losses.Loss):
    """Focal Loss for handling class imbalance and hard examples."""
    
    def __init__(self, gamma=2.0, alpha=None, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha
    
    def call(self, y_true, y_pred):
        # Clip predictions to prevent log(0)
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
        
        # Cross entropy
        ce = -y_true * tf.math.log(y_pred)
        
        # Focal weight
        weight = tf.pow(1 - y_pred, self.gamma) * y_true
        
        # Apply alpha if provided
        if self.alpha is not None:
            weight = weight * self.alpha
        
        focal_loss = weight * ce
        return tf.reduce_mean(tf.reduce_sum(focal_loss, axis=-1))

print("Focal Loss defined with gamma=2.0")

Focal Loss defined with gamma=2.0


In [7]:
# ============================================================
# 7. BUILD MODEL
# ============================================================

def build_model():
    """Build EfficientNetB0 for 3-class classification."""
    
    # Base model
    base_model = EfficientNetB0(
        weights='imagenet',
        include_top=False,
        input_shape=(IMG_SIZE, IMG_SIZE, 3)
    )
    
    # Freeze base model initially
    base_model.trainable = False
    
    # Custom head
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)
    x = Dense(128, activation='relu', kernel_regularizer=l2(L2_REG))(x)
    x = Dropout(DROPOUT_RATE)(x)
    x = Dense(64, activation='relu', kernel_regularizer=l2(L2_REG))(x)
    x = Dropout(DROPOUT_RATE)(x)
    outputs = Dense(NUM_CLASSES, activation='softmax')(x)
    
    model = Model(inputs=base_model.input, outputs=outputs)
    
    # Compile with focal loss
    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss=FocalLoss(gamma=2.0),
        metrics=['accuracy']
    )
    
    return model, base_model

model, base_model = build_model()

print("=" * 60)
print("  MODEL ARCHITECTURE")
print("=" * 60)
print(f"  Total params: {model.count_params():,}")
trainable = sum([tf.keras.backend.count_params(w) for w in model.trainable_weights])
print(f"  Trainable: {trainable:,}")
print(f"  Non-trainable: {model.count_params() - trainable:,}")
print("=" * 60)

  MODEL ARCHITECTURE
  Total params: 4,227,110
  Trainable: 174,979
  Non-trainable: 4,052,131


In [8]:
# ============================================================
# 8. CALLBACKS
# ============================================================

checkpoint = ModelCheckpoint(
    os.path.join(MODEL_DIR, 'best_model.keras'),
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

early_stop = EarlyStopping(
    monitor='val_accuracy',
    patience=PATIENCE,
    restore_best_weights=True,
    verbose=1
)

reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    verbose=1
)

callbacks = [checkpoint, early_stop, reduce_lr]

print("Callbacks configured: ModelCheckpoint, EarlyStopping, ReduceLROnPlateau")

Callbacks configured: ModelCheckpoint, EarlyStopping, ReduceLROnPlateau


In [9]:
# ============================================================
# 9. PHASE 1: TRAIN WITH FROZEN BASE
# ============================================================

print("\n" + "="*60)
print("  PHASE 1: Training with frozen EfficientNetB0")
print("="*60)

start_time = datetime.now()

history1 = model.fit(
    train_generator,
    epochs=30,  # Initial training
    validation_data=val_generator,
    callbacks=callbacks,
    class_weight=class_weights,
    verbose=1
)

phase1_time = (datetime.now() - start_time).total_seconds() / 60
print(f"\nPhase 1 completed in {phase1_time:.1f} minutes")
print(f"Best val_accuracy: {max(history1.history['val_accuracy'])*100:.2f}%")


  PHASE 1: Training with frozen EfficientNetB0
Epoch 1/30
[1m404/404[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.3456 - loss: 3.7272
Epoch 1: val_accuracy improved from None to 0.55115, saving model to d:\SLIIT\Reaserch Project\CoconutHealthMonitor\Research\ml\models\coconut_mite_v9\best_model.keras
[1m404/404[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m582s[0m 1s/step - accuracy: 0.3424 - loss: 3.5097 - val_accuracy: 0.5511 - val_loss: 3.1150 - learning_rate: 1.0000e-04
Epoch 2/30
[1m260/404[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m2:40[0m 1s/step - accuracy: 0.3555 - loss: 3.0684

KeyboardInterrupt: 

In [None]:
# ============================================================
# 10. PHASE 2: FINE-TUNE TOP LAYERS
# ============================================================

print("\n" + "="*60)
print("  PHASE 2: Fine-tuning top layers of EfficientNetB0")
print("="*60)

# Unfreeze top 20 layers
base_model.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False

# Recompile with lower learning rate
model.compile(
    optimizer=Adam(learning_rate=LEARNING_RATE / 10),
    loss=FocalLoss(gamma=2.0),
    metrics=['accuracy']
)

trainable = sum([tf.keras.backend.count_params(w) for w in model.trainable_weights])
print(f"Trainable params after unfreezing: {trainable:,}")

# Continue training
start_time = datetime.now()

history2 = model.fit(
    train_generator,
    epochs=50,
    validation_data=val_generator,
    callbacks=callbacks,
    class_weight=class_weights,
    verbose=1
)

phase2_time = (datetime.now() - start_time).total_seconds() / 60
print(f"\nPhase 2 completed in {phase2_time:.1f} minutes")
print(f"Best val_accuracy: {max(history2.history['val_accuracy'])*100:.2f}%")

In [None]:
# ============================================================
# 11. COMBINE TRAINING HISTORY
# ============================================================

history = {
    'accuracy': history1.history['accuracy'] + history2.history['accuracy'],
    'val_accuracy': history1.history['val_accuracy'] + history2.history['val_accuracy'],
    'loss': history1.history['loss'] + history2.history['loss'],
    'val_loss': history1.history['val_loss'] + history2.history['val_loss']
}

total_training_time = phase1_time + phase2_time

# Save history
with open(os.path.join(MODEL_DIR, 'training_history.json'), 'w') as f:
    json.dump({k: [float(v) for v in vals] for k, vals in history.items()}, f, indent=2)

print(f"Total training time: {total_training_time:.1f} minutes")
print(f"Total epochs: {len(history['accuracy'])}")

In [None]:
# ============================================================
# 12. PLOT TRAINING HISTORY
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

epochs_range = range(1, len(history['accuracy']) + 1)

# Accuracy
axes[0].plot(epochs_range, [x*100 for x in history['accuracy']], 'b-', label='Train', linewidth=2)
axes[0].plot(epochs_range, [x*100 for x in history['val_accuracy']], 'r-', label='Validation', linewidth=2)
axes[0].axvline(x=len(history1.history['accuracy']), color='green', linestyle='--', alpha=0.5, label='Fine-tune start')

best_epoch = np.argmax(history['val_accuracy']) + 1
best_val_acc = max(history['val_accuracy']) * 100
axes[0].scatter([best_epoch], [best_val_acc], color='green', s=100, zorder=5)
axes[0].annotate(f'Best: {best_val_acc:.1f}%', xy=(best_epoch, best_val_acc), 
                 xytext=(best_epoch+2, best_val_acc-3), fontsize=10)

axes[0].set_title('Model Accuracy', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy (%)')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Loss
axes[1].plot(epochs_range, history['loss'], 'b-', label='Train', linewidth=2)
axes[1].plot(epochs_range, history['val_loss'], 'r-', label='Validation', linewidth=2)
axes[1].axvline(x=len(history1.history['loss']), color='green', linestyle='--', alpha=0.5, label='Fine-tune start')
axes[1].set_title('Model Loss', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'training_history.png'), dpi=150, bbox_inches='tight')
plt.show()

# Calculate gap
final_train = history['accuracy'][-1] * 100
final_val = history['val_accuracy'][-1] * 100
gap = abs(final_train - final_val)
print(f"\nFinal Train: {final_train:.2f}%, Val: {final_val:.2f}%, Gap: {gap:.2f}%")

In [None]:
# ============================================================
# 13. LOAD BEST MODEL & EVALUATE ON TEST SET
# ============================================================

print("\n" + "="*60)
print("  TEST SET EVALUATION (HONEST METRICS)")
print("="*60)

# Load best model
best_model = tf.keras.models.load_model(
    os.path.join(MODEL_DIR, 'best_model.keras'),
    custom_objects={'FocalLoss': FocalLoss}
)

# Get predictions
test_generator.reset()
y_probs = best_model.predict(test_generator, verbose=1)
y_pred = np.argmax(y_probs, axis=1)
y_true = test_generator.classes

# Class names in correct order
class_names_ordered = list(test_generator.class_indices.keys())

print(f"\nTest samples: {len(y_true)}")
print(f"Classes: {class_names_ordered}")

In [None]:
# ============================================================
# 14. CLASSIFICATION REPORT
# ============================================================

print("\n" + "="*60)
print("  CLASSIFICATION REPORT")
print("="*60)
print(classification_report(y_true, y_pred, target_names=class_names_ordered, digits=4))

In [None]:
# ============================================================
# 15. DETAILED METRICS
# ============================================================

from sklearn.metrics import precision_recall_fscore_support

accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred, average=None)
macro_f1 = f1_score(y_true, y_pred, average='macro')

print("\n" + "="*60)
print("  DETAILED METRICS")
print("="*60)
print(f"\n  Overall Accuracy: {accuracy*100:.2f}%")
print(f"  Macro F1-Score:   {macro_f1*100:.2f}%")

print("\n  Per-Class:")
print("  " + "-"*50)
for i, cls in enumerate(class_names_ordered):
    print(f"  {cls}:")
    print(f"    Precision: {precision[i]*100:.2f}%")
    print(f"    Recall:    {recall[i]*100:.2f}%")
    print(f"    F1-Score:  {f1[i]*100:.2f}%")
    print(f"    Support:   {support[i]}")

In [None]:
# ============================================================
# 16. CONFUSION MATRIX
# ============================================================

cm = confusion_matrix(y_true, y_pred)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0],
            xticklabels=class_names_ordered, yticklabels=class_names_ordered)
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')

# Normalized
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues', ax=axes[1],
            xticklabels=class_names_ordered, yticklabels=class_names_ordered)
axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'confusion_matrix.png'), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# 17. REQUIREMENTS CHECK
# ============================================================

print("\n" + "="*60)
print("  UTHPALA MISS REQUIREMENTS CHECK")
print("="*60)

all_pass = True

# Req 1: P/R/F1 balanced per class
print("\n  [1] P/R/F1 Balanced per Class (gap < 0.10)")
for i, cls in enumerate(class_names_ordered):
    gap = max(precision[i], recall[i], f1[i]) - min(precision[i], recall[i], f1[i])
    status = "PASS" if gap < 0.10 else "FAIL"
    if gap >= 0.10:
        all_pass = False
    print(f"      {cls}: P={precision[i]:.2f}, R={recall[i]:.2f}, F1={f1[i]:.2f} | Gap={gap:.4f} [{status}]")

# Req 2: Accuracy ~ F1
print("\n  [2] Accuracy ~ F1-Score (diff < 0.05)")
diff = abs(accuracy - macro_f1)
status2 = "PASS" if diff < 0.05 else "FAIL"
if diff >= 0.05:
    all_pass = False
print(f"      Acc={accuracy:.4f}, F1={macro_f1:.4f}, Diff={diff:.4f} [{status2}]")

# Req 3: Class F1 similar
print("\n  [3] Class F1-Scores Similar (max diff < 0.15)")
f1_diff = max(f1) - min(f1)
status3 = "PASS" if f1_diff < 0.15 else "FAIL"
if f1_diff >= 0.15:
    all_pass = False
print(f"      F1s: {[f'{x:.2f}' for x in f1]}, Max diff={f1_diff:.4f} [{status3}]")

# Req 4: Train-Val gap
print("\n  [4] Train-Val Gap < 15%")
train_val_gap = gap  # from earlier
status4 = "PASS" if train_val_gap < 15 else "FAIL"
if train_val_gap >= 15:
    all_pass = False
print(f"      Gap={train_val_gap:.2f}% [{status4}]")

print("\n" + "="*60)
print(f"  FINAL: {'ALL REQUIREMENTS PASSED!' if all_pass else 'SOME REQUIREMENTS NEED WORK'}")
print("="*60)

In [None]:
# ============================================================
# 18. SAVE MODEL INFO
# ============================================================

model_info = {
    'model_name': 'coconut_mite_3class_v9_robust',
    'version': 'v9_robust',
    'architecture': 'EfficientNetB0 + Fine-tuning',
    'num_classes': NUM_CLASSES,
    'classes': class_names_ordered,
    'input_size': [IMG_SIZE, IMG_SIZE, 3],
    'dataset': {
        'version': 'v4_clean',
        'train': train_generator.samples,
        'validation': val_generator.samples,
        'test': test_generator.samples,
        'data_leaks': 0,
        'corrupted_files': 0
    },
    'performance': {
        'test_accuracy': float(accuracy),
        'macro_f1': float(macro_f1),
        'per_class': [
            {
                'class': class_names_ordered[i],
                'precision': float(precision[i]),
                'recall': float(recall[i]),
                'f1': float(f1[i]),
                'support': int(support[i])
            }
            for i in range(NUM_CLASSES)
        ],
        'confusion_matrix': cm.tolist()
    },
    'training': {
        'total_epochs': len(history['accuracy']),
        'best_epoch': int(best_epoch),
        'training_time_minutes': float(total_training_time),
        'train_val_gap': float(train_val_gap),
        'loss_function': 'FocalLoss(gamma=2.0)',
        'class_weights_used': True
    },
    'requirements_check': {
        'pr_balanced': all(max(precision[i], recall[i], f1[i]) - min(precision[i], recall[i], f1[i]) < 0.10 for i in range(NUM_CLASSES)),
        'acc_equals_f1': diff < 0.05,
        'f1_similar': f1_diff < 0.15,
        'gap_ok': train_val_gap < 15,
        'all_pass': all_pass
    },
    'timestamp': datetime.now().isoformat()
}

with open(os.path.join(MODEL_DIR, 'model_info.json'), 'w') as f:
    json.dump(model_info, f, indent=2)

print("Model info saved!")

In [None]:
# ============================================================
# 19. FINAL SUMMARY
# ============================================================

print("\n")
print("=" * 60)
print("  V9 ROBUST MODEL - TRAINING COMPLETE")
print("=" * 60)
print(f"\n  Dataset:    v4_clean ({train_generator.samples + val_generator.samples + test_generator.samples} images)")
print(f"  Training:   {total_training_time:.1f} minutes")
print(f"  Best Epoch: {best_epoch}")
print(f"\n  TEST RESULTS:")
print(f"    Accuracy:   {accuracy*100:.2f}%")
print(f"    Macro F1:   {macro_f1*100:.2f}%")
print(f"\n  Per-Class F1:")
for i, cls in enumerate(class_names_ordered):
    print(f"    {cls}: {f1[i]*100:.2f}%")
print(f"\n  Train-Val Gap: {train_val_gap:.2f}%")
print(f"\n  Requirements: {'ALL PASS' if all_pass else 'NEEDS WORK'}")
print("\n  Model saved to:", MODEL_DIR)
print("=" * 60)