# File: notebooks/03_neural_networks/08_mlp_from_scratch.ipynb

## JAX Neural Networks: MLP from Scratch

Welcome to the neural networks section! This notebook implements a Multi-Layer Perceptron (MLP) from scratch using pure JAX. We'll cover weight initialization, forward propagation, backpropagation, and training loops while leveraging JAX's autodiff capabilities.

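Building a network by hand exposes the mechanics that high-level libraries hide, and it showcases JAX's functional style: parameters live in plain Python dictionaries (pytrees), every model component is a pure function, and `jax.grad` computes the gradients that backpropagation would otherwise require us to derive by hand.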

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, lax
from jax.nn import relu, sigmoid, softmax, log_softmax
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Callable, Dict, Any
import functools

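# Use float64 by default (JAX defaults to float32): more precise, but slower, especially on GPU/TPU
jax.config.update("jax_enable_x64", True)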
print(f"JAX version: {jax.__version__}")
```

## Neural Network Components

### Activation Functions

```python
# Custom activation functions
def swish(x):
    """Swish activation: x * sigmoid(x)"""
    return x * sigmoid(x)

def gelu(x):
    """Gaussian Error Linear Unit (tanh approximation)"""
    return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2/jnp.pi) * (x + 0.044715 * x**3)))

def leaky_relu(x, alpha=0.01):
    """Leaky ReLU with negative slope alpha"""
    return jnp.where(x > 0, x, alpha * x)

# Activation function registry
ACTIVATIONS = {
    'relu': relu,
    'sigmoid': sigmoid,
    'tanh': jnp.tanh,
    'swish': swish,
    'gelu': gelu,
    'leaky_relu': leaky_relu,
    'linear': lambda x: x
}

def test_activations():
    """Test and visualize activation functions"""
    x = jnp.linspace(-3, 3, 100)
    
    # 7 activations need a 2x4 grid (a 2x3 grid silently dropped 'linear')
    fig, axes = plt.subplots(2, 4, figsize=(18, 8))
    axes = axes.ravel()
    
    for ax, (name, func) in zip(axes, ACTIVATIONS.items()):
        y = func(x)
        ax.plot(x, y, label=name, linewidth=2)
        ax.set_title(f'{name} activation')
        ax.grid(True, alpha=0.3)
        ax.set_xlabel('x')
        ax.set_ylabel('f(x)')
    
    # Hide any panels left over after plotting every activation
    for ax in axes[len(ACTIVATIONS):]:
        ax.set_visible(False)
    
    plt.tight_layout()
    plt.show()

# Test activations
test_activations()
```
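Because these activations are plain JAX functions, their derivatives come for free from `jax.grad`. The sketch below, a self-contained check that is not part of the original notebook, compares the autodiff derivative of swish against its hand-derived formula and confirms that the `gelu` formula above matches `jax.nn.gelu(..., approximate=True)`. The name `gelu_tanh` is hypothetical and simply restates the same formula.

```python
import jax
import jax.numpy as jnp

def swish(x):
    """Swish activation: x * sigmoid(x)"""
    return x * jax.nn.sigmoid(x)

def gelu_tanh(x):
    """GELU, tanh approximation (same formula as `gelu` above)"""
    return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))

x = jnp.linspace(-3.0, 3.0, 101)

# jax.grad differentiates a scalar function; vmap maps it over every point of x
d_swish = jax.vmap(jax.grad(swish))(x)

# Hand-derived derivative: s(x) + x * s(x) * (1 - s(x)), with s = sigmoid
s = jax.nn.sigmoid(x)
d_swish_analytic = s + x * s * (1 - s)

# jax.nn.gelu implements the tanh approximation when approximate=True
gelu_ref = jax.nn.gelu(x, approximate=True)
gelu_exact = jax.nn.gelu(x, approximate=False)

print("max |autodiff - analytic| for swish':", jnp.max(jnp.abs(d_swish - d_swish_analytic)))
print("max |gelu_tanh - jax.nn.gelu(approx)|:", jnp.max(jnp.abs(gelu_tanh(x) - gelu_ref)))
print("max |tanh approx - exact GELU|:", jnp.max(jnp.abs(gelu_ref - gelu_exact)))
```

The last line shows how small the price of the tanh approximation is over this input range.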

### Weight Initialization

```python
def xavier_uniform_init(key, shape, gain=1.0):
    """Xavier/Glorot uniform initialization"""
    fan_in, fan_out = shape[0], shape[1]
    limit = gain * jnp.sqrt(6.0 / (fan_in + fan_out))
    return random.uniform(key, shape, minval=-limit, maxval=limit)

def xavier_normal_init(key, shape, gain=1.0):
    """Xavier/Glorot normal initialization"""
    fan_in, fan_out = shape[0], shape[1]
    std = gain * jnp.sqrt(2.0 / (fan_in + fan_out))
    return random.normal(key, shape) * std

def he_uniform_init(key, shape, gain=1.0):
    """He uniform initialization (good for ReLU)"""
    fan_in = shape[0]
    limit = gain * jnp.sqrt(6.0 / fan_in)
    return random.uniform(key, shape, minval=-limit, maxval=limit)

def he_normal_init(key, shape, gain=1.0):
    """He normal initialization (good for ReLU)"""
    fan_in = shape[0]
    std = gain * jnp.sqrt(2.0 / fan_in)
    return random.normal(key, shape) * std

def lecun_normal_init(key, shape, gain=1.0):
    """LeCun normal initialization"""
    fan_in = shape[0]
    std = gain * jnp.sqrt(1.0 / fan_in)
    return random.normal(key, shape) * std

# Initialization registry
INITIALIZERS = {
    'xavier_uniform': xavier_uniform_init,
    'xavier_normal': xavier_normal_init,
    'he_uniform': he_uniform_init,
    'he_normal': he_normal_init,
    'lecun_normal': lecun_normal_init,
    'zeros': lambda key, shape: jnp.zeros(shape),
    'ones': lambda key, shape: jnp.ones(shape),
    'normal': lambda key, shape: random.normal(key, shape) * 0.01
}

def compare_initializations():
    """Compare different weight initialization schemes"""
    key = random.PRNGKey(42)
    shape = (784, 256)  # Input to hidden layer
    
    print("Weight Initialization Comparison:")
    print("=" * 40)
    
    for name, init_func in INITIALIZERS.items():
        if name in ['zeros', 'ones']:
            continue  # Skip trivial cases
            
        weights = init_func(key, shape)
        
        mean_val = jnp.mean(weights)
        std_val = jnp.std(weights)
        min_val = jnp.min(weights)
        max_val = jnp.max(weights)
        
        print(f"{name:15}: mean={mean_val:7.4f}, std={std_val:.4f}, "
              f"range=[{min_val:6.3f}, {max_val:6.3f}]")

compare_initializations()
```
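The statistics above describe the weights in isolation; what matters in practice is how activations behave as they flow through many layers. A minimal sketch, assuming a stack of ten bias-free 256-unit ReLU layers (an illustrative setup, not the notebook's MLP), compares He normal initialization with the fixed `0.01` normal scheme from the registry:

```python
import jax
import jax.numpy as jnp
from jax import random

def he_normal(key, shape):
    """std = sqrt(2 / fan_in), as in he_normal_init above"""
    return random.normal(key, shape) * jnp.sqrt(2.0 / shape[0])

def small_normal(key, shape):
    """Fixed std of 0.01, as in the 'normal' registry entry"""
    return random.normal(key, shape) * 0.01

def activation_stds(init_fn, key, depth=10, width=256, batch=512):
    """Push a random batch through `depth` bias-free ReLU layers and record
    the standard deviation of the activations after each layer."""
    key, x_key = random.split(key)
    h = random.normal(x_key, (batch, width))
    stds = []
    for layer_key in random.split(key, depth):
        W = init_fn(layer_key, (width, width))
        h = jax.nn.relu(h @ W)
        stds.append(float(jnp.std(h)))
    return stds

key = random.PRNGKey(0)
he_stds = activation_stds(he_normal, key)
small_stds = activation_stds(small_normal, key)

for i, (a, b) in enumerate(zip(he_stds, small_stds)):
    print(f"layer {i + 1:2d}: he_normal std={a:.3f}   normal(0.01) std={b:.2e}")
```

With He initialization the activation scale stays roughly constant with depth, while the fixed small scale shrinks it by about two orders of magnitude per layer, which is the vanishing-signal problem the fan-in-aware schemes are designed to avoid.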

## MLP Implementation

### Core MLP Class

```python
class MLP:
    """Multi-Layer Perceptron implementation in JAX"""
    
    def __init__(self, 
                 layer_sizes: List[int],
                 activation: str = 'relu',
                 output_activation: str = 'linear',
                 weight_init: str = 'he_normal',
                 bias_init: str = 'zeros'):
        
        self.layer_sizes = layer_sizes
        self.activation = ACTIVATIONS[activation]
        self.output_activation = ACTIVATIONS[output_activation]
        self.weight_init = INITIALIZERS[weight_init]
        self.bias_init = INITIALIZERS[bias_init]
        self.num_layers = len(layer_sizes) - 1
    
    def init_params(self, key: jax.random.PRNGKey) -> Dict[str, Any]:
        """Initialize network parameters"""
        keys = random.split(key, 2 * self.num_layers)
        params = {}
        
        for i in range(self.num_layers):
            layer_name = f'layer_{i}'
            input_size = self.layer_sizes[i]
            output_size = self.layer_sizes[i + 1]
            
            # Initialize weights and biases
            W = self.weight_init(keys[2*i], (input_size, output_size))
            b = self.bias_init(keys[2*i + 1], (output_size,))
            
            params[layer_name] = {'W': W, 'b': b}
        
        return params
    
    def forward(self, params: Dict[str, Any], x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the network"""
        for i in range(self.num_layers):
            layer_name = f'layer_{i}'
            W = params[layer_name]['W']
            b = params[layer_name]['b']
            
            # Linear transformation
            x = x @ W + b
            
            # Apply activation
            if i < self.num_layers - 1:  # Hidden layers
                x = self.activation(x)
            else:  # Output layer
                x = self.output_activation(x)
        
        return x
    
    def __call__(self, params: Dict[str, Any], x: jnp.ndarray) -> jnp.ndarray:
        """Make the MLP callable"""
        return self.forward(params, x)

# Test MLP creation and forward pass
def test_mlp():
    """Test MLP initialization and forward pass"""
    
    # Create MLP for MNIST-like classification
    mlp = MLP(layer_sizes=[784, 256, 128, 10], 
              activation='relu',
              output_activation='linear',
              weight_init='he_normal')
    
    # Initialize parameters
    key = random.PRNGKey(0)
    params = mlp.init_params(key)
    
    print("MLP Architecture:")
    print(f"Layers: {mlp.layer_sizes}")
    print(f"Activation: {mlp.activation.__name__}")
    
    # Test forward pass
    batch_size = 32
    input_dim = 784
    x = random.normal(random.split(key)[1], (batch_size, input_dim))
    
    output = mlp.forward(params, x)
    print(f"\nForward pass test:")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    
    # Count parameters
    total_params = 0
    for layer_name, layer_params in params.items():
        W_params = layer_params['W'].size
        b_params = layer_params['b'].size
        layer_total = W_params + b_params
        total_params += layer_total
        print(f"{layer_name}: W{layer_params['W'].shape} + b{layer_params['b'].shape} = {layer_total} params")
    
    print(f"Total parameters: {total_params:,}")
    
    return mlp, params

mlp, params = test_mlp()
```
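Since the parameters are an ordinary pytree, generic `jax.tree_util` helpers work on them without knowing the layer layout, and a forward pass written for one example can be batched with `vmap`. The sketch below uses hypothetical functional stand-ins (`init_params`, `forward`) for the `MLP` methods, storing layers in a list instead of a `layer_{i}` dict, to count parameters and to check that the batched forward pass agrees with a `vmap` over single examples:

```python
import jax
import jax.numpy as jnp
from jax import random

def init_params(key, layer_sizes):
    """Hypothetical functional counterpart of MLP.init_params (He normal, zero bias)."""
    keys = random.split(key, len(layer_sizes) - 1)
    return [
        {'W': random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in),
         'b': jnp.zeros(n_out)}
        for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:])
    ]

def forward(params, x):
    """ReLU hidden layers, linear output layer (logits)."""
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer['W'] + layer['b'])
    last = params[-1]
    return x @ last['W'] + last['b']

params = init_params(random.PRNGKey(0), [784, 256, 128, 10])

# (784*256 + 256) + (256*128 + 128) + (128*10 + 10)
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"Total parameters: {n_params:,}")   # Total parameters: 235,146

# The same function handles a batch (matrix multiply) or, via vmap, one example at a time
x = random.normal(random.PRNGKey(1), (32, 784))
batched = forward(params, x)
per_example = jax.vmap(forward, in_axes=(None, 0))(params, x)
print("max difference:", jnp.max(jnp.abs(batched - per_example)))
```

`in_axes=(None, 0)` tells `vmap` to share the parameters across examples and map only over the leading axis of `x`.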

## Loss Functions

### Common Loss Functions

```python
def mse_loss(predictions: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Mean Squared Error loss"""
    return jnp.mean((predictions - targets) ** 2)

def mae_loss(predictions: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Mean Absolute Error loss"""
    return jnp.mean(jnp.abs(predictions - targets))

def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Cross-entropy loss for one-hot (or soft) label distributions"""
    # Numerically stable: log_softmax avoids computing log(softmax(x)) directly
    log_probs = log_softmax(logits)
    return -jnp.mean(jnp.sum(labels * log_probs, axis=1))

def sparse_cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Cross-entropy loss with integer class labels (converted to one-hot)"""
    num_classes = logits.shape[1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes)
    return cross_entropy_loss(logits, one_hot_labels)

def huber_loss(predictions: jnp.ndarray, targets: jnp.ndarray, delta: float = 1.0) -> jnp.ndarray:
    """Huber loss (smooth L1 loss)"""
    residual = jnp.abs(predictions - targets)
    return jnp.mean(
        jnp.where(residual < delta,
                  0.5 * residual ** 2,
                  delta * (residual - 0.5 * delta))
    )

# Loss function registry
LOSSES = {
    'mse': mse_loss,
    'mae': mae_loss,
    'cross_entropy': cross_entropy_loss,
    'sparse_cross_entropy': sparse_cross_entropy_loss,
    'huber': huber_loss
}

def test_losses():
    """Test loss function implementations"""
    key = random.PRNGKey(123)
    
    # Test regression losses
    y_pred = random.normal(key, (100, 1))
    y_true = random.normal(random.split(key)[1], (100, 1))
    
    print("Regression Loss Functions:")
    for name in ['mse', 'mae', 'huber']:
        loss_val = LOSSES[name](y_pred, y_true)
        print(f"{name.upper():6}: {loss_val:.4f}")
    
    # Test classification losses
    logits = random.normal(random.split(key, 3)[2], (100, 10))
    labels_one_hot = jax.nn.one_hot(random.randint(random.split(key, 4)[3], (100,), 0, 10), 10)
    labels_sparse = jnp.argmax(labels_one_hot, axis=1)
    
    print("\nClassification Loss Functions:")
    ce_loss = cross_entropy_loss(logits, labels_one_hot)
    sce_loss = sparse_cross_entropy_loss(logits, labels_sparse)
    
    print(f"Cross-entropy: {ce_loss:.4f}")
    print(f"Sparse CE:     {sce_loss:.4f}")
    print(f"Difference:    {jnp.abs(ce_loss - sce_loss):.6f}")

test_losses()
```
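Two properties of softmax cross-entropy are worth checking numerically. First, integer labels can be handled by gathering the correct-class log-probability instead of building a one-hot matrix (the helper `sparse_ce_no_onehot` below is a hypothetical alternative, not the notebook's implementation). Second, for the batch-mean loss the gradient with respect to the logits has the closed form $\partial L / \partial z = (\mathrm{softmax}(z) - y) / N$, which `jax.grad` should reproduce:

```python
import jax
import jax.numpy as jnp
from jax import random

def cross_entropy_loss(logits, labels):
    """Same definition as above: batch mean of -sum(labels * log_softmax(logits))."""
    return -jnp.mean(jnp.sum(labels * jax.nn.log_softmax(logits), axis=1))

def sparse_ce_no_onehot(logits, int_labels):
    """Hypothetical sparse CE: gather each row's correct-class log-probability."""
    log_probs = jax.nn.log_softmax(logits)
    picked = jnp.take_along_axis(log_probs, int_labels[:, None], axis=1)
    return -jnp.mean(picked)

key_logits, key_labels = random.split(random.PRNGKey(0))
N, C = 8, 5
logits = random.normal(key_logits, (N, C))
int_labels = random.randint(key_labels, (N,), 0, C)
one_hot = jax.nn.one_hot(int_labels, C)

# Gathering and one-hot masking give the same loss
ce = cross_entropy_loss(logits, one_hot)
sce = sparse_ce_no_onehot(logits, int_labels)
print(ce, sce)

# Autodiff gradient vs. the closed form (softmax(logits) - one_hot) / N
autodiff_grad = jax.grad(cross_entropy_loss)(logits, one_hot)
analytic_grad = (jax.nn.softmax(logits) - one_hot) / N
print("max |autodiff - analytic|:", jnp.max(jnp.abs(autodiff_grad - analytic_grad)))
```

The gather version avoids allocating an `N x C` one-hot matrix, which starts to matter when the number of classes is large.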

## Training Implementation

### Optimizer Implementation

```python
class SGD:
    """Stochastic Gradient Descent optimizer"""
    
    def __init__(self, learning_rate: float = 0.01, momentum: float = 0.0):
        self.learning_rate = learning_rate
        self.momentum = momentum
    
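    def init_state(self, params):
        """Initialize optimizer state (one velocity buffer per parameter when using momentum)"""
        if self.momentum > 0:
            return jax.tree_util.tree_map(jnp.zeros_like, params)
        return {}
    
    def update(self, grads, state, params):
        """Update parameters using gradients"""
        # jax.tree_util.tree_map applies the update leaf by leaf across the nested
        # params dict (the old jax.tree_map alias is deprecated and removed in newer JAX)
        if self.momentum > 0:
            # Momentum update: v <- momentum * v + g;  p <- p - lr * v
            new_state = jax.tree_util.tree_map(
                lambda v, g: self.momentum * v + g,
                state, grads
            )
            new_params = jax.tree_util.tree_map(
                lambda p, v: p - self.learning_rate * v,
                params, new_state
            )
            return new_params, new_state
        else:
            # Simple SGD: p <- p - lr * g
            new_params = jax.tree_util.tree_map(
                lambda p, g: p - self.learning_rate * g,
                params, grads
            )
            return new_params, state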

class Adam:
    """Adam optimizer"""
    
    def __init__(self, learning_rate: float = 0.001, 
                 beta1: float = 0.9, beta2: float = 0.999, 
                 eps: float = 1e-8):
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
    
    def init_state(self, params):
        """Initialize Adam state"""
        return {
            'm': jax.tree_util.tree_map(jnp.zeros_like, params),  # First moment
            'v': jax.tree_util.tree_map(jnp.zeros_like, params),  # Second moment
            'step': 0
        }
    
    def update(self, grads, state, params):
        """Adam parameter update"""
        step = state['step'] + 1
        
        # Update biased first and second moments
        m = jax.tree_util.tree_map(
            lambda m_prev, g: self.beta1 * m_prev + (1 - self.beta1) * g,
            state['m'], grads
        )
        v = jax.tree_util.tree_map(
            lambda v_prev, g: self.beta2 * v_prev + (1 - self.beta2) * g**2,
            state['v'], grads
        )
        
        # Bias correction (the moments start at zero, so early estimates are too small)
        m_hat = jax.tree_util.tree_map(lambda m_val: m_val / (1 - self.beta1**step), m)
        v_hat = jax.tree_util.tree_map(lambda v_val: v_val / (1 - self.beta2**step), v)
        
        # Parameter update
        new_params = jax.tree_util.tree_map(
            lambda p, m_val, v_val: p - self.learning_rate * m_val / (jnp.sqrt(v_val) + self.eps),
            params, m_hat, v_hat
        )
        
        new_state = {'m': m, 'v': v, 'step': step}
        return new_params, new_state
```
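A useful consequence of Adam's bias correction: on the very first step, $\hat m = g$ and $\hat v = g^2$, so every coordinate moves by roughly `learning_rate` in the direction opposite its gradient, regardless of how large or small that gradient is. The sketch below restates the `Adam.update` math as a hypothetical standalone function, `adam_step`, and checks this on gradients spanning six orders of magnitude:

```python
import jax
import jax.numpy as jnp

def adam_step(params, grads, m, v, step, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One functional Adam update (same math as the Adam class above)."""
    step = step + 1
    m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)
    v = jax.tree_util.tree_map(lambda v_, g: b2 * v_ + (1 - b2) * g**2, v, grads)
    new_params = jax.tree_util.tree_map(
        lambda p, m_, v_: p - lr * (m_ / (1 - b1**step)) / (jnp.sqrt(v_ / (1 - b2**step)) + eps),
        params, m, v)
    return new_params, m, v, step

params = {'w': jnp.zeros(3)}
zeros = jax.tree_util.tree_map(jnp.zeros_like, params)

# Gradients spanning six orders of magnitude
grads = {'w': jnp.array([1e-3, -5.0, 200.0])}
new_params, m, v, step = adam_step(params, grads, zeros, zeros, 0)

# On step 1, m_hat / sqrt(v_hat) = g / |g|, so each weight moves by about lr
print(new_params['w'])   # approximately [-0.001, 0.001, -0.001]
```

This scale invariance is one reason Adam is often less sensitive to the learning rate than plain SGD, whose step is proportional to the raw gradient.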

### Training Loop

```python
def train_mlp(mlp: MLP, 
              train_data: Tuple[jnp.ndarray, jnp.ndarray],
              test_data: Tuple[jnp.ndarray, jnp.ndarray],
              optimizer,
              loss_fn: Callable,
              num_epochs: int = 100,
              batch_size: int = 32,
              key: jax.random.PRNGKey = random.PRNGKey(0)):
    """Train MLP with given data and optimizer"""
    
    X_train, y_train = train_data
    X_test, y_test = test_data
    
    # Initialize parameters and optimizer state (split the key so the
    # initialization key is never reused for shuffling)
    key, init_key = random.split(key)
    params = mlp.init_params(init_key)
    opt_state = optimizer.init_state(params)
    
    # JIT compile training step
    @jit
    def train_step(params, opt_state, batch_x, batch_y):
        def loss_fn_params(params):
            predictions = mlp(params, batch_x)
            return loss_fn(predictions, batch_y)
        
        loss, grads = jax.value_and_grad(loss_fn_params)(params)
        new_params, new_opt_state = optimizer.update(grads, opt_state, params)
        return new_params, new_opt_state, loss
    
    # JIT compile evaluation
    @jit
    def eval_step(params, x, y):
        predictions = mlp(params, x)
        loss = loss_fn(predictions, y)
        
        # The label format is read from the (static) array shape, so this Python
        # `if` is resolved at trace time and is safe inside jit
        if len(y.shape) > 1 and y.shape[1] > 1:  # One-hot labels
            pred_labels = jnp.argmax(predictions, axis=1)
            true_labels = jnp.argmax(y, axis=1)
        else:  # Sparse integer labels (accuracy is not meaningful for regression)
            pred_labels = jnp.argmax(predictions, axis=1)
            true_labels = y.flatten()
        
        accuracy = jnp.mean(pred_labels == true_labels)
        return loss, accuracy
    
    # Training history
    train_losses = []
    test_losses = []
    test_accuracies = []
    
    # Training loop (the last partial batch is dropped each epoch; because the
    # data is reshuffled, different samples are dropped every time)
    n_train = len(X_train)
    n_batches = n_train // batch_size
    
    for epoch in range(num_epochs):
        # Shuffle training data with a fresh subkey each epoch
        key, perm_key = random.split(key)
        perm = random.permutation(perm_key, n_train)
        X_shuffled = X_train[perm]
        y_shuffled = y_train[perm]
        
        # Mini-batch training
        epoch_losses = []
        for i in range(n_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            
            batch_x = X_shuffled[start_idx:end_idx]
            batch_y = y_shuffled[start_idx:end_idx]
            
            params, opt_state, batch_loss = train_step(params, opt_state, batch_x, batch_y)
            epoch_losses.append(batch_loss)
        
        # Record training loss
        avg_train_loss = jnp.mean(jnp.array(epoch_losses))
        train_losses.append(avg_train_loss)
        
        # Evaluate on test set
        test_loss, test_accuracy = eval_step(params, X_test, y_test)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        
        # Print progress
        if epoch % 10 == 0 or epoch == num_epochs - 1:
            print(f"Epoch {epoch:3d}: train_loss={avg_train_loss:.4f}, "
                  f"test_loss={test_loss:.4f}, test_acc={test_accuracy:.4f}")
    
    return params, {
        'train_losses': jnp.array(train_losses),
        'test_losses': jnp.array(test_losses),
        'test_accuracies': jnp.array(test_accuracies)
    }
```
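`train_mlp` dispatches one jitted `train_step` per mini-batch from a Python loop. An alternative worth knowing is to compile an entire epoch with `jax.lax.scan`, which carries the parameters through the batches inside XLA. The sketch below is not the notebook's approach; it assumes a toy linear-regression problem, plain SGD, and a fixed batch size of 32 (all hypothetical), and like `train_mlp` it drops the final partial batch:

```python
import jax
import jax.numpy as jnp
from jax import random

def mse(params, x, y):
    return jnp.mean((x @ params['W'] + params['b'] - y) ** 2)

@jax.jit
def train_epoch(params, key, X, y, lr):
    """One epoch as a single compiled call: shuffle, reshape into equal-size
    batches, then scan an SGD step over them."""
    batch_size = 32
    n_batches = X.shape[0] // batch_size       # shapes are static at trace time
    perm = random.permutation(key, X.shape[0])[: n_batches * batch_size]
    Xb = X[perm].reshape(n_batches, batch_size, -1)
    yb = y[perm].reshape(n_batches, batch_size, -1)

    def step(params, batch):
        xb, ybatch = batch
        loss, grads = jax.value_and_grad(mse)(params, xb, ybatch)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return params, loss

    params, losses = jax.lax.scan(step, params, (Xb, yb))
    return params, losses.mean()

# Toy data: y = X @ true_W + small noise
key = random.PRNGKey(0)
key, x_key, noise_key = random.split(key, 3)
X = random.normal(x_key, (256, 3))
true_W = jnp.array([[1.0], [-2.0], [0.5]])
y = X @ true_W + 0.1 * random.normal(noise_key, (256, 1))

params = {'W': jnp.zeros((3, 1)), 'b': jnp.zeros(1)}
epoch_losses = []
for epoch in range(20):
    key, perm_key = random.split(key)          # fresh key each epoch, never reused
    params, loss = train_epoch(params, perm_key, X, y, 0.1)
    epoch_losses.append(float(loss))
print(f"first epoch loss {epoch_losses[0]:.4f}, last epoch loss {epoch_losses[-1]:.4f}")
```

The trade-off is flexibility: per-batch logging or early stopping is harder inside `scan`, so the Python-loop version used in this notebook is often the better teaching tool.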

## Practical Example: Classification

### Generate Synthetic Dataset

```python
def generate_classification_data(key, n_samples=1000, n_features=20, n_classes=3):
    """Generate a synthetic Gaussian-blob classification dataset"""
    
    # One independent key per random draw: centers, shuffling, and each class's noise
    keys = random.split(key, n_classes + 2)
    center_key, perm_key, class_keys = keys[0], keys[1], keys[2:]
    
    # Generate class centers
    centers = random.normal(center_key, (n_classes, n_features)) * 2
    
    # Generate samples (any remainder of n_samples / n_classes is dropped,
    # e.g. 2000 samples over 3 classes gives 3 * 666 = 1998)
    samples_per_class = n_samples // n_classes
    X = []
    y = []
    
    for class_idx in range(n_classes):
        # Unit-variance Gaussian noise around the class center
        class_samples = centers[class_idx] + random.normal(
            class_keys[class_idx],
            (samples_per_class, n_features)
        )
        X.append(class_samples)
        y.extend([class_idx] * samples_per_class)
    
    X = jnp.concatenate(X, axis=0)
    y = jnp.array(y)
    
    # Shuffle data
    perm = random.permutation(perm_key, len(X))
    X = X[perm]
    y = y[perm]
    
    return X, y

# Generate dataset
key = random.PRNGKey(42)
X, y = generate_classification_data(key, n_samples=2000, n_features=20, n_classes=3)

# Split into train/test
split_idx = int(0.8 * len(X))
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]

# Convert to one-hot labels
y_train_oh = jax.nn.one_hot(y_train, 3)
y_test_oh = jax.nn.one_hot(y_test, 3)

print(f"Dataset: {len(X)} samples, {X.shape[1]} features, {len(jnp.unique(y))} classes")
print(f"Train: {len(X_train)}, Test: {len(X_test)}")
```

### Train and Evaluate

```python
# Create and train MLP
mlp_classifier = MLP(
    layer_sizes=[20, 64, 32, 3],
    activation='relu',
    output_activation='linear',
    weight_init='he_normal'
)

# Train with SGD
print("Training with SGD:")
sgd_optimizer = SGD(learning_rate=0.01, momentum=0.9)
params_sgd, history_sgd = train_mlp(
    mlp_classifier, 
    (X_train, y_train_oh), 
    (X_test, y_test_oh),
    sgd_optimizer,
    cross_entropy_loss,
    num_epochs=100,
    batch_size=32,
    key=random.PRNGKey(123)
)

print("\nTraining with Adam:")
adam_optimizer = Adam(learning_rate=0.001)
params_adam, history_adam = train_mlp(
    mlp_classifier,
    (X_train, y_train_oh),
    (X_test, y_test_oh), 
    adam_optimizer,
    cross_entropy_loss,
    num_epochs=100,
    batch_size=32,
    key=random.PRNGKey(123)
)

# Compare optimizers
print(f"\nFinal Results:")
print(f"SGD  - Test Accuracy: {history_sgd['test_accuracies'][-1]:.4f}")
print(f"Adam - Test Accuracy: {history_adam['test_accuracies'][-1]:.4f}")
```

## Advanced Features

### Regularization

```python
def l2_regularization(params, weight_decay=0.01):
    """Compute L2 regularization penalty"""
    l2_loss = 0.0
    for layer_params in params.values():
        l2_loss += weight_decay * jnp.sum(layer_params['W']**2)
    return l2_loss

def dropout(key, x, rate=0.5, training=True):
    """Apply dropout to layer activations"""
    if not training or rate == 0.0:
        return x
    
    keep_rate = 1.0 - rate
    mask = random.bernoulli(key, keep_rate, x.shape)
    return jnp.where(mask, x / keep_rate, 0.0)

class RegularizedMLP(MLP):
    """MLP with regularization support"""
    
    def __init__(self, *args, dropout_rate=0.0, weight_decay=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
    
    def forward(self, params, x, key=None, training=True):
        """Forward pass with dropout"""
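        if key is None:
            # Fallback key: reusing it yields the same dropout mask on every
            # call, so pass a fresh key for each training step
            key = random.PRNGKey(0)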
        
        for i in range(self.num_layers):
            layer_name = f'layer_{i}'
            W = params[layer_name]['W']
            b = params[layer_name]['b']
            
            # Linear transformation
            x = x @ W + b
            
            # Apply activation
            if i < self.num_layers - 1:
                x = self.activation(x)
                # Apply dropout to hidden layers
                if self.dropout_rate > 0 and training:
                    key, dropout_key = random.split(key)
                    x = dropout(dropout_key, x, self.dropout_rate, training)
            else:
                x = self.output_activation(x)
        
        return x
    
    def loss_with_regularization(self, params, x, y, key=None, training=True):
        """Cross-entropy loss plus the L2 weight penalty"""
        predictions = self.forward(params, x, key, training)
        base_loss = cross_entropy_loss(predictions, y)
        reg_loss = l2_regularization(params, self.weight_decay)
        return base_loss + reg_loss

# Test regularized MLP
regularized_mlp = RegularizedMLP(
    layer_sizes=[20, 128, 64, 3],
    activation='relu',
    dropout_rate=0.3,
    weight_decay=0.001
)

print("Regularized MLP created with dropout=0.3, weight_decay=0.001")
```
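Two properties make these regularizers behave as intended, and both are easy to check. Inverted dropout divides the surviving activations by `keep_rate`, so the expected activation is unchanged and no rescaling is needed at inference time. The L2 penalty $\lambda \sum W^2$ contributes $2\lambda W$ to each weight gradient and nothing to the biases. The sketch below restates `dropout` and `l2_regularization` so it runs on its own; the input sizes and rates are illustrative:

```python
import jax
import jax.numpy as jnp
from jax import random

def dropout(key, x, rate=0.5, training=True):
    """Inverted dropout, same as above."""
    if not training or rate == 0.0:
        return x
    keep_rate = 1.0 - rate
    mask = random.bernoulli(key, keep_rate, x.shape)
    return jnp.where(mask, x / keep_rate, 0.0)

def l2_regularization(params, weight_decay=0.01):
    """Same penalty as above: weight_decay * sum of squared weights (biases excluded)."""
    return sum(weight_decay * jnp.sum(p['W'] ** 2) for p in params.values())

# 1) Inverted dropout preserves the expected activation
x = jnp.ones((1000, 100))
key1, key2 = random.split(random.PRNGKey(0))
y1 = dropout(key1, x, rate=0.3)
y2 = dropout(key2, x, rate=0.3)
mean_after = float(y1.mean())
frac_zeroed = float((y1 == 0).mean())
print("mean after dropout:", mean_after)       # close to 1.0
print("fraction zeroed:", frac_zeroed)         # close to 0.3
print("masks differ across keys:", bool(jnp.any(y1 != y2)))

# 2) The L2 penalty adds 2 * weight_decay * W to each weight gradient
params = {'layer_0': {'W': jnp.arange(6.0).reshape(2, 3), 'b': jnp.zeros(3)}}
g = jax.grad(l2_regularization)(params, 0.001)
print(jnp.allclose(g['layer_0']['W'], 2 * 0.001 * params['layer_0']['W']))   # True
print(g['layer_0']['b'])                       # zeros: biases are not penalized
```

The differing masks for `key1` and `key2` are the reason each training step should receive its own subkey.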

### Batch Normalization

```python
def batch_norm(x, gamma, beta, running_mean, running_var, training=True, momentum=0.9, eps=1e-5):
    """Batch normalization implementation"""
    
    if training:
        # Compute batch statistics
        batch_mean = jnp.mean(x, axis=0, keepdims=True)
        batch_var = jnp.var(x, axis=0, keepdims=True)
        
        # Update running statistics
        new_running_mean = momentum * running_mean + (1 - momentum) * batch_mean
        new_running_var = momentum * running_var + (1 - momentum) * batch_var
        
        # Normalize using batch stats
        x_norm = (x - batch_mean) / jnp.sqrt(batch_var + eps)
    else:
        # Use running statistics for inference
        x_norm = (x - running_mean) / jnp.sqrt(running_var + eps)
        new_running_mean = running_mean
        new_running_var = running_var
    
    # Scale and shift
    out = gamma * x_norm + beta
    
    return out, new_running_mean, new_running_var

class BatchNormMLP(MLP):
    """MLP with batch normalization"""
    
    def init_params(self, key):
        """Initialize parameters including batch norm params"""
        params = super().init_params(key)
        
        # Add batch norm parameters for hidden layers
        for i in range(self.num_layers - 1):  # Exclude output layer
            layer_size = self.layer_sizes[i + 1]
            bn_name = f'bn_{i}'
            
            params[bn_name] = {
                'gamma': jnp.ones((1, layer_size)),
                'beta': jnp.zeros((1, layer_size)),
                'running_mean': jnp.zeros((1, layer_size)),
                'running_var': jnp.ones((1, layer_size))
            }
        
        return params
    
    def forward(self, params, x, training=True):
        """Forward pass with batch normalization.
        
        Returns (output, new_params). JAX functions should not mutate their
        inputs (under jit or grad such writes are silently lost), so the
        updated running statistics are returned in a new params dict for the
        training loop to carry forward. Note that this also changes the
        return type of the inherited __call__.
        """
        new_params = dict(params)  # shallow copy; bn_* entries are replaced below
        for i in range(self.num_layers):
            layer_name = f'layer_{i}'
            W = params[layer_name]['W']
            b = params[layer_name]['b']
            
            # Linear transformation
            x = x @ W + b
            
            # Apply batch norm to hidden layers
            if i < self.num_layers - 1:
                bn_name = f'bn_{i}'
                bn_params = params[bn_name]
                
                x, new_running_mean, new_running_var = batch_norm(
                    x, bn_params['gamma'], bn_params['beta'],
                    bn_params['running_mean'], bn_params['running_var'],
                    training=training
                )
                
                # Store updated running statistics without mutating `params`
                new_params[bn_name] = {**bn_params,
                                       'running_mean': new_running_mean,
                                       'running_var': new_running_var}
                
                x = self.activation(x)
            else:
                x = self.output_activation(x)
        
        return x, new_params

print("Batch Normalization MLP implementation ready")
```
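A quick way to validate `batch_norm` is to feed it data with a known offset and scale. In training mode each feature of the output should have roughly zero mean and unit variance (with `gamma = 1`, `beta = 0`), and with `momentum=0.9` the running mean should move 10% of the way from its initial zeros toward the batch mean. The sketch below restates the function so it runs standalone; the shift of 3 and scale of 2 are illustrative choices:

```python
import jax.numpy as jnp
from jax import random

def batch_norm(x, gamma, beta, running_mean, running_var, training=True, momentum=0.9, eps=1e-5):
    """Same batch normalization as above."""
    if training:
        batch_mean = jnp.mean(x, axis=0, keepdims=True)
        batch_var = jnp.var(x, axis=0, keepdims=True)
        new_running_mean = momentum * running_mean + (1 - momentum) * batch_mean
        new_running_var = momentum * running_var + (1 - momentum) * batch_var
        x_norm = (x - batch_mean) / jnp.sqrt(batch_var + eps)
    else:
        x_norm = (x - running_mean) / jnp.sqrt(running_var + eps)
        new_running_mean, new_running_var = running_mean, running_var
    return gamma * x_norm + beta, new_running_mean, new_running_var

# 256 samples of 4 features, each with mean ~3 and std ~2
x = 3.0 + 2.0 * random.normal(random.PRNGKey(0), (256, 4))
gamma, beta = jnp.ones((1, 4)), jnp.zeros((1, 4))
rm, rv = jnp.zeros((1, 4)), jnp.ones((1, 4))

out, new_rm, new_rv = batch_norm(x, gamma, beta, rm, rv, training=True)
print(out.mean(axis=0))   # close to 0 for every feature
print(out.std(axis=0))    # close to 1 for every feature
print(jnp.allclose(new_rm, 0.1 * x.mean(axis=0, keepdims=True)))   # True

# Inference mode normalizes with the running statistics and leaves them unchanged
out_eval, rm_eval, rv_eval = batch_norm(x, gamma, beta, new_rm, new_rv, training=False)
```

Because the running statistics start far from the true values and move only 10% per step, early inference outputs can be poorly normalized; this is why models are usually trained for many steps before the running statistics are trusted.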

## Model Analysis and Visualization

### Gradient Analysis

```python
def analyze_gradients(mlp, params, X, y, loss_fn):
    """Analyze gradient magnitudes across layers"""
    
    def loss_fn_params(params):
        predictions = mlp(params, X)
        return loss_fn(predictions, y)
    
    grads = grad(loss_fn_params)(params)
    
    print("Gradient Analysis:")
    print("=" * 30)
    
    for layer_name in sorted(params.keys()):
        if 'layer_' in layer_name:
            W_grad = grads[layer_name]['W']
            b_grad = grads[layer_name]['b']
            
            W_grad_norm = jnp.linalg.norm(W_grad)
            b_grad_norm = jnp.linalg.norm(b_grad)
            
            print(f"{layer_name}: W_grad_norm={W_grad_norm:.6f}, b_grad_norm={b_grad_norm:.6f}")

# Analyze gradients for trained model
analyze_gradients(mlp_classifier, params_adam, X_test, y_test_oh, cross_entropy_loss)
```
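Gradient norms reveal scale problems but not correctness. For correctness, a standard debugging technique is to compare autodiff gradients against central finite differences, $\partial L/\partial w \approx (L(w+h) - L(w-h)) / 2h$, on a few individual weights. The sketch below assumes a small two-layer `tanh` network (smooth, so finite differences behave well near every point, unlike ReLU kinks) and enables float64, as this notebook already does, so the difference quotient is accurate. The network, `loss`, and `fd_entry` are hypothetical helpers:

```python
import jax
import jax.numpy as jnp
from jax import random

jax.config.update("jax_enable_x64", True)   # float64 keeps finite differences accurate

def loss(params, x, y):
    """Cross-entropy of a tiny tanh network."""
    h = jnp.tanh(x @ params['W1'])
    logits = h @ params['W2']
    return -jnp.mean(jnp.sum(y * jax.nn.log_softmax(logits), axis=1))

k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)
params = {'W1': random.normal(k1, (5, 4)), 'W2': random.normal(k2, (4, 3))}
x = random.normal(k3, (16, 5))
y = jax.nn.one_hot(random.randint(k4, (16,), 0, 3), 3)

grads = jax.grad(loss)(params, x, y)

def fd_entry(name, idx, h=1e-6):
    """Central difference of the loss w.r.t. the single weight params[name][idx]."""
    plus = {**params, name: params[name].at[idx].add(h)}
    minus = {**params, name: params[name].at[idx].add(-h)}
    return (loss(plus, x, y) - loss(minus, x, y)) / (2 * h)

checks = [('W1', (0, 0)), ('W1', (3, 2)), ('W2', (1, 1))]
for name, idx in checks:
    print(name, idx, float(grads[name][idx]), float(fd_entry(name, idx)))
```

Checking a handful of entries is usually enough to catch a sign error or a missing term in a hand-written loss; checking every weight this way would cost two forward passes per parameter.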

### Learning Curves Visualization

```python
def plot_learning_curves(history_sgd, history_adam):
    """Plot training curves comparing optimizers"""
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    epochs = range(len(history_sgd['train_losses']))
    
    # Training loss
    axes[0].plot(epochs, history_sgd['train_losses'], label='SGD', alpha=0.7)
    axes[0].plot(epochs, history_adam['train_losses'], label='Adam', alpha=0.7)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Training Loss')
    axes[0].set_title('Training Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Test loss
    axes[1].plot(epochs, history_sgd['test_losses'], label='SGD', alpha=0.7)
    axes[1].plot(epochs, history_adam['test_losses'], label='Adam', alpha=0.7)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Test Loss') 
    axes[1].set_title('Test Loss')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # Test accuracy
    axes[2].plot(epochs, history_sgd['test_accuracies'], label='SGD', alpha=0.7)
    axes[2].plot(epochs, history_adam['test_accuracies'], label='Adam', alpha=0.7)
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Test Accuracy')
    axes[2].set_title('Test Accuracy')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Plot learning curves
plot_learning_curves(history_sgd, history_adam)
```

## Summary

In this notebook, we've built a complete MLP implementation from scratch in JAX:

**Key Components:**

1. **Activation Functions**: ReLU, sigmoid, tanh, swish, GELU, leaky ReLU
2. **Weight Initialization**: Xavier, He, LeCun initialization schemes
3. **MLP Architecture**: Flexible multi-layer implementation
4. **Loss Functions**: MSE, MAE, cross-entropy, sparse cross-entropy, Huber loss
5. **Optimizers**: SGD with momentum, Adam optimizer

**Advanced Features:**
- L2 regularization and dropout for overfitting prevention
- Batch normalization for training stability
- Gradient analysis for debugging
- Learning curve visualization

**JAX-Specific Advantages:**
- Automatic differentiation eliminates manual backpropagation
- JIT compilation for performance optimization
- Functional programming approach with immutable parameters
- `vmap` for per-example vectorization when needed (the batched matrix multiplies used here already process whole mini-batches)

**Best Practices Demonstrated:**
- Proper weight initialization for different activations
- Numerically stable loss function implementations
- Modular design with separate components
- Comprehensive training and evaluation loops

**Performance Insights:**
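- Adam often converges faster than SGD without learning-rate tuning; the learning curves above show how the two compare on this dataset
- Initialization matched to the activation (e.g. He for ReLU) helps keep activations and gradients from vanishing or exploding
- Dropout and L2 regularization typically improve generalization (the regularized model is defined but not trained in this notebook)
- Batch normalization usually stabilizes training, especially in deeper networks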

**Next Steps:**
- The next notebook will cover CNN implementation
- We'll explore convolutional layers and image processing
- Understanding MLPs provides foundation for more complex architectures

This MLP implementation demonstrates JAX's power for neural network development, combining ease of use with high performance and automatic differentiation capabilities.