# Introduction to JAX: The Framework That Thinks in Gradients

**Google's secret weapon for differentiable everything**

---

There's a quiet revolution happening in machine learning.

While everyone argues about PyTorch vs TensorFlow, a third framework has been quietly spreading through research labs at Google DeepMind and top universities. It powers AlphaFold 2. It runs cutting-edge reinforcement learning research. It trains large transformer models on TPU pods.

It's called **JAX**.

And it's not just another deep learning framework. It's a fundamentally different way of thinking about numerical computation.

This notebook will take you from zero to dangerous. By the end, you'll understand not just *how* to use JAX, but *why* it exists and when it's the right tool for the job.

## Part 1: What Is JAX?

JAX is three things fused into one:

1. **NumPy, but faster** - Same API you know, but running on GPU/TPU
2. **Automatic differentiation** - Gradients of arbitrary Python functions
3. **A compiler** - XLA compilation that fuses operations into optimized kernels for large speedups

The name "JAX" is commonly explained as:
- **J**ust-in-time compilation
- **A**utomatic differentiation
- **X**LA (Accelerated Linear Algebra)

### The Philosophy

JAX is built on a radical idea: **functions should be transformable**.

You write a function. JAX gives you tools to:
- Get its gradient (`grad`)
- Compile it for speed (`jit`)
- Vectorize it over batches (`vmap`)
- Parallelize it across devices (`pmap`)

These transformations compose. You can take the gradient of a jit-compiled, vmapped function. This composability is JAX's superpower.
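A minimal sketch of that composition, using an arbitrary example function `f` chosen only for illustration: `grad` produces the derivative, `vmap` evaluates it over a batch of points, and `jit` compiles the whole pipeline.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function of a scalar input
    return jnp.sin(x) * x ** 2

# Compose: derivative -> vectorized over a batch -> compiled with XLA
df_batched = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.array([0.0, 1.0, 2.0])
print(df_batched(xs))

# The same derivative by hand: f'(x) = cos(x) * x^2 + 2x * sin(x)
expected = jnp.cos(xs) * xs ** 2 + 2 * xs * jnp.sin(xs)
print(jnp.allclose(df_batched(xs), expected, atol=1e-5))
```

The order matters only for what each transformation sees: here `grad` works on a scalar function, and `vmap` lifts the result to batches.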

In [None]:
# Install JAX (CPU version for this notebook; jaxlib is installed automatically as a dependency)
# For NVIDIA GPUs: pip install -U "jax[cuda12]"
!pip install -U jax -q

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
import numpy as np

print(f"JAX version: {jax.__version__}")
print(f"Devices available: {jax.devices()}")

## Part 2: JAX as NumPy's Cooler Sibling

If you know NumPy, you already know 80% of JAX. The API is intentionally familiar.

In [None]:
# NumPy code
np_array = np.array([1.0, 2.0, 3.0])
np_result = np.sin(np_array) + np.cos(np_array)

# JAX code - literally the same, just change the import
jax_array = jnp.array([1.0, 2.0, 3.0])
jax_result = jnp.sin(jax_array) + jnp.cos(jax_array)

print("NumPy result:", np_result)
print("JAX result:  ", jax_result)
print("\nAre they equal?", np.allclose(np_result, np.array(jax_result)))

In [None]:
# Most NumPy operations work identically

# Matrix operations
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])

print("Matrix multiplication:")
print(A @ B)

print("\nElement-wise operations:")
print(A * B)

print("\nReductions:")
print(f"Sum: {jnp.sum(A)}, Mean: {jnp.mean(A)}, Max: {jnp.max(A)}")

print("\nLinear algebra:")
print(f"Determinant: {jnp.linalg.det(A.astype(float))}")
print(f"Eigenvalues: {jnp.linalg.eigvals(A.astype(float))}")

In [None]:
# JAX arrays are immutable - this is a KEY difference

# NumPy allows in-place modification
np_arr = np.array([1, 2, 3])
np_arr[0] = 100  # This works
print("NumPy (mutated):", np_arr)

# JAX does NOT allow this
jax_arr = jnp.array([1, 2, 3])
try:
    jax_arr[0] = 100  # This will fail!
except TypeError as e:
    print(f"JAX error: {e}")

# Instead, use .at[].set() which returns a NEW array
jax_arr_new = jax_arr.at[0].set(100)
print("JAX (new array):", jax_arr_new)
print("Original unchanged:", jax_arr)

### Why Immutability?

This isn't JAX being difficult. It's *functional programming*, and it enables:

1. **Safe parallelism** - No race conditions if data can't change
2. **Reliable gradients** - Reverse-mode autodiff reuses values saved during the forward pass; in-place mutation could silently overwrite them
3. **Aggressive compilation** - XLA can optimize better when it knows data won't mutate

Think of it like this: every JAX array is a mathematical value, not a mutable container.
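Beyond `.set`, the `.at[...]` property offers other functional updates such as `.add` and `.multiply`. A small sketch with illustrative values; note that when an index repeats, `.add` applies every update, unlike NumPy's `x[idx] += y`:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(5)

# Every .at[...] method returns a NEW array; x is never modified
y = x.at[1].set(10.0)                  # [0, 10, 0, 0, 0]
z = y.at[1:3].add(1.0)                 # [0, 11, 1, 0, 0]
w = z.at[jnp.array([0, 0])].add(1.0)   # index 0 appears twice, both adds apply: w[0] == 2
m = w.at[1].multiply(0.5)              # [2, 5.5, 1, 0, 0]
print(x, y, z, w, m, sep="\n")

@jax.jit
def bump_first(v):
    t = v * 2.0               # intermediate array created inside the function
    # t is not reused afterwards, so XLA may perform this update in place
    return t.at[0].add(1.0)

print(bump_first(w))
```

So immutability is a property of the programming model; inside compiled code, XLA can often avoid the copies it seems to imply.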

## Part 3: Automatic Differentiation with `grad`

This is where JAX starts to feel like magic.

You write a Python function. JAX gives you its derivative. Automatically.

In [None]:
# Let's start simple: f(x) = x^2
# We know the derivative is f'(x) = 2x

def f(x):
    return x ** 2

# grad() returns a NEW function that computes the derivative
df = grad(f)

# Test it
x = 3.0
print(f"f({x}) = {f(x)}")
print(f"f'({x}) = {df(x)}")
print(f"Expected: 2 * {x} = {2 * x}")

In [None]:
# It works with more complicated functions too

def complex_function(x):
    """f(x) = sin(x^2) + e^(-x) * cos(x)"""
    return jnp.sin(x**2) + jnp.exp(-x) * jnp.cos(x)

# Get the derivative
d_complex = grad(complex_function)

# And the second derivative!
d2_complex = grad(grad(complex_function))

x = 1.0
print(f"f({x})   = {complex_function(x):.6f}")
print(f"f'({x})  = {d_complex(x):.6f}")
print(f"f''({x}) = {d2_complex(x):.6f}")

In [None]:
# Gradients work with Python control flow: without jit, grad runs the
# branch on the actual input value, so each branch is differentiated as written

def piecewise_function(x):
    """A function with branches."""
    if x < 0:
        return -x  # f(x) = -x for x < 0, so f'(x) = -1
    else:
        return x ** 2  # f(x) = x^2 for x >= 0, so f'(x) = 2x

d_piecewise = grad(piecewise_function)

print("Derivative at x = -2.0:", d_piecewise(-2.0))  # Should be -1
print("Derivative at x = 3.0:", d_piecewise(3.0))    # Should be 6

In [None]:
# Multi-variable functions: gradients become vectors

def multivar_function(params):
    """f(x, y) = x^2 + 3xy + y^2"""
    x, y = params
    return x**2 + 3*x*y + y**2

# grad computes partial derivatives
gradient_fn = grad(multivar_function)

params = jnp.array([1.0, 2.0])  # x=1, y=2
grads = gradient_fn(params)

print(f"Point: x={params[0]}, y={params[1]}")
print(f"f(x,y) = {multivar_function(params)}")
print(f"Gradient: [df/dx, df/dy] = {grads}")
print()
print("Manual calculation:")
print(f"  df/dx = 2x + 3y = 2(1) + 3(2) = 8")
print(f"  df/dy = 3x + 2y = 3(1) + 2(2) = 7")

### How Does Autodiff Work?

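`grad` uses **reverse-mode automatic differentiation** (backpropagation). JAX also supports forward mode through `jax.jvp` and `jax.jacfwd`.

It's NOT:
- Symbolic differentiation (like Mathematica) - derivative expressions can blow up in size for complex functions
- Numerical differentiation (finite differences) - needs extra function evaluations per input and suffers from truncation and round-off error

It IS:
- Tracing your function to record the primitive operations it performs (a computation graph)
- Applying the chain rule backwards through that graph
- Returning exact derivatives (up to floating point)

The cost? For a scalar-valued function, reverse mode computes the *entire* gradient for a small constant multiple of the cost of one function evaluation (commonly quoted as roughly 2-3x), no matter how many inputs there are. That is why it suits training, where one scalar loss depends on millions of parameters.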
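A small sketch that checks these claims numerically, assuming an arbitrary illustrative function `g` from R^3 to R^2: forward mode (`jax.jacfwd`) and reverse mode (`jax.jacrev`) should agree on the Jacobian, while a central finite difference lands close to, but not exactly on, the exact derivative.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Illustrative function R^3 -> R^2
    return jnp.stack([x[0] * x[1], jnp.sin(x[2]) + x[0] ** 2])

x = jnp.array([1.0, 2.0, 0.5])

# Forward mode builds the Jacobian one input direction at a time,
# reverse mode one output row at a time; the matrices should match
J_fwd = jax.jacfwd(g)(x)
J_rev = jax.jacrev(g)(x)
print(J_fwd)
print(jnp.allclose(J_fwd, J_rev))

# Central finite difference for d g_1 / d x_0 (exact value: 2 * x_0 = 2)
h = 1e-3
f0 = lambda t: g(x.at[0].set(t))[1]
fd = (f0(1.0 + h) - f0(1.0 - h)) / (2 * h)
print("autodiff:", J_rev[1, 0], " finite difference:", fd)
```

In float32 the finite-difference estimate carries visible round-off error, while the autodiff value is exact up to floating point.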

In [None]:
# value_and_grad gives you both the function value AND gradient in one pass
from jax import value_and_grad

def loss_function(params):
    return jnp.sum(params ** 2)

# This is more efficient than calling f(x) and grad(f)(x) separately
value_and_grad_fn = value_and_grad(loss_function)

params = jnp.array([1.0, 2.0, 3.0])
value, grads = value_and_grad_fn(params)

print(f"Loss: {value}")
print(f"Gradients: {grads}")

## Part 4: JIT Compilation - Making Code Fly

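Python is slow, and without `jit` JAX dispatches each operation one at a time. JAX fixes this by compiling whole functions with XLA.

`jit` (just-in-time) compilation traces your function once for each new combination of input shapes and dtypes, lets XLA fuse the operations, and reuses the compiled version on later calls.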
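To see what tracing produces, `jax.make_jaxpr` prints the intermediate program (a "jaxpr") that JAX records before handing it to XLA. A minimal sketch with an illustrative function `layer`:

```python
import jax
import jax.numpy as jnp

def layer(x):
    return jnp.tanh(jnp.sin(x) + jnp.cos(x))

# Trace `layer` with an abstract float32[3] input and show the recorded program
jaxpr = jax.make_jaxpr(layer)(jnp.ones(3))
print(jaxpr)

# The trace depends only on shape and dtype, not on the values...
same = str(jax.make_jaxpr(layer)(jnp.zeros(3))) == str(jaxpr)
# ...while a different shape yields a different program
other = jax.make_jaxpr(layer)(jnp.ones(4))
print(same, "f32[4]" in str(other))
```

This is why `jit` caches by shape and dtype: the same abstract signature means the same program.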

In [None]:
import time

def slow_function(x):
    """A computation-heavy function."""
    for _ in range(10):
        x = jnp.sin(x) + jnp.cos(x)
        x = jnp.tanh(x @ x.T)
    return jnp.sum(x)

# Create a JIT-compiled version
fast_function = jit(slow_function)

# Test data
x = random.normal(random.PRNGKey(0), (1000, 1000))

# Warm up both versions (each first call includes one-time compilation)
_ = slow_function(x).block_until_ready()
_ = fast_function(x).block_until_ready()

# Time comparison (perf_counter is the standard clock for benchmarking)
start = time.perf_counter()
for _ in range(10):
    result_slow = slow_function(x).block_until_ready()
time_slow = time.perf_counter() - start

start = time.perf_counter()
for _ in range(10):
    result_fast = fast_function(x).block_until_ready()
time_fast = time.perf_counter() - start

print(f"Without JIT: {time_slow:.3f}s")
print(f"With JIT:    {time_fast:.3f}s")
print(f"Speedup:     {time_slow/time_fast:.1f}x")

In [None]:
# You can also use @jit as a decorator

@jit
def neural_network_layer(params, x):
    """A simple dense layer."""
    W, b = params['W'], params['b']
    return jnp.tanh(x @ W + b)

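# Initialize (split the key so W and x get independent random numbers)
key = random.PRNGKey(42)
key_w, key_x = random.split(key)
params = {
    'W': random.normal(key_w, (784, 256)) * 0.01,
    'b': jnp.zeros(256)
}
x = random.normal(key_x, (32, 784))  # Batch of 32 random 784-dim inputs (MNIST-sized)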

# This will be fast
output = neural_network_layer(params, x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

### JIT Gotchas

JIT traces your function with abstract values. This means:

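1. **Python side effects run only during tracing** - `print()` fires when the function is traced, not on every call
2. **Static shapes** - output shapes must be known at trace time, and each new input shape or dtype triggers a retrace and recompilation
3. **Control flow constraints** - Python `if`/`while` can't branch on traced values; use `jax.lax.cond`, `jax.lax.while_loop`, or `jnp.where` instead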
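These gotchas can be sketched as follows. The function names are illustrative, and `trace_count` is a plain Python counter used only to observe when retracing happens.

```python
import jax
import jax.numpy as jnp
from jax import lax

# A Python `if` on a traced value fails under jit
def python_abs(x):
    if x > 0:  # needs a concrete bool, but under jit x is an abstract tracer
        return x
    return -x

raised = False
try:
    jax.jit(python_abs)(-3.0)
except jax.errors.ConcretizationTypeError as err:
    raised = True
    print("Python if under jit fails with:", type(err).__name__)

# lax.cond traces both branches and selects one at run time
@jax.jit
def cond_abs(x):
    return lax.cond(x > 0, lambda v: v, lambda v: -v, x)

print(cond_abs(-3.0))

# A new input shape forces a retrace (and recompilation)
trace_count = 0

@jax.jit
def doubled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return jnp.sum(2.0 * x)

doubled_sum(jnp.ones(3))
doubled_sum(jnp.zeros(3))  # same shape and dtype: cached, no retrace
doubled_sum(jnp.ones(4))   # new shape: traced and compiled again
print("traces:", trace_count)
```

For simple elementwise choices, `jnp.where(x > 0, x, -x)` avoids branching entirely and also works on whole arrays.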

In [None]:
# Example: print() inside JIT only runs during tracing

@jit
def sneaky_function(x):
    print("This only prints once during tracing!")  # Not every call
    return x ** 2

print("First call (triggers tracing):")
result1 = sneaky_function(2.0)

print("\nSecond call (uses cached compilation):")
result2 = sneaky_function(3.0)

print(f"\nResults: {result1}, {result2}")

## Part 5: Vectorization with `vmap`

You write a function for a single example. `vmap` makes it work on batches.

No manual batch dimensions. No reshape gymnastics. Just vmap.
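As a sketch of `vmap` doing more than adding one batch axis: nesting two `vmap`s turns a single-pair `distance` function (an illustrative helper) into a full pairwise distance matrix, with `in_axes` choosing which argument is batched at each level.

```python
import jax
import jax.numpy as jnp

def distance(a, b):
    # Euclidean distance between two single vectors
    return jnp.sqrt(jnp.sum((a - b) ** 2))

# Inner vmap: one point `a` vs. a batch of `b`; outer vmap: a batch of `a`
pairwise = jax.vmap(jax.vmap(distance, in_axes=(None, 0)), in_axes=(0, None))

points = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
D = pairwise(points, points)  # D[i, j] = distance(points[i], points[j])
print(D)
```

The function body never mentions batches; the batching lives entirely in the `in_axes` specification.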

In [None]:
# Function that works on a single vector
def single_prediction(weights, x):
    """Predict for ONE input vector."""
    return jnp.dot(weights, x)

# Create data
weights = jnp.array([1.0, 2.0, 3.0])
single_x = jnp.array([0.5, 0.3, 0.2])

# Works for single input
print("Single prediction:", single_prediction(weights, single_x))

# But what about a batch of inputs?
batch_x = random.normal(random.PRNGKey(0), (100, 3))  # 100 samples

# vmap to the rescue!
batched_prediction = vmap(single_prediction, in_axes=(None, 0))
#                                            weights: not batched, x: batched along axis 0

batch_results = batched_prediction(weights, batch_x)
print(f"Batch prediction shape: {batch_results.shape}")

In [None]:
# vmap is incredibly powerful for neural networks

def single_forward(params, x):
    """Forward pass for a SINGLE example."""
    # Layer 1
    h = jnp.tanh(x @ params['W1'] + params['b1'])
    # Layer 2
    out = h @ params['W2'] + params['b2']
    return out

# Initialize network
key = random.PRNGKey(42)
keys = random.split(key, 4)
params = {
    'W1': random.normal(keys[0], (784, 128)) * 0.01,
    'b1': jnp.zeros(128),
    'W2': random.normal(keys[1], (128, 10)) * 0.01,
    'b2': jnp.zeros(10)
}

# Single example
single_x = random.normal(keys[2], (784,))
print("Single output:", single_forward(params, single_x).shape)

# Batched - just vmap!
batched_forward = vmap(single_forward, in_axes=(None, 0))
batch_x = random.normal(keys[3], (64, 784))
print("Batch output:", batched_forward(params, batch_x).shape)

In [None]:
# vmap composes with grad - per-example gradients!

def loss_single(params, x, y):
    """Loss for a single example."""
    pred = single_forward(params, x)
    return jnp.sum((pred - y) ** 2)

# Gradient for single example
grad_single = grad(loss_single)

# Per-example gradients for a batch
per_example_grads = vmap(grad_single, in_axes=(None, 0, 0))

# Test
batch_x = random.normal(random.PRNGKey(0), (32, 784))
batch_y = random.normal(random.PRNGKey(1), (32, 10))

grads = per_example_grads(params, batch_x, batch_y)
print("Per-example gradient shapes:")
for name, g in grads.items():
    print(f"  {name}: {g.shape}")  # First dimension is batch!

## Part 6: Random Numbers in JAX

JAX handles randomness differently than NumPy. This trips up everyone at first, but it's actually better.

### The Problem with NumPy's RNG

NumPy uses global state for random numbers:
```python
np.random.seed(42)
x = np.random.randn(10)  # Uses and updates global state
```

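This is:
- **Order-dependent** - a function's output depends on how many random draws happened before it, anywhere in the program
- **Not thread-safe** for parallel execution
- **Incompatible with JIT** (hidden global state is a side effect!)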

### JAX Solution: Explicit Keys

Every random operation takes a **key**. You split keys to get independent randomness.
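In a loop, a common pattern is to keep one key purely for splitting and consume only the fresh subkeys; `random.fold_in` is an alternative that derives a key from an integer such as a step or worker index. A minimal sketch (the loop and the "worker" indices are illustrative):

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
samples = []
for _ in range(3):
    # Split off a fresh subkey each step; `key` itself is only ever split, never consumed
    key, subkey = random.split(key)
    samples.append(random.normal(subkey, (2,)))
print(jnp.stack(samples))  # three different rows

# fold_in derives a new key from an existing key plus an integer,
# without threading key state through a loop
base = random.PRNGKey(0)
worker_keys = [random.fold_in(base, i) for i in range(3)]
print(random.uniform(worker_keys[1], (2,)))
```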

In [None]:
from jax import random

# Create a starting key
key = random.PRNGKey(42)
print(f"Initial key: {key}")

# Generate random numbers
x = random.normal(key, (3,))
print(f"Random array: {x}")

# IMPORTANT: Using the same key gives the SAME numbers!
x_again = random.normal(key, (3,))
print(f"Same key again: {x_again}")
print(f"Identical? {jnp.allclose(x, x_again)}")

In [None]:
# To get different random numbers, SPLIT the key

key = random.PRNGKey(42)

# Split into 2 new keys
key1, key2 = random.split(key)

x1 = random.normal(key1, (3,))
x2 = random.normal(key2, (3,))

print(f"x1: {x1}")
print(f"x2: {x2}")
print(f"Different? {not jnp.allclose(x1, x2)}")

In [None]:
# The pattern: split before each use

def initialize_layer(key, in_features, out_features):
    """Initialize a layer with proper key handling."""
    key_w, key_b = random.split(key)
    W = random.normal(key_w, (in_features, out_features)) * 0.01
    b = random.normal(key_b, (out_features,)) * 0.01
    return {'W': W, 'b': b}

def initialize_network(key, layer_sizes):
    """Initialize a full network."""
    keys = random.split(key, len(layer_sizes) - 1)
    params = []
    for k, in_f, out_f in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        params.append(initialize_layer(k, in_f, out_f))
    return params

# Initialize a 784 -> 256 -> 128 -> 10 network
key = random.PRNGKey(0)
network_params = initialize_network(key, [784, 256, 128, 10])

print("Network architecture:")
for i, layer in enumerate(network_params):
    print(f"  Layer {i}: W{layer['W'].shape}, b{layer['b'].shape}")

In [None]:
# Random distributions available

key = random.PRNGKey(42)
keys = random.split(key, 6)

print("Available distributions (examples):")
print(f"  Normal:      {random.normal(keys[0], (3,))}")
print(f"  Uniform:     {random.uniform(keys[1], (3,))}")
# categorical takes logits (log-probabilities), not probabilities
print(f"  Categorical: {random.categorical(keys[2], jnp.log(jnp.array([0.1, 0.3, 0.6])), shape=(5,))}")
print(f"  Bernoulli:   {random.bernoulli(keys[3], 0.7, shape=(5,))}")
print(f"  Randint:     {random.randint(keys[4], (5,), 0, 10)}")
print(f"  Permutation: {random.permutation(keys[5], 5)}")

## Part 7: PyTrees - Nested Data Structures

Neural networks have nested parameters: layers containing weights and biases, attention heads, etc.

JAX handles these with **PyTrees** - arbitrary nested structures of arrays.
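Under the hood, PyTree utilities flatten a structure into a list of leaves plus a `treedef` describing the nesting, and can rebuild it afterwards. A small sketch with an illustrative `params` dict (dict keys are flattened in sorted order):

```python
import jax.numpy as jnp
from jax import tree_util

params = {'dense': {'W': jnp.ones((2, 3)), 'b': jnp.zeros(3)}, 'scale': (jnp.array(1.0),)}

# Flatten: a list of array leaves plus the structure they came from
leaves, treedef = tree_util.tree_flatten(params)
print(treedef)
print([leaf.shape for leaf in leaves])

# Unflatten: rebuild a tree of the same structure from new leaves
doubled = tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])
print(doubled['dense']['W'][0])
```

Functions like `tree_map` are essentially this flatten, apply, unflatten round trip.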

In [None]:
from jax import tree_util

# PyTrees can be dicts, lists, tuples, or nested combinations
params = {
    'encoder': {
        'layer1': {'W': jnp.ones((10, 5)), 'b': jnp.zeros(5)},
        'layer2': {'W': jnp.ones((5, 3)), 'b': jnp.zeros(3)},
    },
    'decoder': {
        'layer1': {'W': jnp.ones((3, 5)), 'b': jnp.zeros(5)},
    }
}

# Get all leaves (the actual arrays)
leaves = tree_util.tree_leaves(params)
print(f"Number of parameter arrays: {len(leaves)}")
print(f"Total parameters: {sum(l.size for l in leaves)}")

In [None]:
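# tree_map applies a function to every leaf and returns a tree of the same structure

# Scale all parameters
scaled_params = tree_util.tree_map(lambda x: x * 0.1, params)

# To treat leaves differently by name, tree_map_with_path also passes the key path;
# for nested dicts, path[-1].key is the innermost dict key ('W' or 'b')
weights_only = tree_util.tree_map_with_path(
    lambda path, x: x * 0.01 if path[-1].key == 'W' else x, params
)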

print("Original W shape and values:")
print(f"  {params['encoder']['layer1']['W'][0, :3]}")
print("Scaled:")
print(f"  {scaled_params['encoder']['layer1']['W'][0, :3]}")

In [None]:
# grad works on PyTrees automatically!

def simple_loss(params, x, y):
    """Loss with nested params."""
    W1, b1 = params['layer1']['W'], params['layer1']['b']
    W2, b2 = params['layer2']['W'], params['layer2']['b']
    
    h = jnp.tanh(x @ W1 + b1)
    pred = h @ W2 + b2
    return jnp.mean((pred - y) ** 2)

# Parameters
params = {
    'layer1': {'W': random.normal(random.PRNGKey(0), (10, 5)), 'b': jnp.zeros(5)},
    'layer2': {'W': random.normal(random.PRNGKey(1), (5, 2)), 'b': jnp.zeros(2)}
}

# Data
x = random.normal(random.PRNGKey(2), (32, 10))
y = random.normal(random.PRNGKey(3), (32, 2))

# Gradient is a PyTree with same structure!
grads = grad(simple_loss)(params, x, y)

print("Gradient structure matches params:")
print(f"  grads['layer1']['W'].shape = {grads['layer1']['W'].shape}")
print(f"  grads['layer2']['b'].shape = {grads['layer2']['b'].shape}")

In [None]:
# SGD update with PyTrees

def sgd_update(params, grads, learning_rate=0.01):
    """Apply SGD update to nested params."""
    return tree_util.tree_map(
        lambda p, g: p - learning_rate * g,
        params, grads
    )

# One training step
loss_before = simple_loss(params, x, y)
grads = grad(simple_loss)(params, x, y)
params = sgd_update(params, grads)
loss_after = simple_loss(params, x, y)

print(f"Loss before: {loss_before:.4f}")
print(f"Loss after:  {loss_after:.4f}")
print(f"Improved: {loss_after < loss_before}")

## Part 8: Putting It All Together - Training a Neural Network

Let's train a real neural network from scratch using everything we've learned.

In [None]:
# Generate synthetic classification data

def generate_spiral_data(key, n_points, n_classes=3):
    """Generate spiral dataset for classification."""
    keys = random.split(key, n_classes)
    
    X_list, y_list = [], []
    for c, k in enumerate(keys):
        # Spiral parameters
        r = jnp.linspace(0.1, 1, n_points)  # Radius
        t = jnp.linspace(c * 4, (c + 1) * 4, n_points) + random.normal(k, (n_points,)) * 0.2
        
        # Convert to Cartesian
        x_coord = r * jnp.sin(t)
        y_coord = r * jnp.cos(t)
        
        X_list.append(jnp.stack([x_coord, y_coord], axis=1))
        y_list.append(jnp.full(n_points, c))
    
    X = jnp.concatenate(X_list)
    y = jnp.concatenate(y_list)
    
    return X, y

# Generate data
key = random.PRNGKey(42)
X, y = generate_spiral_data(key, 100, n_classes=3)
print(f"Data shape: X={X.shape}, y={y.shape}")
print(f"Classes: {jnp.unique(y)}")

In [None]:
# Define the network

def init_mlp(key, layer_sizes):
    """Initialize MLP parameters."""
    params = []
    keys = random.split(key, len(layer_sizes) - 1)
    
    for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        k1, k2 = random.split(k)
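        # He (Kaiming) initialization: std = sqrt(2 / n_in)
        # (Xavier/Glorot would use sqrt(2 / (n_in + n_out)) and is commonly paired with tanh)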
        W = random.normal(k1, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append({'W': W, 'b': b})
    
    return params

def mlp_forward(params, x):
    """Forward pass through MLP."""
    for layer in params[:-1]:
        x = jnp.tanh(x @ layer['W'] + layer['b'])
    # No activation on final layer (logits)
    x = x @ params[-1]['W'] + params[-1]['b']
    return x

def softmax_cross_entropy(logits, labels):
    """Compute cross-entropy loss."""
    # Numerically stable softmax
    logits = logits - jnp.max(logits, axis=-1, keepdims=True)
    log_probs = logits - jnp.log(jnp.sum(jnp.exp(logits), axis=-1, keepdims=True))
    # One-hot encode labels
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

def loss_fn(params, x, y):
    """Compute loss for a batch."""
    logits = mlp_forward(params, x)
    return softmax_cross_entropy(logits, y)

def accuracy(params, x, y):
    """Compute classification accuracy."""
    logits = mlp_forward(params, x)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == y)

# Initialize
key = random.PRNGKey(0)
params = init_mlp(key, [2, 64, 64, 3])  # 2 inputs, 2 hidden layers, 3 outputs

print("Network initialized:")
for i, layer in enumerate(params):
    print(f"  Layer {i}: W{layer['W'].shape}, b{layer['b'].shape}")

In [None]:
# Training loop

@jit
def train_step(params, x, y, learning_rate=0.1):
    """One training step."""
    loss, grads = value_and_grad(loss_fn)(params, x, y)
    # SGD update
    params = tree_util.tree_map(
        lambda p, g: p - learning_rate * g,
        params, grads
    )
    return params, loss

# Train!
n_epochs = 500
history = []

for epoch in range(n_epochs):
    params, loss = train_step(params, X, y)
    
    if epoch % 50 == 0:
        acc = accuracy(params, X, y)
        history.append((epoch, float(loss), float(acc)))
        print(f"Epoch {epoch:4d} | Loss: {loss:.4f} | Accuracy: {acc:.2%}")

# Final accuracy
final_acc = accuracy(params, X, y)
print(f"\nFinal accuracy: {final_acc:.2%}")

In [None]:
# Visualize decision boundary

import matplotlib.pyplot as plt

# Create grid
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = jnp.meshgrid(jnp.linspace(x_min, x_max, 100),
                       jnp.linspace(y_min, y_max, 100))
grid = jnp.c_[xx.ravel(), yy.ravel()]

# Get predictions
logits = mlp_forward(params, grid)
Z = jnp.argmax(logits, axis=-1).reshape(xx.shape)

# Plot
plt.figure(figsize=(10, 8))
plt.contourf(xx, yy, Z, alpha=0.3, cmap='viridis')
scatter = plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', edgecolors='black', s=50)
plt.colorbar(scatter)
plt.xlabel('x1')
plt.ylabel('x2')
plt.title(f'JAX Neural Network - Spiral Classification\nAccuracy: {final_acc:.2%}')
plt.tight_layout()
plt.show()

## Part 9: Advanced Topics - Quick Reference

### `pmap` - Parallel Execution Across Devices

```python
from jax import pmap

# Maps over the leading axis of x, one slice per device;
# x.shape[0] must not exceed jax.local_device_count()
@pmap
def parallel_fn(x):
    return x ** 2
```
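A slightly fuller `pmap` sketch, assuming only the local devices JAX can see (on a CPU-only runtime that is usually one device, so the leading axis has length 1). It also shows a collective, `jax.lax.psum`, which sums a value across the named mapped axis. Newer JAX documentation generally steers new multi-device code toward `jax.jit` with sharding or `shard_map`, so treat `pmap` as the older, simpler entry point.

```python
import functools
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# One row per device along the leading (mapped) axis
x = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3)

@functools.partial(jax.pmap, axis_name="devices")
def normalize(v):
    # psum adds each device's partial sum across all devices on the "devices" axis
    total = jax.lax.psum(jnp.sum(v), axis_name="devices")
    return v / total

out = normalize(x)
print(out.shape, float(out.sum()))
```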

### `lax.scan` - Efficient Loops

```python
from jax import lax

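# Instead of a Python loop (which jit unrolls into one long program that is slow to compile),
# scan traces the loop body once. scan_fn maps (carry, x) -> (new_carry, per-step output)
def scan_fn(carry, x):
    return carry + x, carry

# final = 4950 (sum of 0..99); history[i] is the carry before step i
final, history = lax.scan(scan_fn, init=0, xs=jnp.arange(100))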
```
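A slightly more realistic `scan` sketch: an exponential moving average where the carry is the running average. The `ema` helper and `decay=0.9` are illustrative choices, and a plain Python loop is included as a reference.

```python
import jax
import jax.numpy as jnp
from jax import lax

def ema(xs, decay=0.9):
    """Exponential moving average of a 1-D sequence using lax.scan."""
    def step(avg, x):
        new_avg = decay * avg + (1.0 - decay) * x
        return new_avg, new_avg  # (next carry, output for this step)
    _, averages = lax.scan(step, xs[0], xs)
    return averages

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
print(jax.jit(ema)(xs))

# Reference: the same recurrence with a plain Python loop
avg, ref = 1.0, []
for v in [1.0, 2.0, 3.0, 4.0]:
    avg = 0.9 * avg + 0.1 * v
    ref.append(avg)
print(ref)
```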

### `lax.cond` - JIT-Compatible Conditionals

```python
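from jax import lax

# Instead of a Python if on a traced value (which fails under jit),
# lax.cond(pred, true_fn, false_fn, *operands) traces both branches and picks one at run time
result = lax.cond(condition, true_fn, false_fn, operand)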
```

### Custom Gradients

```python
from jax import custom_vjp

@custom_vjp
def my_fn(x):
    return x ** 2

def my_fn_fwd(x):
    return my_fn(x), x  # (output, residuals)

def my_fn_bwd(residuals, g):
    x = residuals
    return (2 * x * g,)  # Custom gradient

my_fn.defvjp(my_fn_fwd, my_fn_bwd)
```
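As a more practical sketch, here is gradient clipping implemented with `custom_vjp`, following the pattern used in the JAX custom-derivatives documentation: the forward pass is the identity, and the backward pass clips whatever gradient flows into `x`. The `loss` function is an illustrative assumption.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
    return x, (lo, hi)  # save the bounds as residuals

def clip_gradient_bwd(res, g):
    lo, hi = res
    # None means zero cotangent for lo and hi; clip the incoming gradient for x
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def loss(x):
    return clip_gradient(-1.0, 1.0, x) ** 2

print(jax.grad(loss)(3.0))   # 1.0 (the unclipped gradient would be 6.0)
print(jax.grad(loss)(0.25))  # 0.5 (inside the bounds, unchanged)
```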

## Part 10: The JAX Ecosystem

JAX is low-level by design. Higher-level libraries build on top:

| Library | Purpose | Style |
|---------|---------|-------|
| **Flax** | Neural networks | Functional, explicit state |
| **Haiku** | Neural networks | Transformed functions (DeepMind; now in maintenance mode) |
| **Optax** | Optimizers | Composable gradient transforms |
| **Equinox** | Neural networks | PyTorch-like, but functional |
| **Diffrax** | Differential equations | ODEs, SDEs, CDEs |
| **JAXopt** | Optimization | Differentiable optimizers (no longer actively developed; parts ported to Optax) |

### Quick Flax Example

In [None]:
!pip install flax optax -q

import flax.linen as nn
import optax

# Define model with Flax
class MLP(nn.Module):
    hidden_dim: int
    output_dim: int
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.tanh(x)
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.tanh(x)
        x = nn.Dense(self.output_dim)(x)
        return x

# Initialize
model = MLP(hidden_dim=64, output_dim=3)
params = model.init(random.PRNGKey(0), jnp.ones((1, 2)))

print("Flax model initialized!")
print(f"Parameter count: {sum(p.size for p in tree_util.tree_leaves(params))}")

In [None]:
# Optax for optimization

# Create optimizer
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

# Loss function
def flax_loss_fn(params, x, y):
    logits = model.apply(params, x)
    return softmax_cross_entropy(logits, y)

@jit
def flax_train_step(params, opt_state, x, y):
    loss, grads = value_and_grad(flax_loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Quick training
for epoch in range(200):
    params, opt_state, loss = flax_train_step(params, opt_state, X, y)
    if epoch % 50 == 0:
        acc = jnp.mean(jnp.argmax(model.apply(params, X), axis=-1) == y)
        print(f"Epoch {epoch:3d} | Loss: {loss:.4f} | Acc: {acc:.2%}")

## Part 11: Common Gotchas and Debugging

### Gotcha 1: Arrays Are Immutable

In [None]:
# WRONG: Trying to mutate
x = jnp.array([1, 2, 3])
# x[0] = 100  # Error!

# RIGHT: Create new array
x = x.at[0].set(100)
print(x)

### Gotcha 2: Shapes Must Be Static in JIT

In [None]:
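from functools import partial

# WRONG: n is traced, so x[:n] has no static output shape and jit raises an error
@jit
def bad_fn(x, n):
    return x[:n]

try:
    bad_fn(jnp.arange(10), 5)
except Exception as e:
    print(f"bad_fn failed: {type(e).__name__}")

# RIGHT: mark n as static with static_argnums
# (each distinct value of n triggers a separate compilation)
@partial(jit, static_argnums=(1,))
def good_fn(x, n):
    return x[:n]

x = jnp.arange(10)
print(good_fn(x, 5))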

### Gotcha 3: Random Keys Must Be Split

In [None]:
# WRONG: Reusing key gives same values
key = random.PRNGKey(0)
a = random.normal(key, (3,))
b = random.normal(key, (3,))  # Same as a!
print(f"Reused key: a={a}, b={b}")

# RIGHT: Split key
key = random.PRNGKey(0)
key1, key2 = random.split(key)
a = random.normal(key1, (3,))
b = random.normal(key2, (3,))
print(f"Split keys: a={a}, b={b}")

### Gotcha 4: print() Doesn't Work as Expected in JIT

In [None]:
# Use jax.debug.print for debugging inside JIT
from jax import debug

@jit
def debuggable_fn(x):
    debug.print("x = {x}", x=x)  # Works!
    return x ** 2

result = debuggable_fn(jnp.array([1.0, 2.0, 3.0]))

## Part 12: When to Use JAX

### JAX Excels At:
- Research requiring custom gradients or architectures
- Anything involving Hessians or higher-order derivatives
- Probabilistic programming (see NumPyro)
- Scientific computing with autodiff
- TPU training (first-class support)
- Functional programming enthusiasts

### Maybe Use Something Else If:
- You need extensive pre-built model zoo (PyTorch/TF have more)
- You're deploying to mobile/edge (TFLite, CoreML have better tooling)
- You want maximum beginner-friendliness (PyTorch is gentler)
- You want debugger-friendly eager execution everywhere (JAX runs eagerly outside `jit`, but inside `jit` you only see abstract tracers)

### The Sweet Spot:
JAX shines when you need **composable transformations**. If you find yourself saying "I wish I could take the gradient of X", "I wish this loop were faster", or "I want per-example gradients", JAX probably has a clean solution.

## Key Takeaways

1. **JAX = NumPy + Autodiff + XLA** - Familiar API with superpowers

2. **Four Core Transformations**:
   - `grad` - automatic differentiation
   - `jit` - compilation for speed
   - `vmap` - automatic vectorization
   - `pmap` - parallel across devices

3. **Functional Paradigm**:
   - Immutable arrays
   - Pure functions
   - Explicit random keys

4. **PyTrees** handle nested structures naturally

5. **Ecosystem**: Flax/Haiku for NNs, Optax for optimizers

---

### Resources

- **Official Docs**: https://jax.readthedocs.io/
- **JAX GitHub**: https://github.com/jax-ml/jax
- **Flax**: https://github.com/google/flax
- **Optax**: https://github.com/google-deepmind/optax
- **JAX Tutorial (Google)**: https://jax.readthedocs.io/en/latest/tutorials.html

---

*"The best code is no code. The second best is code that writes itself."*

JAX doesn't write your code, but it transforms it in ways that feel like magic. Once you internalize the functional mindset, you'll wonder how you ever lived without `grad`, `jit`, and `vmap`.

Now go differentiate everything.

---

**About this notebook**: A comprehensive introduction to JAX covering core concepts, transformations, and practical neural network training from scratch.

**Connect**: [kaggle.com/seki32](https://kaggle.com/seki32) | [github.com/Rekhii](https://github.com/Rekhii)