176 changes: 63 additions & 113 deletions lectures/jax_intro.md
…directly into the graph-theoretic representations supported by JAX.

Random number generation in JAX differs significantly from the patterns found in NumPy or MATLAB.

At first you might find the syntax rather verbose.

But the syntax and semantics are necessary to maintain the functional programming style we just discussed.

Moreover, full control of random state is essential for parallel programming,
such as when we want to run independent experiments along multiple threads.
### NumPy / MATLAB Approach

In NumPy / MATLAB, generation works by maintaining hidden global state.

```{code-cell} ipython3
np.random.seed(42)
print(np.random.randn(2))
```

Each time we call a random function, the hidden state is updated:

```{code-cell} ipython3
print(np.random.randn(2))
```
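Resetting the seed replays the exact same sequence, which shows that the seed's only role is to initialize this hidden state:

```{code-cell} ipython3
import numpy as np

np.random.seed(42)
a = np.random.randn(2)

np.random.seed(42)      # reset the hidden state...
b = np.random.randn(2)  # ...and the sequence repeats

print(a, b)
```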

The function `np.random.randn` is *not pure* because:

* It's non-deterministic: same inputs, different outputs
* It has side effects: it modifies the global random number generator state

This is dangerous under parallelization --- we must carefully control what happens
in each thread!
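One remedy within NumPy itself is to avoid the shared global state entirely; a sketch using the modern `Generator` API, in which each worker owns its own generator:

```{code-cell} ipython3
import numpy as np

# one independent generator per worker, instead of shared global state
rngs = [np.random.default_rng(seed) for seed in (0, 1, 2)]
draws = [rng.standard_normal(2) for rng in rngs]
```

Each `rng` carries its own state, so workers never contend for the global generator.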


### JAX

In JAX, the state of the random number generator is controlled explicitly.
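The basic pattern is to create a key from a seed and pass it explicitly to every random function; a minimal sketch:

```{code-cell} ipython3
import jax

key = jax.random.key(42)
x = jax.random.normal(key, (3,))
y = jax.random.normal(key, (3,))   # same key, same draws

# for fresh draws, split the key into new independent keys
key, subkey = jax.random.split(key)
z = jax.random.normal(subkey, (3,))
```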

For example, the following function generates `k` random matrices, splitting the key at each step:

```{code-cell} ipython3
def gen_random_matrices(key, n=2, k=3):
    matrices = []
    for _ in range(k):
        key, subkey = jax.random.split(key)
        A = jax.random.uniform(subkey, (n, n))
        matrices.append(A)
    return matrices
```

```{code-cell} ipython3
seed = 42
key = jax.random.key(seed)
matrices = gen_random_matrices(key)
```

We can also use `fold_in` when iterating in a loop:

```{code-cell} ipython3
def gen_random_matrices(key, n=2, k=3):
    matrices = []
    for i in range(k):
        step_key = jax.random.fold_in(key, i)
        A = jax.random.uniform(step_key, (n, n))
        matrices.append(A)
    return matrices
```

```{code-cell} ipython3
key = jax.random.key(seed)
matrices = gen_random_matrices(key)
```
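Unlike `split`, which threads the key through the loop, `fold_in` derives a fresh key deterministically from the original key and an integer, so each iteration is reproducible in isolation:

```{code-cell} ipython3
import jax

key = jax.random.key(42)

a = jax.random.uniform(jax.random.fold_in(key, 0))
b = jax.random.uniform(jax.random.fold_in(key, 0))  # same (key, 0): same draw
c = jax.random.uniform(jax.random.fold_in(key, 1))  # new index: new draw
```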


Calling the function again with the same key reproduces exactly the same matrices:

```{code-cell} ipython3
gen_random_matrices(key)
```

As we saw above, JAX makes randomness explicit through keys.

For example,

```{code-cell} ipython3
def random_sum_jax(key):
    key1, key2 = jax.random.split(key)
    x = jax.random.normal(key1)
    y = jax.random.normal(key2)
    return x + y
```

With the same key, we always get the same result:

```{code-cell} ipython3
key = jax.random.key(42)
random_sum_jax(key)
```

```{code-cell} ipython3
random_sum_jax(key)
```
This function is *pure*:

* Deterministic: the same key always produces the same output
* No side effects: no hidden state is modified

To get new draws, we supply a new key.
### Benefits

The explicitness of JAX brings significant benefits:

* Reproducibility: Easy to reproduce results by reusing keys
* Parallelization: Each thread can have its own key, so we control exactly what happens on each thread
* Debugging: No hidden state makes code easier to reason about and test
* JIT compatibility: The compiler can optimize pure functions more aggressively

The last point is expanded on in the next section.
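The parallelization point can be sketched directly: split one key into per-worker keys, and each worker draws independently with no shared state.

```{code-cell} ipython3
import jax

key = jax.random.key(0)
worker_keys = jax.random.split(key, 4)   # one independent key per worker

# each worker (here simulated sequentially) uses only its own key
results = [jax.random.normal(k, (2,)) for k in worker_keys]
```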


## JIT Compilation

The JIT compiler generates efficient machine code that varies with both task size and hardware.
We saw the power of JAX's JIT compiler combined with parallel hardware
{ref}`above <jax_speed>`, when we applied `cos` to a large array.

Here we study JIT compilation for more complex functions.


### With NumPy

We'll try first with NumPy, using

```{code-cell}
def f(x):
    y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
    return y
```

Let's run with large `x`:

```{code-cell}
n = 50_000_000
x = np.linspace(0, 10, n)

with qe.Timer():
    y = f(x)
```

NumPy uses an **eager** execution model:

* Each operation is executed immediately as it is encountered, materializing its
  result before the next operation begins.

This brings some disadvantages:

* Minimal parallelization
* Heavy memory footprint --- many intermediate arrays are produced
* Lots of memory reads and writes
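Concretely, evaluating an expression like the one in `f` proceeds one operation at a time, and each step allocates a full-size temporary array:

```{code-cell} ipython3
import numpy as np

x = np.ones(1_000_000)

t1 = x**2              # temporary array 1
t2 = np.cos(2 * t1)    # temporary array 2
t3 = np.abs(x)         # temporary array 3
y = t2 + np.sqrt(t3)   # more temporaries before the final result
```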



### With JAX

As a first pass, we replace `np` with `jnp` throughout:

```{code-cell}
def f(x):
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2
    return y

x = jnp.linspace(0, 10, n)

with qe.Timer():
    y = f(x)
    jax.block_until_ready(y);
```
The outcome is similar to the `cos` example --- JAX is faster, especially on the
second run after JIT compilation.

But we are still using eager execution, with its heavy memory use and
read/write traffic.


### Compiling the Whole Function

Fortunately, with JAX, we have another trick up our sleeve --- we can JIT-compile
the entire function, not just individual operations.

The compiler fuses all array operations into a single optimized kernel.

Let's try this with the function `f`:

```{code-cell}
f = jax.jit(f)

with qe.Timer():
    y = f(x)
    jax.block_until_ready(y);
```

The runtime has improved again --- now because we fused all the operations.

Fusing lets the compiler

* optimize aggressively based on the entire computational sequence
* eliminate multiple calls to the hardware accelerator
* avoid creating intermediate arrays

Incidentally, a more common syntax when targeting a function for the JIT
compiler is
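This is the `@jax.jit` decorator form; here is a self-contained sketch in which a `print` call makes compilation visible (it executes only while JAX traces the function):

```{code-cell} ipython3
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    print("tracing...")  # runs only during tracing/compilation
    return jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2

x = jnp.linspace(0, 10, 1_000)
y1 = g(x)            # first call: traces and compiles
y2 = g(x)            # same shape and dtype: reuses cached code
y3 = g(jnp.ones(5))  # new shape: triggers a fresh trace and compile
```

Changing the input shape or dtype invalidates the cache entry, so each distinct signature compiles once.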
The first call triggers compilation; subsequent calls with the same input shapes and types reuse the cached
compiled code and run at full speed.



### Compiling non-pure functions

Now that we've seen how powerful JIT compilation can be, it's important to
understand its relationship with pure functions.

While JAX will not usually throw errors when compiling impure functions,
execution becomes unpredictable!

Here's an illustration of this fact, using global variables:

```{code-cell} ipython3
a = 1 # global
```

```{code-cell} ipython3
for row in X:
    mm_diff(row)
```
However, Python loops are slow and cannot be efficiently compiled or
parallelized by JAX.
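For concreteness, here is a self-contained sketch of the setup assumed above; the body of `mm_diff` and the array `X` are hypothetical stand-ins, with `mm_diff` written for a single row at a time:

```{code-cell} ipython3
import jax
import jax.numpy as jnp

def mm_diff(x):
    # hypothetical body: spread between largest and smallest entry of one row
    return jnp.max(x) - jnp.min(x)

X = jnp.arange(12.0).reshape(4, 3)  # a batch of four rows
```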

With `vmap`, we can avoid loops and keep the computation on the accelerator;
the result also composes with other JAX transformations like `jit` and `grad`:

```{code-cell} ipython3
batch_mm_diff = jax.vmap(mm_diff)  # Create a new "vectorized" version
batch_mm_diff(X)                   # Apply to each row of X
```

The function `mm_diff` was written for a single array, and `vmap` automatically
lifted it to operate row-wise over a matrix --- no loops, no reshaping.

### Combining transformations
