# 🧠 CAPSTONE-LAZARUS: Model Training Pipeline

## Advanced Plant Disease Classification Training
**State-of-the-art transfer learning with comprehensive evaluation**

### 🎯 Training Objectives:
- **Multi-Architecture Evaluation**: EfficientNet, ResNet, MobileNet comparisons
- **Transfer Learning**: Pre-trained ImageNet → Agricultural fine-tuning
- **Balanced Training**: Class-weighted loss for imbalanced dataset
- **Advanced Augmentation**: Field condition simulation
- **Comprehensive Metrics**: F1, Precision, Recall, Confusion Matrix
- **Model Optimization**: Pruning, quantization for deployment

In [None]:
# üì¶ Import Essential Libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# ü§ñ Deep Learning Framework
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, callbacks, metrics
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# üìä Interactive Visualizations
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# üéØ Metrics & Evaluation
from sklearn.metrics import (
    classification_report, confusion_matrix, 
    f1_score, precision_recall_fscore_support
)
from sklearn.utils.class_weight import compute_class_weight

# üîß Utilities
import pickle
import json
from datetime import datetime
from typing import Dict, List, Tuple, Any

# Add src to path
sys.path.append('../src')
from data_utils import PlantDiseaseDataLoader
from model_factory import ModelFactory

# üé® Configure Plotly
# Default every Plotly Express figure in this notebook to the "plotly_white" template.
px.defaults.template = "plotly_white"

# üîß TensorFlow Configuration
# Turn off TensorFloat-32 execution so float32 matmuls/convolutions keep full float32 precision.
# NOTE(review): ModelFactory is created with use_mixed_precision=True further down,
# so much of the model may compute in float16 anyway -- confirm this setting is still wanted.
tf.config.experimental.enable_tensor_float_32_execution(False)
# Log the runtime environment (TF version, visible GPUs, working directory).
print(f"üöÄ TensorFlow version: {tf.__version__}")
print(f"üñ•Ô∏è  GPU Available: {tf.config.list_physical_devices('GPU')}")
print(f"üìÇ Working directory: {os.getcwd()}")

## 📊 Load Dataset & Preprocessing

In [None]:
# üìÇ Load Dataset Splits from EDA
data_dir = "../data"
models_dir = Path("../models")
models_dir.mkdir(exist_ok=True)

# Reuse the splits and class weights saved earlier; if either .npy file is
# missing, rebuild them directly with PlantDiseaseDataLoader.
# NOTE(review): allow_pickle=True unpickles these files -- only load artifacts
# this project produced itself.
try:
    split_info = np.load('../models/dataset_splits.npy', allow_pickle=True).item()
    class_weights = np.load('../models/class_weights.npy', allow_pickle=True).item()

    # One (paths, labels) pair per split, plus the class metadata.
    X_train, y_train = split_info['train_paths'], split_info['train_labels']
    X_val, y_val = split_info['val_paths'], split_info['val_labels']
    X_test, y_test = split_info['test_paths'], split_info['test_labels']
    class_names, label_mapping = split_info['class_names'], split_info['label_mapping']

    print("‚úÖ Loaded pre-computed dataset splits and class weights")

except FileNotFoundError:
    print("‚ö†Ô∏è  Pre-computed splits not found. Running EDA first...")
    # Fallback: scan the raw images and build balanced splits + weights here.
    # NOTE(review): label_mapping is only assigned on the success path above;
    # a later cell that reads it will fail after this fallback.
    loader = PlantDiseaseDataLoader(data_dir, img_size=(224, 224), batch_size=32)
    dataset_stats = loader.scan_dataset()
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = loader.create_balanced_splits()
    class_weights = loader.compute_class_weights(y_train)
    class_names = loader.class_names

num_classes = len(class_names)
print("\n".join([
    f"\nüìä Dataset Configuration:",
    f"   üöÇ Training: {len(X_train):,} images",
    f"   üîç Validation: {len(X_val):,} images",
    f"   üß™ Testing: {len(X_test):,} images",
    f"   üè∑Ô∏è  Classes: {num_classes}",
    f"   ‚öñÔ∏è  Using class weights: {len(class_weights)} weights computed",
]))

In [None]:
# üîß Create TensorFlow Datasets with Optimizations
def create_optimized_dataset(paths: List[str], labels: List[int], 
                           batch_size: int = 32, is_training: bool = True,
                           img_size: Tuple[int, int] = (224, 224)) -> tf.data.Dataset:
    """Build a batched, prefetched ``tf.data`` pipeline from image files.

    Each example is decoded as a 3-channel image, resized to ``img_size``,
    scaled to [0, 1], optionally augmented (training only) and finally
    standardised with the ImageNet per-channel mean/std.

    Args:
        paths: Image file paths, one per example.
        labels: Class labels aligned with ``paths``.
        batch_size: Examples per batch. The final batch may be smaller
            (``drop_remainder`` is not set).
        is_training: If True, shuffle, repeat indefinitely and augment.
            Training datasets are therefore infinite, so ``fit`` needs an
            explicit ``steps_per_epoch``.
        img_size: Target ``(height, width)`` of the output images.

    Returns:
        A dataset yielding ``(images, labels)`` batches.
    """
    # Random-zoom crop keeps 90% of each spatial side.
    crop_size = [int(img_size[0] * 0.9), int(img_size[1] * 0.9), 3]

    def load_and_preprocess(path, label):
        """Decode, resize, (optionally) augment and standardise one image."""
        image = tf.io.read_file(path)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.cast(image, tf.float32)
        image = tf.image.resize(image, img_size)

        # Scale to [0, 1]; the color augmentations below assume this range.
        image = image / 255.0

        # `is_training` is a Python bool, so this branch is resolved at trace time.
        if is_training:
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_flip_up_down(image)

            image = tf.image.random_brightness(image, max_delta=0.2)
            image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
            # Brightness/contrast can push pixels outside [0, 1]; clip so the
            # RGB<->HSV based hue/saturation ops receive valid RGB values.
            image = tf.clip_by_value(image, 0.0, 1.0)
            image = tf.image.random_hue(image, max_delta=0.1)
            image = tf.image.random_saturation(image, lower=0.8, upper=1.2)
            image = tf.clip_by_value(image, 0.0, 1.0)

            # Random zoom (not rotation): crop a random 90% window, resize back.
            image = tf.image.random_crop(image, size=crop_size)
            image = tf.image.resize(image, img_size)

        # Standardise with ImageNet channel statistics.
        # NOTE(review): some Keras backbones (e.g. EfficientNetV2 with built-in
        # preprocessing) rescale inputs themselves -- confirm ModelFactory
        # expects ImageNet-standardised input.
        image = (image - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]

        return image, label

    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))

    if is_training:
        # Buffer must be >= 1 even for an empty path list.
        dataset = dataset.shuffle(buffer_size=max(1, min(len(paths), 10000)))
        dataset = dataset.repeat()  # infinite; epochs are bounded by steps_per_epoch

    dataset = dataset.map(load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

# üìä Create optimized datasets
BATCH_SIZE = 32
IMG_SIZE = (224, 224)

train_dataset = create_optimized_dataset(X_train, y_train, BATCH_SIZE, is_training=True, img_size=IMG_SIZE)
val_dataset = create_optimized_dataset(X_val, y_val, BATCH_SIZE, is_training=False, img_size=IMG_SIZE)
test_dataset = create_optimized_dataset(X_test, y_test, BATCH_SIZE, is_training=False, img_size=IMG_SIZE)

print("‚úÖ TensorFlow datasets created with optimizations")
# The training set repeats forever, so an epoch is the number of *full* batches.
print(f"   üöÇ Training batches per epoch: {len(X_train) // BATCH_SIZE}")
# Validation/test are finite and keep their last partial batch: ceiling division.
print(f"   üîç Validation batches: {-(-len(X_val) // BATCH_SIZE)}")
print(f"   üß™ Test batches: {-(-len(X_test) // BATCH_SIZE)}")

## 🏗️ Model Architecture Comparison

In [None]:
# üè≠ Initialize Model Factory
# Derive the input shape from IMG_SIZE so the model always matches the
# resolution the tf.data pipelines produce.
factory = ModelFactory(input_shape=(*IMG_SIZE, 3), num_classes=num_classes, use_mixed_precision=True)

# üéØ Define Model Architectures to Evaluate
# Informational table only: 'target_size' and 'expected_accuracy' are
# targets, not measured values.
architectures_to_test = [
    {
        'name': 'EfficientNetV2-B0',
        'arch': 'efficientnet_v2_b0',
        'description': 'ü•á Best accuracy-efficiency balance',
        'target_size': '~15MB',
        'expected_accuracy': '0.85+'
    },
    {
        'name': 'MobileNetV3-Large',
        'arch': 'mobilenet_v3_large',
        'description': 'üì± Optimized for mobile deployment',
        'target_size': '~10MB',
        'expected_accuracy': '0.82+'
    },
    {
        'name': 'ResNet50',
        'arch': 'resnet50',
        'description': 'üèóÔ∏è Reliable baseline performance',
        'target_size': '~25MB',
        'expected_accuracy': '0.83+'
    },
    {
        'name': 'Custom CNN',
        'arch': 'custom_cnn',
        'description': 'üé® Lightweight custom architecture',
        'target_size': '~5MB',
        'expected_accuracy': '0.78+'
    }
]

# üìä Display Architecture Comparison Table
arch_df = pd.DataFrame(architectures_to_test)
print("üèóÔ∏è Architecture Comparison:")
print("=" * 80)
for _, row in arch_df.iterrows():
    print(f"{row['name']:20} | {row['description']:35} | {row['target_size']:8} | {row['expected_accuracy']}")
print("=" * 80)

# üéØ Select primary architecture for full training (must be an 'arch' key above)
PRIMARY_ARCHITECTURE = 'efficientnet_v2_b0'
print(f"\nüéØ Selected primary architecture: {PRIMARY_ARCHITECTURE}")

## 🧠 Model Training Pipeline

In [None]:
# üîß Training Configuration
# Hyper-parameters read by the compilation, callback and training cells below.
# 'monitor_metric' names the Keras log key the callbacks watch.
TRAINING_CONFIG = dict(
    epochs=50,
    initial_learning_rate=1e-3,
    min_learning_rate=1e-7,
    patience_early_stop=15,
    patience_lr_reduce=8,
    lr_reduction_factor=0.2,
    validation_freq=1,
    save_best_only=True,
    monitor_metric='val_f1_score',
)

print("‚öôÔ∏è Training Configuration:")
for key, value in TRAINING_CONFIG.items():
    print(f"   ‚Ä¢ {key}: {value}")

# üìä Custom F1 Score Metric
class F1Score(tf.keras.metrics.Metric):
    """Streaming F1 score built from Keras ``Precision`` and ``Recall``.

    ``Precision``/``Recall`` require ``y_true`` to have the same shape as
    ``y_pred``. This notebook trains with ``sparse_categorical_crossentropy``,
    so labels arrive as integer class ids (``(batch,)`` or ``(batch, 1)``);
    those are one-hot encoded against ``y_pred``'s last axis before being
    forwarded. Labels that already match ``y_pred`` pass through unchanged.

    The result is ``2 * p * r / (p + r)`` over all class columns, with
    predictions thresholded by ``Precision``/``Recall``'s default threshold.
    """

    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    @staticmethod
    def _align_labels(y_true, y_pred):
        """Return ``y_true`` one-hot encoded when it holds integer class ids."""
        y_true = tf.convert_to_tensor(y_true)
        y_pred = tf.convert_to_tensor(y_pred)
        true_rank, pred_rank = y_true.shape.rank, y_pred.shape.rank
        if true_rank is None or pred_rank is None:
            return y_true  # ranks unknown at trace time: forward unchanged
        # Keras can hand sparse labels over as a trailing (batch, 1) column.
        if true_rank == pred_rank and y_true.shape[-1] == 1 and y_pred.shape[-1] != 1:
            y_true = tf.squeeze(y_true, axis=-1)
            true_rank -= 1
        if true_rank == pred_rank - 1:
            y_true = tf.one_hot(
                tf.cast(y_true, tf.int32),
                depth=tf.shape(y_pred)[-1],
                dtype=y_pred.dtype,
            )
        return y_true

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = self._align_labels(y_true, y_pred)
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)

    def result(self):
        p = self.precision.result()
        r = self.recall.result()
        # epsilon keeps the score at 0 instead of NaN when p + r == 0.
        return 2 * ((p * r) / (p + r + tf.keras.backend.epsilon()))

    def reset_state(self):
        self.precision.reset_state()
        self.recall.reset_state()

In [None]:
# üèóÔ∏è Create and Compile Model
class _SparseLabelAdapter(tf.keras.metrics.Metric):
    """Let a metric that expects one-hot ``y_true`` accept integer class ids.

    The model is compiled with ``sparse_categorical_crossentropy``, so labels
    reach the metrics as class ids (``(batch,)`` or ``(batch, 1)``), while
    ``Precision``/``Recall`` need ``y_true`` shaped like ``y_pred``
    (``(batch, num_classes)``). Sparse labels are one-hot encoded; labels
    already matching ``y_pred`` are forwarded unchanged.
    """

    def __init__(self, inner, name=None, **kwargs):
        super().__init__(name=name or inner.name, **kwargs)
        self.inner = inner

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.convert_to_tensor(y_true)
        y_pred = tf.convert_to_tensor(y_pred)
        true_rank, pred_rank = y_true.shape.rank, y_pred.shape.rank
        if true_rank is not None and pred_rank is not None:
            # Keras can hand sparse labels over as a trailing (batch, 1) column.
            if true_rank == pred_rank and y_true.shape[-1] == 1 and y_pred.shape[-1] != 1:
                y_true = tf.squeeze(y_true, axis=-1)
                true_rank -= 1
            if true_rank == pred_rank - 1:
                y_true = tf.one_hot(
                    tf.cast(y_true, tf.int32),
                    depth=tf.shape(y_pred)[-1],
                    dtype=y_pred.dtype,
                )
        self.inner.update_state(y_true, y_pred, sample_weight)

    def result(self):
        return self.inner.result()

    def reset_state(self):
        self.inner.reset_state()


def create_and_compile_model(architecture: str, num_classes: int, 
                           learning_rate: float = 1e-3) -> tf.keras.Model:
    """Build ``architecture`` via the global ``factory`` and compile it.

    Args:
        architecture: Architecture key understood by ``factory.get_model``.
        num_classes: Number of target classes. The output layer size comes
            from the factory's own configuration; here it only caps top-k.
        learning_rate: Initial learning rate for AdamW.

    Returns:
        A compiled model using sparse integer labels. Logged metric names are
        ``accuracy``, ``top_3_accuracy``, ``precision``, ``recall`` and
        ``f1_score`` (``val_``-prefixed on validation data).
    """
    model = factory.get_model(architecture, dropout_rate=0.3, freeze_backbone=False)

    # Fixed initial LR; it is adjusted during training by ReduceLROnPlateau.
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=1e-4,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7
    )

    # Labels are integer class ids, so every metric must be sparse-aware:
    # the one-hot metrics are wrapped, and top-k uses the sparse variant.
    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=min(3, num_classes), name='top_3_accuracy'),
            _SparseLabelAdapter(tf.keras.metrics.Precision(), name='precision'),
            _SparseLabelAdapter(tf.keras.metrics.Recall(), name='recall'),
            _SparseLabelAdapter(F1Score(), name='f1_score'),
        ]
    )

    return model

# üéØ Create primary model
print(f"üèóÔ∏è Creating {PRIMARY_ARCHITECTURE} model...")
model = create_and_compile_model(PRIMARY_ARCHITECTURE, num_classes, TRAINING_CONFIG['initial_learning_rate'])

# üìä Model Summary
print("\nüìã Model Architecture:")
model.summary()

# üìà Count parameters
total_params = model.count_params()
# Count trainable elements directly from variable shapes: works with both
# tf.keras 2.x and Keras 3 (which has no keras.backend.count_params).
trainable_params = int(sum(np.prod(tuple(weight.shape)) for weight in model.trainable_weights))

print(f"\nüìä Model Statistics:")
print(f"   ‚Ä¢ Total parameters: {total_params:,}")
print(f"   ‚Ä¢ Trainable parameters: {trainable_params:,}")
print(f"   ‚Ä¢ Non-trainable parameters: {total_params - trainable_params:,}")
# Size estimate assumes 4 bytes (float32) per parameter.
print(f"   ‚Ä¢ Estimated size: ~{total_params * 4 / (1024*1024):.1f} MB")

In [None]:
# üîÑ Advanced Callbacks Setup
def create_callbacks(model_name: str, config: Dict[str, Any]) -> List[tf.keras.callbacks.Callback]:
    """Build the training callbacks used by ``model.fit``.

    Args:
        model_name: Prefix for the checkpoint file and TensorBoard run directory.
        config: Training configuration; reads ``monitor_metric``,
            ``save_best_only``, ``patience_early_stop``,
            ``lr_reduction_factor``, ``patience_lr_reduce`` and
            ``min_learning_rate``.

    Returns:
        In list order: ModelCheckpoint, EarlyStopping, ReduceLROnPlateau,
        TensorBoard and a console progress printer. The first three watch
        ``config['monitor_metric']`` in ``'max'`` mode, so that metric must be
        "higher is better".
    """
    callbacks_list = []

    # üíæ Model Checkpoint - Save best model
    checkpoint_path = f"../models/{model_name}_best.h5"
    callbacks_list.append(
        tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            monitor=config['monitor_metric'],
            save_best_only=config['save_best_only'],
            save_weights_only=False,
            mode='max',
            verbose=1
        )
    )

    # ‚èπÔ∏è Early Stopping (rolls the model back to its best epoch when it stops)
    callbacks_list.append(
        tf.keras.callbacks.EarlyStopping(
            monitor=config['monitor_metric'],
            patience=config['patience_early_stop'],
            mode='max',
            restore_best_weights=True,
            verbose=1
        )
    )

    # üìâ Learning Rate Scheduler
    callbacks_list.append(
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor=config['monitor_metric'],
            factor=config['lr_reduction_factor'],
            patience=config['patience_lr_reduce'],
            min_lr=config['min_learning_rate'],
            mode='max',
            verbose=1
        )
    )

    # üìä TensorBoard Logging (one timestamped run directory per call)
    log_dir = f"../models/logs/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    callbacks_list.append(
        tf.keras.callbacks.TensorBoard(
            log_dir=log_dir,
            histogram_freq=1,
            write_graph=True,
            write_images=True,
            update_freq='epoch'
        )
    )

    # üéØ Custom Progress Callback
    class TrainingProgressCallback(tf.keras.callbacks.Callback):
        """Print a compact per-epoch summary of accuracy, F1 and learning rate."""

        def on_epoch_end(self, epoch, logs=None):
            if logs:
                # The LR log key is 'lr' in tf.keras 2.x and 'learning_rate'
                # in Keras 3; accept either.
                current_lr = logs.get('lr', logs.get('learning_rate', 0))
                print(f"\nüìä Epoch {epoch + 1} Summary:")
                print(f"   ‚Ä¢ Training Accuracy: {logs.get('accuracy', 0):.4f}")
                print(f"   ‚Ä¢ Validation Accuracy: {logs.get('val_accuracy', 0):.4f}")
                print(f"   ‚Ä¢ Training F1: {logs.get('f1_score', 0):.4f}")
                print(f"   ‚Ä¢ Validation F1: {logs.get('val_f1_score', 0):.4f}")
                print(f"   ‚Ä¢ Learning Rate: {current_lr:.2e}")

    # Appended last so the LR callback has already updated `logs` for this epoch.
    callbacks_list.append(TrainingProgressCallback())

    return callbacks_list

# üîÑ Create callbacks
# Shared prefix for the checkpoint file and the TensorBoard run directory.
model_name = f"{PRIMARY_ARCHITECTURE}_plant_disease"
training_callbacks = create_callbacks(model_name, TRAINING_CONFIG)

print(f"‚úÖ Created {len(training_callbacks)} training callbacks")
# NOTE(review): these paths are re-typed here rather than read from the
# callbacks -- keep them in sync with the paths built in create_callbacks.
print(f"   ‚Ä¢ Model checkpoint: ../models/{model_name}_best.h5")
print(f"   ‚Ä¢ TensorBoard logs: ../models/logs/{model_name}_*")

## 🚂 Execute Training

In [None]:
# üöÄ Start Training Process
print("üöÇ Starting model training...")
print("=" * 60)
print(f"üèóÔ∏è Architecture: {PRIMARY_ARCHITECTURE}")
print(f"üìä Training samples: {len(X_train):,}")
print(f"üîç Validation samples: {len(X_val):,}")
print(f"‚è±Ô∏è Max epochs: {TRAINING_CONFIG['epochs']}")
print(f"üéØ Batch size: {BATCH_SIZE}")
print(f"‚öñÔ∏è Using class weights: Yes")
print("=" * 60)

# The training dataset repeats forever, so an epoch is the number of full
# batches; keep it >= 1 so a tiny training split cannot yield 0 steps.
steps_per_epoch = max(1, len(X_train) // BATCH_SIZE)
# The validation dataset is finite and keeps its final partial batch. Ceiling
# division evaluates every validation image each epoch (floor division skipped
# up to BATCH_SIZE - 1 images and became 0 for splits smaller than a batch).
validation_steps = -(-len(X_val) // BATCH_SIZE)

# üèãÔ∏è Train the model
# NOTE(review): Keras expects class_weight keys to be the integer class ids
# 0..num_classes-1 -- confirm the loaded/computed class_weights use that form.
start_time = datetime.now()

try:
    history = model.fit(
        train_dataset,
        epochs=TRAINING_CONFIG['epochs'],
        steps_per_epoch=steps_per_epoch,
        validation_data=val_dataset,
        validation_steps=validation_steps,
        validation_freq=TRAINING_CONFIG['validation_freq'],
        class_weight=class_weights,
        callbacks=training_callbacks,
        verbose=1
    )
except Exception as e:
    print(f"‚ùå Training failed: {str(e)}")
    raise

# Post-processing runs outside the try so a save error is not reported as a
# training failure.
training_time = datetime.now() - start_time

print(f"\n‚úÖ Training completed successfully!")
print(f"‚è±Ô∏è Total training time: {training_time}")
print(f"üìà Final training accuracy: {history.history['accuracy'][-1]:.4f}")
print(f"üîç Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")
print(f"üéØ Final validation F1: {history.history['val_f1_score'][-1]:.4f}")

# üíæ Save training history (metric values cast to float for JSON)
history_path = f"../models/{model_name}_history.json"
with open(history_path, 'w') as f:
    json.dump({
        'history': {k: [float(x) for x in v] for k, v in history.history.items()},
        'config': TRAINING_CONFIG,
        'training_time': str(training_time),
        'architecture': PRIMARY_ARCHITECTURE,
        'total_params': int(total_params),
        'trainable_params': int(trainable_params)
    }, f, indent=2)

print(f"üíæ Training history saved: {history_path}")

## 📊 Training Analysis & Visualization

In [None]:
# üìà Interactive Training History Visualization
def plot_training_history(history_dict: Dict[str, List[float]]) -> None:
    """Show a 2x2 Plotly figure of accuracy, loss, F1 and learning rate.

    Args:
        history_dict: ``History.history``-style mapping of metric name to one
            value per epoch. ``accuracy`` is required because it defines the
            epoch axis. Every other series is optional; a missing one simply
            leaves its trace out. The learning rate is read from ``lr``
            (tf.keras 2.x) or ``learning_rate`` (Keras 3).

    Uses the notebook-level ``PRIMARY_ARCHITECTURE`` in the figure title and
    displays the figure; returns None.
    """
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('üìà Accuracy', 'üìâ Loss', 'üéØ F1 Score', 'üìä Learning Rate'),
        specs=[[{"secondary_y": False}, {"secondary_y": False}],
               [{"secondary_y": False}, {"secondary_y": False}]]
    )

    epochs = list(range(1, len(history_dict['accuracy']) + 1))

    def _add_series(key, label, color, row, col):
        """Add a line trace for ``history_dict[key]`` if that key was logged."""
        values = history_dict.get(key)
        if values is None:
            return
        fig.add_trace(
            go.Scatter(x=epochs, y=values, name=label,
                       line=dict(color=color), mode='lines+markers'),
            row=row, col=col
        )

    # (history key, legend label, color, subplot row, subplot col)
    series = [
        ('accuracy', 'Training Accuracy', 'blue', 1, 1),
        ('val_accuracy', 'Validation Accuracy', 'red', 1, 1),
        ('loss', 'Training Loss', 'blue', 1, 2),
        ('val_loss', 'Validation Loss', 'red', 1, 2),
        ('f1_score', 'Training F1', 'green', 2, 1),
        ('val_f1_score', 'Validation F1', 'orange', 2, 1),
    ]
    for key, label, color, row, col in series:
        _add_series(key, label, color, row, col)

    # tf.keras 2.x logs the LR as 'lr'; Keras 3 renamed it 'learning_rate'.
    lr_key = 'lr' if 'lr' in history_dict else 'learning_rate'
    _add_series(lr_key, 'Learning Rate', 'purple', 2, 2)

    fig.update_layout(
        height=800,
        title_text=f"üß† {PRIMARY_ARCHITECTURE} Training History Analysis",
        title_x=0.5,
        showlegend=True
    )

    fig.update_yaxes(title_text="Accuracy", row=1, col=1)
    fig.update_yaxes(title_text="Loss", row=1, col=2)
    fig.update_yaxes(title_text="F1 Score", row=2, col=1)
    # LR decays by orders of magnitude, so a log axis keeps it readable.
    fig.update_yaxes(title_text="Learning Rate", type="log", row=2, col=2)

    fig.update_xaxes(title_text="Epoch", row=2, col=1)
    fig.update_xaxes(title_text="Epoch", row=2, col=2)

    fig.show()

# üìä Plot training history
if 'history' in locals():
    plot_training_history(history.history)

    # üìà Performance Summary
    best_val_acc = max(history.history['val_accuracy'])
    best_val_f1 = max(history.history['val_f1_score'])
    # tf.keras 2.x logs the LR as 'lr'; Keras 3 renamed it 'learning_rate'.
    lr_history = history.history.get('lr', history.history.get('learning_rate'))
    final_lr = f"{lr_history[-1]:.2e}" if lr_history else 'N/A'

    print(f"\nüèÜ Training Performance Summary:")
    print(f"   ‚Ä¢ Best Validation Accuracy: {best_val_acc:.4f}")
    print(f"   ‚Ä¢ Best Validation F1 Score: {best_val_f1:.4f}")
    print(f"   ‚Ä¢ Final Learning Rate: {final_lr}")
    print(f"   ‚Ä¢ Total Epochs Completed: {len(history.history['accuracy'])}")
else:
    print("‚ö†Ô∏è No training history available to plot")