# Coconut Tree Branch Health Detection Model v1

This notebook trains a model to detect healthy vs unhealthy coconut tree branches.

**Model Configuration:**
- **Architecture:** MobileNetV2 (Transfer Learning)
- **Loss Function:** Focal Loss (gamma=2.0) for handling class imbalance
- **Training Strategy:** 2-phase training (frozen base → fine-tuning)
- **Classes:** 2 (healthy, unhealthy)

**Goals:**
- Detect unhealthy coconut tree branches
- Provide unhealthy percentage
- High accuracy and balanced metrics
- Avoid overfitting

## 1. Setup and Imports

In [None]:
import os
import shutil
import json
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import random

# Report the TensorFlow build and whether a GPU is visible to it
_env_report = (
    f"TensorFlow version: {tf.__version__}",
    f"GPU Available: {tf.config.list_physical_devices('GPU')}",
)
print(*_env_report, sep="\n")

# Seed every RNG the notebook touches (NumPy, TensorFlow, Python's `random`)
# so shuffling, augmentation and sampling repeat between runs.
for _seed_fn in (np.random.seed, tf.random.set_seed, random.seed):
    _seed_fn(42)

## 2. Configuration

In [None]:
# Project layout. Paths are resolved from the parent of the current working
# directory (presumably the notebook runs from a sub-folder -- confirm).
BASE_DIR = os.path.abspath(os.pardir)
RAW_DATA_DIR = os.path.join(BASE_DIR, 'data', 'raw')
DATASET_DIR = os.path.join(RAW_DATA_DIR, 'coconut_branch_health', 'dataset')
MODEL_DIR = os.path.join(BASE_DIR, 'models', 'coconut_branch_health_v1')

# Model / training hyper-parameters
IMG_SIZE, BATCH_SIZE = 224, 32
PHASE1_EPOCHS, PHASE2_EPOCHS = 20, 15                  # frozen base / fine-tuning
LEARNING_RATE_PHASE1, LEARNING_RATE_PHASE2 = 1e-3, 1e-4

# Label order used by every generator and report below
class_names = ['healthy', 'unhealthy']

# All artifacts (plots, checkpoints, model_info.json) are written here
os.makedirs(MODEL_DIR, exist_ok=True)

for _label, _value in (
    ("Base Directory", BASE_DIR),
    ("Dataset Directory", DATASET_DIR),
    ("Model Directory", MODEL_DIR),
    ("Classes", class_names),
):
    print(f"{_label}: {_value}")

## 3. Create Proper Dataset Structure

Reorganize data from separate folders into ImageDataGenerator-compatible structure:
```
dataset/
  train/
    healthy/
    unhealthy/
  val/
    healthy/
    unhealthy/
  test/
    healthy/
    unhealthy/
```

In [None]:
# Source directories: each holds training/, validation/ and test/ sub-folders of images
healthy_src = os.path.join(RAW_DATA_DIR, 'healthy-leaves')
unhealthy_src = os.path.join(RAW_DATA_DIR, 'coconut_tree_branch_Health')

print(f"Healthy source: {healthy_src}")
print(f"Unhealthy source: {unhealthy_src}")

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Raw split folder name -> split folder name used under DATASET_DIR
SPLIT_FOLDER_MAP = {'training': 'train', 'validation': 'val', 'test': 'test'}
# Class name -> raw source root for that class
CLASS_SOURCE_ROOTS = {'healthy': healthy_src, 'unhealthy': unhealthy_src}


def _list_images(folder):
    """Return the sorted names of files in *folder* that have an image extension."""
    return sorted(f for f in os.listdir(folder) if f.lower().endswith(IMAGE_EXTENSIONS))


for split, dest_split in SPLIT_FOLDER_MAP.items():
    for cls, cls_root in CLASS_SOURCE_ROOTS.items():
        dest_dir = os.path.join(DATASET_DIR, dest_split, cls)
        os.makedirs(dest_dir, exist_ok=True)

        # Skip folders already populated by an earlier run. Only image files count,
        # so stray files (e.g. .DS_Store) do not block the copy.
        existing = _list_images(dest_dir)
        if existing:
            print(f"{dest_split}/{cls}/ already has {len(existing)} files")
            continue

        src_dir = os.path.join(cls_root, split)
        if not os.path.isdir(src_dir):
            raise FileNotFoundError(
                f"Missing source folder for {dest_split}/{cls}/: {src_dir}"
            )

        files = _list_images(src_dir)
        for f in files:
            shutil.copy2(os.path.join(src_dir, f), os.path.join(dest_dir, f))
        print(f"Copied {len(files)} files to {dest_split}/{cls}/")

print("\nDataset structure created successfully!")

## 4. Exploratory Data Analysis (EDA)

In [None]:
# Count images (by file extension) in each split and class
data_summary = {}

print("="*70)
print("DATASET SUMMARY")
print("="*70)

for split in ['train', 'val', 'test']:
    split_path = os.path.join(DATASET_DIR, split)
    print(f"\n{split.upper()}:")
    print("-"*50)

    data_summary[split] = {}
    for cls in class_names:
        cls_path = os.path.join(split_path, cls)
        count = sum(1 for f in os.listdir(cls_path) if f.lower().endswith(('.jpg', '.jpeg', '.png')))
        data_summary[split][cls] = count
        print(f"  {cls:<15} {count:>6} images")

    split_total = sum(data_summary[split].values())
    print(f"  {'TOTAL':<15} {split_total:>6} images")

print("\n" + "="*70)

# Visualize class distribution (one bar chart per split)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
fig.suptitle('Class Distribution Across Splits', fontsize=14, fontweight='bold')

for ax, split in zip(axes, ['train', 'val', 'test']):
    counts = [data_summary[split][cls] for cls in class_names]
    ax.bar(class_names, counts, color=['green', 'red'])
    ax.set_title(f'{split.upper()} Split', fontweight='bold')
    ax.set_ylabel('Number of Images')
    ax.set_xlabel('Class')

    # Add count labels on top of the bars
    for i, count in enumerate(counts):
        ax.text(i, count, str(count), ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'class_distribution.png'), dpi=150, bbox_inches='tight')
plt.show()

# Check for class imbalance in the training split
train_healthy = data_summary['train']['healthy']
train_unhealthy = data_summary['train']['unhealthy']
if min(train_healthy, train_unhealthy) == 0:
    # A class with no training images cannot be learned, and the ratio below
    # would divide by zero -- fail with an actionable message instead.
    raise ValueError(
        f"Training split has an empty class (healthy={train_healthy}, "
        f"unhealthy={train_unhealthy}); check {os.path.join(DATASET_DIR, 'train')}"
    )
imbalance_ratio = max(train_healthy, train_unhealthy) / min(train_healthy, train_unhealthy)
print(f"\nClass imbalance ratio (train): {imbalance_ratio:.2f}")
if imbalance_ratio > 1.5:
    print("⚠ Class imbalance detected - Focal Loss will help!")
else:
    print("✓ Classes are relatively balanced")

## 5. Visualize Sample Images

In [None]:
from tensorflow.keras.preprocessing import image

SAMPLES_PER_CLASS = 5

# Show random sample images from each class (one row per class). squeeze=False
# keeps `axes` 2-D even if there is only one class.
fig, axes = plt.subplots(len(class_names), SAMPLES_PER_CLASS, figsize=(15, 7), squeeze=False)
fig.suptitle('Sample Images from Training Set', fontsize=14, fontweight='bold')

for row, cls in enumerate(class_names):
    cls_dir = os.path.join(DATASET_DIR, 'train', cls)
    # os.listdir order is arbitrary; sort so the seeded sample is reproducible
    images_list = sorted(f for f in os.listdir(cls_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png')))

    sample_imgs = random.sample(images_list, min(SAMPLES_PER_CLASS, len(images_list)))

    for col, img_name in enumerate(sample_imgs):
        img = image.load_img(os.path.join(cls_dir, img_name), target_size=(IMG_SIZE, IMG_SIZE))
        axes[row, col].imshow(img)

    # Add class label to the left of the row
    axes[row, 0].text(-0.3, 0.5, cls.upper(), transform=axes[row, 0].transAxes,
                      fontsize=12, fontweight='bold', va='center', ha='right',
                      color='green' if cls == 'healthy' else 'red')

# Hide axes on every panel, including slots left empty when a class has < 5 images
for ax in axes.flat:
    ax.axis('off')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'sample_images.png'), dpi=150, bbox_inches='tight')
plt.show()

## 6. Data Generators with Augmentation

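**Important:** Train, validation, and test images come from separate directories, so each generator reads a disjoint set of files. This separation is what prevents data leakage. Augmentation is applied to the training generator only.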

In [None]:
# Training images are augmented on the fly; every split is rescaled to [0, 1].
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    zoom_range=0.2,
    brightness_range=[0.8, 1.2],
    fill_mode='nearest',
)

# Validation/test images only get the same rescaling (no augmentation)
val_test_datagen = ImageDataGenerator(rescale=1./255)

# Settings shared by all three directory iterators. Passing `classes` makes the
# label indices follow the order of `class_names`.
_flow_settings = dict(
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    classes=class_names,
)

train_generator = train_datagen.flow_from_directory(
    os.path.join(DATASET_DIR, 'train'), shuffle=True, seed=42, **_flow_settings
)
# shuffle=False keeps file order so predictions line up with `.classes` / `.filenames`
val_generator = val_test_datagen.flow_from_directory(
    os.path.join(DATASET_DIR, 'val'), shuffle=False, **_flow_settings
)
test_generator = val_test_datagen.flow_from_directory(
    os.path.join(DATASET_DIR, 'test'), shuffle=False, **_flow_settings
)

print()
print(f"Training samples: {train_generator.samples}")
print(f"Validation samples: {val_generator.samples}")
print(f"Test samples: {test_generator.samples}")
print(f"Class indices: {train_generator.class_indices}")

## 7. Build Model with Transfer Learning

In [None]:
def build_model():
    """Build the MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained MobileNetV2 backbone (without its top) is frozen and
    followed by a small dense head ending in a softmax over ``class_names``.

    The model takes images of shape (IMG_SIZE, IMG_SIZE, 3) with pixel values in
    [0, 1], which is what this notebook's ``rescale=1./255`` generators produce.

    Returns:
        tuple: ``(model, base_model)``. ``model`` is the full Keras model.
        ``base_model`` is the MobileNetV2 backbone, returned separately so it can
        be unfrozen for fine-tuning later.
    """
    # ImageNet-pretrained backbone without its classification top
    base_model = MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet'
    )

    # Freeze the backbone for phase 1 (only the new head is trained)
    base_model.trainable = False

    inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    # The generators deliver pixels in [0, 1] (rescale=1./255), but MobileNetV2's
    # ImageNet weights expect [-1, 1] (see mobilenet_v2.preprocess_input).
    # Map x -> 2x - 1 inside the graph so the saved model keeps accepting [0, 1] input.
    x = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
    # training=False runs the backbone's BatchNorm layers in inference mode,
    # also after its top layers are unfrozen for fine-tuning
    x = base_model(x, training=False)

    # Classification head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(len(class_names), activation='softmax')(x)

    model = keras.Model(inputs, outputs)

    return model, base_model

# Instantiate the network; keep a handle on the backbone for phase-2 unfreezing
model, base_model = build_model()

print("Model Architecture:")
model.summary()

# Layer bookkeeping -- note the whole backbone is a single entry in model.layers
_layer_report = {
    "Base model layers": len(base_model.layers),
    "Total model layers": len(model.layers),
    "Trainable layers": sum(layer.trainable for layer in model.layers),
}
print()
for _label, _count in _layer_report.items():
    print(f"{_label}: {_count}")

## 8. Define Focal Loss

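Focal Loss helps with class imbalance by down-weighting easy, well-classified examples, so training focuses on hard-to-classify ones. With a single scalar `alpha`, every class gets the same weight. In that case `alpha` only rescales the loss; it does not rebalance the classes by itself.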

In [None]:
def focal_loss(gamma=2.0, alpha=0.25):
    """
    Build a categorical focal loss for one-hot targets and softmax outputs.

    loss = sum_c  alpha_c * (1 - p_c)^gamma * (-y_c * log(p_c))

    Args:
        gamma: Focusing parameter. Higher values down-weight easy, well-classified
            examples more strongly, so training focuses on hard ones.
        alpha: Class weight. A scalar applies the same weight to every class, so it
            only rescales the loss and does not rebalance classes. A sequence with
            one entry per class (in label order) weights each class separately,
            e.g. ``[0.25, 0.75]``.

    Returns:
        Callable ``focal_loss_fn(y_true, y_pred)`` returning one loss value per
        sample. The inner name ``focal_loss_fn`` is the key this notebook passes in
        ``custom_objects`` when reloading the saved model, so it must not change.
    """
    def focal_loss_fn(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        epsilon = tf.keras.backend.epsilon()
        # Clip so log() never sees 0 and (1 - p) stays positive
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

        # Per-class cross entropy (non-zero only for the true class of a one-hot target)
        cross_entropy = -y_true * tf.math.log(y_pred)

        # Focal modulation: confident correct predictions contribute little
        focal_weight = tf.pow(1.0 - y_pred, gamma)

        # Scalar alpha broadcasts to every class; a per-class sequence broadcasts
        # along the last (class) axis
        alpha_t = tf.convert_to_tensor(alpha, dtype=y_pred.dtype)

        loss = alpha_t * focal_weight * cross_entropy
        return tf.reduce_sum(loss, axis=-1)

    return focal_loss_fn

# Report the focal-loss settings used by both training phases below.
# A scalar alpha weights every class equally (it rescales the loss rather than
# rebalancing classes), so the message states that instead of "class balancing".
print("Focal Loss function defined!")
print("  Gamma: 2.0 (focus on hard examples)")
print("  Alpha: 0.25 (uniform weight applied to every class)")

## 9. Phase 1: Train with Frozen Base

In [None]:
# Phase 1: train only the new classification head; the backbone stays frozen.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE_PHASE1),
    loss=focal_loss(gamma=2.0, alpha=0.25),
    metrics=['accuracy'],
)

# Keep the checkpoint with the highest validation accuracy on disk ...
checkpoint_phase1 = ModelCheckpoint(
    filepath=os.path.join(MODEL_DIR, 'phase1_best.keras'),
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)
# ... while early stopping and LR decay watch validation loss. When early stopping
# triggers, restore_best_weights rolls the in-memory model back to its best
# val_loss epoch, which is then the starting point for phase 2.
early_stop_phase1 = EarlyStopping(
    monitor='val_loss', patience=7, restore_best_weights=True, verbose=1
)
reduce_lr_phase1 = ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=3, min_lr=1e-7, verbose=1
)

_divider = "=" * 70
print(_divider)
print("PHASE 1: Training with Frozen Base Model")
print(_divider)
print(f"Learning Rate: {LEARNING_RATE_PHASE1}")
print(f"Epochs: {PHASE1_EPOCHS}")
print(f"Base Model Trainable: {base_model.trainable}")
print(_divider)

start_time = time.time()
history_phase1 = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=PHASE1_EPOCHS,
    callbacks=[checkpoint_phase1, early_stop_phase1, reduce_lr_phase1],
    verbose=1,
)
phase1_time = (time.time() - start_time) / 60  # minutes

best_val_acc_phase1 = max(history_phase1.history['val_accuracy'])
print(f"\nPhase 1 completed in {phase1_time:.1f} minutes")
print(f"Best validation accuracy: {best_val_acc_phase1*100:.2f}%")

## 10. Phase 2: Fine-tuning

In [None]:
# Unfreeze the top of the MobileNetV2 backbone for fine-tuning
FINE_TUNE_LAST_N_LAYERS = 30
base_model.trainable = True

# Re-freeze everything except the last FINE_TUNE_LAST_N_LAYERS backbone layers.
# max(0, ...) matters if the backbone ever has fewer layers than that: a negative
# slice stop would still freeze the leading layers.
fine_tune_at = max(0, len(base_model.layers) - FINE_TUNE_LAST_N_LAYERS)
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# model.layers lists the whole backbone as ONE layer, so count the backbone's
# own layers to show how many are actually being fine-tuned
trainable_base_layers = sum(layer.trainable for layer in base_model.layers)
print(f"Total base model layers: {len(base_model.layers)}")
print(f"Fine-tuning from layer: {fine_tune_at}")
print(f"Trainable base model layers: {trainable_base_layers}")
print(f"Trainable top-level model layers: {sum(layer.trainable for layer in model.layers)}")

# Recompile so the new trainable flags take effect; the lower learning rate
# limits damage to the pretrained features
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE_PHASE2),
    loss=focal_loss(gamma=2.0, alpha=0.25),
    metrics=['accuracy']
)

# Callbacks for Phase 2 (best_model.keras is what section 12 reloads)
checkpoint_phase2 = ModelCheckpoint(
    os.path.join(MODEL_DIR, 'best_model.keras'),
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

early_stop_phase2 = EarlyStopping(
    monitor='val_loss',
    patience=7,
    restore_best_weights=True,
    verbose=1
)

reduce_lr_phase2 = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-8,
    verbose=1
)

print("\n" + "="*70)
print("PHASE 2: Fine-tuning")
print("="*70)
print(f"Learning Rate: {LEARNING_RATE_PHASE2}")
print(f"Epochs: {PHASE2_EPOCHS}")
print("="*70)

# Train Phase 2
start_time = time.time()

history_phase2 = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=PHASE2_EPOCHS,
    callbacks=[checkpoint_phase2, early_stop_phase2, reduce_lr_phase2],
    verbose=1
)

phase2_time = (time.time() - start_time) / 60  # minutes
total_time = phase1_time + phase2_time

print(f"\nPhase 2 completed in {phase2_time:.1f} minutes")
print(f"Total training time: {total_time:.1f} minutes")
print(f"Best validation accuracy: {max(history_phase2.history['val_accuracy'])*100:.2f}%")

## 11. Training History Visualization

In [None]:
# Combine histories: phase-2 epochs are appended after phase-1 epochs
history_combined = {
    key: history_phase1.history[key] + history_phase2.history[key]
    for key in ('accuracy', 'val_accuracy', 'loss', 'val_loss')
}

# The first fine-tuning epoch sits at index len(phase1); draw the marker halfway
# between it and the last frozen epoch
fine_tune_boundary = len(history_phase1.history['loss']) - 0.5

# Plot training history: (metric key, panel title, y-axis / legend label)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('Training History - Coconut Branch Health Model v1', fontsize=14, fontweight='bold')

panels = [('accuracy', 'Model Accuracy', 'Accuracy'), ('loss', 'Model Loss', 'Loss')]
for ax, (metric, title, ylabel) in zip(axes, panels):
    ax.plot(history_combined[metric], label=f'Train {ylabel}', linewidth=2)
    ax.plot(history_combined[f'val_{metric}'], label=f'Val {ylabel}', linewidth=2)
    ax.axvline(x=fine_tune_boundary, color='red', linestyle='--', label='Fine-tuning starts', alpha=0.7)
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'training_history.png'), dpi=150, bbox_inches='tight')
plt.show()

# Check for overfitting on the last recorded epoch. The evaluated checkpoint
# (best_model.keras) is selected by val_accuracy, so it may come from an earlier
# epoch than these final-epoch numbers.
final_train_acc = history_combined['accuracy'][-1]
final_val_acc = history_combined['val_accuracy'][-1]
acc_gap = abs(final_train_acc - final_val_acc)

print("\n" + "="*70)
print("OVERFITTING CHECK")
print("="*70)
print(f"Final Train Accuracy: {final_train_acc*100:.2f}%")
print(f"Final Val Accuracy: {final_val_acc*100:.2f}%")
print(f"Accuracy Gap: {acc_gap*100:.2f}%")

if acc_gap < 0.05:
    print("✓ No significant overfitting detected!")
elif acc_gap < 0.10:
    print("⚠ Minor overfitting - acceptable")
else:
    print("⚠⚠ Overfitting detected - consider more regularization")
print("="*70)

## 12. Load Best Model and Evaluate on Test Set

In [None]:
# Reload the best fine-tuned checkpoint. The custom loss closure is supplied
# under its inner function name so deserialization can resolve it.
best_model = keras.models.load_model(
    os.path.join(MODEL_DIR, 'best_model.keras'),
    custom_objects={'focal_loss_fn': focal_loss(gamma=2.0, alpha=0.25)},
)
print("Best model loaded!")

# test_generator was built with shuffle=False, so prediction rows follow
# the same order as test_generator.classes
test_generator.reset()
predictions = best_model.predict(test_generator, verbose=1)

y_true = test_generator.classes
y_pred = predictions.argmax(axis=1)

test_accuracy = (y_true == y_pred).mean()
print(f"\nTest Accuracy: {test_accuracy*100:.2f}%")

## 13. Detailed Metrics - Class-wise Performance

In [None]:
# Calculate per-class metrics. `labels` guarantees one entry per class (in
# class_names order) even if a class never appears in y_true/y_pred, so the
# [0]/[1] indexing used later cannot go out of range. zero_division=0 reports
# undefined precision/recall as 0 without emitting a warning.
label_ids = list(range(len(class_names)))
precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred, labels=label_ids, average=None, zero_division=0
)
macro_precision = np.mean(precision)
macro_recall = np.mean(recall)
macro_f1 = np.mean(f1)

print("="*80)
print("DETAILED CLASS-WISE METRICS")
print("="*80)
print(f"\n{'Class':<15} {'Precision':>12} {'Recall':>12} {'F1-Score':>12} {'Support':>10}")
print("-"*80)

for i, cls in enumerate(class_names):
    print(f"{cls:<15} {precision[i]*100:>11.2f}% {recall[i]*100:>11.2f}% {f1[i]*100:>11.2f}% {support[i]:>10}")

print("-"*80)
print(f"{'Macro Average':<15} {macro_precision*100:>11.2f}% {macro_recall*100:>11.2f}% {macro_f1*100:>11.2f}%")
print("="*80)

# Check metric balance: largest pairwise gap between precision, recall and F1
print("\n" + "="*80)
print("METRIC BALANCE CHECK")
print("="*80)

for i, cls in enumerate(class_names):
    p, r, f = precision[i], recall[i], f1[i]
    max_diff = max(abs(p-r), abs(p-f), abs(r-f))
    print(f"\n{cls.upper()}:")
    print(f"  Precision: {p*100:.2f}%")
    print(f"  Recall: {r*100:.2f}%")
    print(f"  F1-Score: {f*100:.2f}%")
    print(f"  Max difference: {max_diff*100:.2f}%")

    if max_diff < 0.05:
        print("  ✓ Well balanced!")
    elif max_diff < 0.10:
        print("  ⚠ Acceptable balance")
    else:
        print("  ⚠⚠ Imbalanced - one metric is weaker")

# Compare accuracy with macro F1; a large gap hints that one class is much weaker
acc_f1_diff = abs(test_accuracy - macro_f1)
print(f"\n{'='*80}")
print(f"Accuracy: {test_accuracy*100:.2f}%")
print(f"Macro F1: {macro_f1*100:.2f}%")
print(f"Difference: {acc_f1_diff*100:.2f}%")

if acc_f1_diff < 0.03:
    print("✓ Accuracy and F1 are very close!")
elif acc_f1_diff < 0.05:
    print("✓ Accuracy and F1 are reasonably close")
else:
    print("⚠ Accuracy and F1 have notable difference")
print("="*80)

## 14. Confusion Matrix

In [None]:
# Create confusion matrix with one fixed row/column per class (class_names order),
# so it stays square even if a class is absent from the test labels or predictions
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

# Plot confusion matrices
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('Confusion Matrix - Coconut Branch Health Model v1', fontsize=14, fontweight='bold')

# Counts
sns.heatmap(cm, annot=True, fmt='d', cmap='RdYlGn',
            xticklabels=class_names, yticklabels=class_names, ax=axes[0],
            cbar_kws={'label': 'Count'})
axes[0].set_title('Confusion Matrix (Counts)', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Predicted Label')
axes[0].set_ylabel('True Label')

# Percentages per true class (row-normalized); rows with no samples stay at 0
# instead of producing NaN from a division by zero
row_totals = cm.sum(axis=1, keepdims=True)
cm_pct = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0) * 100
sns.heatmap(cm_pct, annot=True, fmt='.1f', cmap='RdYlGn',
            xticklabels=class_names, yticklabels=class_names, ax=axes[1],
            cbar_kws={'label': 'Percentage (%)'})
axes[1].set_title('Confusion Matrix (Percentages)', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Predicted Label')
axes[1].set_ylabel('True Label')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'confusion_matrix.png'), dpi=150, bbox_inches='tight')
plt.show()

# Print confusion matrix details (diagonal = correct, off-diagonal = errors)
print("\nConfusion Matrix Details:")
print("-"*50)
for i, true_cls in enumerate(class_names):
    for j, pred_cls in enumerate(class_names):
        count = cm[i, j]
        pct = cm_pct[i, j]
        if i == j:
            print(f"✓ {true_cls} correctly classified: {count} ({pct:.1f}%)")
        elif count > 0:
            print(f"✗ {true_cls} misclassified as {pred_cls}: {count} ({pct:.1f}%)")

## 15. Classification Report

In [None]:
# Print full classification report. `labels` keeps one row per entry of
# class_names (so target_names always matches), even if a class is missing from
# the test labels/predictions; zero_division=0 avoids undefined-metric warnings.
print("\n" + "="*80)
print("CLASSIFICATION REPORT")
print("="*80)
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(class_names))),
    target_names=class_names,
    digits=4,
    zero_division=0,
))
print("="*80)

## 16. Sample Predictions Visualization

In [None]:
# Get filenames (test_generator uses shuffle=False, so they align with y_true/y_pred)
filenames = test_generator.filenames

# Indices of correct and wrong predictions, as plain lists (random.sample needs a sequence)
correct_idx = np.flatnonzero(y_true == y_pred).tolist()
wrong_idx = np.flatnonzero(y_true != y_pred).tolist()

print(f"Total: {len(y_true)} | Correct: {len(correct_idx)} | Wrong: {len(wrong_idx)}")

MAX_PREDICTIONS_SHOWN = 10
GRID_COLS = 5


def _plot_prediction_grid(indices, title, color, out_name):
    """Show the test images at *indices* (5 per row) with labels and confidence; save to MODEL_DIR/out_name."""
    n_rows = max(1, -(-len(indices) // GRID_COLS))  # ceiling division, at least one row
    fig, axes = plt.subplots(n_rows, GRID_COLS, figsize=(15, 3.5 * n_rows), squeeze=False)
    fig.suptitle(title, fontsize=14, fontweight='bold', color=color)

    for ax, i in zip(axes.flat, indices):
        img = image.load_img(os.path.join(DATASET_DIR, 'test', filenames[i]),
                             target_size=(IMG_SIZE, IMG_SIZE))
        ax.imshow(img)
        true_label = class_names[y_true[i]]
        pred_label = class_names[y_pred[i]]
        confidence = predictions[i][y_pred[i]] * 100
        ax.set_title(f'True: {true_label}\nPred: {pred_label} ({confidence:.1f}%)',
                     fontsize=9, color=color)

    # Hide axes on every panel, including empty slots in the last row
    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(MODEL_DIR, out_name), dpi=150, bbox_inches='tight')
    plt.show()


# Plot a random sample of correct predictions
sample_correct = random.sample(correct_idx, min(MAX_PREDICTIONS_SHOWN, len(correct_idx)))
_plot_prediction_grid(sample_correct, 'CORRECT Predictions (Sample)', 'green', 'correct_predictions.png')

# Plot the first wrong predictions, if any (the title states how many are shown)
if wrong_idx:
    shown_wrong = wrong_idx[:MAX_PREDICTIONS_SHOWN]
    _plot_prediction_grid(
        shown_wrong,
        f'WRONG Predictions (showing {len(shown_wrong)} of {len(wrong_idx)})',
        'red',
        'wrong_predictions.png',
    )
else:
    print("\n🎉 Perfect! No wrong predictions!")

print(f"\nImages saved to {MODEL_DIR}/")

## 17. Save Model Information

In [None]:
# Epochs actually run can be fewer than configured because of EarlyStopping
phase1_epochs_run = len(history_phase1.history['loss'])
phase2_epochs_run = len(history_phase2.history['loss'])

# Create model info dictionary (configuration, data stats and test metrics)
model_info = {
    'model_name': 'coconut_branch_health_v1',
    'architecture': 'MobileNetV2',
    'classes': class_names,
    'num_classes': len(class_names),
    'input_shape': [IMG_SIZE, IMG_SIZE, 3],
    'training': {
        'phase1_epochs': PHASE1_EPOCHS,
        'phase2_epochs': PHASE2_EPOCHS,
        'total_epochs': PHASE1_EPOCHS + PHASE2_EPOCHS,
        'phase1_epochs_run': phase1_epochs_run,
        'phase2_epochs_run': phase2_epochs_run,
        'total_epochs_run': phase1_epochs_run + phase2_epochs_run,
        'batch_size': BATCH_SIZE,
        'learning_rate_phase1': LEARNING_RATE_PHASE1,
        'learning_rate_phase2': LEARNING_RATE_PHASE2,
        'loss_function': 'Focal Loss (gamma=2.0, alpha=0.25)',
        'optimizer': 'Adam',
        'training_time_minutes': round(total_time, 1),
        'final_train_accuracy': float(final_train_acc),
        'final_val_accuracy': float(final_val_acc)
    },
    'data': {
        'train_samples': train_generator.samples,
        'val_samples': val_generator.samples,
        'test_samples': test_generator.samples,
        'train_healthy': data_summary['train']['healthy'],
        'train_unhealthy': data_summary['train']['unhealthy'],
        'class_imbalance_ratio': float(imbalance_ratio)
    },
    'test_performance': {
        'accuracy': float(test_accuracy),
        'macro_precision': float(macro_precision),
        'macro_recall': float(macro_recall),
        'macro_f1': float(macro_f1),
        'healthy_precision': float(precision[0]),
        'healthy_recall': float(recall[0]),
        'healthy_f1': float(f1[0]),
        'unhealthy_precision': float(precision[1]),
        'unhealthy_recall': float(recall[1]),
        'unhealthy_f1': float(f1[1])
    },
    # Mirrors the ImageDataGenerator settings used for the training split
    'augmentation': {
        'rotation_range': 30,
        'width_shift_range': 0.2,
        'height_shift_range': 0.2,
        'horizontal_flip': True,
        'vertical_flip': True,
        'zoom_range': 0.2,
        'brightness_range': [0.8, 1.2]
    }
}

# Save model info
model_info_path = os.path.join(MODEL_DIR, 'model_info.json')
with open(model_info_path, 'w', encoding='utf-8') as f:
    json.dump(model_info, f, indent=2)

# Artifacts written to MODEL_DIR by this notebook
created_files = [
    'best_model.keras',
    'phase1_best.keras',
    'model_info.json',
    'training_history.png',
    'confusion_matrix.png',
    'class_distribution.png',
    'sample_images.png',
    'correct_predictions.png',
]
if len(wrong_idx) > 0:
    created_files.append('wrong_predictions.png')

print(f"Model information saved to: {model_info_path}")
print("\nModel files created:")
for created in created_files:
    print(f"  - {created}")

## 18. Final Summary

In [None]:
OK_MARK, WARN_MARK = '✓', '⚠'


def _metrics_balanced(i, tol=0.05):
    """Return True if precision, recall and F1 of class index *i* differ pairwise by less than *tol*."""
    p, r, f = precision[i], recall[i], f1[i]
    # Same three-way comparison as the METRIC BALANCE CHECK in section 13
    return max(abs(p - r), abs(p - f), abs(r - f)) < tol


print("\n" + "="*80)
print("COCONUT BRANCH HEALTH MODEL v1 - FINAL SUMMARY")
print("="*80)
print()
print("  Model Information:")
print("  " + "-"*60)
print("  Name:                    Coconut Branch Health Detection Model v1")
print("  Architecture:            MobileNetV2 (Transfer Learning)")
print("  Loss Function:           Focal Loss (gamma=2.0)")
print(f"  Input Size:              {IMG_SIZE}x{IMG_SIZE}x3")
print(f"  Classes:                 {class_names}")
print()
print("  Training Summary:")
print("  " + "-"*60)
print(f"  Phase 1 (Frozen):        {PHASE1_EPOCHS} epochs, LR={LEARNING_RATE_PHASE1}")
print(f"  Phase 2 (Fine-tune):     {PHASE2_EPOCHS} epochs, LR={LEARNING_RATE_PHASE2}")
print(f"  Total Training Time:     {total_time:.1f} minutes")
print(f"  Final Train Accuracy:    {final_train_acc*100:.2f}%")
print(f"  Final Val Accuracy:      {final_val_acc*100:.2f}%")
print()
print("  Test Performance:")
print("  " + "-"*60)
print(f"  Test Accuracy:           {test_accuracy*100:.2f}%")
print(f"  Macro Precision:         {macro_precision*100:.2f}%")
print(f"  Macro Recall:            {macro_recall*100:.2f}%")
print(f"  Macro F1-Score:          {macro_f1*100:.2f}%")
print()
print("  Class-wise Performance:")
print("  " + "-"*60)
for i, cls in enumerate(class_names):
    print(f"  {cls.upper():12} P={precision[i]*100:.2f}% R={recall[i]*100:.2f}% F1={f1[i]*100:.2f}%")
print()
print("  Quality Checks:")
print("  " + "-"*60)
print(f"  Overfitting (Train-Val gap):    {acc_gap*100:.2f}% {OK_MARK if acc_gap < 0.05 else WARN_MARK}")
print(f"  Accuracy-F1 alignment:          {acc_f1_diff*100:.2f}% {OK_MARK if acc_f1_diff < 0.03 else WARN_MARK}")
for i, cls in enumerate(class_names):
    balance_label = f"Metric balance ({cls}):"
    print(f"  {balance_label:<32}{OK_MARK if _metrics_balanced(i) else WARN_MARK}")
print()
print("  Files:")
print("  " + "-"*60)
print(f"  Model:                   {MODEL_DIR}/best_model.keras")
print(f"  Model Info:              {MODEL_DIR}/model_info.json")
print()
print("="*80)
print("                         ✓ TRAINING COMPLETE!")
print("="*80)