# JAX Practice Worksheet

A simple study guide covering JAX fundamentals.

### What is JAX?

JAX is Google's library for high-performance numerical computing and machine learning research. Think of it as **NumPy on steroids** — it gives you a familiar NumPy-like API but adds three superpowers:

1. **Automatic differentiation** (`grad`) — compute gradients of any function automatically, which is the backbone of training neural networks and optimization in general.
2. **Just-in-time compilation** (`jit`) — compile your Python functions down to optimized machine code using XLA (Accelerated Linear Algebra), the same compiler backend that powers TensorFlow.
3. **Auto-vectorization** (`vmap`) — write a function that works on a single example, then instantly vectorize it to work on entire batches with no manual loop writing.

### Why JAX instead of NumPy or PyTorch?

- **vs NumPy**: JAX can run on GPU/TPU and supports autodiff. NumPy is CPU-only and has no built-in gradients.
- **vs PyTorch**: JAX takes a more *functional* approach. Its core API has no `nn.Module`; you write pure functions and transform them, which makes code easier to reason about and compose (libraries built on JAX, such as Flax, layer module abstractions on top). PyTorch is more object-oriented and imperative.
- **Composability**: JAX transformations (`grad`, `jit`, `vmap`) can be freely composed. You can `jit(vmap(grad(f)))` and it just works.
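
As a small illustration of that composition, the sketch below differentiates a scalar function, maps the derivative over a batch, and compiles the result. The function `f` and the name `batched_df` are made up for this example.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def f(x):
    # Scalar in, scalar out: f(x) = x^3
    return x ** 3

# grad(f) differentiates at a single scalar, vmap maps that over a batch,
# and jit compiles the batched derivative.
batched_df = jit(vmap(grad(f)))

xs = jnp.array([1.0, 2.0, 3.0])
derivs = batched_df(xs)  # analytically 3 * x^2, i.e. 3, 12, 27
print(derivs)
```

Note that `grad(f)` alone would reject `xs`, because `f(xs)` is not a scalar; `vmap` is what lets the scalar derivative run over the whole array.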

### Key mental model

JAX functions should be **pure functions** — they take inputs and return outputs with no side effects. This is what enables all the powerful transformations to work correctly.
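
To see why purity matters, here is a minimal sketch of what can go wrong when a jitted function reads hidden state. The `config` dictionary and both function names are hypothetical, chosen only for this illustration.

```python
import jax
import jax.numpy as jnp

config = {"scale": 2.0}  # hypothetical hidden state outside the function

@jax.jit
def impure_scale(x):
    # Reads hidden state: the value is baked in when the function is traced.
    return x * config["scale"]

@jax.jit
def pure_scale(x, scale):
    # Everything the result depends on is an explicit argument.
    return x * scale

x = jnp.arange(3.0)
before = impure_scale(x)
config["scale"] = 10.0
after = impure_scale(x)         # same shape and dtype: cached program, still uses 2.0
explicit = pure_scale(x, 10.0)  # the new scale is actually used

print(before, after, explicit)
```

The impure version silently ignores the update, while the pure version behaves as expected because the scale is passed in.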

## 1. Setup & Imports

We import four core pieces of JAX:

- **`jax`**: the top-level module; gives us `jax.devices()` to check what hardware we're running on.
- **`jax.numpy` (as `jnp`)**: a near drop-in replacement for NumPy. Almost every `np.something()` has a `jnp.something()` equivalent. The key differences are that `jnp` arrays can live on accelerators (GPU/TPU) and are immutable.
- **`grad`, `jit`, `vmap`**: the three core transformations. These are *higher-order functions*: they take a function as input and return a new, transformed function.
- **`jax.random`**: JAX's random number system. Unlike NumPy's `np.random`, JAX doesn't use global random state. Every random call requires an explicit key (more on this in Section 3).

In [2]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

print("JAX version:", jax.__version__)
print("Devices:", jax.devices())

JAX version: 0.9.0
Devices: [CudaDevice(id=0)]


## 2. JAX Arrays vs NumPy

`jax.numpy` mirrors the NumPy API almost exactly, so if you know NumPy, you already know most of `jnp`. The critical differences:

### Immutability

JAX arrays are **immutable** — once created, you cannot modify them in-place. This means:

```python
# NumPy (works fine):
x[0] = 5

# JAX (raises an error!):
x[0] = 5  # TypeError: JAX arrays are immutable
```

Instead, JAX provides the `.at[].set()` syntax which returns a **new** array with the change applied, leaving the original untouched. This functional style is essential because JAX's transformations (grad, jit, vmap) rely on functions being pure — no side effects, no mutation.

### Other `.at[]` operations

Beyond `.set()`, you can also use:
- `x.at[i].add(v)` — add `v` to position `i`
- `x.at[i].multiply(v)` — multiply position `i` by `v`
- `x.at[i].min(v)` / `x.at[i].max(v)` — element-wise min/max
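
The `.at[]` variants listed above can be tried side by side. This is a small sketch with made-up variable names; each call returns a new array and leaves `x` unchanged.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])

added = x.at[0].add(10.0)       # index 0 becomes 1 + 10 = 11
scaled = x.at[1].multiply(5.0)  # index 1 becomes 2 * 5 = 10
raised = x.at[2].max(7.0)       # index 2 becomes max(3, 7) = 7
lowered = x.at[3].min(0.5)      # index 3 becomes min(4, 0.5) = 0.5

print(x)  # the original array is untouched
print(added, scaled, raised, lowered)
```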

### Device placement

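JAX arrays are placed on the default device, which is the first entry of `jax.devices()` (an accelerator such as a GPU or TPU when one is available, otherwise the CPU). You can check where an array lives with `x.devices()`. Transfers between host and accelerator memory happen automatically, but they can become a performance bottleneck if they occur inside a hot loop.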
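
A quick way to inspect and control placement is sketched below. It assumes only that the CPU backend is present, which is the case in a standard install; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 2))
print("default device:", jax.devices()[0])
print("x lives on:", x.devices())  # a set of devices

cpu = jax.devices("cpu")[0]        # first CPU device
x_cpu = jax.device_put(x, cpu)     # explicit transfer (or a no-op if x is already there)
print("x_cpu lives on:", x_cpu.devices())
```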

In [10]:
# Creating arrays (just like numpy)
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.zeros((3, 3))
c = jnp.linspace(50, 100, 11)

print("a:", a)
print("b:\n", b)
print("c:", c)

a: [1. 2. 3.]
b:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
c: [ 50.  55.  60.  65.  70.  75.  80.  85.  90.  95. 100.]


In [11]:
# Immutable updates — use .at[].set()
x = jnp.zeros(5)
x_updated = x.at[2].set(99.0)

print("original:", x)
print("updated: ", x_updated)

original: [0. 0. 0. 0. 0.]
updated:  [ 0.  0. 99.  0.  0.]


### Exercise 2a
Create a 4x4 identity matrix using `jnp.eye()`, then replace the top-left element with `7.0`.

In [21]:
x = jnp.eye(4)
x_updated = x.at[0,0].set(7.0)
print(x)
print(x_updated)

[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
[[7. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]


## 3. Random Numbers (PRNGKey)

This is one of the **biggest differences** between JAX and NumPy, and it trips up almost everyone at first.

### The problem with NumPy's random

In NumPy, random numbers come from a **global** random state:

```python
np.random.seed(42)
x = np.random.normal(size=(3,))  # uses and mutates global state
```

This is convenient but causes problems:
- **Not reproducible under parallelism** — if two threads pull random numbers, the order is unpredictable.
- **Not compatible with `jit`** — JAX's JIT compiler needs pure functions with no hidden state.

### JAX's solution: explicit PRNG keys

In JAX, every random call takes an explicit **key** (a pair of 32-bit integers):

```python
key = random.PRNGKey(42)          # create a key from a seed
x = random.normal(key, shape=(3,)) # use the key
```

**Critical rule**: never reuse a key for two different random calls! If you do, you'll get the same "random" numbers. Instead, **split** the key:

```python
key, subkey = random.split(key)   # split into 2 new keys
x = random.normal(subkey, (3,))   # use the subkey, keep key for later
```

### The split pattern

`random.split(key, n)` takes one key and returns `n` independent new keys. The common pattern is:

```python
key, *subkeys = random.split(key, 4)  # keep key, get 3 subkeys
```

This feels verbose at first, but it makes results **reproducible**: the same key always yields the same samples, regardless of execution order or parallelism, which is essential for scientific computing and ML research.
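
The effect of reusing versus splitting a key can be checked directly. The following is a minimal sketch; the seed and variable names are arbitrary.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Reusing a key: both calls return identical samples.
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# Splitting first: each subkey produces its own samples.
key, k1, k2 = random.split(key, 3)
c = random.normal(k1, (3,))
d = random.normal(k2, (3,))

print("same key, same draw:", bool(jnp.all(a == b)))
print("split keys differ:  ", bool(jnp.any(c != d)))
```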

In [22]:
key = random.PRNGKey(42)

# Split the key to get independent sub-keys
key, subkey1, subkey2 = random.split(key, 3)

x = random.normal(subkey1, shape=(3,))
y = random.uniform(subkey2, shape=(3,))

print("normal samples:", x)
print("uniform samples:", y)

normal samples: [ 0.60576403  0.7990441  -0.908927  ]
uniform samples: [0.6672406 0.7214867 0.1267947]


### Exercise 3a
Generate a 2x3 matrix of random integers from 0 up to (but not including) 10. (Hint: `random.randint(key, shape, minval, maxval)`; `maxval` is exclusive.)

In [36]:
key = random.PRNGKey(1)

# Split before every draw so each call gets a fresh subkey
key, subkey = random.split(key, 2)
x = random.randint(subkey, (2, 3), 0, 10)
print(x)

key, subkey = random.split(key, 2)
x = random.randint(subkey, (2, 3), 0, 10)
print(x)

key, subkey = random.split(key, 2)
x = random.randint(subkey, (2, 3), 0, 10)
print(x)

[[8 6 6]
 [8 7 7]]
[[4 2 6]
 [4 1 4]]
[[5 0 5]
 [7 3 1]]


## 4. `grad` — Automatic Differentiation

This is arguably JAX's most important feature for machine learning.

### What is automatic differentiation?

Differentiation (finding derivatives/gradients) is the core of how neural networks learn. There are three ways to compute derivatives:

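1. **Symbolic**: like you'd do by hand in calculus class. Exact but gets messy for complex functions.
2. **Numerical**: approximate with `(f(x+h) - f(x)) / h`. Simple, but it needs extra function evaluations for every input dimension and is only approximate: too large an `h` gives truncation error, too small an `h` gives round-off error.
3. **Automatic**: what JAX does. It traces through your Python code and applies the chain rule to each operation. Exact up to floating-point rounding *and* efficient.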
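
To make the difference between the numerical and automatic approaches concrete, the sketch below compares both against the known derivative of `sin`. The step size `h = 1e-3` is an arbitrary choice for illustration.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    return jnp.sin(x)

x = 1.0
h = 1e-3

numerical = (f(x + h) - f(x)) / h  # forward difference, error grows with h
automatic = grad(f)(x)             # chain rule applied by JAX
exact = jnp.cos(x)                 # known closed form for comparison

print("numerical error:", abs(numerical - exact))
print("automatic error:", abs(automatic - exact))
```

With default 32-bit floats, shrinking `h` much further tends to make the numerical estimate worse rather than better, because round-off error starts to dominate.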

### How `grad` works

`grad(f)` takes a function `f` and returns a **new function** that computes the derivative:

```python
def f(x):
    return x ** 2

df = grad(f)     # df is now a function that computes 2x
df(3.0)          # returns 6.0
```

Key details:
- **`grad` differentiates w.r.t. the first argument by default.** Use `argnums` to change this: `grad(f, argnums=1)` differentiates w.r.t. the second argument.
- **Input must be a float (or array of floats).** `grad` won't work on integers.
- **Output must be a scalar.** If your function returns an array, use `jax.jacobian` instead, or sum/mean the output first.
- **You can compose `grad`** — `grad(grad(f))` gives you the second derivative, `grad(grad(grad(f)))` the third, and so on.
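
The details above can be sketched in a few lines. The functions `loss` and `square_all` are invented for this example; the expected values in the comments follow from differentiating them by hand.

```python
import jax
import jax.numpy as jnp
from jax import grad

def loss(w, x):
    # Scalar output: sum of squares of the elementwise product
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

dw = grad(loss)(w, x)             # w.r.t. w: 2 * w * x^2 = [18, 64]
dx = grad(loss, argnums=1)(w, x)  # w.r.t. x: 2 * w^2 * x = [6, 32]

def square_all(v):
    # Vector output, so grad does not apply; use the Jacobian instead
    return v ** 2

J = jax.jacobian(square_all)(w)   # diagonal matrix with entries 2 * w

print(dw, dx)
print(J)
```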

### Why this matters for ML

In machine learning, we define a **loss function** that measures how wrong our model is. `grad` lets us compute exactly how to adjust each parameter to reduce that loss — that's the gradient, and it's the signal that drives learning.

In [5]:
def f(x): 
    return x ** 2

df = grad(f)
print(df(3.0))


6.0


In [6]:
def f(x):
    return x ** 3 + 2 * x ** 2 - 5 * x + 1

df = grad(f)        # first derivative
ddf = grad(grad(f))  # second derivative

x = 2.0
print(f"f({x})   = {f(x)}")
print(f"f'({x})  = {df(x)}")    # 3x^2 + 4x - 5 => 15
print(f"f''({x}) = {ddf(x)}")   # 6x + 4 => 16

f(2.0)   = 7.0
f'(2.0)  = 15.0
f''(2.0) = 16.0


### Exercise 4a
Define `g(x) = sin(x) * exp(-x)`. Compute its gradient at `x = 1.0`.

In [15]:
def g(x):
    return jnp.sin(x) * jnp.exp(-x)

x=1.0
print("Output: ", g(x))

dv = grad(g)
print("Derivative: ", dv(x))


Output:  0.3095599
Derivative:  -0.110793784


## 5. `jit` — Just-In-Time Compilation

### The problem: Python is slow

Python is an interpreted language, which means each operation is executed one at a time with lots of overhead. For numerical code with many operations, this overhead adds up fast.

### The solution: XLA compilation

Wrapping a function with `jit` does not compile anything yet. The work happens the first time you call the wrapped function, when JAX:

1. **Traces** the function: runs it once with abstract "placeholder" values to figure out what operations it performs.
2. **Compiles** the traced operations into a single optimized XLA program: fusing operations, eliminating redundant computation, and targeting your specific hardware (CPU/GPU/TPU).
3. **Caches** the compiled version: subsequent calls with the same input shapes and dtypes skip tracing and run the optimized code directly.
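
One way to watch tracing and caching happen is to record a Python side effect inside the function. This is a sketch for demonstration only; `trace_log` and `double` are hypothetical names, and side effects like this are not something to rely on in real code.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []  # records each time Python actually executes the body

@jit
def double(x):
    trace_log.append(x.shape)  # Python side effect: runs only while tracing
    return x * 2

double(jnp.ones(3))   # first call: trace and compile
double(jnp.zeros(3))  # same shape and dtype: cached, no retrace
double(jnp.ones(4))   # new shape: traced and compiled again

print(trace_log)
```

Three calls produce only two log entries, one per distinct input shape.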

### When to use `jit`

- Any function you call repeatedly with the same input shapes (e.g., a training step).
- Functions with many small operations that can be fused together.
- Inner loops of numerical algorithms.

### Gotchas to watch out for

- **First call is slow**: that's the tracing + compilation step. Subsequent calls with the same input shapes and dtypes are fast.
- **No Python side effects inside `jit`**: `print()` only runs during tracing, not on subsequent calls. Python-level `if`/`else` on array *values* is a separate problem: the value is unknown during tracing, so it raises an error. Use `jnp.where` or `jax.lax.cond` instead.
- **Shapes are baked into the compiled program**: if you pass differently shaped inputs, JAX recompiles (slow). For non-array arguments that change the structure of the computation (for example a Python `int` used as a loop count), mark them with `static_argnums` or `static_argnames`; each new value then triggers a recompile.
- **`block_until_ready()`**: JAX uses async dispatch, so timing benchmarks need this call to force the computation to finish before measuring.
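
Two of these gotchas have standard workarounds, sketched below with made-up function names: `jnp.where` replaces a value-dependent Python `if` (which fails on a traced array), and `static_argnums` marks a non-array argument as a compile-time constant.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

@jit
def relu(x):
    # A Python `if x > 0:` cannot branch on a traced value; jnp.where can.
    return jnp.where(x > 0, x, 0.0)

@partial(jit, static_argnums=1)
def power(x, n):
    # n is static, so this Python loop is unrolled at trace time.
    out = jnp.ones_like(x)
    for _ in range(n):
        out = out * x
    return out

x = jnp.array([-1.0, 2.0])
r = relu(x)      # [0, 2]
p = power(x, 3)  # [-1, 8]
print(r, p)
```

Calling `power(x, 4)` afterwards would compile a second program, since every distinct static value gets its own cache entry.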

In [17]:
def slow_fn(x):
    for _ in range(50):
        x = x @ x
    return x

fast_fn = jit(slow_fn)

mat = random.normal(random.PRNGKey(0), (100, 100))

%timeit slow_fn(mat).block_until_ready()
%timeit fast_fn(mat).block_until_ready()

2.55 ms ± 104 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.34 ms ± 12.2 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## 6. `vmap` — Auto-Vectorization

### The problem: batching is tedious

In ML, you almost always work with **batches** of data. Say you write a function that processes a single image — now you need it to work on 64 images at once. You have two bad options:

1. **Python loop** — `for img in batch: process(img)`. Works but extremely slow.
2. **Manual batching** — rewrite your function to handle an extra batch dimension everywhere. Works and is fast, but error-prone and clutters the code.

### The solution: `vmap`

`vmap(f)` takes a function `f` that works on a **single example** and returns a new function that works on a **batch of examples** — automatically. Under the hood, it transforms the function to operate over an extra leading axis, with no Python loops and no manual reshaping.

```python
# Works on a single vector
def normalize(x):
    return x / jnp.linalg.norm(x)

# Now works on a batch of vectors
batch_normalize = vmap(normalize)
```

### Key parameters

- **`in_axes`** — which axis of each input to map over. Default is `0` (first axis). Use `None` for arguments that shouldn't be batched.
  ```python
  # x is batched (axis 0), weights is shared across the batch
  vmap(f, in_axes=(0, None))(batch_x, weights)
  ```
- **`out_axes`** — which axis of the output the mapped dimension should appear on. Default is `0`.
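
Both parameters can be seen in a short sketch. The functions `predict` and `scale` and the data are invented for this example.

```python
import jax.numpy as jnp
from jax import vmap

def predict(x, weights):
    # Single example: x and weights both have shape (3,)
    return jnp.dot(x, weights)

def scale(x, weights):
    # Single example: returns a vector of shape (3,)
    return x * weights

batch_x = jnp.arange(12.0).reshape(4, 3)  # 4 examples of length 3
weights = jnp.array([1.0, 0.0, -1.0])     # shared across the batch

# Map over axis 0 of batch_x, broadcast weights unchanged.
preds = vmap(predict, in_axes=(0, None))(batch_x, weights)

# Same mapping, but put the batch dimension on axis 1 of the output.
scaled_t = vmap(scale, in_axes=(0, None), out_axes=1)(batch_x, weights)

print(preds.shape)     # (4,)
print(scaled_t.shape)  # (3, 4)
```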

### Why `vmap` matters

- **Clean code**: write single-example logic, get batch processing for free.
- **Performance**: `vmap` pushes the batch dimension into the underlying array operations, so the result is usually as efficient as batched code written by hand.
- **Composability**: `vmap(vmap(f))` maps over two axes (e.g., batch of sequences of vectors).
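
Nested `vmap` can be sketched with a per-vector norm applied to a batch of sequences. The shapes and names below are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import vmap

def l2_norm(v):
    # Works on one vector of shape (3,)
    return jnp.sqrt(jnp.sum(v ** 2))

# 2 sequences, each containing 5 vectors of length 3
data = jnp.ones((2, 5, 3))

# The outer vmap maps over sequences, the inner vmap over vectors in a sequence.
norms = vmap(vmap(l2_norm))(data)

print(norms.shape)  # (2, 5)
```

Every entry of `norms` equals the square root of 3 here, since each vector is all ones.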

In [None]:
def l2_norm(x):
    return jnp.sqrt(jnp.sum(x ** 2))

batch_l2 = vmap(l2_norm)

batch = random.normal(random.PRNGKey(1), (5, 3))
print("batch shape:", batch.shape)
print("norms:", batch_l2(batch))

### Exercise 6a
Write a function `dot_product(a, b)` for single vectors, then use `vmap` to compute dot products for a batch of 10 vector pairs (each of length 4).

In [None]:
# Your code here


## 7. Putting It Together — Simple Gradient Descent

Now we combine `grad` with a loop to do **gradient descent**, the optimization algorithm that underlies most of deep learning.

### How gradient descent works

1. Start with some initial guess for your parameter `x`.
2. Compute the **gradient** of your loss function at `x` — this tells you the direction of steepest *increase*.
3. Take a small step in the **opposite** direction (to decrease the loss): `x = x - lr * gradient`.
4. Repeat until the loss is small enough.

### The learning rate (`lr`)

The learning rate controls how big each step is:
- **Too large** — you overshoot the minimum and the loss explodes.
- **Too small** — convergence is painfully slow.
- **Just right** — smooth convergence to the minimum.

A common starting point is `0.01` or `0.1` for simple problems.
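
The three regimes can be reproduced on `(x - 3)^2`, the same function used in this section. This is a rough sketch; the helper `run` and the specific learning rates are assumptions picked to make each regime visible within 20 steps.

```python
from jax import grad

def loss(x):
    return (x - 3.0) ** 2

grad_loss = grad(loss)

def run(lr, steps=20, x0=0.0):
    # Plain gradient descent from x0 for a fixed number of steps
    x = x0
    for _ in range(steps):
        x = x - lr * grad_loss(x)
    return float(x)

too_large = run(1.1)    # each step multiplies the error by -1.2: diverges
too_small = run(0.001)  # each step multiplies the error by 0.998: barely moves
just_right = run(0.1)   # each step multiplies the error by 0.8: converges

print(f"lr=1.1   -> x = {too_large:.3f}")
print(f"lr=0.001 -> x = {too_small:.3f}")
print(f"lr=0.1   -> x = {just_right:.3f}")
```

For this quadratic, one update multiplies the distance to the minimum by `1 - 2 * lr`, so any `lr` above `1.0` makes the iterates grow instead of shrink.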

### What's happening in the code below

We minimize `f(x) = (x - 3)^2`, which has its minimum at `x = 3`. The gradient is `f'(x) = 2(x - 3)`. Starting from `x = 0`, each step nudges `x` closer to 3. JAX computes `f'(x)` for us automatically via `grad` — we never write the derivative by hand.

In [None]:
def loss(x):
    return (x - 3.0) ** 2

grad_loss = grad(loss)

x = 0.0
lr = 0.1

for i in range(20):
    x = x - lr * grad_loss(x)
    if i % 5 == 0:
        print(f"step {i:2d}: x = {x:.4f}, loss = {loss(x):.4f}")

print(f"\nFinal x: {x:.4f} (should be close to 3.0)")

### Exercise 7a — Linear Regression with Gradient Descent

This is the "real" version of what you just saw. Instead of optimizing a single number, you're optimizing a **vector of weights** `w` to fit a linear model `y = Xw`.

The loss function is the **mean squared error**: `f(w) = (1/n) ||Xw - y||^2`, where `n` is the number of rows in `X`.

Steps:
- Generate random `X` (20x3) and `w_true` (3,), compute `y = X @ w_true`
- Start from random `w`, run gradient descent to recover `w_true`
- `grad` handles the vector calculus for you — it returns a gradient with the same shape as `w`

In [None]:
# Your code here


## Quick Reference

| Function | Purpose |
|----------|--------|
| `jnp.*` | NumPy-like array ops (immutable) |
| `grad(f)` | Auto-differentiation |
| `jit(f)` | XLA compilation for speed |
| `vmap(f)` | Auto-vectorize over batches |
| `random.PRNGKey(seed)` | Create explicit RNG key |
| `random.split(key, n)` | Split key into n sub-keys |
| `x.at[i].set(v)` | Immutable array update |