# Coconut Leaf Dieback Detection Model v6

This notebook trains a model to detect **leaf dieback disease** in coconut trees.

## Model Configuration
- **Architecture:** MobileNetV2 (Transfer Learning)
- **Loss Function:** Focal Loss (gamma=2.0) for handling class imbalance
- **Training Strategy:** 2-phase training (frozen base → fine-tuning)
- **Classes:** 3 (healthy, leaf_die_back, not_cocount)

## Goals (Madam's Requirements)
- Avoid data leakage and overfitting
- Balanced Precision, Recall, F1-score for each class
- Similar values across all classes
- Accuracy close to macro F1-score
- Real outputs, not hardcoded

## 1. Setup and Imports

In [None]:
import os
import shutil
import json
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import random
import warnings
warnings.filterwarnings('ignore')

# Report the runtime environment before any work is done.
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU Available: {gpu_devices}")

# Seed NumPy, TensorFlow and Python's `random` with one shared value; this is
# intended to make shuffling and sampling repeatable between runs.
SEED = 42
for seed_fn in (np.random.seed, tf.random.set_seed, random.seed):
    seed_fn(SEED)

print(f"\nRandom seed set to {SEED} for reproducibility")

## 2. Configuration

In [None]:
# Paths: everything is resolved relative to the directory two levels above
# the current working directory.
BASE_DIR = os.path.abspath(os.path.join(os.pardir, os.pardir))
RAW_DATASET_DIR = os.path.join(BASE_DIR, 'data', 'raw', 'leaf_dieback_v1')
BALANCED_DATASET_DIR = os.path.join(BASE_DIR, 'data', 'processed', 'leaf_dieback_balanced')
MODEL_DIR = os.path.join(BASE_DIR, 'models', 'leaf_dieback_v6')

# Model parameters
IMG_SIZE = 224
BATCH_SIZE = 32
PHASE1_EPOCHS = 25  # Frozen base
PHASE2_EPOCHS = 20  # Fine-tuning
LEARNING_RATE_PHASE1 = 1e-3
LEARNING_RATE_PHASE2 = 1e-5

# Classes (must match folder names)
CLASS_NAMES = ['healthy', 'leaf_die_back', 'not_cocount']

# Make sure the output locations exist before anything is written to them.
for output_dir in (MODEL_DIR, BALANCED_DATASET_DIR):
    os.makedirs(output_dir, exist_ok=True)

# Echo the effective configuration.
print("=" * 70)
print("CONFIGURATION")
print("=" * 70)
for label, value in (
    ("Base Directory:      ", BASE_DIR),
    ("Raw Dataset:         ", RAW_DATASET_DIR),
    ("Balanced Dataset:    ", BALANCED_DATASET_DIR),
    ("Model Directory:     ", MODEL_DIR),
    ("Image Size:          ", f"{IMG_SIZE}x{IMG_SIZE}"),
    ("Batch Size:          ", BATCH_SIZE),
    ("Classes:             ", CLASS_NAMES),
):
    print(f"{label}{value}")
print("=" * 70)

## 3. Exploratory Data Analysis (EDA)

First, let's examine the raw dataset to understand the class distribution.

In [None]:
# Tally the images per split and class in the raw dataset (BEFORE balancing).
# Result: raw_data_summary[split][class_name] -> number of image files.
raw_data_summary = {}
image_suffixes = ('.jpg', '.jpeg', '.png')

print("=" * 70)
print("RAW DATASET SUMMARY (BEFORE BALANCING)")
print("=" * 70)

for split in ['train', 'val', 'test']:
    print(f"\n{split.upper()}:")
    print("-" * 50)

    split_total = 0
    for cls in CLASS_NAMES:
        cls_path = os.path.join(RAW_DATASET_DIR, split, cls)
        # A missing class folder is reported as zero images, not as an error.
        count = 0
        if os.path.exists(cls_path):
            count = sum(1 for name in os.listdir(cls_path) if name.lower().endswith(image_suffixes))
        print(f"  {cls:<20} {count:>6} images")

        raw_data_summary.setdefault(split, {})[cls] = count
        split_total += count

    print(f"  {'TOTAL':<20} {split_total:>6} images")

print("\n" + "=" * 70)

# Report the max/min class-count ratio for the evaluation splits and flag
# any split whose ratio exceeds 1.5x.
print("\nIMBALANCE DETECTED:")
for split in ['val', 'test']:
    counts = [raw_data_summary[split][cls] for cls in CLASS_NAMES]
    min_count, max_count = min(counts), max(counts)
    if min_count > 0:
        ratio = max_count / min_count
    else:
        ratio = float('inf')
    print(f"  {split.upper()}: min={min_count}, max={max_count}, ratio={ratio:.2f}x")
    if ratio > 1.5:
        print(f"    --> Need to balance {split} set!")

## 4. Balance Validation and Test Sets

The validation and test sets are imbalanced. We'll balance them by:
1. Moving images from training to val/test for underrepresented classes
2. Creating a properly balanced dataset

**IMPORTANT:** This prevents biased evaluation and ensures fair metrics.

In [None]:
def balance_dataset(raw_dir, balanced_dir, class_names, target_val_test=90):
    """
    Build a copy of the raw dataset whose val/test splits are topped up per class.

    Strategy:
    - Copy every raw image into ``balanced_dir``; the raw data is never modified.
    - For each class in ``val`` and then ``test``: if the split holds fewer than
      ``target_val_test`` images, move randomly chosen images of that class out
      of the balanced ``train`` split until the target is reached or the train
      images run out. Files are moved (not copied), so a moved image no longer
      exists in train. Moved files get a ``moved_from_train_`` name prefix.
    - Splits that already meet or exceed the target are left untouched
      (over-represented classes are not trimmed).

    The selection uses the global ``random`` state, so seed it beforehand for a
    repeatable split.

    Args:
        raw_dir: Root of the raw dataset, laid out as ``<split>/<class>/<image>``.
        balanced_dir: Root of the output dataset (created if missing).
        class_names: Class folder names to process.
        target_val_test: Minimum number of images wanted per class in each of
            the val and test splits.

    Returns:
        True if the dataset was built during this call, False if every
        split/class folder under ``balanced_dir`` already contained files and
        the work was skipped.
    """
    splits = ('train', 'val', 'test')
    image_exts = ('.jpg', '.jpeg', '.png')

    def list_images(folder):
        # Sorted so the later shuffle depends only on the RNG state and not on
        # the arbitrary order in which os.listdir() returns entries.
        return sorted(f for f in os.listdir(folder) if f.lower().endswith(image_exts))

    print("=" * 70)
    print("BALANCING DATASET")
    print("=" * 70)

    # Create directory structure
    for split in splits:
        for cls in class_names:
            os.makedirs(os.path.join(balanced_dir, split, cls), exist_ok=True)

    # The dataset counts as already built when every split/class folder holds
    # at least one entry. NOTE: a run that was interrupted part-way can also
    # satisfy this check; delete balanced_dir to force a clean rebuild.
    already_done = all(
        os.listdir(os.path.join(balanced_dir, split, cls))
        for split in splits
        for cls in class_names
    )
    if already_done:
        print("Dataset already balanced. Skipping...")
        return False

    # Step 1: Copy all raw data to balanced directory first
    print("\nStep 1: Copying raw data...")
    for split in splits:
        for cls in class_names:
            src_dir = os.path.join(raw_dir, split, cls)
            dst_dir = os.path.join(balanced_dir, split, cls)
            if not os.path.isdir(src_dir):
                continue

            files = list_images(src_dir)
            for name in files:
                dst_file = os.path.join(dst_dir, name)
                # Never overwrite a file that is already in place.
                if not os.path.exists(dst_file):
                    shutil.copy2(os.path.join(src_dir, name), dst_file)
            print(f"  Copied {len(files)} files to {split}/{cls}")

    # Step 2: Balance val and test sets
    print("\nStep 2: Balancing val and test sets...")
    for split in ('val', 'test'):
        for cls in class_names:
            split_cls_dir = os.path.join(balanced_dir, split, cls)
            needed = target_val_test - len(list_images(split_cls_dir))
            if needed <= 0:
                continue

            train_cls_dir = os.path.join(balanced_dir, 'train', cls)
            train_files = list_images(train_cls_dir)
            random.shuffle(train_files)

            chosen = train_files[:needed]
            for name in chosen:
                # Prefix the name so it cannot collide with a file already in
                # the destination and so moved files remain identifiable.
                shutil.move(
                    os.path.join(train_cls_dir, name),
                    os.path.join(split_cls_dir, f"moved_from_train_{name}"),
                )

            print(f"  Moved {len(chosen)} images from train to {split}/{cls}")
            if len(chosen) < needed:
                print(f"  WARNING: train/{cls} ran out of images; "
                      f"{split}/{cls} is still {needed - len(chosen)} short of {target_val_test}")

    print("\nDataset balancing complete!")
    return True

# Produce the balanced dataset under BALANCED_DATASET_DIR from the raw data.
balance_dataset(
    raw_dir=RAW_DATASET_DIR,
    balanced_dir=BALANCED_DATASET_DIR,
    class_names=CLASS_NAMES,
    target_val_test=90,
)

In [None]:
# Tally the images per split and class in the balanced dataset.
# Result: data_summary[split][class_name] -> number of image files.
data_summary = {}
total_images = 0

print("=" * 70)
print("BALANCED DATASET SUMMARY")
print("=" * 70)

for split in ['train', 'val', 'test']:
    print(f"\n{split.upper()}:")
    print("-" * 50)

    split_total = 0
    for cls in CLASS_NAMES:
        cls_path = os.path.join(BALANCED_DATASET_DIR, split, cls)
        # A missing class folder is reported as zero images.
        if not os.path.exists(cls_path):
            count = 0
        else:
            image_names = [name for name in os.listdir(cls_path)
                           if name.lower().endswith(('.jpg', '.jpeg', '.png'))]
            count = len(image_names)
        print(f"  {cls:<20} {count:>6} images")

        data_summary.setdefault(split, {})[cls] = count
        split_total += count

    print(f"  {'TOTAL':<20} {split_total:>6} images")
    total_images += split_total

print("\n" + "=" * 70)
print(f"TOTAL BALANCED IMAGES: {total_images}")
print("=" * 70)

In [None]:
# Bar chart of per-class image counts, one panel per split (AFTER balancing).
colors = ['#2ecc71', '#e74c3c', '#3498db']  # green, red, blue

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('Class Distribution Across Splits (BALANCED)', fontsize=14, fontweight='bold')

for ax, split in zip(axes, ['train', 'val', 'test']):
    counts = [data_summary[split][cls] for cls in CLASS_NAMES]
    bars = ax.bar(CLASS_NAMES, counts, color=colors)
    ax.set_title(f'{split.upper()} Split', fontweight='bold')
    ax.set_ylabel('Number of Images')
    ax.set_xlabel('Class')
    ax.tick_params(axis='x', rotation=15)

    # Write each count just above its bar.
    for bar, count in zip(bars, counts):
        x_center = bar.get_x() + bar.get_width()/2
        ax.text(x_center, bar.get_height(), str(count),
                ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'class_distribution_balanced.png'), dpi=150, bbox_inches='tight')
plt.show()

# Report the max/min class-count ratio of the evaluation splits; a ratio
# below 1.2x is labelled as balanced.
print("\nBalance Check:")
for split in ['val', 'test']:
    counts = [data_summary[split][cls] for cls in CLASS_NAMES]
    smallest, largest = min(counts), max(counts)
    ratio = largest / smallest if smallest > 0 else float('inf')
    if ratio < 1.2:
        status = "Balanced!"
    else:
        status = "Needs attention"
    print(f"  {split.upper()}: ratio={ratio:.2f}x - {status}")

## 5. Visualize Sample Images

In [None]:
from tensorflow.keras.preprocessing import image

# Grid of example training images: one row per class, up to five random
# images per row.
class_colors = {'healthy': 'green', 'leaf_die_back': 'red', 'not_cocount': 'blue'}

fig, axes = plt.subplots(3, 5, figsize=(15, 10))
fig.suptitle('Sample Images from Balanced Training Set', fontsize=14, fontweight='bold')

for row_axes, cls in zip(axes, CLASS_NAMES):
    cls_dir = os.path.join(BALANCED_DATASET_DIR, 'train', cls)
    images_list = [name for name in os.listdir(cls_dir)
                   if name.lower().endswith(('.jpg', '.jpeg', '.png'))]

    # Draw at most five images, or fewer when the class has fewer files.
    sample_imgs = random.sample(images_list, min(5, len(images_list)))

    for ax, img_name in zip(row_axes, sample_imgs):
        img = image.load_img(os.path.join(cls_dir, img_name), target_size=(IMG_SIZE, IMG_SIZE))
        ax.imshow(img)
        ax.axis('off')

    # Class label placed to the left of the row's first panel.
    first_ax = row_axes[0]
    first_ax.text(-0.3, 0.5, cls.replace('_', '\n').upper(),
                  transform=first_ax.transAxes,
                  fontsize=10, fontweight='bold', va='center', ha='right',
                  color=class_colors[cls])

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'sample_images.png'), dpi=150, bbox_inches='tight')
plt.show()

## 6. Data Generators with Augmentation

**IMPORTANT - Preventing Data Leakage:**
- Augmentation is ONLY applied to training data
- Validation and test use raw images with only rescaling
- No data from val/test leaks into training

In [None]:
# Augmentation settings used for the training images only. Pixel values are
# additionally rescaled from 0-255 to 0-1.
augmentation_settings = {
    'rescale': 1./255,
    'rotation_range': 30,
    'width_shift_range': 0.2,
    'height_shift_range': 0.2,
    'horizontal_flip': True,
    'vertical_flip': True,
    'zoom_range': 0.2,
    'brightness_range': [0.8, 1.2],
    'shear_range': 0.15,
    'fill_mode': 'nearest',
}
train_datagen = ImageDataGenerator(**augmentation_settings)

# Validation and test images are only rescaled - NO augmentation.
val_test_datagen = ImageDataGenerator(rescale=1./255)

print("Data Augmentation (Training Only):")
print("-" * 50)
for summary_line in (
    "  - Rotation: +/-30 degrees",
    "  - Width/Height Shift: +/-20%",
    "  - Horizontal & Vertical Flip",
    "  - Zoom: +/-20%",
    "  - Brightness: 80-120%",
    "  - Shear: +/-15%",
):
    print(summary_line)
print("\nNo augmentation on validation/test data (prevents data leakage)")

In [None]:
# Directory iterators over the BALANCED dataset. All three share the same
# image size, batch size, label encoding and class order.
shared_flow_args = dict(
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    classes=CLASS_NAMES,
)

# Training batches are shuffled (seeded); augmentation comes from train_datagen.
train_generator = train_datagen.flow_from_directory(
    os.path.join(BALANCED_DATASET_DIR, 'train'),
    shuffle=True,
    seed=SEED,
    **shared_flow_args
)

# Validation and test iterate in a fixed order (shuffle=False) so predictions
# can be compared position-by-position with the generator's `classes`.
val_generator = val_test_datagen.flow_from_directory(
    os.path.join(BALANCED_DATASET_DIR, 'val'),
    shuffle=False,
    **shared_flow_args
)

test_generator = val_test_datagen.flow_from_directory(
    os.path.join(BALANCED_DATASET_DIR, 'test'),
    shuffle=False,
    **shared_flow_args
)

print("\n" + "=" * 70)
print("DATA GENERATORS CREATED (FROM BALANCED DATASET)")
print("=" * 70)
print(f"Training samples:   {train_generator.samples}")
print(f"Validation samples: {val_generator.samples}")
print(f"Test samples:       {test_generator.samples}")
print(f"\nClass indices: {train_generator.class_indices}")
print("=" * 70)

## 7. Visualize Augmented Images

In [None]:
# Show one training image next to nine augmented variants of it.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
fig.suptitle('Data Augmentation Examples (Same Image)', fontsize=14, fontweight='bold')

# Pick a random image of the diseased class from the training split.
sample_cls = 'leaf_die_back'
cls_dir = os.path.join(BALANCED_DATASET_DIR, 'train', sample_cls)
candidate_names = [name for name in os.listdir(cls_dir)
                   if name.lower().endswith(('.jpg', '.jpeg', '.png'))]
sample_img_name = random.choice(candidate_names)
sample_img_path = os.path.join(cls_dir, sample_img_name)

# Load it and add a leading batch axis, as required by `flow`.
img = image.load_img(sample_img_path, target_size=(IMG_SIZE, IMG_SIZE))
img_array = np.expand_dims(image.img_to_array(img), axis=0)

panels = axes.ravel()

# First panel: the untouched original.
panels[0].imshow(img)
panels[0].set_title('Original', fontweight='bold')
panels[0].axis('off')

# Remaining nine panels: successive outputs of the training augmenter.
aug_gen = train_datagen.flow(img_array, batch_size=1, seed=42)
for variant_no, ax in enumerate(panels[1:], start=1):
    augmented = next(aug_gen)[0]
    ax.imshow(augmented)
    ax.set_title(f'Augmented {variant_no}')
    ax.axis('off')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'augmentation_examples.png'), dpi=150, bbox_inches='tight')
plt.show()

## 8. Define Focal Loss

Focal Loss helps with class imbalance by focusing on hard-to-classify examples.

In [None]:
def focal_loss(gamma=2.0, alpha=0.25):
    """
    Create a categorical focal-loss function for use with `model.compile`.

    Args:
        gamma: Focusing exponent applied to (1 - p); larger values shrink the
            contribution of confidently predicted entries more strongly.
        alpha: Scalar multiplier applied to every term of the loss. It is the
            same for all classes, so it scales the loss as a whole rather than
            weighting individual classes differently.

    Returns:
        A function `focal_loss_fn(y_true, y_pred)` that returns the focal loss
        summed over the last axis of its inputs.
    """
    K = tf.keras.backend

    # NOTE: the inner function's name is used as the `custom_objects` key when
    # the saved model is reloaded, so it must stay `focal_loss_fn`.
    def focal_loss_fn(y_true, y_pred):
        # Keep probabilities strictly inside (0, 1) so log() stays finite.
        eps = K.epsilon()
        probs = K.clip(y_pred, eps, 1.0 - eps)

        # Element-wise cross entropy: -y * log(p)
        ce_terms = -y_true * K.log(probs)

        # Modulating factor: (1 - p)^gamma
        modulator = K.pow(1.0 - probs, gamma)

        return K.sum(alpha * modulator * ce_terms, axis=-1)

    return focal_loss_fn

# Announce the loss settings (fixed text; these match the values passed to
# focal_loss() when the model is compiled).
for banner_line in (
    "Focal Loss function defined!",
    "  Gamma: 2.0 (focus on hard examples)",
    "  Alpha: 0.25 (class balancing)",
):
    print(banner_line)

## 9. Build Model with Transfer Learning

In [None]:
def build_model(num_classes):
    """
    Build a MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained MobileNetV2 backbone is frozen and topped with a
    regularised dense head ending in a softmax over ``num_classes``.

    Inputs are expected in the [0, 1] range, which is what the data generators
    in this notebook produce (``rescale=1./255``). The pretrained MobileNetV2
    weights expect pixels in [-1, 1], so the model remaps its input internally;
    callers keep feeding [0, 1] images.

    Args:
        num_classes: Number of output classes (size of the softmax layer).

    Returns:
        (model, base_model): the full Keras model and the MobileNetV2 backbone,
        returned separately so the caller can unfreeze backbone layers later.
    """
    # Load pre-trained MobileNetV2 without its ImageNet classification head
    base_model = MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet'
    )

    # Freeze base model initially
    base_model.trainable = False

    inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

    # Map [0, 1] pixels to the [-1, 1] range the pretrained weights were
    # trained on (same result as mobilenet_v2.preprocess_input on 0-255 data).
    x = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)

    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, including after layers are unfrozen for fine-tuning.
    x = base_model(x, training=False)

    # Classification head: dropout, L2 weight decay and batch norm are all
    # used as regularisation.
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(256, activation='relu', kernel_regularizer=keras.regularizers.l2(0.01))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation='relu', kernel_regularizer=keras.regularizers.l2(0.01))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = keras.Model(inputs, outputs)

    return model, base_model

# Instantiate the network with one output per class.
model, base_model = build_model(len(CLASS_NAMES))

print("=" * 70)
print("MODEL ARCHITECTURE")
print("=" * 70)
model.summary()

# Count the scalar parameters held by the currently trainable weights.
trainable_param_count = sum(np.prod(weight.shape) for weight in model.trainable_weights)
print(f"\nBase model layers: {len(base_model.layers)}")
print(f"Trainable parameters: {trainable_param_count:,}")

## 10. Phase 1: Training with Frozen Base

In [None]:
# Phase 1 setup: Adam at the higher learning rate with focal loss.
phase1_optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE_PHASE1)
model.compile(
    optimizer=phase1_optimizer,
    loss=focal_loss(gamma=2.0, alpha=0.25),
    metrics=['accuracy'],
)

# Callbacks
# - keep the weights of the epoch with the highest validation accuracy on disk
checkpoint_phase1 = ModelCheckpoint(
    os.path.join(MODEL_DIR, 'phase1_best.keras'),
    monitor='val_accuracy', mode='max',
    save_best_only=True,
    verbose=1,
)

# - stop after 10 epochs without a better validation loss
early_stop_phase1 = EarlyStopping(
    monitor='val_loss', patience=10,
    restore_best_weights=True,
    verbose=1,
)

# - halve the learning rate after 5 epochs without a better validation loss
reduce_lr_phase1 = ReduceLROnPlateau(
    monitor='val_loss', patience=5,
    factor=0.5, min_lr=1e-7,
    verbose=1,
)

print("=" * 70)
print("PHASE 1: TRAINING WITH FROZEN BASE MODEL")
print("=" * 70)
print(f"Learning Rate:      {LEARNING_RATE_PHASE1}")
print(f"Max Epochs:         {PHASE1_EPOCHS}")
print(f"Base Model Frozen:  {not base_model.trainable}")
print("=" * 70)

In [None]:
# Run Phase 1 and time it (wall-clock, reported in minutes).
phase1_callbacks = [checkpoint_phase1, early_stop_phase1, reduce_lr_phase1]

start_time = time.time()
history_phase1 = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=PHASE1_EPOCHS,
    callbacks=phase1_callbacks,
    verbose=1,
)
phase1_time = (time.time() - start_time) / 60

# Early stopping may end training before PHASE1_EPOCHS, so report the number
# of epochs that were actually logged.
phase1_epochs_run = len(history_phase1.history['accuracy'])
phase1_best_val_acc = max(history_phase1.history['val_accuracy'])

print("\n" + "=" * 70)
print("PHASE 1 COMPLETE")
print("=" * 70)
print(f"Training Time: {phase1_time:.1f} minutes")
print(f"Epochs Run: {phase1_epochs_run}")
print(f"Best Val Accuracy: {phase1_best_val_acc*100:.2f}%")
print("=" * 70)

## 11. Phase 2: Fine-tuning

In [None]:
# Make the backbone trainable, then re-freeze everything except its last
# 30 layers so that only those are fine-tuned.
base_model.trainable = True

total_base_layers = len(base_model.layers)
fine_tune_at = total_base_layers - 30

frozen_layers = base_model.layers[:fine_tune_at]
for frozen_layer in frozen_layers:
    frozen_layer.trainable = False

print(f"Total base model layers: {total_base_layers}")
print(f"Fine-tuning from layer: {fine_tune_at}")
print(f"Layers being fine-tuned: {total_base_layers - fine_tune_at}")

In [None]:
# Phase 2 setup: recompile so the changed `trainable` flags take effect,
# using a much lower learning rate for fine-tuning.
phase2_optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE_PHASE2)
model.compile(
    optimizer=phase2_optimizer,
    loss=focal_loss(gamma=2.0, alpha=0.25),
    metrics=['accuracy'],
)

# Callbacks for Phase 2
# - the epoch with the highest validation accuracy is written to best_model.keras
checkpoint_phase2 = ModelCheckpoint(
    os.path.join(MODEL_DIR, 'best_model.keras'),
    monitor='val_accuracy', mode='max',
    save_best_only=True,
    verbose=1,
)

# - stop after 10 epochs without a better validation loss
early_stop_phase2 = EarlyStopping(
    monitor='val_loss', patience=10,
    restore_best_weights=True,
    verbose=1,
)

# - halve the learning rate after 5 epochs without a better validation loss
reduce_lr_phase2 = ReduceLROnPlateau(
    monitor='val_loss', patience=5,
    factor=0.5, min_lr=1e-8,
    verbose=1,
)

print("\n" + "=" * 70)
print("PHASE 2: FINE-TUNING")
print("=" * 70)
print(f"Learning Rate:      {LEARNING_RATE_PHASE2}")
print(f"Max Epochs:         {PHASE2_EPOCHS}")
print("=" * 70)

In [None]:
# Run Phase 2 (fine-tuning) and time it (wall-clock, reported in minutes).
phase2_callbacks = [checkpoint_phase2, early_stop_phase2, reduce_lr_phase2]

start_time = time.time()
history_phase2 = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=PHASE2_EPOCHS,
    callbacks=phase2_callbacks,
    verbose=1,
)
phase2_time = (time.time() - start_time) / 60

# Combined training time of both phases.
total_time = phase1_time + phase2_time
phase2_best_val_acc = max(history_phase2.history['val_accuracy'])

print("\n" + "=" * 70)
print("PHASE 2 COMPLETE")
print("=" * 70)
print(f"Phase 2 Time: {phase2_time:.1f} minutes")
print(f"Total Training Time: {total_time:.1f} minutes")
print(f"Best Val Accuracy: {phase2_best_val_acc*100:.2f}%")
print("=" * 70)

## 12. Training History Visualization

In [None]:
# Stitch the two phases into one continuous per-epoch record.
history_combined = {
    metric: history_phase1.history[metric] + history_phase2.history[metric]
    for metric in ('accuracy', 'val_accuracy', 'loss', 'val_loss')
}

phase1_epochs = len(history_phase1.history['accuracy'])

# Side-by-side accuracy and loss curves.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('Training History - Leaf Dieback Model v6', fontsize=14, fontweight='bold')

# (history key, display name, legend position, fixed y-limits or None)
panel_specs = [
    ('accuracy', 'Accuracy', 'lower right', [0, 1.05]),
    ('loss', 'Loss', 'upper right', None),
]

for ax, (metric, display_name, legend_loc, y_limits) in zip(axes, panel_specs):
    ax.plot(history_combined[metric], label=f'Train {display_name}', linewidth=2, color='#3498db')
    ax.plot(history_combined[f'val_{metric}'], label=f'Val {display_name}', linewidth=2, color='#e74c3c')
    # Dashed marker drawn at the index of the last Phase 1 epoch.
    ax.axvline(x=phase1_epochs-1, color='gray', linestyle='--', label='Fine-tuning starts', alpha=0.7)
    ax.set_title(f'Model {display_name}', fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(display_name)
    ax.legend(loc=legend_loc)
    ax.grid(True, alpha=0.3)
    if y_limits is not None:
        ax.set_ylim(y_limits)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'training_history.png'), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Compare the train and validation accuracy logged for the final epoch; a
# large positive gap (train > val) is reported as overfitting.
final_train_acc = history_combined['accuracy'][-1]
final_val_acc = history_combined['val_accuracy'][-1]
acc_gap = final_train_acc - final_val_acc

# Thresholds: under 5 points is fine, under 10 is tolerated.
if acc_gap >= 0.10:
    overfit_verdict = "Overfitting detected - consider more regularization"
elif acc_gap >= 0.05:
    overfit_verdict = "Minor overfitting - acceptable"
else:
    overfit_verdict = "No significant overfitting!"

print("=" * 70)
print("OVERFITTING CHECK")
print("=" * 70)
print(f"Final Train Accuracy: {final_train_acc*100:.2f}%")
print(f"Final Val Accuracy:   {final_val_acc*100:.2f}%")
print(f"Accuracy Gap:         {acc_gap*100:.2f}%")
print("-" * 70)
print(overfit_verdict)
print("=" * 70)

## 13. Load Best Model and Evaluate on Test Set

In [None]:
# Reload the checkpoint written during Phase 2. The custom loss has to be
# supplied under the name it was saved with.
best_model_path = os.path.join(MODEL_DIR, 'best_model.keras')
best_model = keras.models.load_model(
    best_model_path,
    custom_objects={'focal_loss_fn': focal_loss(gamma=2.0, alpha=0.25)},
)

print("Best model loaded from: best_model.keras")

# Predict on the test set, starting from the generator's first batch.
test_generator.reset()
print("\nMaking predictions on test set...")
predictions = best_model.predict(test_generator, verbose=1)

# True labels come from the generator; the predicted label is the class with
# the highest probability, and that probability is kept as the confidence.
y_true = test_generator.classes
y_pred = np.argmax(predictions, axis=1)
y_pred_proba = np.max(predictions, axis=1)

# Fraction of test images whose predicted label matches the true label.
test_accuracy = np.mean(y_true == y_pred)
print("\n" + "=" * 70)
print(f"TEST ACCURACY: {test_accuracy*100:.2f}%")
print("=" * 70)

## 14. Detailed Class-wise Metrics (Madam's Requirements)

Checking:
- Precision, Recall, F1-score for each class
- P, R, F1 should be close to each other
- Similar values across all classes
- Accuracy close to macro F1

In [None]:
# Per-class precision / recall / F1 / support, plus unweighted (macro) means.
precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred, average=None)
macro_precision, macro_recall, macro_f1 = (np.mean(scores) for scores in (precision, recall, f1))


def _as_pct(value):
    # Right-aligned percentage with two decimals, e.g. "      93.33%".
    return f"{value*100:>11.2f}%"


print("=" * 85)
print("DETAILED CLASS-WISE METRICS")
print("=" * 85)
print(f"\n{'Class':<20} {'Precision':>12} {'Recall':>12} {'F1-Score':>12} {'Support':>10}")
print("-" * 85)

for i, cls in enumerate(CLASS_NAMES):
    print(f"{cls:<20} {_as_pct(precision[i])} {_as_pct(recall[i])} {_as_pct(f1[i])} {support[i]:>10}")

print("-" * 85)
print(f"{'MACRO AVERAGE':<20} {_as_pct(macro_precision)} {_as_pct(macro_recall)} {_as_pct(macro_f1)}")
print("=" * 85)

In [None]:
# Grouped bar chart: precision, recall and F1 side by side for each class.
fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(CLASS_NAMES))
width = 0.25

# Each metric is drawn one bar-width apart around the class position.
bars1, bars2, bars3 = (
    ax.bar(x + shift, scores * 100, width, label=metric_name, color=bar_color)
    for shift, scores, metric_name, bar_color in (
        (-width, precision, 'Precision', '#3498db'),
        (0, recall, 'Recall', '#2ecc71'),
        (width, f1, 'F1-Score', '#e74c3c'),
    )
)

ax.set_xlabel('Class', fontweight='bold')
ax.set_ylabel('Score (%)', fontweight='bold')
ax.set_title('Class-wise Precision, Recall, and F1-Score', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels([name.replace('_', '\n') for name in CLASS_NAMES])
ax.legend(loc='lower right')
ax.set_ylim([0, 110])
ax.grid(axis='y', alpha=0.3)

# Print each bar's value just above it.
for bar_group in (bars1, bars2, bars3):
    for bar in bar_group:
        height = bar.get_height()
        bar_center = bar.get_x() + bar.get_width() / 2
        ax.annotate(f'{height:.1f}%',
                    xy=(bar_center, height),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'class_metrics.png'), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# For every class, report how far apart precision, recall and F1 are.
print("=" * 85)
print("METRIC BALANCE CHECK (Madam's Requirements)")
print("=" * 85)

for i, cls in enumerate(CLASS_NAMES):
    p, r, f = precision[i], recall[i], f1[i]
    # The largest pairwise difference equals the spread between the highest
    # and the lowest of the three values.
    max_diff = max(p, r, f) - min(p, r, f)

    print(f"\n{cls.upper()}:")
    print(f"  Precision: {p*100:.2f}%")
    print(f"  Recall:    {r*100:.2f}%")
    print(f"  F1-Score:  {f*100:.2f}%")
    print(f"  Max difference: {max_diff*100:.2f}%")

    # Under 5 points counts as well balanced, under 10 as acceptable.
    if max_diff >= 0.10:
        print("  --> Needs attention")
    elif max_diff >= 0.05:
        print("  --> Acceptable balance")
    else:
        print("  --> Well balanced!")

# Spread of the F1-score between the best and the worst class.
print("\n" + "-" * 85)
print("CROSS-CLASS BALANCE:")
f1_range = max(f1) - min(f1)
print(f"F1-Score range: {f1_range*100:.2f}%")
print("Similar performance across all classes!" if f1_range < 0.10 else "Some variation across classes")
print("=" * 85)

In [None]:
# Compare overall test accuracy with the macro-averaged F1-score.
acc_f1_diff = abs(test_accuracy - macro_f1)

# Under 2 points is rated excellent, under 5 good.
if acc_f1_diff >= 0.05:
    alignment_verdict = "Moderate difference - may indicate class imbalance"
elif acc_f1_diff >= 0.02:
    alignment_verdict = "Good! Accuracy and F1 are reasonably close (< 5%)"
else:
    alignment_verdict = "Excellent! Accuracy and F1 are very close (< 2%)"

print("=" * 85)
print("ACCURACY vs F1-SCORE ALIGNMENT")
print("=" * 85)
print(f"\nTest Accuracy:    {test_accuracy*100:.2f}%")
print(f"Macro F1-Score:   {macro_f1*100:.2f}%")
print(f"Difference:       {acc_f1_diff*100:.2f}%")
print("-" * 85)
print(alignment_verdict)
print("=" * 85)

## 15. Confusion Matrix

In [None]:
# Confusion matrix: rows are true labels, columns are predicted labels.
cm = confusion_matrix(y_true, y_pred)

# Same matrix with every row expressed as a percentage of that row's total.
cm_pct = cm.astype('float') / cm.sum(axis=1, keepdims=True) * 100

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('Confusion Matrix - Leaf Dieback Model v6', fontsize=14, fontweight='bold')

# (data, cell format, colour map, colour-bar label, panel title)
heatmap_specs = [
    (cm, 'd', 'Blues', 'Count', 'Confusion Matrix (Counts)'),
    (cm_pct, '.1f', 'RdYlGn', 'Percentage (%)', 'Confusion Matrix (Row %)'),
]

for ax, (matrix, cell_fmt, colormap, bar_label, panel_title) in zip(axes, heatmap_specs):
    sns.heatmap(matrix, annot=True, fmt=cell_fmt, cmap=colormap,
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax,
                cbar_kws={'label': bar_label})
    ax.set_title(panel_title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'confusion_matrix.png'), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Text version of the confusion matrix: for each true class, list the correct
# count and every non-empty misclassification.
print("\n" + "=" * 70)
print("CONFUSION MATRIX ANALYSIS")
print("=" * 70)

for i, true_cls in enumerate(CLASS_NAMES):
    print(f"\n{true_cls.upper()} (True):")
    for j, pred_cls in enumerate(CLASS_NAMES):
        count, pct = cm[i, j], cm_pct[i, j]
        on_diagonal = (i == j)
        # The diagonal entry is always printed; off-diagonal entries only
        # when at least one image landed there.
        if not on_diagonal and count <= 0:
            continue
        outcome = "Correctly classified" if on_diagonal else "Misclassified"
        print(f"  {outcome} as {pred_cls}: {count} ({pct:.1f}%)")

print("\n" + "=" * 70)

## 16. Classification Report

In [None]:
# scikit-learn's own per-class report, printed with four decimals.
report_text = classification_report(y_true, y_pred, target_names=CLASS_NAMES, digits=4)

report_rule = "=" * 85
print("\n" + report_rule)
print("SKLEARN CLASSIFICATION REPORT")
print(report_rule)
print(report_text)
print(report_rule)

## 17. Sample Predictions Visualization

In [None]:
# File names in the same order as the predictions.
filenames = test_generator.filenames

# Positions of the test images that were classified correctly / incorrectly.
is_correct = (y_true == y_pred)
correct_idx = np.flatnonzero(is_correct).tolist()
wrong_idx = np.flatnonzero(~is_correct).tolist()

n_test = len(y_true)
print(f"Total test images: {n_test}")
print(f"Correct predictions: {len(correct_idx)} ({len(correct_idx)/n_test*100:.1f}%)")
print(f"Wrong predictions: {len(wrong_idx)} ({len(wrong_idx)/n_test*100:.1f}%)")

In [None]:
# Plot a random sample of up to 10 correctly classified test images, each
# titled with the predicted class and its confidence.
n_samples = min(10, len(correct_idx))
fig, axes = plt.subplots(2, 5, figsize=(15, 7))
fig.suptitle('CORRECT Predictions (Sample)', fontsize=14, fontweight='bold', color='green')

panels = axes.ravel()

sample_correct = random.sample(correct_idx, n_samples)
for ax, i in zip(panels, sample_correct):
    img_path = os.path.join(BALANCED_DATASET_DIR, 'test', filenames[i])
    img = image.load_img(img_path, target_size=(IMG_SIZE, IMG_SIZE))

    ax.imshow(img)
    ax.axis('off')
    pred_label = CLASS_NAMES[y_pred[i]]
    confidence = predictions[i][y_pred[i]] * 100
    ax.set_title(f'{pred_label}\n({confidence:.1f}%)', fontsize=9, color='green')

# Hide the panels left over when fewer than 10 correct predictions exist, so
# the figure does not show empty framed axes.
for ax in panels[n_samples:]:
    ax.axis('off')

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'correct_predictions.png'), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Plot the first (up to 10) misclassified test images, five per row, each
# titled with the true class, the predicted class and its confidence.
if wrong_idx:
    n_wrong = min(10, len(wrong_idx))
    rows = (n_wrong + 4) // 5  # number of 5-wide rows needed

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(rows, 5, figsize=(15, 3.5 * rows), squeeze=False)
    fig.suptitle(f'WRONG Predictions ({len(wrong_idx)} total)', fontsize=14, fontweight='bold', color='red')

    panels = axes.ravel()

    for ax, i in zip(panels, wrong_idx[:n_wrong]):
        img_path = os.path.join(BALANCED_DATASET_DIR, 'test', filenames[i])
        img = image.load_img(img_path, target_size=(IMG_SIZE, IMG_SIZE))

        ax.imshow(img)
        ax.axis('off')
        true_label = CLASS_NAMES[y_true[i]]
        pred_label = CLASS_NAMES[y_pred[i]]
        confidence = predictions[i][y_pred[i]] * 100
        ax.set_title(f'True: {true_label}\nPred: {pred_label} ({confidence:.1f}%)',
                     fontsize=9, color='red')

    # Hide empty subplots
    for ax in panels[n_wrong:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(MODEL_DIR, 'wrong_predictions.png'), dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("\nPerfect! No wrong predictions!")

## 18. Confidence Analysis

In [None]:
# Analyze confidence distribution: one panel compares correct vs wrong
# predictions, the other breaks confidence down by predicted class.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('Prediction Confidence Analysis', fontsize=14, fontweight='bold')

# Correct vs wrong. An empty index list simply yields an empty confidence list.
correct_conf = [y_pred_proba[i] for i in correct_idx]
wrong_conf = [y_pred_proba[i] for i in wrong_idx]

# Overlaid histograms must share bin edges; letting each call pick its own 20
# bins gives different bar widths and makes the heights incomparable.
overlay_bins = np.histogram_bin_edges(correct_conf + wrong_conf, bins=20)

axes[0].hist(correct_conf, bins=overlay_bins, alpha=0.7, label=f'Correct (n={len(correct_conf)})', color='green')
if wrong_conf:
    axes[0].hist(wrong_conf, bins=overlay_bins, alpha=0.7, label=f'Wrong (n={len(wrong_conf)})', color='red')
axes[0].set_xlabel('Confidence')
axes[0].set_ylabel('Count')
axes[0].set_title('Confidence: Correct vs Wrong', fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# By class: group each sample's confidence under the class it was predicted as
per_class_conf = [
    [conf for conf, pred in zip(y_pred_proba, y_pred) if pred == class_id]
    for class_id in range(len(CLASS_NAMES))
]
# Common edges across classes, for the same reason as the overlay above
class_bins = np.histogram_bin_edges(
    [conf for cls_conf in per_class_conf for conf in cls_conf], bins=15
)
for cls, cls_conf in zip(CLASS_NAMES, per_class_conf):
    axes[1].hist(cls_conf, bins=class_bins, alpha=0.5, label=cls)

axes[1].set_xlabel('Confidence')
axes[1].set_ylabel('Count')
axes[1].set_title('Confidence by Predicted Class', fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'confidence_analysis.png'), dpi=150, bbox_inches='tight')
plt.show()

print("\nConfidence Statistics:")
# np.min raises ValueError (and np.mean returns NaN) on an empty list, so only
# summarise a group that actually has members.
if correct_conf:
    print(f"  Correct: Mean={np.mean(correct_conf)*100:.1f}%, Min={np.min(correct_conf)*100:.1f}%")
else:
    print("  Correct: n/a (no correct predictions)")
if wrong_conf:
    print(f"  Wrong: Mean={np.mean(wrong_conf)*100:.1f}%, Max={np.max(wrong_conf)*100:.1f}%")

## 19. Save Model Information

In [None]:
# Collect the run's configuration and results into one JSON-serialisable record
# and write it to model_info.json inside the model directory.

# Per-class test metrics keyed by class name; values are cast to plain Python
# numbers before serialisation.
per_class_metrics = {
    cls: {
        'precision': float(precision[i]),
        'recall': float(recall[i]),
        'f1_score': float(f1[i]),
        'support': int(support[i])
    }
    for i, cls in enumerate(CLASS_NAMES)
}

# Epoch counts come from the recorded history length, i.e. the epochs that
# actually ran in each phase.
training_section = {
    'phase1_epochs': len(history_phase1.history['accuracy']),
    'phase2_epochs': len(history_phase2.history['accuracy']),
    'batch_size': BATCH_SIZE,
    'learning_rate_phase1': LEARNING_RATE_PHASE1,
    'learning_rate_phase2': LEARNING_RATE_PHASE2,
    'loss_function': 'Focal Loss (gamma=2.0, alpha=0.25)',
    'training_time_minutes': round(total_time, 1),
    'final_train_accuracy': float(final_train_acc),
    'final_val_accuracy': float(final_val_acc),
    'overfitting_gap': float(acc_gap)
}

data_section = {
    'train_samples': train_generator.samples,
    'val_samples': val_generator.samples,
    'test_samples': test_generator.samples,
    'balanced': True
}

test_section = {
    'accuracy': float(test_accuracy),
    'macro_precision': float(macro_precision),
    'macro_recall': float(macro_recall),
    'macro_f1': float(macro_f1),
    'accuracy_f1_difference': float(acc_f1_diff),
    'per_class': per_class_metrics
}

model_info = {
    'model_name': 'leaf_dieback_v6',
    'model_type': 'disease_detection',
    'architecture': 'MobileNetV2',
    'classes': CLASS_NAMES,
    'num_classes': len(CLASS_NAMES),
    'input_shape': [IMG_SIZE, IMG_SIZE, 3],
    'training': training_section,
    'data': data_section,
    'test_performance': test_section
}

# Save
model_info_path = os.path.join(MODEL_DIR, 'model_info.json')
with open(model_info_path, 'w') as info_file:
    json.dump(model_info, info_file, indent=2)

print(f"Model information saved to: {model_info_path}")

In [None]:
# Print every entry in the model directory, alphabetically, with its size
# shown in KB (or MB once it exceeds 1024 KB).
banner = "=" * 70
print("\n" + banner)
print("SAVED FILES")
print(banner)

for f in sorted(os.listdir(MODEL_DIR)):
    size_kb = os.path.getsize(os.path.join(MODEL_DIR, f)) / 1024
    if size_kb <= 1024:
        print(f"  {f} ({size_kb:.1f} KB)")
        continue
    print(f"  {f} ({size_kb/1024:.1f} MB)")

print(banner)

## 20. Final Summary

In [None]:
# Final console report: model description, training summary, test metrics,
# per-class scores, threshold-based quality checks and output locations.
heavy_rule = "=" * 85
light_rule = "-" * 65

print("\n" + heavy_rule)
print("                LEAF DIEBACK DETECTION MODEL v6 - FINAL SUMMARY")
print(heavy_rule)

print("\nMODEL INFORMATION:")
print(light_rule)
print("  Name:              Leaf Dieback Detection Model v6")
print("  Architecture:      MobileNetV2 (Transfer Learning)")
print("  Loss Function:     Focal Loss (gamma=2.0)")
print(f"  Input Size:        {IMG_SIZE}x{IMG_SIZE}x3")
print(f"  Classes:           {CLASS_NAMES}")

print("\nTRAINING SUMMARY:")
print(light_rule)
# History length = number of epochs that actually ran in each phase
phase1_ran = len(history_phase1.history['accuracy'])
phase2_ran = len(history_phase2.history['accuracy'])
print(f"  Phase 1 (Frozen):     {phase1_ran} epochs")
print(f"  Phase 2 (Fine-tune):  {phase2_ran} epochs")
print(f"  Total Training Time:  {total_time:.1f} minutes")

print("\nTEST PERFORMANCE:")
print(light_rule)
print(f"  Test Accuracy:        {test_accuracy*100:.2f}%")
print(f"  Macro Precision:      {macro_precision*100:.2f}%")
print(f"  Macro Recall:         {macro_recall*100:.2f}%")
print(f"  Macro F1-Score:       {macro_f1*100:.2f}%")

print("\nCLASS-WISE PERFORMANCE:")
print(light_rule)
for cls, cls_p, cls_r, cls_f1 in zip(CLASS_NAMES, precision, recall, f1):
    print(f"  {cls:<18} P={cls_p*100:5.2f}%  R={cls_r*100:5.2f}%  F1={cls_f1*100:5.2f}%")

print("\nQUALITY CHECKS (Madam's Requirements):")
print(light_rule)
# Each check passes when its value is strictly below the threshold used here:
# 0.10 train/val gap, 0.05 accuracy-vs-F1 difference, 0.15 F1 spread.
gap_status = 'Pass' if acc_gap < 0.10 else 'Check'
align_status = 'Pass' if acc_f1_diff < 0.05 else 'Check'
f1_spread = max(f1) - min(f1)
spread_status = 'Pass' if f1_spread < 0.15 else 'Check'
print(f"  Overfitting (Train-Val gap):      {acc_gap*100:.2f}%  {gap_status}")
print(f"  Accuracy-F1 alignment:            {acc_f1_diff*100:.2f}%  {align_status}")
print(f"  Cross-class F1 range:             {f1_spread*100:.2f}%  {spread_status}")

# Per class: largest pairwise distance between precision, recall and F1
for cls, cls_p, cls_r, cls_f1 in zip(CLASS_NAMES, precision, recall, f1):
    max_diff = max(abs(cls_p - cls_r), abs(cls_p - cls_f1), abs(cls_r - cls_f1))
    if max_diff < 0.10:
        status = 'Pass'
    else:
        status = 'Check'
    print(f"  {cls} P/R/F1 balance:  {max_diff*100:.2f}%  {status}")

print("\nOUTPUT FILES:")
print(light_rule)
print(f"  Model:       {MODEL_DIR}/best_model.keras")
print(f"  Info:        {MODEL_DIR}/model_info.json")
print(f"  Charts:      {MODEL_DIR}/*.png")

print("\n" + heavy_rule)
print("                         TRAINING COMPLETE!")
print(heavy_rule)