# Optimizer Comparison and Visualization

This notebook demonstrates how to compare different optimizers and visualize their behavior on various loss surfaces.

## Setup and Imports

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Setup for inline plots
# `%matplotlib inline` is an IPython magic, so this cell only runs inside an
# IPython/Jupyter kernel; figures are rendered in the cell output.
%matplotlib inline
# Apply the 'seaborn-v0_8' matplotlib style sheet to all figures created below.
# NOTE(review): this style name presumably requires matplotlib >= 3.6 — confirm
# against the environment this notebook is expected to run in.
plt.style.use('seaborn-v0_8')

# Add paths for optimizer imports.
# The relative location is resolved against the current working directory once,
# so the imports below keep working even if the working directory changes later.
# The membership check makes the cell idempotent: re-running it does not stack
# duplicate entries on sys.path.
VALIDATION_DIR = os.path.abspath('../validation')
if VALIDATION_DIR not in sys.path:
    sys.path.append(VALIDATION_DIR)

# Import optimizers
from adam.solution import adam_init, adam_update
from adamw.solution import adamw_init, adamw_update
from sgd.solution import sgd_init, sgd_update
from muon.solution import muon_init, muon_update

# Import visualization tools
from contract import RosenbrockModel, SimpleQuadraticModel, OptimizerConfig
from compare import compare_optimizers, print_comparison_table, plot_loss_curves, plot_convergence_summary
from visualize import plot_2d_trajectories, plot_high_dimensional_trajectories, plot_parameter_evolution

print("✅ Setup complete!")

## Configure Optimizers

We'll compare Adam, AdamW, SGD with momentum, and Muon across different problems.

In [None]:
# Configure optimizers for comparison.
# Each row is (name, init_fn, update_fn, hyperparameter kwargs, plot color);
# the name doubles as the dictionary key and as OptimizerConfig.name.
optimizer_specs = [
    ('Adam', adam_init, adam_update, {'lr': 0.01}, 'red'),
    ('AdamW', adamw_init, adamw_update, {'lr': 0.01, 'weight_decay': 0.001}, 'blue'),
    ('SGD+Momentum', sgd_init, sgd_update, {'lr': 0.002, 'momentum': 0.9}, 'green'),
    ('Muon', muon_init, muon_update,
     {'lr': 0.005, 'momentum': 0.9, 'orthogonalize': True}, 'purple'),
]

# Dict insertion order follows the table above.
optimizers = {
    spec_name: OptimizerConfig(
        name=spec_name,
        init_fn=spec_init,
        update_fn=spec_update,
        kwargs=spec_kwargs,
        color=spec_color,
    )
    for spec_name, spec_init, spec_update, spec_kwargs, spec_color in optimizer_specs
}

print(f"Configured {len(optimizers)} optimizers: {list(optimizers.keys())}")

## Example 1: Rosenbrock Function (2D)

The Rosenbrock function is a classic non-convex optimization challenge with a narrow valley leading to the global minimum.

In [None]:
# Setup 2D problem: model, starting point, and the batch list reused by later cells.
batches = [None]  # No batches for deterministic functions
initial_params_2d = jnp.array([-2.0, 2.0])
model_2d = RosenbrockModel(dim=2)

# One print call, one summary line per argument.
print(
    f"🎯 Rosenbrock function starting from {initial_params_2d}",
    f"🎯 Global minimum at {model_2d.optimal_params()}",
    f"🎯 Initial loss: {model_2d.loss(initial_params_2d):.2f}",
    sep="\n",
)

In [None]:
# Run optimization comparison: all configured optimizers on the 2D problem.
rosenbrock_run = dict(
    model=model_2d,
    optimizer_configs=optimizers,
    initial_params=initial_params_2d,
    batches=batches,
    num_steps=1000,
    verbose=False,  # Keep notebook clean
)
results_2d = compare_optimizers(**rosenbrock_run)

print("✅ Optimization complete!")

# Show comparison table
print_comparison_table(results_2d)

In [None]:
# Visualize 2D trajectories for the Rosenbrock run.
trajectory_fig = plot_2d_trajectories(
    results_2d,
    model_2d,
    x_range=(-2.5, 2.5),
    y_range=(-1, 3),
    figsize=(12, 10),
)
trajectory_fig.suptitle("Rosenbrock Function - Optimizer Trajectories", fontsize=16)
plt.show()

In [None]:
# Loss curves for the Rosenbrock run.
loss_fig = plot_loss_curves(
    results_2d,
    title="Rosenbrock Function - Loss Curves",
    figsize=(12, 6),
)
plt.show()

In [None]:
# Convergence analysis for the Rosenbrock run.
convergence_fig = plot_convergence_summary(results_2d)
plt.show()

## Example 2: High-Dimensional Ill-Conditioned Problem

Test optimizers on a challenging high-dimensional quadratic with poor conditioning.

In [None]:
# Create challenging high-dimensional problem.
# The dimension and condition number are named once so the model, the shape of
# the random starting point, and the printed summary cannot drift apart.
HD_DIM = 20
HD_CONDITION_NUMBER = 1000

model_hd = SimpleQuadraticModel(dim=HD_DIM, condition_number=HD_CONDITION_NUMBER)
# Fixed PRNG key keeps the starting point reproducible; the draw is scaled by 2.0.
initial_params_hd = jax.random.normal(jax.random.PRNGKey(42), (HD_DIM,)) * 2.0

print(f"🎯 High-dimensional problem: {model_hd.dim}D")
print(f"🎯 Condition number: {HD_CONDITION_NUMBER}")
print(f"🎯 Initial loss: {model_hd.loss(initial_params_hd):.2f}")

In [None]:
# Run optimization (fewer steps for high-dim), reusing the optimizer configs
# and the batch list defined earlier in the notebook.
high_dim_run = dict(
    model=model_hd,
    optimizer_configs=optimizers,
    initial_params=initial_params_hd,
    batches=batches,
    num_steps=500,
    verbose=False,
)
results_hd = compare_optimizers(**high_dim_run)

print("✅ High-dimensional optimization complete!")
print_comparison_table(results_hd)

In [None]:
# PCA projection of high-dimensional trajectories (2 components).
pca_2d_fig = plot_high_dimensional_trajectories(
    results_hd,
    method='pca',
    n_components=2,
    figsize=(12, 8),
)
pca_2d_fig.suptitle("High-Dimensional Trajectories (PCA Projection)", fontsize=16)
plt.show()

In [None]:
# 3D PCA projection
fig = plot_high_dimensional_trajectories(
    results_hd, method='pca', n_components=3, figsize=(12, 10)
)
# Title the figure in the same way as the 2-component projection so it can be
# identified on its own when the notebook is skimmed.
fig.suptitle("High-Dimensional Trajectories (3D PCA Projection)", fontsize=16)
plt.show()

In [None]:
# Parameter evolution for first 6 parameters of the high-dimensional run.
evolution_fig = plot_parameter_evolution(
    results_hd,
    max_params=6,
    figsize=(15, 10),
)
evolution_fig.suptitle("Parameter Evolution (High-Dimensional Problem)", fontsize=16)
plt.show()

## Example 3: Effect of Problem Conditioning

Compare how optimizers handle problems with different condition numbers.

In [None]:
# Test on different condition numbers
condition_numbers = [1, 10, 100, 1000]
conditioning_results = {}

# Problem dimension, shared by every model in the sweep and by the start point.
COND_DIM = 5

# Use subset of optimizers for cleaner comparison
optimizers_subset = {name: optimizers[name] for name in ('Adam', 'Muon')}

# The starting point depends only on the fixed PRNG key and the dimension, so
# it is drawn once here instead of being recomputed on every loop iteration.
# Every condition number therefore starts from the same point, as before.
initial_params_cond = jax.random.normal(jax.random.PRNGKey(123), (COND_DIM,))

for cond_num in condition_numbers:
    print(f"Testing condition number: {cond_num}")

    model_cond = SimpleQuadraticModel(dim=COND_DIM, condition_number=cond_num)

    results_cond = compare_optimizers(
        model=model_cond,
        optimizer_configs=optimizers_subset,
        initial_params=initial_params_cond,
        batches=batches,
        num_steps=200,
        verbose=False
    )

    # Keyed by condition number so the plotting cell can look results up.
    conditioning_results[cond_num] = results_cond

print("✅ Conditioning study complete!")

In [None]:
# Plot conditioning analysis: one semilog loss panel per condition number.
# The grid is sized from len(condition_numbers) rather than fixed at 2x2, so
# editing the sweep list neither overruns the axes array nor leaves blank
# panels. With the default four condition numbers this is a 2x2 grid at
# figsize (15, 10).
n_panels = len(condition_numbers)
n_cols = 2
n_rows = max(1, (n_panels + n_cols - 1) // n_cols)  # ceiling division, at least one row

# squeeze=False always returns a 2D axes array, so flatten() works even when
# the grid has a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, cond_num in zip(axes, condition_numbers):
    for name, trajectory in conditioning_results[cond_num].items():
        losses = np.array(trajectory.losses)
        ax.semilogy(np.arange(len(losses)), losses, label=name, linewidth=2)

    ax.set_title(f'Condition Number: {cond_num}')
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Hide any trailing axes that have no condition number assigned to them.
for unused_ax in axes[n_panels:]:
    unused_ax.set_visible(False)

fig.suptitle('Effect of Problem Conditioning on Optimizer Performance', fontsize=16)
plt.tight_layout()
plt.show()

## Key Insights

From these experiments, we can observe:

1. **Adam/AdamW** adapt well to different scales and are generally robust
2. **SGD+Momentum** needs careful learning rate tuning but can be very stable
3. **Muon's orthogonalization** helps with ill-conditioned problems
4. **High condition numbers** challenge all optimizers, but adaptive methods cope better
5. **2D visualizations** reveal trajectory differences clearly
6. **PCA projections** help understand high-dimensional behavior

## Custom Experiments

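You can easily run your own experiments by defining a model class with `loss` and `grad` methods and passing an instance to `compare_optimizers`: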

In [None]:
# Example: Create your own model
class CustomModel:
    """Template model with a two-parameter multi-modal loss.

    Provides ``param_shape``, ``loss`` and ``grad``; replace the body of
    ``loss`` with your own objective.
    """

    def __init__(self):
        # loss() reads params[0] and params[1], i.e. a length-2 parameter vector.
        self.param_shape = (2,)

    def loss(self, params, batch=None):
        """Return the scalar loss at ``params``.

        Args:
            params: Indexable with at least two entries; ``params[0]`` and
                ``params[1]`` are used as ``x`` and ``y``.
            batch: Accepted for interface compatibility; this example loss
                does not use it.

        Returns:
            ``x**4 + y**4 - 2*x**2 - 2*y**2 + x*y``.
        """
        # Your custom loss function
        x, y = params[0], params[1]
        return x**4 + y**4 - 2*x**2 - 2*y**2 + x*y  # Multi-modal function

    def grad(self, params, batch=None):
        """Return the gradient of ``loss`` with respect to ``params``.

        Args:
            params: Point at which to differentiate (a JAX array).
            batch: Forwarded unchanged to ``loss`` so that a batch-dependent
                loss is differentiated on the same batch it is evaluated on.

        Returns:
            Gradient with the same shape as ``params``.
        """
        # Analytical or use jax.grad.
        # The closure captures `batch` so it reaches loss() without being
        # traced or differentiated; only `params` is differentiated.
        return jax.grad(lambda p: self.loss(p, batch))(params)

# Test your custom model: instantiate it, pick a start point, and choose
# which of the configured optimizers to compare.
custom_model = CustomModel()
custom_initial = jnp.array([2.0, -1.0])
custom_optimizers = {
    'Adam': optimizers['Adam'],
    'SGD': optimizers['SGD+Momentum'],
}

custom_run = dict(
    model=custom_model,
    optimizer_configs=custom_optimizers,
    initial_params=custom_initial,
    batches=[None],
    num_steps=500,
    verbose=False,
)
custom_results = compare_optimizers(**custom_run)

# Visualize
plot_loss_curves(custom_results, title="Custom Multi-Modal Function")
plt.show()

print("✅ Custom experiment complete!")