# File: notebooks/04_training_optimization/11_optimizers_in_jax.ipynb

## JAX Training Optimization: Optimizers in JAX

This notebook implements various optimization algorithms from scratch in JAX, including SGD, momentum variants, Adam, AdamW, RMSprop, and advanced techniques like learning rate scheduling and gradient clipping. We'll analyze their behavior and provide practical guidance for choosing optimizers.

Understanding optimization algorithms is crucial for effective neural network training, as the choice of optimizer significantly impacts convergence speed, stability, and final performance.

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, lax
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, Optional, Callable, Tuple
import functools
from dataclasses import dataclass

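jax.config.update("jax_enable_x64", True)

# `jax.tree_map` and `jax.tree_leaves` were deprecated and later removed from the
# top-level namespace. Alias them to `jax.tree_util` so the cells below run on
# both older and newer JAX releases.
for _name in ("tree_map", "tree_leaves"):
    if not hasattr(jax, _name):
        setattr(jax, _name, getattr(jax.tree_util, _name))

print(f"JAX version: {jax.__version__}")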
```

## Base Optimizer Interface

### Optimizer Abstract Base

```python
class Optimizer:
    """Base class for optimizers"""
    
    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate
    
    def init_state(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Initialize optimizer state"""
        raise NotImplementedError
    
    def update(self, grads: Dict[str, Any], state: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Update parameters given gradients"""
        raise NotImplementedError
    
    def get_learning_rate(self, step: int) -> float:
        """Get learning rate for current step (can be overridden for scheduling)"""
        return self.learning_rate

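def clip_gradients(grads, max_norm):
    """Clip gradients by global norm (the L2 norm over all leaves of the pytree)"""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in leaves))
    # Scale down only when the norm exceeds max_norm; the epsilon guards against division by zero
    clip_factor = jnp.minimum(max_norm / (global_norm + 1e-8), 1.0)
    return jax.tree_util.tree_map(lambda g: g * clip_factor, grads)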
```
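
As a small illustration of what global-norm clipping does, the sketch below rescales a gradient pytree whose norm exceeds the threshold and leaves a small one untouched. The helper name `clip_by_global_norm` and the toy gradients are assumptions made for this example; the logic mirrors `clip_gradients` above.

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm(grads, max_norm):
    """Hypothetical helper: returns (clipped_grads, norm_before_clipping)."""
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(max_norm / (norm + 1e-8), 1.0)
    return jax.tree_util.tree_map(lambda g: g * scale, grads), norm

def tree_norm(tree):
    """L2 norm over every leaf of a pytree."""
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(tree)))

# Toy gradients: sqrt(3^2 + 4^2 + 12^2) = 13 and sqrt(0.3^2 + 0.4^2) = 0.5
large = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}
small = {'w': jnp.array([0.3, 0.4]), 'b': jnp.array([0.0])}

clipped_large, norm_large = clip_by_global_norm(large, 1.0)
clipped_small, norm_small = clip_by_global_norm(small, 1.0)

print("large:", float(norm_large), "->", float(tree_norm(clipped_large)))
print("small:", float(norm_small), "->", float(tree_norm(clipped_small)))
```

Because every leaf is multiplied by the same factor, the direction of the overall gradient is preserved and only its length changes.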

## Basic Optimizers

### Stochastic Gradient Descent (SGD)

```python
class SGD(Optimizer):
    """Stochastic Gradient Descent optimizer"""
    
    def __init__(self, learning_rate: float = 0.01, momentum: float = 0.0, weight_decay: float = 0.0):
        super().__init__(learning_rate)
        self.momentum = momentum
        self.weight_decay = weight_decay
    
    def init_state(self, params):
        """Initialize SGD state"""
        if self.momentum > 0:
            return {
                'velocity': jax.tree_map(jnp.zeros_like, params),
                'step': 0
            }
        return {'step': 0}
    
    def update(self, grads, state, params):
        """SGD parameter update"""
        step = state['step'] + 1
        
        # Add weight decay to gradients
        if self.weight_decay > 0:
            grads = jax.tree_map(
                lambda g, p: g + self.weight_decay * p,
                grads, params
            )
        
        if self.momentum > 0:
            # Momentum update
            velocity = jax.tree_map(
                lambda v, g: self.momentum * v + g,
                state['velocity'], grads
            )
            
            new_params = jax.tree_map(
                lambda p, v: p - self.learning_rate * v,
                params, velocity
            )
            
            new_state = {'velocity': velocity, 'step': step}
        else:
            # Vanilla SGD
            new_params = jax.tree_map(
                lambda p, g: p - self.learning_rate * g,
                params, grads
            )
            new_state = {'step': step}
        
        return new_params, new_state

# Test SGD
def test_sgd():
    """Test SGD optimizer"""
    
    key = random.PRNGKey(42)
    params = {'w': random.normal(key, (3, 2)), 'b': jnp.zeros(2)}
    grads = {'w': random.normal(random.split(key)[1], (3, 2)), 'b': jnp.ones(2)}
    
    # Test vanilla SGD
    sgd = SGD(learning_rate=0.1)
    state = sgd.init_state(params)
    new_params, new_state = sgd.update(grads, state, params)
    
    print("SGD Test:")
    print(f"Original params['w'][0,0]: {params['w'][0,0]:.4f}")
    print(f"Updated params['w'][0,0]: {new_params['w'][0,0]:.4f}")
    print(f"Step count: {new_state['step']}")
    
    # Test SGD with momentum
    sgd_momentum = SGD(learning_rate=0.1, momentum=0.9)
    state_momentum = sgd_momentum.init_state(params)
    new_params_momentum, new_state_momentum = sgd_momentum.update(grads, state_momentum, params)
    
    print(f"SGD+Momentum params['w'][0,0]: {new_params_momentum['w'][0,0]:.4f}")
    print(f"Velocity shape: {new_state_momentum['velocity']['w'].shape}")

test_sgd()
```
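
For reference, the update applied by the `SGD` class above can be written out as follows. The symbols are notation introduced here rather than taken from the notebook: $\theta$ for the parameters, $g_t$ for the gradient, $\eta$ for `learning_rate`, $\mu$ for `momentum`, and $\lambda$ for `weight_decay`.

$$
\begin{aligned}
\tilde{g}_t &= g_t + \lambda\,\theta_{t-1} \\
v_t &= \mu\,v_{t-1} + \tilde{g}_t, \qquad v_0 = 0 \\
\theta_t &= \theta_{t-1} - \eta\,v_t
\end{aligned}
$$

Setting $\mu = 0$ gives $v_t = \tilde{g}_t$, which recovers the vanilla SGD branch of the code.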

### Nesterov Momentum and Adaptive Optimizers (RMSprop, Adam, AdamW)

```python
class Nesterov(Optimizer):
    """Nesterov Accelerated Gradient optimizer"""
    
    def __init__(self, learning_rate: float = 0.01, momentum: float = 0.9, weight_decay: float = 0.0):
        super().__init__(learning_rate)
        self.momentum = momentum
        self.weight_decay = weight_decay
    
    def init_state(self, params):
        """Initialize Nesterov state"""
        return {
            'velocity': jax.tree_map(jnp.zeros_like, params),
            'step': 0
        }
    
    def update(self, grads, state, params):
        """Nesterov parameter update"""
        step = state['step'] + 1
        
        # Add weight decay
        if self.weight_decay > 0:
            grads = jax.tree_map(
                lambda g, p: g + self.weight_decay * p,
                grads, params
            )
        
        # Nesterov update in its reformulated "look-ahead" form: refresh the velocity,
        # then step along momentum * velocity + gradient
        velocity = jax.tree_map(
            lambda v, g: self.momentum * v + g,
            state['velocity'], grads
        )
        
        new_params = jax.tree_map(
            lambda p, v, g: p - self.learning_rate * (self.momentum * v + g),
            params, velocity, grads
        )
        
        new_state = {'velocity': velocity, 'step': step}
        return new_params, new_state

class RMSprop(Optimizer):
    """RMSprop optimizer"""
    
    def __init__(self, learning_rate: float = 0.001, rho: float = 0.9, eps: float = 1e-8, weight_decay: float = 0.0):
        super().__init__(learning_rate)
        self.rho = rho
        self.eps = eps
        self.weight_decay = weight_decay
    
    def init_state(self, params):
        """Initialize RMSprop state"""
        return {
            'v': jax.tree_map(jnp.zeros_like, params),
            'step': 0
        }
    
    def update(self, grads, state, params):
        """RMSprop parameter update"""
        step = state['step'] + 1
        
        # Add weight decay
        if self.weight_decay > 0:
            grads = jax.tree_map(
                lambda g, p: g + self.weight_decay * p,
                grads, params
            )
        
        # Update squared gradient accumulator
        v = jax.tree_map(
            lambda v_prev, g: self.rho * v_prev + (1 - self.rho) * g**2,
            state['v'], grads
        )
        
        # Parameter update
        new_params = jax.tree_map(
            lambda p, g, v_val: p - self.learning_rate * g / (jnp.sqrt(v_val) + self.eps),
            params, grads, v
        )
        
        new_state = {'v': v, 'step': step}
        return new_params, new_state

class Adam(Optimizer):
    """Adam optimizer"""
    
    def __init__(self, learning_rate: float = 0.001, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8, weight_decay: float = 0.0):
        super().__init__(learning_rate)
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
    
    def init_state(self, params):
        """Initialize Adam state"""
        return {
            'm': jax.tree_map(jnp.zeros_like, params),
            'v': jax.tree_map(jnp.zeros_like, params),
            'step': 0
        }
    
    def update(self, grads, state, params):
        """Adam parameter update"""
        step = state['step'] + 1
        
        # Add weight decay (L2 regularization)
        if self.weight_decay > 0:
            grads = jax.tree_map(
                lambda g, p: g + self.weight_decay * p,
                grads, params
            )
        
        # Update biased first moment estimate
        m = jax.tree_map(
            lambda m_prev, g: self.beta1 * m_prev + (1 - self.beta1) * g,
            state['m'], grads
        )
        
        # Update biased second moment estimate
        v = jax.tree_map(
            lambda v_prev, g: self.beta2 * v_prev + (1 - self.beta2) * g**2,
            state['v'], grads
        )
        
        # Bias correction
        m_hat = jax.tree_map(lambda m_val: m_val / (1 - self.beta1**step), m)
        v_hat = jax.tree_map(lambda v_val: v_val / (1 - self.beta2**step), v)
        
        # Parameter update
        new_params = jax.tree_map(
            lambda p, m_val, v_val: p - self.learning_rate * m_val / (jnp.sqrt(v_val) + self.eps),
            params, m_hat, v_hat
        )
        
        new_state = {'m': m, 'v': v, 'step': step}
        return new_params, new_state

class AdamW(Optimizer):
    """AdamW optimizer (Adam with decoupled weight decay)"""
    
    def __init__(self, learning_rate: float = 0.001, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8, weight_decay: float = 0.01):
        super().__init__(learning_rate)
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
    
    def init_state(self, params):
        """Initialize AdamW state"""
        return {
            'm': jax.tree_map(jnp.zeros_like, params),
            'v': jax.tree_map(jnp.zeros_like, params),
            'step': 0
        }
    
    def update(self, grads, state, params):
        """AdamW parameter update with decoupled weight decay"""
        step = state['step'] + 1
        
        # Update biased first moment estimate
        m = jax.tree_map(
            lambda m_prev, g: self.beta1 * m_prev + (1 - self.beta1) * g,
            state['m'], grads
        )
        
        # Update biased second moment estimate
        v = jax.tree_map(
            lambda v_prev, g: self.beta2 * v_prev + (1 - self.beta2) * g**2,
            state['v'], grads
        )
        
        # Bias correction
        m_hat = jax.tree_map(lambda m_val: m_val / (1 - self.beta1**step), m)
        v_hat = jax.tree_map(lambda v_val: v_val / (1 - self.beta2**step), v)
        
        # Parameter update with decoupled weight decay
        new_params = jax.tree_map(
            lambda p, m_val, v_val: p - self.learning_rate * (m_val / (jnp.sqrt(v_val) + self.eps) + self.weight_decay * p),
            params, m_hat, v_hat
        )
        
        new_state = {'m': m, 'v': v, 'step': step}
        return new_params, new_state

def test_optimizers():
    """Test all optimizers on simple quadratic function"""
    
    key = random.PRNGKey(42)
    
    # Simple quadratic: f(x) = x^T A x + b^T x
    A = jnp.array([[2.0, 0.5], [0.5, 1.0]])  # Positive definite
    b = jnp.array([1.0, -0.5])
    x_optimal = -jnp.linalg.solve(A, b) / 2  # Analytical optimum
    
    def objective(params):
        x = params['x']
        return jnp.dot(x, A @ x) + jnp.dot(b, x)
    
    grad_fn = grad(objective)
    
    # Test different optimizers
    optimizers = {
        'SGD': SGD(0.1),
        'Momentum': SGD(0.1, momentum=0.9),
        'Nesterov': Nesterov(0.1, momentum=0.9),
        'RMSprop': RMSprop(0.1),
        'Adam': Adam(0.1),
        'AdamW': AdamW(0.1, weight_decay=0.01)
    }
    
    n_steps = 100
    results = {}
    
    for name, optimizer in optimizers.items():
        params = {'x': jnp.array([2.0, -1.0])}  # Starting point
        state = optimizer.init_state(params)
        losses = []
        
        for step in range(n_steps):
            grads = grad_fn(params)
            params, state = optimizer.update(grads, state, params)
            loss = objective(params)
            losses.append(loss)
        
        results[name] = {
            'losses': jnp.array(losses),
            'final_params': params['x'],
            'distance_to_optimal': jnp.linalg.norm(params['x'] - x_optimal)
        }
    
    # Plot convergence
    plt.figure(figsize=(12, 4))
    
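    plt.subplot(1, 2, 1)
    # The minimum of this quadratic is negative, so a log axis cannot show the raw loss.
    # Plot the gap to the optimal value instead, floored to keep it strictly positive.
    f_optimal = objective({'x': x_optimal})
    for name, result in results.items():
        plt.semilogy(jnp.maximum(result['losses'] - f_optimal, 1e-16), label=name)
    plt.xlabel('Step')
    plt.ylabel('Loss - optimal loss')
    plt.title('Convergence Comparison')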
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 2, 2)
    names = list(results.keys())
    distances = [results[name]['distance_to_optimal'] for name in names]
    plt.bar(names, distances)
    plt.ylabel('Distance to Optimal')
    plt.title('Final Optimization Error')
    plt.xticks(rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    print("Optimizer Comparison Results:")
    print(f"Optimal solution: {x_optimal}")
    for name, result in results.items():
        print(f"{name:10}: final_x={result['final_params']}, error={result['distance_to_optimal']:.6f}")
    
    return results

optimizer_results = test_optimizers()
```
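
To make the difference between L2 regularization in `Adam` and decoupled weight decay in `AdamW` concrete, here is a minimal scalar sketch of the first update step, starting from zero moments. The function `adam_first_step` and its arguments `l2` and `decoupled_wd` are hypothetical names introduced for this example. It follows the same formulas as the classes above but is not a replacement for them.

```python
import jax.numpy as jnp

def adam_first_step(p, g, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8,
                    l2=0.0, decoupled_wd=0.0):
    """One Adam update at step 1, starting from m = v = 0 (illustrative only)."""
    g = g + l2 * p                      # L2 penalty is folded into the gradient
    m = (1 - beta1) * g                 # first moment after one step
    v = (1 - beta2) * g ** 2            # second moment after one step
    m_hat = m / (1 - beta1 ** 1)        # bias correction at step 1 recovers g
    v_hat = v / (1 - beta2 ** 1)        # bias correction at step 1 recovers g^2
    # Decoupled decay is applied outside the adaptive rescaling
    return p - lr * (m_hat / (jnp.sqrt(v_hat) + eps) + decoupled_wd * p)

p = jnp.array(2.0)
small = adam_first_step(p, jnp.array(1e-3))                      # tiny gradient
large = adam_first_step(p, jnp.array(10.0))                      # huge gradient
adam_l2 = adam_first_step(p, jnp.array(10.0), l2=0.01)           # L2 inside the gradient
adamw = adam_first_step(p, jnp.array(10.0), decoupled_wd=0.01)   # decay outside

print(float(small), float(large), float(adam_l2), float(adamw))
```

After bias correction the first step has magnitude close to `lr` whatever the gradient scale, so the L2 term is normalized away along with the gradient. The decoupled variant shrinks the parameter by an additional `lr * decoupled_wd * p`.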

## Learning Rate Scheduling

### Learning Rate Schedulers

```python
class LRScheduler:
    """Base class for learning rate schedulers"""
    
    def __call__(self, step: int) -> float:
        raise NotImplementedError

class StepLR(LRScheduler):
    """Step decay learning rate scheduler"""
    
    def __init__(self, initial_lr: float, step_size: int, gamma: float = 0.1):
        self.initial_lr = initial_lr
        self.step_size = step_size
        self.gamma = gamma
    
    def __call__(self, step: int) -> float:
        return self.initial_lr * (self.gamma ** (step // self.step_size))

class ExponentialLR(LRScheduler):
    """Exponential decay learning rate scheduler"""
    
    def __init__(self, initial_lr: float, gamma: float):
        self.initial_lr = initial_lr
        self.gamma = gamma
    
    def __call__(self, step: int) -> float:
        return self.initial_lr * (self.gamma ** step)

class CosineAnnealingLR(LRScheduler):
    """Cosine annealing learning rate scheduler"""
    
    def __init__(self, initial_lr: float, T_max: int, eta_min: float = 0):
        self.initial_lr = initial_lr
        self.T_max = T_max
        self.eta_min = eta_min
    
    def __call__(self, step: int) -> float:
        return self.eta_min + (self.initial_lr - self.eta_min) * (1 + jnp.cos(jnp.pi * step / self.T_max)) / 2

class WarmupLR(LRScheduler):
    """Warmup + cosine learning rate scheduler"""
    
    def __init__(self, initial_lr: float, warmup_steps: int, max_steps: int):
        self.initial_lr = initial_lr
        self.warmup_steps = warmup_steps
        self.max_steps = max_steps
    
    def __call__(self, step: int) -> float:
        if step < self.warmup_steps:
            return self.initial_lr * step / self.warmup_steps
        else:
            # Clamp progress at 1 so the rate stays at zero instead of rising again past max_steps
            progress = min(1.0, (step - self.warmup_steps) / (self.max_steps - self.warmup_steps))
            return self.initial_lr * 0.5 * (1 + jnp.cos(jnp.pi * progress))

def test_schedulers():
    """Test learning rate schedulers"""
    
    schedulers = {
        'Constant': lambda step: 0.1,
        'Step': StepLR(0.1, step_size=30, gamma=0.5),
        'Exponential': ExponentialLR(0.1, gamma=0.95),
        'Cosine': CosineAnnealingLR(0.1, T_max=100),
        'Warmup+Cosine': WarmupLR(0.1, warmup_steps=20, max_steps=100)
    }
    
    steps = jnp.arange(100)
    
    plt.figure(figsize=(10, 6))
    for name, scheduler in schedulers.items():
        lrs = [scheduler(step) for step in steps]
        plt.plot(steps, lrs, label=name, linewidth=2)
    
    plt.xlabel('Step')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedules')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

test_schedulers()
```
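
The scheduler classes above branch with a Python `if`, which works eagerly but not on traced values inside `jax.jit`. As a rough sketch of one alternative, the function below expresses the same warmup-plus-cosine shape with `jnp.where`. The name `warmup_cosine` and its default arguments are assumptions chosen to match the `WarmupLR(0.1, warmup_steps=20, max_steps=100)` example.

```python
import jax
import jax.numpy as jnp

def warmup_cosine(step, initial_lr=0.1, warmup_steps=20, max_steps=100):
    """Linear warmup followed by cosine decay, written without Python branching."""
    step = jnp.asarray(step, dtype=jnp.float32)
    warmup = initial_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    cosine = initial_lr * 0.5 * (1 + jnp.cos(jnp.pi * progress))
    # Both branches are computed; jnp.where selects per element
    return jnp.where(step < warmup_steps, warmup, cosine)

# Evaluate the whole schedule in one vectorized, compiled call
lrs = jax.jit(jax.vmap(warmup_cosine))(jnp.arange(100))
print(lrs[:3], lrs[20], lrs[-1])
```

Since the function is pure and array-based, it can be called from inside a jitted training step with the step counter as a traced value.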

## Advanced Optimizer Features

### Gradient Clipping and Optimizer Wrappers

```python
class OptimWithClipping:
    """Optimizer wrapper with gradient clipping"""
    
    def __init__(self, optimizer: Optimizer, max_grad_norm: float = 1.0):
        self.optimizer = optimizer
        self.max_grad_norm = max_grad_norm
    
    def init_state(self, params):
        return self.optimizer.init_state(params)
    
    def update(self, grads, state, params):
        # Clip gradients
        clipped_grads = clip_gradients(grads, self.max_grad_norm)
        return self.optimizer.update(clipped_grads, state, params)

class OptimWithScheduler:
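    """Optimizer wrapper with learning rate scheduling.

    Note: `update` overwrites `optimizer.learning_rate` in place on every call,
    so this wrapper is stateful and is not safe to `jit` as written.
    """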
    
    def __init__(self, optimizer: Optimizer, scheduler: LRScheduler):
        self.optimizer = optimizer
        self.scheduler = scheduler
    
    def init_state(self, params):
        state = self.optimizer.init_state(params)
        state['scheduler_step'] = 0
        return state
    
    def update(self, grads, state, params):
        # Update learning rate
        step = state.get('scheduler_step', 0)
        self.optimizer.learning_rate = self.scheduler(step)
        
        # Update parameters
        new_params, new_state = self.optimizer.update(grads, state, params)
        new_state['scheduler_step'] = step + 1
        
        return new_params, new_state

def test_advanced_features():
    """Test gradient clipping and scheduling"""
    
    key = random.PRNGKey(123)
    
    # Create problem with large gradients
    def loss_fn(params):
        x = params['x']
        return jnp.sum(x**4)  # Fourth power creates large gradients
    
    grad_fn = grad(loss_fn)
    
    # Test with and without clipping
    params = {'x': jnp.array([2.0, -1.5, 3.0])}
    grads = grad_fn(params)
    
    print("Gradient Clipping Test:")
    print(f"Original gradient norm: {jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grads))):.4f}")
    
    clipped_grads = clip_gradients(grads, max_norm=1.0)
    print(f"Clipped gradient norm: {jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(clipped_grads))):.4f}")
    
    # Test optimizer with scheduling
    base_optimizer = Adam(0.1)
    scheduler = CosineAnnealingLR(0.1, T_max=50)
    scheduled_optimizer = OptimWithScheduler(base_optimizer, scheduler)
    
    state = scheduled_optimizer.init_state(params)
    
    lrs = []
    for step in range(20):
        grads = grad_fn(params)
        params, state = scheduled_optimizer.update(grads, state, params)
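        # Convert the JAX scalar to a Python float so the printed list stays readable
        lrs.append(float(base_optimizer.learning_rate))
    
    print(f"Learning rates over 20 steps: {[round(lr, 4) for lr in lrs[:5]]}...{[round(lr, 4) for lr in lrs[-5:]]}")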

test_advanced_features()
```

## Optimizer Comparison on Neural Network

### Training MLP with Different Optimizers

```python
def create_classification_data(key, n_samples=1000, n_features=20, n_classes=3, noise=0.1):
    """Create synthetic classification dataset"""
    
    # Derive one independent key for the centers and one per class;
    # a key should not be reused after it has been split
    keys = random.split(key, n_classes + 1)
    
    # Generate class centers
    centers = random.normal(keys[0], (n_classes, n_features)) * 2
    
    # Generate samples
    samples_per_class = n_samples // n_classes
    X = []
    y = []
    
    for class_idx in range(n_classes):
        class_samples = centers[class_idx] + noise * random.normal(
            keys[class_idx + 1],
            (samples_per_class, n_features)
        )
        X.append(class_samples)
        y.extend([class_idx] * samples_per_class)
    
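    X = jnp.concatenate(X, axis=0)
    y = jax.nn.one_hot(jnp.array(y), n_classes)
    
    # The samples are generated class by class, so shuffle them; otherwise a
    # contiguous train/test split leaves only the last class in the test set
    perm = random.permutation(random.fold_in(key, n_classes + 1), X.shape[0])
    return X[perm], y[perm]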

def simple_mlp_forward(params, x):
    """Simple 2-layer MLP forward pass"""
    h = jax.nn.relu(x @ params['W1'] + params['b1'])
    return h @ params['W2'] + params['b2']

def cross_entropy_loss(params, x, y):
    """Cross-entropy loss"""
    logits = simple_mlp_forward(params, x)
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(y * log_probs, axis=1))

def accuracy(params, x, y):
    """Classification accuracy"""
    logits = simple_mlp_forward(params, x)
    pred_classes = jnp.argmax(logits, axis=1)
    true_classes = jnp.argmax(y, axis=1)
    return jnp.mean(pred_classes == true_classes)

def compare_optimizers_on_mlp():
    """Compare optimizers training an MLP"""
    
    key = random.PRNGKey(42)
    
    # Generate data
    X, y = create_classification_data(key, n_samples=800, n_features=20, n_classes=3)
    
    # Split train/test
    split_idx = int(0.8 * len(X))
    X_train, X_test = X[:split_idx], X[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]
    
    # Initialize MLP parameters
    def init_mlp_params(key, input_size, hidden_size, output_size):
        keys = random.split(key, 4)
        return {
            'W1': random.normal(keys[0], (input_size, hidden_size)) * jnp.sqrt(2.0 / input_size),
            'b1': jnp.zeros(hidden_size),
            'W2': random.normal(keys[1], (hidden_size, output_size)) * jnp.sqrt(2.0 / hidden_size),
            'b2': jnp.zeros(output_size)
        }
    
    # Test optimizers
    optimizers_to_test = {
        'SGD': SGD(0.01),
        'SGD+Momentum': SGD(0.01, momentum=0.9),
        'Adam': Adam(0.001),
        'AdamW': AdamW(0.001, weight_decay=0.01),
        'RMSprop': RMSprop(0.001)
    }
    
    n_epochs = 50
    results = {}
    
    grad_fn = grad(cross_entropy_loss)
    
    for opt_name, optimizer in optimizers_to_test.items():
        print(f"Training with {opt_name}...")
        
        # Initialize parameters and state
        params = init_mlp_params(random.split(key)[1], 20, 64, 3)
        state = optimizer.init_state(params)
        
        train_losses = []
        test_accs = []
        
        for epoch in range(n_epochs):
            # Training step
            grads = grad_fn(params, X_train, y_train)
            params, state = optimizer.update(grads, state, params)
            
            # Record metrics
            train_loss = cross_entropy_loss(params, X_train, y_train)
            test_acc = accuracy(params, X_test, y_test)
            
            train_losses.append(train_loss)
            test_accs.append(test_acc)
        
        results[opt_name] = {
            'train_losses': jnp.array(train_losses),
            'test_accs': jnp.array(test_accs),
            'final_acc': test_acc
        }
    
    # Plot results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Training loss
    for opt_name, result in results.items():
        ax1.plot(result['train_losses'], label=opt_name, linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss')
    ax1.set_title('Training Loss Comparison')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_yscale('log')
    
    # Test accuracy
    for opt_name, result in results.items():
        ax2.plot(result['test_accs'], label=opt_name, linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Test Accuracy')
    ax2.set_title('Test Accuracy Comparison')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\nFinal Results:")
    for opt_name, result in results.items():
        print(f"{opt_name:15}: {result['final_acc']:.4f}")
    
    return results

mlp_results = compare_optimizers_on_mlp()
```
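
The training loops in this notebook call each optimizer update eagerly. As a rough sketch of how a step of this kind could be compiled, the example below writes SGD with momentum as a pure function and wraps it in `jax.jit`. The linear model, the toy regression data, and the names `mse_loss` and `train_step` are assumptions made to keep the example short; they are not part of the notebook's optimizer classes.

```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
    """Mean squared error of a linear model (stand-in for the MLP loss)."""
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, velocity, x, y, lr, momentum):
    """One SGD-with-momentum step: all state is passed in and returned."""
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    velocity = jax.tree_util.tree_map(lambda v, g: momentum * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
    return params, velocity, loss

# Noise-free toy regression data, used only to exercise the step
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params = {'w': jnp.zeros(3), 'b': jnp.zeros(())}
velocity = jax.tree_util.tree_map(jnp.zeros_like, params)

losses = []
for _ in range(200):
    params, velocity, loss = train_step(params, velocity, x, y, 0.05, 0.9)
    losses.append(float(loss))  # loss evaluated before this step's update

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.6f}")
```

The function is traced and compiled on its first call, and later calls with the same shapes reuse the compiled version. Passing `lr` as an argument rather than reading it from a mutable attribute is what lets a schedule change the rate between steps without retracing.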

## Summary

In this notebook, we've implemented and analyzed various optimization algorithms:

**Core Optimizers:**

1. **SGD**: Simple gradient descent with optional momentum
2. **Nesterov**: Accelerated gradient with look-ahead
3. **RMSprop**: Adaptive learning rates based on gradient magnitude
4. **Adam**: Adaptive moments with bias correction
5. **AdamW**: Adam with decoupled weight decay

**Advanced Features:**
- Gradient clipping for training stability
- Learning rate scheduling (step, exponential, cosine, warmup)
- Optimizer wrappers for modular functionality
- Weight decay vs L2 regularization differences

**Key Insights:**
- Adam often converges faster than SGD on neural networks
- AdamW often generalizes better than Adam with L2 regularization, because its decay term is not rescaled by the adaptive step size
- Learning rate scheduling is often important for final performance
- Gradient clipping bounds the update size, which limits the damage from exploding gradients
- Different optimizers suit different problem types

**Practical Guidelines:**
- Use Adam/AdamW as default for most neural networks
- SGD+momentum for well-tuned scenarios or when batch size is large
- RMSprop for RNNs and non-stationary objectives
- Prefer a learning rate schedule (for example warmup plus cosine decay) over a constant rate when tuning for final performance
- Clip gradients when training deep networks or RNNs

**JAX Implementation Benefits:**
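- Optimizer state is passed in and returned explicitly instead of being stored on the optimizer (the `OptimWithScheduler` wrapper, which overwrites `learning_rate`, is the exception)
- Easy to compose with other transformations
- The updates are built from `jax.numpy` and tree operations, so they should be compatible with `jax.jit`, although this notebook runs them eagerly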
- Tree operations for complex parameter structures

**Next Steps:**
- The next notebook will cover loss functions in detail
- We'll explore different loss formulations and their properties
- A working understanding of optimizers makes it easier to train more complex models effectively

Together, the optimizers, schedulers, and wrappers above form a small from-scratch toolkit for training neural networks in JAX.