# Coconut Disease Detection Model v1

## 4-Class Classification: Leaf Rot, Leaf Spot, Healthy, Not Coconut

**Author:** Coconut Health Monitor Team  
**Date:** January 2026  
**Model:** EfficientNetB0 (Transfer Learning)  

### Objectives:
1. Train a disease detection model for coconut leaves
2. Prevent data leakage and overfitting
3. Achieve balanced Precision, Recall, F1-score across all classes
4. Ensure accuracy is close to F1-score

### Classes:
- `healthy` - Healthy coconut leaves
- `Leaf Rot` - Leaves affected by rot disease
- `Leaf_Spot` - Leaves with spot disease
- `not_cocount` - Non-coconut images (rejection class; the name is spelled exactly as the dataset folder on disk)

---
## 1. Import Libraries

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import json
import warnings
warnings.filterwarnings('ignore')

# TensorFlow and Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import regularizers

# Sklearn for metrics
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score
from sklearn.utils.class_weight import compute_class_weight

# --- Environment report ------------------------------------------------------
# Record the TensorFlow build and the GPUs TensorFlow can see, so the run's
# output documents the environment it was produced in.
print(f"TensorFlow Version: {tf.__version__}")
visible_gpus = tf.config.list_physical_devices('GPU')
print(f"GPU Available: {visible_gpus}")

# --- Reproducibility -----------------------------------------------------------
# One seed for NumPy's and TensorFlow's global RNGs; the same SEED is also passed
# to the training generator's shuffle when the data is loaded.
SEED = 42
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(SEED)

---
## 2. Configuration

In [None]:
# --- Paths ---------------------------------------------------------------------
# Dataset root, pre-split into train/, val/ and test/ folders that each hold one
# sub-folder per class.
BASE_PATH = r'D:\SLIIT\Reaserch Project\CoconutHealthMonitor\Research\ml\data\raw\stage_2_split'
TRAIN_PATH, VAL_PATH, TEST_PATH = (
    os.path.join(BASE_PATH, split_dir) for split_dir in ('train', 'val', 'test')
)

# Where the checkpoint, metadata JSON and plots are written (created if missing).
MODEL_SAVE_PATH = r'D:\SLIIT\Reaserch Project\CoconutHealthMonitor\Research\ml\models\disease_detection_v1'
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

# --- Hyperparameters -----------------------------------------------------------
IMG_SIZE = 224           # square input resolution in pixels
BATCH_SIZE = 32
EPOCHS = 50              # final epoch index for phase 2; EarlyStopping may stop sooner
LEARNING_RATE = 0.0001   # phase-1 learning rate; phase 2 uses LEARNING_RATE / 10
DROPOUT_RATE = 0.5       # dropout after each hidden dense layer of the head

# --- Classes -------------------------------------------------------------------
# Sorted folder names, i.e. the class-index order ImageDataGenerator assigns.
# The data-loading step later replaces this list with what the generator found.
CLASS_NAMES = ['Leaf Rot', 'Leaf_Spot', 'healthy', 'not_cocount']
NUM_CLASSES = len(CLASS_NAMES)

print(f"Image Size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Number of Classes: {NUM_CLASSES}")
print(f"Classes: {CLASS_NAMES}")

---
## 3. Explore Dataset

In [None]:
def count_images(path, extensions=('.jpg', '.jpeg', '.png')):
    """Count image files in each class sub-folder of ``path``.

    Args:
        path: Directory whose immediate sub-directories are class folders
            (the layout ``flow_from_directory`` reads). Plain files directly
            inside ``path`` are ignored.
        extensions: File suffixes that count as images, matched
            case-insensitively. Defaults to ``.jpg``, ``.jpeg`` and ``.png``.

    Returns:
        dict mapping class-folder name -> number of matching files. Keys are in
        sorted order, the same order ``flow_from_directory`` uses for class
        indices, so tables and bar charts are deterministic and line up with
        the model's class indices (``os.listdir`` order is arbitrary).
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    counts = {}
    for class_name in sorted(os.listdir(path)):
        class_path = os.path.join(path, class_name)
        if not os.path.isdir(class_path):
            continue
        counts[class_name] = sum(
            1 for f in os.listdir(class_path) if f.lower().endswith(suffixes)
        )
    return counts

# Per-class image counts for every split, keyed by the row label used below.
split_counts = {
    'Train': count_images(TRAIN_PATH),
    'Validation': count_images(VAL_PATH),
    'Test': count_images(TEST_PATH),
}
train_counts = split_counts['Train']
val_counts = split_counts['Validation']
test_counts = split_counts['Test']

print("=" * 60)
print("DATASET DISTRIBUTION")
print("=" * 60)

# Transpose so rows are splits and columns are classes, then add a per-split total.
df_counts = pd.DataFrame(split_counts).T
df_counts['Total'] = df_counts.sum(axis=1)
print(df_counts)

print("\n" + "=" * 60)
print(f"Total Train Images: {sum(train_counts.values())}")
print(f"Total Validation Images: {sum(val_counts.values())}")
print(f"Total Test Images: {sum(test_counts.values())}")

In [None]:
# One bar chart per split showing images per class, each bar labelled with its count.
bar_colors = ['#2ecc71', '#e74c3c', '#3498db', '#9b59b6']
# (counts, title, vertical offset of the count label in data units)
distribution_panels = [
    (train_counts, 'Training Set Distribution', 100),
    (val_counts, 'Validation Set Distribution', 5),
    (test_counts, 'Test Set Distribution', 5),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (counts, title, label_offset) in zip(axes, distribution_panels):
    ax.bar(counts.keys(), counts.values(), color=bar_colors)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Class')
    ax.set_ylabel('Number of Images')
    ax.tick_params(axis='x', rotation=45)
    # The training panel uses a larger offset, presumably because its counts
    # (and hence its y-axis scale) are larger than the val/test splits.
    for position, count in enumerate(counts.values()):
        ax.text(position, count + label_offset, str(count), ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_SAVE_PATH, 'dataset_distribution.png'), dpi=150, bbox_inches='tight')
plt.show()

---
## 4. Data Loading with Augmentation

**Important:** 
- Training data gets augmentation (rotation, flip, zoom, etc.)
- Validation and test data only get rescaling (no augmentation), so evaluation is measured on unaltered images

In [None]:
# Random geometric augmentation, applied to the training split only.
augmentation_params = dict(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest',
)

# Every split is rescaled from [0, 255] to [0, 1] by the same factor.
train_datagen = ImageDataGenerator(rescale=1./255, **augmentation_params)

# Validation/test: rescaling only, so evaluation sees unaltered images.
val_test_datagen = ImageDataGenerator(rescale=1./255)

print("Data generators created successfully!")
print("- Training: With augmentation (rotation, flip, zoom, shift)")
print("- Validation/Test: Only rescaling (no augmentation to prevent data leaking)")

In [None]:
# Settings shared by every split.
flow_kwargs = dict(
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

# Training split: augmented and shuffled (seeded for reproducibility).
train_generator = train_datagen.flow_from_directory(
    TRAIN_PATH, shuffle=True, seed=SEED, **flow_kwargs
)

# Validation / test splits: no augmentation and no shuffling, so
# generator.classes stays aligned with the order of model.predict outputs.
val_generator = val_test_datagen.flow_from_directory(
    VAL_PATH, shuffle=False, **flow_kwargs
)
test_generator = val_test_datagen.flow_from_directory(
    TEST_PATH, shuffle=False, **flow_kwargs
)

# Get class indices
print("\nClass Indices (alphabetical order):")
print(train_generator.class_indices)

# Class indices are derived per split from its folder names. If a split is
# missing (or has an extra) class folder, indices shift and evaluation labels
# would silently disagree with training labels, so fail loudly instead.
for split_name, generator in (('validation', val_generator), ('test', test_generator)):
    if generator.class_indices != train_generator.class_indices:
        raise ValueError(
            f"Class folders in the {split_name} split {generator.class_indices} "
            f"do not match the training split {train_generator.class_indices}"
        )

# Class names ordered by their index, and the class count actually found
# (keeps NUM_CLASSES consistent with CLASS_NAMES for the model head).
CLASS_NAMES = sorted(train_generator.class_indices, key=train_generator.class_indices.get)
NUM_CLASSES = len(CLASS_NAMES)
print(f"\nClass Names: {CLASS_NAMES}")

In [None]:
# Visualize some training images with augmentation (up to 8, in a 2x4 grid).
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

# Get a batch of images
sample_batch, sample_labels = next(train_generator)

# A batch can hold fewer images than the grid has panels (e.g. BATCH_SIZE < 8);
# fill only the panels we have data for instead of raising IndexError.
n_shown = min(len(axes), len(sample_batch))
for i, ax in enumerate(axes):
    if i < n_shown:
        ax.imshow(sample_batch[i])
        class_idx = np.argmax(sample_labels[i])
        ax.set_title(f'Class: {CLASS_NAMES[class_idx]}', fontsize=10)
    ax.axis('off')

plt.suptitle('Sample Training Images (with Augmentation)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(MODEL_SAVE_PATH, 'sample_images.png'), dpi=150, bbox_inches='tight')
plt.show()

---
## 5. Calculate Class Weights

Handle class imbalance by computing class weights

In [None]:
# Calculate class weights to handle imbalance.
# 'balanced' assigns n_samples / (n_classes * n_samples_in_class), so rarer
# classes receive proportionally larger weights.
train_labels = train_generator.classes
present_classes = np.unique(train_labels)

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=train_labels
)

# Key each weight by its real class index. np.unique returns only labels that
# actually occur, so enumerate() would shift the mapping if a class index had
# no training images. Plain int/float values also keep the dict JSON-friendly.
class_weight_dict = {
    int(label): float(weight) for label, weight in zip(present_classes, class_weights)
}

print("Class Weights (to handle imbalance):")
print("=" * 40)
for idx, weight in class_weight_dict.items():
    print(f"  {CLASS_NAMES[idx]}: {weight:.4f}")

print("\n(Higher weight = less samples = model pays more attention)")

---
## 6. Build Model Architecture

Using **EfficientNetB0** with Transfer Learning:
- Pre-trained on ImageNet
- Fine-tune top layers
- Add dropout for regularization
- Add L2 regularization to prevent overfitting

In [None]:
def build_model(num_classes, img_size=224, dropout_rate=0.5, input_scale=255.0):
    """
    Build an EfficientNetB0 transfer-learning model with a custom classification head.

    Keras' EfficientNetB0 contains its own rescaling/normalization layers and
    expects raw pixel values in the [0, 255] range. The data generators in this
    notebook feed images already rescaled to [0, 1] (``rescale=1./255``), so the
    inputs are multiplied by ``input_scale`` (255 by default) before reaching the
    backbone. Without this step the backbone receives values of at most ~0.004,
    and its ImageNet features are largely wasted. The resulting model keeps a
    simple contract: feed it images scaled to [0, 1].

    Args:
        num_classes: Number of output classes (size of the softmax layer).
        img_size: Height and width of the square 3-channel input.
        dropout_rate: Dropout rate applied after each hidden dense layer.
        input_scale: Multiplier applied to the inputs before the backbone.
            Keep the default 255.0 for inputs in [0, 1]; pass 1.0 (or None) if
            the inputs are already raw [0, 255] pixel values.

    Returns:
        Tuple ``(model, base_model)``: the full, *uncompiled* Keras model, and
        the EfficientNetB0 backbone. The backbone is returned separately so its
        layers can later be unfrozen for fine-tuning.
    """
    # ImageNet-pretrained backbone without its original classifier.
    base_model = EfficientNetB0(
        weights='imagenet',
        include_top=False,
        input_shape=(img_size, img_size, 3)
    )

    # Freeze the backbone initially: phase 1 trains only the new head.
    base_model.trainable = False

    inputs = keras.Input(shape=(img_size, img_size, 3))
    x = inputs
    if input_scale is not None and input_scale != 1.0:
        # Bring [0, 1] inputs back to the [0, 255] range EfficientNet expects.
        x = layers.Rescaling(input_scale)(x)

    # training=False runs the backbone's BatchNormalization layers in inference
    # mode, which also holds after its top layers are unfrozen for fine-tuning.
    x = base_model(x, training=False)

    # Collapse the spatial feature map to one feature vector per image.
    x = layers.GlobalAveragePooling2D()(x)

    # Two dense blocks; L2 weight penalties and dropout to limit overfitting.
    for units in (256, 128):
        x = layers.Dense(
            units,
            activation='relu',
            kernel_regularizer=regularizers.l2(0.01)
        )(x)
        x = layers.BatchNormalization()(x)
        x = layers.Dropout(dropout_rate)(x)

    # Output layer: class probabilities.
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs)

    return model, base_model

# Instantiate the network. base_model (the EfficientNetB0 backbone) is kept as a
# separate handle so phase 2 can unfreeze its top layers.
model_kwargs = dict(
    num_classes=NUM_CLASSES,
    img_size=IMG_SIZE,
    dropout_rate=DROPOUT_RATE,
)
model, base_model = build_model(**model_kwargs)

print("Model Architecture Summary:")
model.summary()

In [None]:
# Phase-1 compilation: Adam at the base learning rate, with categorical
# cross-entropy matching the one-hot labels from class_mode='categorical'.
phase1_optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(
    optimizer=phase1_optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

print("Model compiled successfully!")
print(f"  Optimizer: Adam (lr={LEARNING_RATE})")
print("  Loss: Categorical Crossentropy")
print("  Metrics: Accuracy")

---
## 7. Setup Callbacks

Callbacks to prevent overfitting and save best model

In [None]:
# Stop once val_loss has not improved for 10 epochs, rolling back to the weights
# of the best val_loss epoch.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
    verbose=1,
)

# Write the full model to disk whenever val_accuracy reaches a new maximum.
# NOTE(review): this monitors val_accuracy while EarlyStopping/ReduceLROnPlateau
# monitor val_loss, so the saved checkpoint and the in-memory weights after
# training may come from different epochs.
checkpoint = ModelCheckpoint(
    filepath=os.path.join(MODEL_SAVE_PATH, 'best_model.keras'),
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1,
)

# Halve the learning rate after 5 epochs without val_loss improvement (floor 1e-7).
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    verbose=1,
)

callbacks = [early_stopping, checkpoint, reduce_lr]

print("Callbacks configured:")
print("  1. EarlyStopping (patience=10, monitor=val_loss)")
print("  2. ModelCheckpoint (save best model by val_accuracy)")
print("  3. ReduceLROnPlateau (factor=0.5, patience=5)")

---
## 8. Phase 1: Train with Frozen Base Model

In [None]:
# Phase 1: train only the new head while the EfficientNet backbone stays frozen.
PHASE1_EPOCHS = 15  # upper bound for this phase; EarlyStopping may end it sooner

banner = "=" * 60
print(banner)
print("PHASE 1: Training with Frozen Base Model")
print(banner)
print("Training only the classification head...")
print(f"Base model layers: {len(base_model.layers)} (all frozen)")
print()

history_phase1 = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=PHASE1_EPOCHS,
    class_weight=class_weight_dict,
    callbacks=callbacks,
    verbose=1,
)

print("\nPhase 1 training completed!")

---
## 9. Phase 2: Fine-tune Top Layers of Base Model

In [None]:
print("=" * 60)
print("PHASE 2: Fine-tuning Top Layers")
print("=" * 60)

# How many backbone layers (counted from the top) to unfreeze, and the reduced
# learning rate used while fine-tuning them.
FINE_TUNE_LAST_N = 20
FINE_TUNE_LR = LEARNING_RATE / 10

# Make the backbone trainable, then re-freeze everything below the top N layers.
base_model.trainable = True
for layer in base_model.layers[:-FINE_TUNE_LAST_N]:
    layer.trainable = False

# Keep BatchNormalization layers frozen even inside the unfrozen top block, as
# the Keras EfficientNet fine-tuning example does, so their ImageNet-learned
# parameters are not updated from this much smaller dataset.
for layer in base_model.layers[-FINE_TUNE_LAST_N:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Count trainable layers (BatchNormalization layers excluded by the loop above).
trainable_layers = sum(1 for layer in base_model.layers if layer.trainable)
print(f"Unfrozen top layers: {trainable_layers}")

# Re-compile so the changed trainable flags take effect, using the lower LR.
model.compile(
    optimizer=Adam(learning_rate=FINE_TUNE_LR),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

print(f"New learning rate: {FINE_TUNE_LR}")
print("\nStarting fine-tuning...")

In [None]:
# Phase 2 (fine-tuning) resumes epoch numbering where phase 1 stopped. With
# initial_epoch set, Keras treats `epochs` as the final epoch index, not as an
# additional epoch count.
phase1_epochs_run = len(history_phase1.history['loss'])
history_phase2 = model.fit(
    train_generator,
    validation_data=val_generator,
    initial_epoch=phase1_epochs_run,
    epochs=EPOCHS,
    class_weight=class_weight_dict,
    callbacks=callbacks,
    verbose=1,
)

print("\nPhase 2 fine-tuning completed!")

---
## 10. Plot Training History

In [None]:
# Combine histories
def combine_histories(h1, h2):
    """Concatenate the per-epoch metric lists of two Keras History objects.

    Keys come from ``h1``. Each value is ``h1``'s list followed by ``h2``'s list
    for the same key. A KeyError is raised if ``h2`` lacks one of ``h1``'s keys;
    keys present only in ``h2`` are dropped. New lists are returned, and neither
    input is modified.
    """
    first, second = h1.history, h2.history
    return {metric: values + second[metric] for metric, values in first.items()}

history = combine_histories(history_phase1, history_phase2)

# Index of the last phase-1 epoch; a dashed line there marks where fine-tuning starts.
fine_tune_start = len(history_phase1.history['loss']) - 1

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, train key, validation key, train label, validation label, title, y-label)
history_panels = (
    (axes[0], 'accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
    (axes[1], 'loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Model Loss', 'Loss'),
)
for ax, train_key, val_key, train_label, val_label, title, y_label in history_panels:
    ax.plot(history[train_key], label=train_label, linewidth=2)
    ax.plot(history[val_key], label=val_label, linewidth=2)
    ax.axvline(x=fine_tune_start, color='r', linestyle='--', label='Fine-tuning Start')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_SAVE_PATH, 'training_history.png'), dpi=150, bbox_inches='tight')
plt.show()

# Best validation accuracy and its 1-based epoch; list.index() picks the first
# epoch that reached the maximum.
val_accuracy_curve = history['val_accuracy']
best_val_acc = max(val_accuracy_curve)
best_epoch = val_accuracy_curve.index(best_val_acc) + 1
print(f"\nBest Validation Accuracy: {best_val_acc:.4f} (Epoch {best_epoch})")

---
## 11. Load Best Model and Evaluate on Test Set

In [None]:
# Reload the checkpoint written by ModelCheckpoint instead of using the weights
# left in memory after training.
best_model_path = os.path.join(MODEL_SAVE_PATH, 'best_model.keras')
model = keras.models.load_model(best_model_path)
print(f"Loaded best model from: {best_model_path}")

banner = "=" * 60
print("\n" + banner)
print("EVALUATING ON TEST SET")
print(banner)

# evaluate() yields (loss, accuracy) given the model's compiled metrics=['accuracy'].
test_loss, test_accuracy = model.evaluate(test_generator, verbose=1)
print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")

---
## 12. Generate Predictions

In [None]:
# Rewind the test iterator so predictions start at its first file. Because the
# test generator is not shuffled, prediction rows line up with
# test_generator.classes.
test_generator.reset()
predictions = model.predict(test_generator, verbose=1)

# Highest-probability class index per image, and the ground-truth indices.
y_pred = predictions.argmax(axis=1)
y_true = test_generator.classes

print(f"Total test samples: {len(y_true)}")
print(f"Total predictions: {len(y_pred)}")

---
## 13. Class-wise Metrics (Precision, Recall, F1-Score)

**Key Requirement:** Precision, Recall, and F1-score should be close to each other for each class

In [None]:
# Generate classification report
print("=" * 70)
print("CLASSIFICATION REPORT (Class-wise Metrics)")
print("=" * 70)

# Pass the label ids explicitly so there is always one row per entry of
# CLASS_NAMES. Without labels=, sklearn infers the classes from y_true/y_pred and
# raises ValueError when a class never appears in either, because target_names
# would then be longer than the inferred label set. zero_division=0 reports an
# undefined precision/recall as 0.0 explicitly.
report_labels = list(range(len(CLASS_NAMES)))

report = classification_report(
    y_true,
    y_pred,
    labels=report_labels,
    target_names=CLASS_NAMES,
    digits=4,
    zero_division=0
)
print(report)

# Same report as a dictionary, for the metric analysis that follows.
report_dict = classification_report(
    y_true,
    y_pred,
    labels=report_labels,
    target_names=CLASS_NAMES,
    output_dict=True,
    zero_division=0
)

In [None]:
# Create DataFrame for better visualization: one row per class.
metrics_df = pd.DataFrame([
    {
        'Class': class_name,
        'Precision': report_dict[class_name]['precision'],
        'Recall': report_dict[class_name]['recall'],
        'F1-Score': report_dict[class_name]['f1-score'],
        'Support': report_dict[class_name]['support'],
    }
    for class_name in CLASS_NAMES
])

# Absolute pairwise gaps between the three metrics, per class. F1 is the harmonic
# mean of precision and recall, so it lies between them, and 'P-R Diff' is always
# the largest of the three gaps.
metrics_df['P-R Diff'] = abs(metrics_df['Precision'] - metrics_df['Recall'])
metrics_df['P-F1 Diff'] = abs(metrics_df['Precision'] - metrics_df['F1-Score'])
metrics_df['R-F1 Diff'] = abs(metrics_df['Recall'] - metrics_df['F1-Score'])

print("\n" + "=" * 70)
print("CLASS-WISE METRICS SUMMARY")
print("=" * 70)
print(metrics_df.to_string(index=False))

# Check if metrics are balanced
print("\n" + "=" * 70)
print("METRICS BALANCE CHECK")
print("=" * 70)
max_diff = max(metrics_df['P-R Diff'].max(), metrics_df['P-F1 Diff'].max(), metrics_df['R-F1 Diff'].max())
print(f"Maximum Precision-Recall Difference: {metrics_df['P-R Diff'].max():.4f}")
print(f"Maximum Precision-F1 Difference: {metrics_df['P-F1 Diff'].max():.4f}")
print(f"Maximum Recall-F1 Difference: {metrics_df['R-F1 Diff'].max():.4f}")

if max_diff < 0.10:
    print("\n✅ GOOD: Metrics are well balanced (difference < 10%)")
elif max_diff < 0.15:
    print("\n⚠️ ACCEPTABLE: Metrics are reasonably balanced (difference < 15%)")
else:
    # This branch also covers max_diff == 0.15 exactly, hence "≥" rather than ">".
    print("\n❌ WARNING: Metrics have significant imbalance (difference ≥ 15%)")

In [None]:
# Grouped bar chart: precision, recall and F1 side by side for every class.
fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(CLASS_NAMES))
width = 0.25

# (metrics_df column, bar x-positions, colour)
bar_specs = (
    ('Precision', x - width, '#3498db'),
    ('Recall', x, '#2ecc71'),
    ('F1-Score', x + width, '#e74c3c'),
)
bar_containers = [
    ax.bar(positions, metrics_df[column], width, label=column, color=color)
    for column, positions, color in bar_specs
]

# Print each bar's value just above it.
for container in bar_containers:
    for bar in container:
        bar_height = bar.get_height()
        ax.annotate(f'{bar_height:.2f}',
                    xy=(bar.get_x() + bar.get_width() / 2, bar_height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=9)

ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Class-wise Precision, Recall, and F1-Score', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(CLASS_NAMES, rotation=15)
ax.legend()
ax.set_ylim(0, 1.15)  # headroom above 1.0 for the value labels
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_SAVE_PATH, 'class_metrics.png'), dpi=150, bbox_inches='tight')
plt.show()

---
## 14. Confusion Matrix

In [None]:
# Rows = true class, columns = predicted class, both in CLASS_NAMES index order.
# labels= keeps the matrix square with one row/column per class, even if a class
# never occurs in y_true or y_pred (otherwise the tick labels would not match).
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASS_NAMES))))

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Raw counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=axes[0])
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')

# Row-normalised: each row shows how one true class was spread over the
# predictions, so the diagonal is per-class recall. A class with no test samples
# would make its row 0/0 = NaN, so such rows are left at zero instead.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm, row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0
)
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=axes[1])
axes[1].set_title('Confusion Matrix (Normalized %)', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_SAVE_PATH, 'confusion_matrix.png'), dpi=150, bbox_inches='tight')
plt.show()

---
## 15. Overall Metrics Summary

In [None]:
# Aggregate scores over the whole test set. 'weighted' averages the per-class
# scores by support; 'macro' is their unweighted mean.
overall_precision, overall_recall, overall_f1 = (
    score_fn(y_true, y_pred, average='weighted')
    for score_fn in (precision_score, recall_score, f1_score)
)
macro_f1 = f1_score(y_true, y_pred, average='macro')

print("=" * 60)
print("OVERALL MODEL PERFORMANCE SUMMARY")
print("=" * 60)
print(f"\n  Test Accuracy:       {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print(f"  Weighted Precision:  {overall_precision:.4f} ({overall_precision*100:.2f}%)")
print(f"  Weighted Recall:     {overall_recall:.4f} ({overall_recall*100:.2f}%)")
print(f"  Weighted F1-Score:   {overall_f1:.4f} ({overall_f1*100:.2f}%)")
print(f"  Macro F1-Score:      {macro_f1:.4f} ({macro_f1*100:.2f}%)")

# Gap between test accuracy and weighted F1.
acc_f1_diff = abs(test_accuracy - overall_f1)
print(f"\n  Accuracy - F1 Diff:  {acc_f1_diff:.4f}")

print("\n" + "=" * 60)
print("REQUIREMENTS CHECK")
print("=" * 60)

# Requirement: accuracy within 0.05 of weighted F1.
if acc_f1_diff >= 0.05:
    print(f"⚠️ Accuracy and F1-Score have some difference ({acc_f1_diff:.2%})")
else:
    print("✅ Accuracy is close to F1-Score (diff < 5%)")

# Requirement: per-class F1 spread (np.std, population std) below 0.10.
f1_scores = [report_dict[c]['f1-score'] for c in CLASS_NAMES]
f1_std = np.std(f1_scores)
if f1_std >= 0.10:
    print(f"⚠️ F1-Scores vary across classes (std = {f1_std:.2%})")
else:
    print("✅ F1-Scores are similar across classes (std < 10%)")

In [None]:
# Bar chart of the headline test-set scores.
fig, ax = plt.subplots(figsize=(10, 6))

# (bar label, value, colour)
summary_bars = [
    ('Accuracy', test_accuracy, '#3498db'),
    ('Precision', overall_precision, '#2ecc71'),
    ('Recall', overall_recall, '#e74c3c'),
    ('F1-Score (W)', overall_f1, '#9b59b6'),
    ('F1-Score (M)', macro_f1, '#f39c12'),
]
metrics, values, colors = (list(column) for column in zip(*summary_bars))

bars = ax.bar(metrics, values, color=colors)

# Percentage label above each bar.
for bar, val in zip(bars, values):
    ax.annotate(f'{val:.2%}',
                xy=(bar.get_x() + bar.get_width() / 2, val),
                xytext=(0, 5),
                textcoords="offset points",
                ha='center', va='bottom', fontsize=12, fontweight='bold')

ax.set_ylabel('Score', fontsize=12)
ax.set_title('Overall Model Performance Metrics', fontsize=14, fontweight='bold')
ax.set_ylim(0, 1.1)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_SAVE_PATH, 'overall_metrics.png'), dpi=150, bbox_inches='tight')
plt.show()

---
## 16. Save Model Info

In [None]:
# Save model information
model_info = {
    'model_name': 'Coconut Disease Detection Model',
    'version': 'v1',
    'architecture': 'EfficientNetB0 (Transfer Learning)',
    'input_size': [IMG_SIZE, IMG_SIZE, 3],
    # Training and evaluation images were rescaled with rescale=1./255 by the
    # ImageDataGenerators, so inference code must scale pixels to [0, 1] the
    # same way before calling the saved model.
    'preprocessing': {
        'rescale': 1.0 / 255,
        'input_range': [0.0, 1.0],
    },
    'num_classes': NUM_CLASSES,
    'classes': CLASS_NAMES,
    'class_indices': train_generator.class_indices,
    'performance': {
        'test_accuracy': float(test_accuracy),
        'weighted_precision': float(overall_precision),
        'weighted_recall': float(overall_recall),
        'weighted_f1': float(overall_f1),
        'macro_f1': float(macro_f1)
    },
    'class_metrics': {
        class_name: {
            'precision': float(report_dict[class_name]['precision']),
            'recall': float(report_dict[class_name]['recall']),
            'f1_score': float(report_dict[class_name]['f1-score']),
            'support': int(report_dict[class_name]['support'])
        } for class_name in CLASS_NAMES
    },
    'training_config': {
        'batch_size': BATCH_SIZE,
        'initial_learning_rate': LEARNING_RATE,
        'max_epochs': EPOCHS,
        'dropout_rate': DROPOUT_RATE,
        'augmentation': True,
        'class_weights': {CLASS_NAMES[k]: float(v) for k, v in class_weight_dict.items()}
    },
    'training_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    'total_epochs': len(history['loss']),
    'best_epoch': best_epoch
}

# Save to JSON (explicit encoding so the file is identical across platforms).
info_path = os.path.join(MODEL_SAVE_PATH, 'model_info.json')
with open(info_path, 'w', encoding='utf-8') as f:
    json.dump(model_info, f, indent=2)

print(f"Model info saved to: {info_path}")
print("\nModel Info Summary:")
print(json.dumps(model_info, indent=2))

---
## 17. Final Summary

In [None]:
print("\n" + "=" * 70)
print("                    TRAINING COMPLETE - FINAL SUMMARY")
print("=" * 70)

# Every artifact this notebook writes into MODEL_SAVE_PATH.
saved_artifacts = [
    'best_model.keras',
    'model_info.json',
    'dataset_distribution.png',
    'sample_images.png',
    'training_history.png',
    'confusion_matrix.png',
    'class_metrics.png',
    'overall_metrics.png',
]
saved_files_text = "\n".join(
    f"  - {os.path.join(MODEL_SAVE_PATH, name)}" for name in saved_artifacts
)

# Values are left-aligned in a 13-character field (plus one leading space), so
# each row is exactly as wide as the header and the right border lines up.
print(f"""
Model: Coconut Disease Detection v1
Architecture: EfficientNetB0 (Transfer Learning)

Classes ({NUM_CLASSES}):
  - Leaf Rot (Coconut leaf rot disease)
  - Leaf_Spot (Coconut leaf spot disease)
  - healthy (Healthy coconut leaves)
  - not_cocount (Non-coconut images - rejection class)

Performance:
  ┌──────────────────┬──────────────┐
  │ Metric           │ Value        │
  ├──────────────────┼──────────────┤
  │ Test Accuracy    │ {test_accuracy:<13.2%}│
  │ Precision (W)    │ {overall_precision:<13.2%}│
  │ Recall (W)       │ {overall_recall:<13.2%}│
  │ F1-Score (W)     │ {overall_f1:<13.2%}│
  │ F1-Score (Macro) │ {macro_f1:<13.2%}│
  └──────────────────┴──────────────┘

Files Saved:
{saved_files_text}
""")

print("=" * 70)
print("                          ✅ TRAINING SUCCESSFUL")
print("=" * 70)