# File: notebooks/04_training_optimization/13_training_loops.ipynb

## JAX Training Optimization: Training Loops

This notebook implements training loops in JAX, covering basic training, validation, early stopping, checkpointing, gradient clipping, data augmentation, and loss scaling for mixed precision. The goal is reusable training infrastructure that works for small experiments and can be extended for larger jobs.

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, lax
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, Optional, Callable, Tuple, List
import functools
import time
import pickle
from dataclasses import dataclass
from pathlib import Path

jax.config.update("jax_enable_x64", True)
print(f"JAX version: {jax.__version__}")
```

## Core Training Infrastructure

### Training State Management

```python
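@dataclass
class TrainingState:
    """Training state container"""
    params: Dict[str, Any]
    opt_state: Dict[str, Any]
    step: int
    epoch: int
    best_metric: float
    patience_counter: int

# A plain dataclass is not a valid argument to a jit-compiled function.
# Registering it as a pytree lets JAX flatten it into leaves and rebuild it.
jax.tree_util.register_pytree_node(
    TrainingState,
    lambda s: ((s.params, s.opt_state, s.step, s.epoch,
                s.best_metric, s.patience_counter), None),
    lambda _, children: TrainingState(*children),
)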

class TrainingLoop:
    """Comprehensive training loop with all features"""
    
    def __init__(self, 
                 model_fn: Callable,
                 loss_fn: Callable,
                 optimizer,
                 metrics_fn: Optional[Callable] = None):
        
        self.model_fn = model_fn
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.metrics_fn = metrics_fn or (lambda pred, target: {})
        
        # JIT compile training functions
        self.train_step_fn = jit(self._train_step)
        self.eval_step_fn = jit(self._eval_step)
    
    def _train_step(self, state: TrainingState, batch_x: jnp.ndarray, batch_y: jnp.ndarray):
        """Single training step"""
        def loss_fn_params(params):
            predictions = self.model_fn(params, batch_x)
            return self.loss_fn(predictions, batch_y)
        
        loss, grads = jax.value_and_grad(loss_fn_params)(state.params)
        new_params, new_opt_state = self.optimizer.update(grads, state.opt_state, state.params)
        
        new_state = TrainingState(
            params=new_params,
            opt_state=new_opt_state,
            step=state.step + 1,
            epoch=state.epoch,
            best_metric=state.best_metric,
            patience_counter=state.patience_counter
        )
        
        return new_state, loss
    
    def _eval_step(self, params: Dict[str, Any], batch_x: jnp.ndarray, batch_y: jnp.ndarray):
        """Single evaluation step"""
        predictions = self.model_fn(params, batch_x)
        loss = self.loss_fn(predictions, batch_y)
        metrics = self.metrics_fn(predictions, batch_y)
        return loss, metrics
    
    def train_epoch(self, state: TrainingState, train_data: Tuple[jnp.ndarray, jnp.ndarray], batch_size: int):
        """Train for one epoch"""
        X_train, y_train = train_data
        n_samples = len(X_train)
        n_batches = n_samples // batch_size
        
        epoch_losses = []
        
        # Shuffle data
        key = random.PRNGKey(state.epoch)
        perm = random.permutation(key, n_samples)
        X_shuffled = X_train[perm]
        y_shuffled = y_train[perm]
        
        for i in range(n_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            
            batch_x = X_shuffled[start_idx:end_idx]
            batch_y = y_shuffled[start_idx:end_idx]
            
            state, batch_loss = self.train_step_fn(state, batch_x, batch_y)
            epoch_losses.append(batch_loss)
        
        avg_loss = jnp.mean(jnp.array(epoch_losses))
        return state, avg_loss
    
    def evaluate(self, params: Dict[str, Any], eval_data: Tuple[jnp.ndarray, jnp.ndarray], batch_size: int):
        """Evaluate model on the full dataset, including the final partial batch"""
        X_eval, y_eval = eval_data
        n_samples = len(X_eval)
        
        eval_losses = []
        batch_sizes = []
        all_metrics = {}
        
        # Step through every sample; a smaller last batch triggers one extra
        # JIT compilation for the new shape
        for start_idx in range(0, n_samples, batch_size):
            batch_x = X_eval[start_idx:start_idx + batch_size]
            batch_y = y_eval[start_idx:start_idx + batch_size]
            
            batch_loss, batch_metrics = self.eval_step_fn(params, batch_x, batch_y)
            eval_losses.append(batch_loss)
            batch_sizes.append(len(batch_x))
            
            # Accumulate metrics
            for key, value in batch_metrics.items():
                all_metrics.setdefault(key, []).append(value)
        
        # Weight each batch by its size so the result is a per-sample mean
        weights = jnp.array(batch_sizes) / n_samples
        avg_loss = jnp.sum(jnp.array(eval_losses) * weights)
        avg_metrics = {key: jnp.sum(jnp.array(values) * weights) for key, values in all_metrics.items()}
        
        return avg_loss, avg_metrics
    
    def train(self, 
              initial_params: Dict[str, Any],
              train_data: Tuple[jnp.ndarray, jnp.ndarray],
              val_data: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
              num_epochs: int = 100,
              batch_size: int = 32,
              early_stopping_patience: int = 10,
              checkpoint_dir: Optional[str] = None):
        """Complete training loop with validation and early stopping"""
        
        # Initialize training state
        opt_state = self.optimizer.init_state(initial_params)
        state = TrainingState(
            params=initial_params,
            opt_state=opt_state,
            step=0,
            epoch=0,
            best_metric=float('inf'),
            patience_counter=0
        )
        
        history = {'train_loss': [], 'val_loss': [], 'val_metrics': []}
        
        print(f"Starting training for {num_epochs} epochs...")
        start_time = time.time()
        
        for epoch in range(num_epochs):
            state = TrainingState(
                params=state.params,
                opt_state=state.opt_state,
                step=state.step,
                epoch=epoch,
                best_metric=state.best_metric,
                patience_counter=state.patience_counter
            )
            
            # Training epoch
            epoch_start = time.time()
            state, train_loss = self.train_epoch(state, train_data, batch_size)
            epoch_time = time.time() - epoch_start
            
            history['train_loss'].append(float(train_loss))
            
            # Validation
            if val_data is not None:
                val_loss, val_metrics = self.evaluate(state.params, val_data, batch_size)
                history['val_loss'].append(float(val_loss))
                history['val_metrics'].append(val_metrics)
                
                # Early stopping check
                if val_loss < state.best_metric:
                    state = TrainingState(
                        params=state.params,
                        opt_state=state.opt_state,
                        step=state.step,
                        epoch=state.epoch,
                        best_metric=float(val_loss),
                        patience_counter=0
                    )
                    
                    # Save best model
                    if checkpoint_dir:
                        self.save_checkpoint(state, checkpoint_dir, 'best_model.pkl')
                        
                else:
                    state = TrainingState(
                        params=state.params,
                        opt_state=state.opt_state,
                        step=state.step,
                        epoch=state.epoch,
                        best_metric=state.best_metric,
                        patience_counter=state.patience_counter + 1
                    )
                
                # Print progress
                if epoch % 10 == 0 or epoch == num_epochs - 1:
                    print(f"Epoch {epoch:3d}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, "
                          f"time={epoch_time:.2f}s")
                    if val_metrics:
                        metric_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
                        print(f"         Metrics: {metric_str}")
                
                # Early stopping
                if state.patience_counter >= early_stopping_patience:
                    print(f"Early stopping at epoch {epoch} (patience={early_stopping_patience})")
                    break
            else:
                if epoch % 10 == 0 or epoch == num_epochs - 1:
                    print(f"Epoch {epoch:3d}: train_loss={train_loss:.4f}, time={epoch_time:.2f}s")
        
        total_time = time.time() - start_time
        print(f"Training completed in {total_time:.2f}s")
        
        return state, history
    
    def save_checkpoint(self, state: TrainingState, checkpoint_dir: str, filename: str):
        """Save training checkpoint"""
        Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
        checkpoint_path = Path(checkpoint_dir) / filename
        
        checkpoint = {
            'params': state.params,
            'opt_state': state.opt_state,
            'step': state.step,
            'epoch': state.epoch,
            'best_metric': state.best_metric
        }
        
        with open(checkpoint_path, 'wb') as f:
            pickle.dump(checkpoint, f)
    
    def load_checkpoint(self, checkpoint_path: str) -> TrainingState:
        """Load training checkpoint"""
        with open(checkpoint_path, 'rb') as f:
            checkpoint = pickle.load(f)
        
        return TrainingState(
            params=checkpoint['params'],
            opt_state=checkpoint['opt_state'],
            step=checkpoint['step'],
            epoch=checkpoint['epoch'],
            best_metric=checkpoint['best_metric'],
            patience_counter=0
        )

def accuracy_metric(predictions, targets):
    """Accuracy metric for classification"""
    pred_classes = jnp.argmax(predictions, axis=1)
    true_classes = jnp.argmax(targets, axis=1)
    return {'accuracy': jnp.mean(pred_classes == true_classes)}

# Example usage
def demo_training_loop():
    """Demonstrate the training loop with a simple MLP"""
    
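    key = random.PRNGKey(42)
    # Split once so that no key is consumed by more than one random call
    key, data_key, label_key = random.split(key, 3)
    
    # Generate synthetic data (labels are random, so accuracy stays near chance)
    n_samples, n_features, n_classes = 1000, 20, 3
    X = random.normal(data_key, (n_samples, n_features))
    y = jax.nn.one_hot(random.randint(label_key, (n_samples,), 0, n_classes), n_classes)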
    
    # Split data
    split_idx = int(0.8 * n_samples)
    train_data = (X[:split_idx], y[:split_idx])
    val_data = (X[split_idx:], y[split_idx:])
    
    # Simple MLP model
    def init_mlp_params(key, input_size, hidden_size, output_size):
        keys = random.split(key, 4)
        return {
            'W1': random.normal(keys[0], (input_size, hidden_size)) * jnp.sqrt(2.0 / input_size),
            'b1': jnp.zeros(hidden_size),
            'W2': random.normal(keys[1], (hidden_size, output_size)) * jnp.sqrt(2.0 / hidden_size),
            'b2': jnp.zeros(output_size)
        }
    
    def mlp_forward(params, x):
        h = jax.nn.relu(x @ params['W1'] + params['b1'])
        return h @ params['W2'] + params['b2']
    
    def cross_entropy_loss(predictions, targets):
        log_probs = jax.nn.log_softmax(predictions)
        return -jnp.mean(jnp.sum(targets * log_probs, axis=1))
    
    # Initialize model and optimizer
    params = init_mlp_params(random.split(key, 3)[2], n_features, 64, n_classes)
    
    # Simple Adam optimizer
    class SimpleAdam:
        def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
            self.learning_rate = learning_rate
            self.beta1 = beta1
            self.beta2 = beta2
            self.eps = eps
        
        def init_state(self, params):
            # jax.tree_util.tree_map is used because the top-level
            # jax.tree_map alias is deprecated
            return {
                'm': jax.tree_util.tree_map(jnp.zeros_like, params),
                'v': jax.tree_util.tree_map(jnp.zeros_like, params),
                'step': 0
            }
        
        def update(self, grads, state, params):
            tree_map = jax.tree_util.tree_map
            step = state['step'] + 1
            
            # Exponential moving averages of the gradient and its square
            m = tree_map(lambda m_prev, g: self.beta1 * m_prev + (1 - self.beta1) * g, state['m'], grads)
            v = tree_map(lambda v_prev, g: self.beta2 * v_prev + (1 - self.beta2) * g**2, state['v'], grads)
            
            # Bias correction for the zero-initialized moments
            m_hat = tree_map(lambda m_val: m_val / (1 - self.beta1**step), m)
            v_hat = tree_map(lambda v_val: v_val / (1 - self.beta2**step), v)
            
            new_params = tree_map(
                lambda p, m_val, v_val: p - self.learning_rate * m_val / (jnp.sqrt(v_val) + self.eps),
                params, m_hat, v_hat
            )
            
            new_state = {'m': m, 'v': v, 'step': step}
            return new_params, new_state
    
    optimizer = SimpleAdam(learning_rate=0.001)
    
    # Create and run training loop
    trainer = TrainingLoop(mlp_forward, cross_entropy_loss, optimizer, accuracy_metric)
    
    final_state, history = trainer.train(
        params, train_data, val_data,
        num_epochs=100, batch_size=32,
        early_stopping_patience=10
    )
    
    # Plot training curves
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    epochs = range(len(history['train_loss']))
    ax1.plot(epochs, history['train_loss'], label='Train Loss')
    if history['val_loss']:
        ax1.plot(epochs, history['val_loss'], label='Val Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Progress')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    if history['val_metrics']:
        accuracies = [m['accuracy'] for m in history['val_metrics']]
        ax2.plot(epochs, accuracies)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Validation Accuracy')
        ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"Final validation accuracy: {history['val_metrics'][-1]['accuracy']:.4f}")
    
    return final_state, history

# Run demo
final_state, training_history = demo_training_loop()

class GradientClippingTrainer(TrainingLoop):
    """Training loop with gradient clipping"""
    
    def __init__(self, model_fn, loss_fn, optimizer, max_grad_norm=1.0, metrics_fn=None):
        super().__init__(model_fn, loss_fn, optimizer, metrics_fn)
        self.max_grad_norm = max_grad_norm
        self.train_step_fn = jit(self._train_step_with_clipping)
    
    def _train_step_with_clipping(self, state, batch_x, batch_y):
        """Training step with gradient clipping"""
        def loss_fn_params(params):
            predictions = self.model_fn(params, batch_x)
            return self.loss_fn(predictions, batch_y)
        
        loss, grads = jax.value_and_grad(loss_fn_params)(state.params)
        
        # Clip gradients by their global L2 norm across all parameter leaves
        global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grads)))
        clip_factor = jnp.minimum(self.max_grad_norm / (global_norm + 1e-8), 1.0)
        clipped_grads = jax.tree_util.tree_map(lambda g: g * clip_factor, grads)
        
        new_params, new_opt_state = self.optimizer.update(clipped_grads, state.opt_state, state.params)
        
        new_state = TrainingState(
            params=new_params,
            opt_state=new_opt_state,
            step=state.step + 1,
            epoch=state.epoch,
            best_metric=state.best_metric,
            patience_counter=state.patience_counter
        )
        
        return new_state, loss

class AugmentedTrainingLoop(TrainingLoop):
    """Training loop with data augmentation"""
    
    def __init__(self, model_fn, loss_fn, optimizer, augment_fn=None, metrics_fn=None):
        super().__init__(model_fn, loss_fn, optimizer, metrics_fn)
        self.augment_fn = augment_fn or (lambda x, key: x)
    
    def train_epoch(self, state, train_data, batch_size):
        """Train epoch with data augmentation"""
        X_train, y_train = train_data
        n_samples = len(X_train)
        n_batches = n_samples // batch_size
        
        epoch_losses = []
        
        # Shuffle data; use separate keys for the permutation and the augmentation
        shuffle_key, aug_key = random.split(random.PRNGKey(state.epoch))
        perm = random.permutation(shuffle_key, n_samples)
        X_shuffled = X_train[perm]
        y_shuffled = y_train[perm]
        
        # One augmentation key per batch, split once outside the loop
        aug_keys = random.split(aug_key, n_batches)
        
        for i in range(n_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            
            batch_x = X_shuffled[start_idx:end_idx]
            batch_y = y_shuffled[start_idx:end_idx]
            
            # Apply augmentation
            batch_x = self.augment_fn(batch_x, aug_keys[i])
            
            state, batch_loss = self.train_step_fn(state, batch_x, batch_y)
            epoch_losses.append(batch_loss)
        
        avg_loss = jnp.mean(jnp.array(epoch_losses))
        return state, avg_loss

def simple_augment(x, key):
    """Simple augmentation: add noise"""
    noise = 0.01 * random.normal(key, x.shape)
    return x + noise

print("Advanced training loop classes defined successfully!")

# Benchmarking utilities
class TrainingBenchmark:
    """Benchmark training performance"""
    
    @staticmethod
    def time_training_step(trainer, state, batch_x, batch_y, num_iterations=100):
        """Benchmark training step performance"""
        
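        # Warmup (includes JIT compilation); block so it does not leak into the timing
        for _ in range(10):
            state, loss = trainer.train_step_fn(state, batch_x, batch_y)
        loss.block_until_ready()
        
        # Timing: JAX dispatches work asynchronously, so wait for the final
        # result before reading the clock
        start_time = time.perf_counter()
        for _ in range(num_iterations):
            state, loss = trainer.train_step_fn(state, batch_x, batch_y)
        loss.block_until_ready()
        end_time = time.perf_counter()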
        
        avg_time_per_step = (end_time - start_time) / num_iterations
        steps_per_second = 1.0 / avg_time_per_step
        
        return {
            'avg_time_per_step': avg_time_per_step,
            'steps_per_second': steps_per_second,
            'final_loss': float(loss)
        }

print("Training infrastructure complete!")
```
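
The global-norm clipping inside `GradientClippingTrainer` can be checked in isolation. The following is a minimal sketch, not part of the trainer itself: `clip_by_global_norm` and the toy gradient values are hypothetical names and data chosen for illustration.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm):
    """Rescale a gradient pytree so its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # The factor is 1.0 when the norm is already within bounds
    clip_factor = jnp.minimum(max_norm / (global_norm + 1e-8), 1.0)
    clipped = jax.tree_util.tree_map(lambda g: g * clip_factor, grads)
    return clipped, global_norm


def global_l2_norm(tree):
    """Global L2 norm over all leaves of a pytree."""
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(tree)))


# Toy gradients whose global norm is sqrt(3^2 + 4^2) = 5
toy_grads = {"W": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
clipped_grads, norm_before = clip_by_global_norm(toy_grads, max_norm=1.0)
norm_after = global_l2_norm(clipped_grads)
print(f"norm before: {float(norm_before):.3f}, norm after: {float(norm_after):.3f}")
```

Because every leaf is multiplied by the same factor, clipping changes the length of the update but not its direction.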

## Advanced Training Features

### Mixed Precision Training

```python
def mixed_precision_loss_fn(loss_fn, loss_scale=1024.0):
    """Wrap a loss function with static loss scaling.
    
    Scaling the loss keeps small gradients representable in low-precision
    dtypes. This wrapper only scales; it does not cast params or activations.
    """
    def scaled_loss_fn(*args, **kwargs):
        return loss_fn(*args, **kwargs) * loss_scale
    
    def unscale_grads(grads):
        return jax.tree_util.tree_map(lambda g: g / loss_scale, grads)
    
    return scaled_loss_fn, unscale_grads

class MixedPrecisionTrainer(TrainingLoop):
    """Training loop with static loss scaling, a building block of mixed precision"""
    
    def __init__(self, model_fn, loss_fn, optimizer, loss_scale=1024.0, metrics_fn=None):
        self.loss_scale = loss_scale
        self.scaled_loss_fn, self.unscale_grads = mixed_precision_loss_fn(loss_fn, loss_scale)
        # Pass the unscaled loss to the base class so evaluation reports true losses
        super().__init__(model_fn, loss_fn, optimizer, metrics_fn)
        
    def _train_step(self, state, batch_x, batch_y):
        """Training step that differentiates the scaled loss"""
        def loss_fn_params(params):
            predictions = self.model_fn(params, batch_x)
            return self.scaled_loss_fn(predictions, batch_y)
        
        scaled_loss, scaled_grads = jax.value_and_grad(loss_fn_params)(state.params)
        loss = scaled_loss / self.loss_scale  # Unscale loss for logging
        grads = self.unscale_grads(scaled_grads)
        
        new_params, new_opt_state = self.optimizer.update(grads, state.opt_state, state.params)
        
        new_state = TrainingState(
            params=new_params,
            opt_state=new_opt_state,
            step=state.step + 1,
            epoch=state.epoch,
            best_metric=state.best_metric,
            patience_counter=state.patience_counter
        )
        
        return new_state, loss

print("Mixed precision training implemented!")
```
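
The trainer above applies loss scaling only. For comparison, here is a minimal sketch of how a low-precision forward pass could be combined with it, assuming float32 master parameters and a `bfloat16` compute dtype. The names `cast_floats`, `linear_loss`, and `mixed_precision_grads` are hypothetical, and the linear model stands in for a real network.

```python
import jax
import jax.numpy as jnp


def cast_floats(tree, dtype):
    """Cast every floating-point leaf of a pytree to dtype."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x, tree
    )


def linear_loss(params, x, y):
    """Mean squared error of a linear model; the reduction runs in float32."""
    pred = x @ params["W"] + params["b"]
    return jnp.mean((pred.astype(jnp.float32) - y.astype(jnp.float32)) ** 2)


def mixed_precision_grads(master_params, x, y, loss_scale=1024.0, compute_dtype=jnp.bfloat16):
    """Gradients w.r.t. float32 master params using a low-precision forward pass."""
    def scaled_loss(p):
        low_p = cast_floats(p, compute_dtype)           # low-precision copy of the weights
        loss = linear_loss(low_p, x.astype(compute_dtype), y)
        return loss * loss_scale, loss                  # differentiate the scaled loss

    (_, loss), scaled_grads = jax.value_and_grad(scaled_loss, has_aux=True)(master_params)
    grads = jax.tree_util.tree_map(lambda g: g / loss_scale, scaled_grads)
    return loss, grads


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
master = {
    "W": jax.random.normal(key_w, (4, 1), dtype=jnp.float32),
    "b": jnp.zeros((1,), dtype=jnp.float32),
}
x = jax.random.normal(key_x, (8, 4), dtype=jnp.float32)
y = jnp.ones((8, 1), dtype=jnp.float32)

loss, grads = jax.jit(mixed_precision_grads)(master, x, y)
print(loss.dtype, grads["W"].dtype)
```

The gradients come back in the dtype of the master parameters because differentiation passes back through the cast, so the optimizer update can stay in float32.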

## Summary

In this notebook, we've built comprehensive training infrastructure:

**Core Features:**
1. **TrainingState**: Centralized state management
2. **TrainingLoop**: Complete training pipeline with validation
3. **Early Stopping**: Prevents overfitting with patience mechanism
4. **Checkpointing**: Save/load model states for resumption
5. **Metrics Tracking**: Extensible metrics computation
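
The patience mechanism behind early stopping can be traced on its own. This is a small sketch in plain Python that mirrors the comparison used in `TrainingLoop.train`. `StopState`, `update_early_stopping`, `epochs_until_stop`, and the loss values are hypothetical and chosen for illustration.

```python
from typing import NamedTuple


class StopState(NamedTuple):
    best_metric: float
    patience_counter: int


def update_early_stopping(state: StopState, val_loss: float) -> StopState:
    """Reset patience on improvement, otherwise count one more stale epoch."""
    if val_loss < state.best_metric:
        return StopState(best_metric=val_loss, patience_counter=0)
    return StopState(state.best_metric, state.patience_counter + 1)


def epochs_until_stop(val_losses, patience):
    """Return the epoch index at which training would stop, or None."""
    state = StopState(best_metric=float("inf"), patience_counter=0)
    for epoch, val_loss in enumerate(val_losses):
        state = update_early_stopping(state, val_loss)
        if state.patience_counter >= patience:
            return epoch
    return None


# Made-up validation losses: the best value (0.7) is reached at epoch 2
losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
stop_epoch = epochs_until_stop(losses, patience=3)
print(stop_epoch)  # → 5
```

Epochs 3, 4, and 5 all fail to beat 0.7, so the counter reaches the patience of 3 at epoch 5 and the last epoch is never run.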

**Advanced Features:**
1. **Gradient Clipping**: Prevents exploding gradients
2. **Data Augmentation**: On-the-fly data transformation
3. **Mixed Precision**: Static loss scaling, a building block for low-precision training
4. **Benchmarking**: Performance measurement tools

**Key Benefits:**
- JIT-compiled training and evaluation steps
- Modular design for easy customization
- Validation, early stopping, and checkpointing built into one loop
- Extensible architecture for new features

**Best Practices Implemented:**
- Proper data shuffling each epoch
- Validation metrics for model selection
- Progress reporting and timing
- Checkpoint management for long training runs
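
For checkpoint management, one option is to copy arrays to host memory before pickling and convert them back on load. The sketch below assumes a nested-dict checkpoint like the one built in `save_checkpoint`. `save_pytree`, `load_pytree`, and the file name are hypothetical.

```python
import pickle
import tempfile
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np


def save_pytree(tree, path):
    """Pickle a pytree after copying its arrays to host NumPy arrays."""
    host_tree = jax.device_get(tree)
    with open(path, "wb") as f:
        pickle.dump(host_tree, f)


def load_pytree(path):
    """Load a pickled pytree and turn NumPy arrays back into JAX arrays."""
    with open(path, "rb") as f:
        host_tree = pickle.load(f)
    return jax.tree_util.tree_map(
        lambda x: jnp.asarray(x) if isinstance(x, np.ndarray) else x, host_tree
    )


checkpoint = {"params": {"W": jnp.arange(6.0).reshape(2, 3)}, "step": 7, "best_metric": 0.25}

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = Path(tmp_dir) / "demo_checkpoint.pkl"
    save_pytree(checkpoint, ckpt_path)
    restored = load_pytree(ckpt_path)

print(restored["step"], restored["params"]["W"].shape)
```

Pickle executes arbitrary code when loading, so this approach is only suitable for checkpoint files from a trusted source.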

**JAX Advantages:**
- Functional programming paradigm
- Easy composition with other JAX transforms
- Efficient memory usage with tree operations
- Seamless CPU/GPU execution
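
As an illustration of composing transforms, the sketch below stacks `grad`, `vmap`, and `jit` to compute per-example gradients. The function `example_loss` and the data are hypothetical and not part of the training loop above.

```python
import jax
import jax.numpy as jnp


def example_loss(w, x, y):
    """Squared error of a linear model on a single example."""
    return (jnp.dot(x, w) - y) ** 2


# grad differentiates w.r.t. w, vmap maps over the batch axis of x and y
# while sharing w, and jit compiles the composed function
per_example_grads = jax.jit(jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

grads = per_example_grads(w, xs, ys)
print(grads.shape)  # → (3, 2)
```

Each row is the gradient for one example, here `2 * (x·w - y) * x`, which is the quantity needed for techniques such as per-example clipping.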

This training infrastructure provides a foundation for research workflows and a starting point for larger jobs. Production use would additionally need error handling, a more robust checkpoint format than pickle, and multi-device support, none of which are covered here.