# Introduction to JAX/Flax/Optax

This notebook covers the essential concepts needed to understand the ENF codebase. We'll focus on the key features and patterns used throughout the code.

## Key Concepts

### 1. JAX Basics

JAX combines a NumPy-compatible API (`jax.numpy`) with automatic differentiation and just-in-time compilation via XLA for CPU, GPU, and TPU. The key features used in our codebase:

In [None]:
import jax
import jax.numpy as jnp

# 1. jit compilation: traces the function once per input shape/dtype and compiles it with XLA
@jax.jit
def fast_function(x):
    return jnp.sum(x ** 2)

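# 2. Automatic differentiation: jax.grad / jax.value_and_grad transform a scalar function
def loss_fn(x):
    return jnp.sum(x ** 2)

@jax.jit
def loss_and_grad(x):
    # value_and_grad returns loss_fn(x) and its gradient w.r.t. x (here 2 * x)
    loss, grad = jax.value_and_grad(loss_fn)(x)
    return loss, grad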

# 3. Random number handling - JAX requires explicit PRNG key management
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)  # Consume subkey in a random call; keep key for future splits
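
A standalone sketch checking the three features above on a concrete input: for f(x) = sum(x ** 2) the gradient is 2 * x, and a given key always yields the same random numbers. `jax.random.normal` is just one example of a call that consumes a subkey.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    return jnp.sum(x ** 2)

# jit-compiled function returning both the value and the gradient
value_and_grad_fn = jax.jit(jax.value_and_grad(sum_of_squares))

x = jnp.array([1.0, 2.0, 3.0])
value, grad = value_and_grad_fn(x)
# value == 14.0 and grad == 2 * x == [2.0, 4.0, 6.0]

# Explicit PRNG handling: split first, then consume the subkey
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))

# Reusing the same key reproduces exactly the same numbers
same_noise = jax.random.normal(subkey, (3,))
```

Because reusing a key silently repeats randomness, the usual habit is to split before every random call and never pass a consumed subkey to a second one.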

### 2. Flax Neural Networks

Flax is a neural network library built on JAX. It uses a class-based approach with the `nn.Module` system:

In [None]:
from flax import linen as nn

class SimpleNN(nn.Module):
    hidden_size: int  # Configuration parameters are class attributes
    
    @nn.compact  # Lets submodules be declared inline in __call__
    def __call__(self, x):
        # Submodules are declared inline; their parameters are created by model.init
        x = nn.Dense(self.hidden_size)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)

# Initialize model and parameters
model = SimpleNN(hidden_size=64)
key = jax.random.PRNGKey(0)
x = jnp.ones((1, 10))
params = model.init(key, x)  # Returns a variables dict of the form {'params': {...}}

# Applying the model requires passing the parameters
output = model.apply(params, x)

### 3. Optax Optimizers

Optax provides optimizers and gradient transformation utilities:

In [None]:
import optax

# Create optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)  # Initialize optimizer state

# Training step
@jax.jit
def train_step(params, opt_state, batch):
    def loss_fn(params):
        output = model.apply(params, batch)
        return jnp.mean(output ** 2)  # Toy objective for illustration: drive outputs toward 0
    
    # Get loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(params)
    
    # Update parameters
    # Passing params is required by some transformations (e.g. weight decay in optax.adamw)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state, loss

### 4. Key Patterns in the ENF Codebase

1. **Model Structure**: Models are defined as Flax modules (like `EquivariantNeuralField`)
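2. **Training Loop Pattern**:
   - Inner loop: a few gradient steps fit per-sample latent variables while the model parameters stay fixed (meta-learning)
   - Outer loop: the shared model parameters are updated with gradients taken through the inner loop
3. **Functional Updates**: Everything is immutable; new states and parameters are returned rather than modified in place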
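
To make "functional updates" concrete: a plain SGD step written with `jax.tree_util.tree_map` builds a new parameter pytree and leaves the old one untouched. The nested-dict layout and the `sgd_update` helper are hypothetical stand-ins for a Flax `params` tree and an optimizer step.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter and gradient pytrees with matching structure
params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}
grads = {"dense": {"kernel": jnp.full((2, 2), 0.5), "bias": jnp.ones((2,))}}

def sgd_update(params, grads, lr=0.1):
    # Apply p - lr * g leaf by leaf; returns a brand-new pytree
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

new_params = sgd_update(params, grads)
# params still holds the old values; new_params holds the updated ones
```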

### Common Gotchas

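1. JAX arrays are immutable: operations return new arrays, and indexed updates use `x.at[idx].set(value)` instead of `x[idx] = value`
2. `@jax.jit` functions must be pure: the same inputs give the same outputs, with no side effects. Python side effects such as `print` run only while the function is traced, not on later calls
3. Random operations need explicit PRNG key management: reusing a key reproduces the same numbers
4. Shape errors can be cryptic. Inside jitted code, use `jax.debug.print()` to see runtime values; a plain `print` shows only abstract tracers (shape and dtype)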
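
A short sketch of gotchas 1, 2, and 4 in action; `scaled_sum` is a hypothetical toy function. The second call has the same shape and dtype as the first, so JAX reuses the compiled version and only `jax.debug.print` fires.

```python
import jax
import jax.numpy as jnp

# Gotcha 1: x[0] = 10.0 would raise a TypeError; .at[] returns a new array
x = jnp.zeros(3)
y = x.at[0].set(10.0)  # x itself is unchanged

# Gotchas 2 and 4: Python print runs only during tracing,
# jax.debug.print runs on every call with concrete values
@jax.jit
def scaled_sum(v):
    print("tracing, v =", v)  # shows an abstract tracer, once per new shape/dtype
    total = jnp.sum(v) * 2.0
    jax.debug.print("total = {t}", t=total)  # shows the actual value each call
    return total

scaled_sum(jnp.arange(3.0))        # first call: traces and compiles, both lines print
scaled_sum(jnp.arange(3.0) + 1.0)  # cached: only jax.debug.print fires
```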

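### Example: Inner/Outer Loop Meta-Learning

The cell below combines the pieces above into the training-loop pattern: an inner loop fits latent variables `z` with plain gradient descent, and an outer loop updates the model parameters with Adam, differentiating through the inner loop.

In [None]: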
from functools import partial

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn

class MetaLearningExample(nn.Module):
    hidden_size: int
    
    @nn.compact
    def __call__(self, x, z):
        # x: input data, shape (batch, 10)
        # z: latent variables optimized in the inner loop, shape (batch, 5)
        combined = jnp.concatenate([x, z], axis=-1)
        return nn.Dense(1)(nn.relu(nn.Dense(self.hidden_size)(combined)))

# Initialize model with dummy inputs of the right shapes
model = MetaLearningExample(hidden_size=64)
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones((1, 10)), jnp.ones((1, 5)))

# Create optimizer for outer loop (model parameters)
outer_optimizer = optax.adam(learning_rate=1e-3)
outer_opt_state = outer_optimizer.init(params)

# inner_steps sets the scan length, so it must be a static (compile-time) argument
@partial(jax.jit, static_argnames=("inner_steps",))
def inner_loop(params, x, y, z, inner_steps=3, inner_lr=0.1):
    """Optimize latent variables z to fit current data; params stay fixed."""
    def loss_fn(z):
        pred = model.apply(params, x, z)
        return jnp.mean((pred - y) ** 2)
    
    def inner_step(z, _):
        loss, grads = jax.value_and_grad(loss_fn)(z)
        z = z - inner_lr * grads  # Plain gradient descent on z
        return z, loss
    
    # Run inner optimization loop; losses holds the per-step loss history
    z, losses = jax.lax.scan(inner_step, z, None, length=inner_steps)
    
    return loss_fn(z), z

@jax.jit
def outer_step(params, opt_state, x, y, z):
    """Update model parameters using meta-gradients."""
    def meta_loss(params):
        # Gradients flow back through every inner-loop step
        loss, _ = inner_loop(params, x, y, z)
        return loss
    
    # Get meta-gradients w.r.t. the model parameters
    loss, grads = jax.value_and_grad(meta_loss)(params)
    
    # Update model parameters
    updates, opt_state = outer_optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state, loss

# Hypothetical toy data standing in for a real dataloader: 4 batches of 16 samples
num_epochs = 3
data_key = jax.random.PRNGKey(1)
xs = jax.random.normal(data_key, (4, 16, 10))
ys = jnp.sum(xs, axis=-1, keepdims=True)  # toy regression target, shape (4, 16, 1)
dataloader = list(zip(xs, ys))

# Training loop
for epoch in range(num_epochs):
    for batch_x, batch_y in dataloader:
        # Initialize latent variables for this batch
        z = jnp.zeros((batch_x.shape[0], 5))
        
        # Outer loop update
        params, outer_opt_state, loss = outer_step(params, outer_opt_state, batch_x, batch_y, z)
    print(f"epoch {epoch}: loss = {float(loss):.4f}")
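
The inner loop relies on `jax.lax.scan`, which threads a carry through a step function and stacks the per-step outputs. A minimal sketch with a toy halving step, checking that scan matches an ordinary Python loop; in the meta-learning cell the carry is `z` and the stacked output is the per-step loss.

```python
import jax
import jax.numpy as jnp

def step(carry, _):
    # Second argument is unused because xs=None is passed to scan
    new_carry = carry * 0.5
    return new_carry, new_carry  # (next carry, per-step output)

init = jnp.array(8.0)
final, history = jax.lax.scan(step, init, None, length=3)
# final == 1.0, history == [4.0, 2.0, 1.0]

# Equivalent Python loop; under jit it would be unrolled into one long program,
# whereas scan compiles the step function once
carry, outputs = init, []
for _ in range(3):
    carry, out = step(carry, None)
    outputs.append(out)
```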