In [None]:
# %% [markdown]
# # 03 - Model Training & Experimentation
# This notebook compares model architectures, tunes the learning rate and batch size, then trains, evaluates, and saves the final model.

# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from datetime import datetime

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix

import warnings
warnings.filterwarnings('ignore')

# Import custom modules
import sys
sys.path.append('src')
from model import ModelBuilder
from data_loader import PlantVillageDataLoader

plt.style.use('seaborn-v0_8-darkgrid')

# %%
# Report the TensorFlow build and which GPUs (if any) it can see.
print("TensorFlow version:", tf.__version__)
gpu_devices = tf.config.list_physical_devices('GPU')
print("GPU Available:", len(gpu_devices) > 0)
print("GPU Devices:", gpu_devices)

# %% [markdown]
# ## 1. Load Data

# %%
# Initialize data loader (224x224 matches the input_shape used when building models)
loader_settings = dict(
    raw_data_path='data/raw/PlantVillage',
    processed_data_path='data/processed',
    img_size=(224, 224),
    seed=42,
)
loader = PlantVillageDataLoader(**loader_settings)

# Create data generators; augment_train=True asks the loader to augment the training split
train_gen, val_gen, test_gen = loader.create_data_generators(batch_size=32, augment_train=True)

# Load class names; later cells look them up by position (class_names[i])
with open('data/processed/class_names.json', 'r') as names_file:
    class_names = json.load(names_file)

# Summary of what was loaded
print(f"\n✅ Data loaded:")
print(f"   Train: {train_gen.samples} images")
print(f"   Val: {val_gen.samples} images")
print(f"   Test: {test_gen.samples} images")
print(f"   Classes: {len(class_names)}")

# %% [markdown]
# ## 2. Experiment 1: Compare Model Architectures

# %%
def quick_train_test(model_name, epochs=5):
    """
    Train one architecture briefly and score it on validation data.

    Uses the notebook-level ``train_gen``, ``val_gen`` and ``class_names``.
    Every epoch is capped at 100 training batches and 50 validation batches,
    so the scores are only meant for ranking architectures against each
    other, not as final accuracy figures.

    Args:
        model_name: Architecture key passed to ``ModelBuilder.build``.
        epochs: Number of (truncated) epochs to train for.

    Returns:
        dict with keys ``model``, ``params``, ``val_accuracy``, ``val_top3``,
        ``val_loss`` and ``training_history`` (the Keras ``History.history``
        dict).
    """
    # A fresh model is built on every call. Release the Keras global state
    # left behind by earlier models so memory does not keep growing while
    # several architectures are compared in a loop.
    tf.keras.backend.clear_session()

    print(f"\n{'='*70}")
    print(f"Testing: {model_name.upper()}")
    print('='*70)

    # Build model
    model = ModelBuilder.build(
        model_name=model_name,
        input_shape=(224, 224, 3),
        num_classes=len(class_names)
    )

    # Compile
    ModelBuilder.compile_model(model, learning_rate=0.001)

    # Train for few epochs
    history = model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        steps_per_epoch=100,  # Limited for quick test
        validation_steps=50,
        verbose=1
    )

    # Evaluate on the same 50 validation batches used during training.
    # The result is unpacked positionally: loss first, then the metrics
    # attached by compile_model — presumably accuracy, top-3 accuracy,
    # precision, recall in that order; verify against ModelBuilder.compile_model.
    val_loss, val_acc, val_top3, val_precision, val_recall = model.evaluate(
        val_gen,
        steps=50,
        verbose=0
    )

    results = {
        'model': model_name,
        'params': model.count_params(),
        'val_accuracy': val_acc,
        'val_top3': val_top3,
        'val_loss': val_loss,
        'training_history': history.history
    }

    return results

# %% [markdown]
# ### Test Different Architectures

# %%
# IMPORTANT: This will take time! Comment out models you don't want to test

models_to_test = [
    'custom_cnn',
    'mobilenet_v2',
    # 'efficientnet_b0',  # Uncomment to test
    # 'resnet50'  # Uncomment to test
]

comparison_results = []

# Best-effort sweep: an architecture that fails is reported and skipped so
# the remaining ones still get evaluated.
for arch in models_to_test:
    try:
        comparison_results.append(quick_train_test(arch, epochs=5))
    except Exception as err:
        print(f"❌ Error with {arch}: {err}")

# %%
# Compare results: one pre-formatted row per architecture that finished training
comparison_rows = []
for entry in comparison_results:
    comparison_rows.append({
        'Model': entry['model'],
        'Parameters': f"{entry['params']:,}",
        'Val Accuracy': f"{entry['val_accuracy']:.4f}",
        'Val Top-3': f"{entry['val_top3']:.4f}",
        'Val Loss': f"{entry['val_loss']:.4f}"
    })
df_comparison = pd.DataFrame(comparison_rows)

print("\n" + "="*70, "MODEL COMPARISON (5 epochs)", "="*70, sep="\n")
print(df_comparison.to_string(index=False))

# %%
# Visualize comparison: validation accuracy and parameter count per architecture
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

models = [r['model'] for r in comparison_results]
accuracies = [r['val_accuracy'] for r in comparison_results]
params = [r['params'] / 1e6 for r in comparison_results]  # in millions

# One colour per architecture, shared by both panels
bar_colors = ['skyblue', 'coral', 'lightgreen', 'plum'][:len(models)]

# Accuracy comparison
axes[0].bar(models, accuracies, color=bar_colors)
axes[0].set_ylabel('Validation Accuracy', fontsize=12)
axes[0].set_title('Model Accuracy Comparison (5 epochs)', fontsize=14, fontweight='bold')
axes[0].set_ylim([0, 1])
for i, v in enumerate(accuracies):
    # Value label just above each bar
    axes[0].text(i, v + 0.02, f'{v:.3f}', ha='center', fontweight='bold')

# Parameters comparison
axes[1].bar(models, params, color=bar_colors)
axes[1].set_ylabel('Parameters (Millions)', fontsize=12)
axes[1].set_title('Model Size Comparison', fontsize=14, fontweight='bold')
for i, v in enumerate(params):
    axes[1].text(i, v + 0.1, f'{v:.1f}M', ha='center', fontweight='bold')

plt.tight_layout()
# savefig does not create missing directories, so make sure results/ exists first
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/model_architecture_comparison.png', dpi=150)
plt.show()

# %% [markdown]
# ## 3. Experiment 2: Learning Rate Tuning

# %%
def test_learning_rate(lr, epochs=10):
    """Train a fresh MobileNetV2 briefly at one learning rate.

    Uses the notebook-level ``train_gen``, ``val_gen`` and ``class_names``.
    Each epoch is capped at 100 training batches and 50 validation batches.

    Args:
        lr: Learning rate passed to ``ModelBuilder.compile_model``.
        epochs: Number of (truncated) epochs to train for.

    Returns:
        The Keras ``History.history`` dict for the run.
    """
    # A new model is built for every learning rate. Release the Keras global
    # state left by the previous one so memory does not grow across the sweep.
    tf.keras.backend.clear_session()

    print(f"\n Testing LR: {lr}")

    model = ModelBuilder.build('mobilenet_v2', num_classes=len(class_names))
    ModelBuilder.compile_model(model, learning_rate=lr)

    history = model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        steps_per_epoch=100,
        validation_steps=50,
        verbose=0
    )

    return history.history

# %%
# Test different learning rates; results are keyed by the rate itself
learning_rates = [0.0001, 0.0005, 0.001, 0.005]
lr_results = {}

print("Testing learning rates...")
for candidate_lr in learning_rates:
    # Stored one at a time so finished runs survive an interrupted sweep
    lr_results[candidate_lr] = test_learning_rate(candidate_lr, epochs=5)

# %%
# Plot learning rate comparison: training (left) and validation (right) accuracy per epoch
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

for lr in learning_rates:
    ax1.plot(lr_results[lr]['accuracy'], label=f'LR={lr}', marker='o')
    ax2.plot(lr_results[lr]['val_accuracy'], label=f'LR={lr}', marker='o')

ax1.set_title('Training Accuracy vs Learning Rate', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Accuracy')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.set_title('Validation Accuracy vs Learning Rate', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Validation Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories, so make sure results/ exists first
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/learning_rate_comparison.png', dpi=150)
plt.show()

# Find best LR: the rate whose best single epoch reached the highest validation accuracy
best_lr = max(learning_rates, key=lambda lr: max(lr_results[lr]['val_accuracy']))
print(f"\n✅ Best Learning Rate: {best_lr}")

# %% [markdown]
# ## 4. Experiment 3: Batch Size Impact

# %%
def test_batch_size(batch_size, epochs=5):
    """Train a fresh MobileNetV2 briefly with generators of one batch size.

    Builds its own train/validation generators from the notebook-level
    ``loader`` (the test generator it returns is discarded). Each epoch is
    capped at 50 training batches and 25 validation batches, so the number
    of images seen per epoch scales with ``batch_size``.

    Args:
        batch_size: Batch size passed to ``loader.create_data_generators``.
        epochs: Number of (truncated) epochs to train for.

    Returns:
        The Keras ``History.history`` dict for the run.
    """
    # A new model is built for every batch size. Release the Keras global
    # state left by the previous one so memory does not grow across the sweep.
    tf.keras.backend.clear_session()

    print(f"\nTesting Batch Size: {batch_size}")

    # Create new generators with different batch size
    train_gen_temp, val_gen_temp, _ = loader.create_data_generators(
        batch_size=batch_size,
        augment_train=True
    )

    model = ModelBuilder.build('mobilenet_v2', num_classes=len(class_names))
    ModelBuilder.compile_model(model, learning_rate=0.001)

    history = model.fit(
        train_gen_temp,
        epochs=epochs,
        validation_data=val_gen_temp,
        steps_per_epoch=50,
        validation_steps=25,
        verbose=0
    )

    return history.history

# %%
# Test different batch sizes (if you have enough memory)
batch_sizes = [16, 32]  # Add 64, 128 if you have GPU with enough memory
batch_results = {}

print("Testing batch sizes...")
# Best-effort sweep: a batch size that fails is reported and left out of
# batch_results, and the remaining sizes still run.
for candidate_bs in batch_sizes:
    try:
        batch_results[candidate_bs] = test_batch_size(candidate_bs, epochs=5)
    except Exception as err:
        print(f"❌ Batch size {candidate_bs} failed: {err}")

# %%
# Plot batch size comparison: training (left) and validation (right) loss per epoch
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Only batch sizes whose run succeeded are present in batch_results
for bs in batch_results.keys():
    axes[0].plot(batch_results[bs]['loss'], label=f'BS={bs}', marker='o')
    axes[1].plot(batch_results[bs]['val_loss'], label=f'BS={bs}', marker='o')

axes[0].set_title('Training Loss vs Batch Size', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].set_title('Validation Loss vs Batch Size', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Validation Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories, so make sure results/ exists first
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/batch_size_comparison.png', dpi=150)
plt.show()

# %% [markdown]
# ## 5. Full Training: Best Configuration

# %%
print("\n" + "="*70, "TRAINING FINAL MODEL (BEST CONFIGURATION)", "="*70, sep="\n")

# Best configuration from experiments
# NOTE(review): these values are hard-coded; they are not read from the
# experiment outputs above (e.g. best_lr), so update them by hand if the
# experiments point elsewhere. 'batch_size' is informational here: the
# generators used for the final fit were already created with batch_size=32.
BEST_CONFIG = dict(
    model='mobilenet_v2',
    learning_rate=0.001,
    batch_size=32,
    epochs=20,  # Increase to 30-50 for production
)

print("\nBest Configuration:")
for setting, value in BEST_CONFIG.items():
    print(f"  {setting}: {value}")

# %%
# Build best model
final_model_kwargs = dict(
    model_name=BEST_CONFIG['model'],
    input_shape=(224, 224, 3),
    num_classes=len(class_names),
    trainable_layers=20,
    dropout_rate=0.5,
)
final_model = ModelBuilder.build(**final_model_kwargs)

# Compile with the learning rate recorded in BEST_CONFIG
ModelBuilder.compile_model(final_model, learning_rate=BEST_CONFIG['learning_rate'])

# %%
# Callbacks
# Stop after 5 epochs without val_loss improvement and roll back to the best weights.
stop_when_stalled = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)
# Halve the learning rate after 3 stalled epochs, never going below 1e-7.
halve_lr_on_plateau = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    verbose=1,
)
callbacks = [stop_when_stalled, halve_lr_on_plateau]

# %%
# Train final model (no step caps here, unlike the quick experiments above)
print("\n🚀 Starting training...")
print("⏰ This will take time! Go grab a coffee ☕")

history = final_model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=BEST_CONFIG['epochs'],
    callbacks=callbacks,
    verbose=1,
)

print("\n✅ Training complete!")

# %%
# Plot training history
def _plot_history_panel(ax, metrics, curves, title, ylabel):
    """Draw one panel of per-epoch curves.

    Args:
        ax: Matplotlib axes to draw on.
        metrics: Mapping of metric name -> per-epoch values (``history.history``).
        curves: Iterable of ``(metric_name, legend_label, extra_plot_kwargs)``.
        title: Panel title.
        ylabel: Y-axis label.
    """
    for metric_name, legend_label, extra in curves:
        ax.plot(metrics[metric_name], label=legend_label, linewidth=2, **extra)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Accuracy
_plot_history_panel(
    axes[0, 0], history.history,
    [('accuracy', 'Train', {}), ('val_accuracy', 'Validation', {})],
    'Model Accuracy', 'Accuracy',
)

# Loss
_plot_history_panel(
    axes[0, 1], history.history,
    [('loss', 'Train', {}), ('val_loss', 'Validation', {})],
    'Model Loss', 'Loss',
)

# Top-3 Accuracy
_plot_history_panel(
    axes[1, 0], history.history,
    [('top_3_accuracy', 'Train', {}), ('val_top_3_accuracy', 'Validation', {})],
    'Top-3 Accuracy', 'Top-3 Accuracy',
)

# Precision and Recall (recall drawn dashed to tell it apart from precision)
_plot_history_panel(
    axes[1, 1], history.history,
    [
        ('precision', 'Train Precision', {}),
        ('val_precision', 'Val Precision', {}),
        ('recall', 'Train Recall', {'linestyle': '--'}),
        ('val_recall', 'Val Recall', {'linestyle': '--'}),
    ],
    'Precision & Recall', 'Score',
)

plt.suptitle('Training History - Final Model', fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()
# savefig does not create missing directories, so make sure results/ exists first
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/final_training_curves.png', dpi=150)
plt.show()

# %% [markdown]
# ## 6. Model Evaluation on Test Set

# %%
print("\n" + "="*70, "EVALUATING ON TEST SET", "="*70, sep="\n")

# Evaluate on the full test generator. The result is unpacked positionally:
# loss first, then the metrics attached by compile_model — presumably
# accuracy, top-3 accuracy, precision, recall; verify against ModelBuilder.
test_scores = final_model.evaluate(test_gen, verbose=1)
test_loss, test_acc, test_top3, test_precision, test_recall = test_scores

print(f"\n📊 Test Results:")
print(f"   Accuracy: {test_acc:.4f}")
print(f"   Top-3 Accuracy: {test_top3:.4f}")
print(f"   Loss: {test_loss:.4f}")
print(f"   Precision: {test_precision:.4f}")
print(f"   Recall: {test_recall:.4f}")

# %%
# Get predictions for confusion matrix
# NOTE(review): y_true is taken from test_gen.classes, so it only lines up with
# y_pred if test_gen yields samples in that same order (i.e. is not shuffled)
# — confirm in PlantVillageDataLoader.create_data_generators.
print("\nGenerating predictions...")
y_pred_probs = final_model.predict(test_gen, verbose=1)
# Predicted class index = position of the highest score in each row
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = test_gen.classes

# %%
# Classification Report
print("\n" + "="*70)
print("CLASSIFICATION REPORT")
print("="*70)
# Pass the label set explicitly. Without it sklearn infers the labels from
# y_true/y_pred and raises if a class is missing from both, because
# target_names would then have a different length. This assumes
# class_names[i] names class index i, as the positional lookups elsewhere
# in this notebook also do.
report = classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    digits=4,
)
print(report)

# %%
# Confusion Matrix
# Fix the label set to every index in class_names so the matrix is always
# len(class_names) square and lines up with the tick labels, even if some
# class never occurs in y_true or y_pred.
class_indices = np.arange(len(class_names))
cm = confusion_matrix(y_true, y_pred, labels=class_indices)

# Row-normalize (each row = one true class). Rows with no true samples stay
# all-zero instead of turning into NaN through a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

plt.figure(figsize=(20, 18))
sns.heatmap(
    cm_normalized,
    annot=False,
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    cbar_kws={'label': 'Normalized Count'}
)
plt.title('Confusion Matrix (Normalized)', fontsize=16, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=90, ha='right', fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
# savefig does not create missing directories, so make sure results/ exists first
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/final_confusion_matrix.png', dpi=150)
plt.show()

# %% [markdown]
# ## 7. Analyze Best and Worst Performing Classes

# %%
# Per-class precision / recall / F1 / support, taken from sklearn's report.
# The label set is passed explicitly so the report has one entry per name in
# class_names; without it sklearn raises when a class is missing from both
# y_true and y_pred. Assumes class_names[i] names class index i.
report_dict = classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    output_dict=True,
)

class_metrics = []
for class_name in class_names:
    # The report also holds aggregate rows; only per-class entries are collected
    if class_name in report_dict:
        class_metrics.append({
            'class': class_name,
            'precision': report_dict[class_name]['precision'],
            'recall': report_dict[class_name]['recall'],
            'f1-score': report_dict[class_name]['f1-score'],
            'support': report_dict[class_name]['support']
        })

# Best classes first; later cells rely on this ordering for head()/tail()
df_metrics = pd.DataFrame(class_metrics)
df_metrics = df_metrics.sort_values('f1-score', ascending=False)

# %%
# Top 10 best classes (df_metrics is sorted by f1-score, highest first)
print("\n" + "="*70, "TOP 10 BEST PERFORMING CLASSES", "="*70, sep="\n")
top_ten = df_metrics.head(10)
print(top_ten.to_string(index=False))

# %%
# Bottom 10 worst classes (last rows of the same sorted frame)
print("\n" + "="*70, "BOTTOM 10 WORST PERFORMING CLASSES", "="*70, sep="\n")
bottom_ten = df_metrics.tail(10)
print(bottom_ten.to_string(index=False))

# %%
# Visualize per-class performance: three grouped horizontal bars per class
# Figure height grows with the number of classes so the labels stay readable
fig, ax = plt.subplots(figsize=(12, max(10, len(class_names) * 0.3)))

x = np.arange(len(df_metrics))
width = 0.25  # bar thickness; the three bars of a class sit at x - width, x, x + width

ax.barh(x - width, df_metrics['precision'], width, label='Precision', color='steelblue')
ax.barh(x, df_metrics['recall'], width, label='Recall', color='coral')
ax.barh(x + width, df_metrics['f1-score'], width, label='F1-Score', color='mediumseagreen')

ax.set_xlabel('Score', fontsize=12)
ax.set_title('Per-Class Performance Metrics', fontsize=14, fontweight='bold')
ax.set_yticks(x)
ax.set_yticklabels(df_metrics['class'], fontsize=8)
ax.legend()
ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories, so make sure results/ exists first
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/per_class_performance.png', dpi=150)
plt.show()

# %% [markdown]
# ## 8. Analyze Misclassifications

# %%
# Find most confused pairs
# Work on a copy with the diagonal (correct predictions) zeroed, so every
# remaining non-zero cell is a (true class, predicted class) confusion.
misclassified = cm.copy()
np.fill_diagonal(misclassified, 0)

# np.nonzero yields the (row, column) indices of the confused cells in
# row-major order, replacing a Python double loop over every cell.
confused_pairs = [
    {
        'true': class_names[i],
        'predicted': class_names[j],
        'count': misclassified[i, j]
    }
    for i, j in zip(*np.nonzero(misclassified))
]

# Naming the columns up front keeps sort_values('count') valid even when the
# model made no mistakes and confused_pairs is empty.
df_confused = pd.DataFrame(
    confused_pairs, columns=['true', 'predicted', 'count']
).sort_values('count', ascending=False)

print("\n" + "="*70)
print("TOP 10 MOST CONFUSED CLASS PAIRS")
print("="*70)
print(df_confused.head(10).to_string(index=False))

# %%
# Visualize top confused pairs
top_confused = df_confused.head(10)

fig, ax = plt.subplots(figsize=(12, 8))
y_pos = np.arange(len(top_confused))
# Two-line tick label "true → predicted", each name truncated to 30 characters
labels = [f"{row['true'][:30]}\n→ {row['predicted'][:30]}" for _, row in top_confused.iterrows()]

ax.barh(y_pos, top_confused['count'], color='salmon')
ax.set_yticks(y_pos)
ax.set_yticklabels(labels, fontsize=9)
ax.set_xlabel('Number of Misclassifications', fontsize=12)
ax.set_title('Top 10 Most Confused Class Pairs', fontsize=14, fontweight='bold')
ax.invert_yaxis()  # most confused pair at the top

for i, v in enumerate(top_confused['count']):
    # Count label just past the end of each bar
    ax.text(v + 0.5, i, str(int(v)), va='center', fontweight='bold')

plt.tight_layout()
# savefig does not create missing directories, so make sure results/ exists first
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/confused_pairs.png', dpi=150)
plt.show()

# %% [markdown]
# ## 9. Save Model

# %%
# Save model
model_path = Path('models/notebook_trained_model.h5')
model_path.parent.mkdir(parents=True, exist_ok=True)
final_model.save(model_path)

print(f"\n✅ Model saved to {model_path}")
print(f"   Model size: {model_path.stat().st_size / (1024*1024):.2f} MB")

# %%
# Save training history
# Cast every value to a built-in float before dumping (json rejects NumPy
# scalar types such as float32, which Keras may record in the history).
history_dict = {k: [float(v) for v in vals] for k, vals in history.history.items()}

history_path = Path('results/notebook_training_history.json')
# open(..., 'w') does not create missing directories; handle results/ the
# same way as models/ above.
history_path.parent.mkdir(parents=True, exist_ok=True)
with open(history_path, 'w') as f:
    json.dump(history_dict, f, indent=2)

print(f"✅ Training history saved to {history_path}")

# %%
# Save evaluation results
# df_metrics is sorted by f1-score (descending), so head/tail are best/worst
eval_results = {
    'test_accuracy': float(test_acc),
    'test_top3_accuracy': float(test_top3),
    'test_loss': float(test_loss),
    'test_precision': float(test_precision),
    'test_recall': float(test_recall),
    'best_classes': df_metrics.head(5)['class'].tolist(),
    'worst_classes': df_metrics.tail(5)['class'].tolist(),
    'timestamp': datetime.now().isoformat()
}

eval_path = Path('results/notebook_evaluation.json')
eval_path.parent.mkdir(parents=True, exist_ok=True)
with open(eval_path, 'w') as f:
    json.dump(eval_results, f, indent=2)

print(f"✅ Evaluation results saved to {eval_path}")

# %% [markdown]
# ## 10. Summary & Recommendations

# %%
print("\n" + "="*70)
print("TRAINING SUMMARY & RECOMMENDATIONS")
print("="*70)

# Human-readable recap. The metric values, configuration, epoch count and
# class names are interpolated from earlier cells; the "NEXT STEPS" and
# "PRODUCTION RECOMMENDATIONS" sections are fixed text and do not depend on
# the measured results.
# "Epochs Trained" is the length of the recorded loss history, which can be
# shorter than BEST_CONFIG['epochs'] when early stopping ends training.
# Best/worst classes rely on df_metrics being sorted by f1-score (descending).
summary = f"""
✅ FINAL MODEL PERFORMANCE:
   • Test Accuracy: {test_acc:.2%}
   • Top-3 Accuracy: {test_top3:.2%}
   • Precision: {test_precision:.2%}
   • Recall: {test_recall:.2%}

✅ MODEL CONFIGURATION:
   • Architecture: {BEST_CONFIG['model']}
   • Learning Rate: {BEST_CONFIG['learning_rate']}
   • Batch Size: {BEST_CONFIG['batch_size']}
   • Epochs Trained: {len(history.history['loss'])}

✅ BEST PERFORMING CLASSES:
   {', '.join(df_metrics.head(3)['class'].tolist())}

⚠️  CLASSES NEEDING IMPROVEMENT:
   {', '.join(df_metrics.tail(3)['class'].tolist())}

📊 NEXT STEPS:
   1. If accuracy < 95%: Train longer (40-50 epochs)
   2. Try EfficientNetB0 for +1-2% accuracy
   3. Add more augmentation for struggling classes
   4. Consider ensemble methods
   5. Deploy model using streamlit app

💡 PRODUCTION RECOMMENDATIONS:
   • Use model quantization for mobile deployment
   • Implement confidence thresholding (reject < 70%)
   • Add Grad-CAM for explainability
   • Monitor model drift in production
"""

print(summary)

# %%
# Closing banner: static text only; nothing here checks that the earlier
# cells actually succeeded.
print("\n" + "="*70)
print("🎉 MODEL TRAINING NOTEBOOK COMPLETE!")
print("="*70)
print("\n✅ Model trained and evaluated")
print("✅ All visualizations saved to results/")
print("✅ Model ready for deployment")
print("\n📍 Next Steps:")
print("   1. Test model: python src/predict.py --image examples/test.jpg")
print("   2. Run web app: streamlit run app/app.py")
print("   3. Push to GitHub")