# An Introduction to JAX

----

#### John Stachurski

#### Prepared for the CBC Workshop (May 2024)

----


This lecture provides a short introduction to [Google JAX](https://github.com/google/jax).

What GPUs do we have access to?

In [None]:
!nvidia-smi

## JAX as a NumPy Replacement


One way to use JAX is as a plug-in NumPy replacement. Let's look at the
similarities and differences.

### Similarities


The following import is standard, replacing `import numpy as np`:

In [None]:
import jax
import jax.numpy as jnp

Now we can use `jnp` in place of `np` for the usual array operations:

In [None]:
a = jnp.array((1.0, 3.2, -1.5))

In [None]:
print(a)

In [None]:
print(jnp.sum(a))

In [None]:
print(jnp.mean(a))

In [None]:
print(jnp.dot(a, a))

In [None]:
print(a @ a)  # Equivalent

However, the array object `a` is not a NumPy array:

In [None]:
a

In [None]:
type(a)

Even scalar-valued maps on arrays return JAX arrays.

In [None]:
jnp.sum(a)

JAX arrays are also called "device arrays," where the term "device" refers to a
hardware accelerator (GPU or TPU).

(In the terminology of GPUs, the "host" is the machine that launches GPU operations, while the "device" is the GPU itself.)



Operations on higher dimensional arrays are also similar to NumPy:

In [None]:
A = jnp.ones((2, 2))
B = jnp.identity(2)
A @ B

In [None]:
from jax.numpy import linalg

In [None]:
linalg.inv(B)   # Inverse of identity is identity

In [None]:
result = linalg.eigh(B)  # Computes eigenvalues and eigenvectors
result.eigenvalues

In [None]:
result.eigenvectors

### Differences

Let's now look at the differences between JAX and NumPy.

#### 32 bit floats

One difference between NumPy and JAX is that JAX currently uses 32 bit floats by default.  

This is standard for GPU computing and can lead to significant speed gains with small loss of precision.

However, for some calculations precision matters.  In these cases 64 bit floats can be enforced via the command

In [None]:
jax.config.update("jax_enable_x64", True)

Let's check this works:

In [None]:
jnp.ones(3)

#### Mutability

As a NumPy replacement, a more significant difference is that arrays are treated as **immutable**.  

For example, with NumPy we can write

In [None]:
import numpy as np
a = np.linspace(0, 1, 3)
a

and then mutate the data in memory:

In [None]:
a[0] = 1
a

In JAX this fails:

In [None]:
a = jnp.linspace(0, 1, 3)
a

In [None]:
# a[0] = 1   # uncommenting produces a TypeError

The designers of JAX chose to make arrays immutable because JAX uses a
functional programming style.  More on this below.  

#### Sneaky mutation

Note that, while mutation is discouraged, JAX offers the `at` syntax, which looks like in-place mutation but in fact returns an updated copy, as in

In [None]:
a = jnp.linspace(0, 1, 3)
id(a)

In [None]:
a

In [None]:
a.at[0].set(1)

The call returns a new array and leaves `a` itself untouched, so the identity of `a` is unchanged:

In [None]:
id(a)

In general it's better to avoid mutating arrays --- more discussion below.
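
To see precisely what `at` does, here is a minimal sketch. The name `b` is introduced here purely for illustration. The call returns a new array, and the original array keeps its values.

```python
import jax.numpy as jnp

a = jnp.linspace(0, 1, 3)

# `b` is a hypothetical name for the updated copy returned by `at[...].set`
b = a.at[0].set(1)

a_unchanged = bool(a[0] == 0.0)   # the original array still starts at 0
b_updated = bool(b[0] == 1.0)     # the copy carries the new value

print(a)
print(b)
```

The usual idiom is therefore to rebind the name, as in `a = a.at[0].set(1)`, rather than to rely on the original array changing.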

## Random Numbers

Random numbers are also a bit different in JAX, relative to NumPy.  


### Controlling the state

Typically, in JAX, the state of the random number generator needs to be controlled explicitly.

(This is also related to JAX's functional programming paradigm, discussed below.  JAX does not typically work with objects that maintain state, such as the state of a random number generator.)

In [None]:
import jax.random as random

First we produce a key, which seeds the random number generator.

In [None]:
key = random.PRNGKey(1)

In [None]:
type(key)

In [None]:
print(key)

Now we can use the key to generate some random numbers:

In [None]:
x = random.normal(key, (3, 3))
x

If we use the same key again, we initialize at the same seed, so the random numbers are the same:

In [None]:
random.normal(key, (3, 3))

### Generating fresh draws

To produce a (quasi-) independent draw, we can use `split`:

In [None]:
new_keys = random.split(key, 5)   # Generate 5 new keys

In [None]:
len(new_keys)

In [None]:
for key in new_keys:
    print(random.normal(key, (3, )))
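
A common way to use `split` is to carry one key forward and consume a fresh subkey for each draw. The sketch below shows this pattern. The seed value and the names `subkey`, `x` and `y` are illustrative choices rather than anything fixed by JAX.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)   # illustrative seed

# Split once per draw: keep `key` for later, use `subkey` now
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))

key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, (3,))

# The two draws come from different subkeys, so they differ
different = not bool(jnp.allclose(x, y))
print(x, y)
```

Because each subkey is used exactly once, no draw is accidentally repeated.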

Another function we can use to update the key is `fold_in`.

In [None]:
seed = 1234  # seed to generate new key from old
key = jax.random.fold_in(key, seed)

random.normal(key, (3, 1))

This is often used in loops -- here's an example that produces `k` (quasi-) independent random `n x n` matrices using this procedure and prints their determinants.

In [None]:
def gen_random_matrices(seed=1234, n=10, k=5):
    key = random.PRNGKey(seed)
    for i in range(k):
        key = random.fold_in(key, i)
        d = jnp.linalg.det(random.uniform(key, (n, n)))
        print(f"Determinant = {d:.4}")

gen_random_matrices()

## JIT compilation

The JAX just-in-time (JIT) compiler generates efficient, parallelized machine code optimized for either the CPU or the GPU/TPU, depending on whether one of these accelerators is detected.

### A first example

To see the JIT compiler in action, consider the following function.

In [None]:
def f(x):
    a = 3*x + jnp.sin(x) + jnp.cos(x**2) - jnp.cos(2*x) - x**2 * 0.4 * x**1.5
    return jnp.sum(a)

Let's build an array to call the function on.

In [None]:
n = 50_000_000
x = jnp.ones(n)

How long does the function take to execute?

In [None]:
%time f(x).block_until_ready()

(In order to measure actual speed, we use the `block_until_ready()` method
to hold the interpreter until the results of the computation are returned from
the device. This is necessary because JAX uses asynchronous dispatch, which
allows the Python interpreter to run ahead of GPU computations.)

The code doesn't run as fast as we might hope, given that it's running on a GPU.

But if we run it a second time it becomes much faster:

In [None]:
%time f(x).block_until_ready()

In [None]:
%timeit f(x).block_until_ready()

This is because the built-in functions like `jnp.cos` are JIT compiled and the
first run includes compile time.

### When does JAX recompile?

You might remember that Numba recompiles if we change the types of variables in a function call.

JAX recompiles more often --- in particular, it recompiles every time we change array sizes.

For example, let's try

In [None]:
m = n + 1
y = jnp.ones(m)

In [None]:
%time f(y).block_until_ready()

Notice that the execution time increases, because now new versions of 
the built-ins like `jnp.cos` are being compiled, specialized to the new array
size.

If we run again, the code is dispatched to the correct compiled version and we
get faster execution.

In [None]:
%time f(y).block_until_ready()

Why does JAX generate fresh machine code every time we change the array size? Because the compiler specializes on array sizes, so each new size needs its own compiled version.

The compiled versions for the previous array size are still available in memory
too, and the following call is dispatched to the correct compiled code.

In [None]:
%time f(x).block_until_ready()
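
One way to observe this caching directly is to count how often JAX traces a jitted function. The sketch below does so with a Python-side counter. The function `h` and the counter `trace_count` are hypothetical names, and the counter is a deliberate side effect used only for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0   # hypothetical counter, incremented only while tracing

@jax.jit
def h(x):
    global trace_count
    trace_count += 1          # Python code runs at trace time, not on cached calls
    return jnp.sum(jnp.cos(x))

h(jnp.ones(3))   # first call: traced and compiled for shape (3,)
h(jnp.ones(3))   # same shape and dtype: the cached version is reused
h(jnp.ones(4))   # new shape: traced and compiled again

print(trace_count)   # prints 2
```

The function is traced once per distinct input shape, which matches the timing pattern seen above.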

### Compiling user-built functions

We can instruct JAX to compile entire functions that we build.

For example, consider

In [None]:
def g(x):
    y = jnp.zeros_like(x)
    for i in range(10):
        y += x**i
    return y

In [None]:
n = 1_000_000
x = jnp.ones(n)

Let's time it.

In [None]:
%time g(x).block_until_ready()

In [None]:
%time g(x).block_until_ready()

In [None]:
g_jit = jax.jit(g)   # target for JIT compilation

Let's run once to compile it:

In [None]:
g_jit(x)

And now let's time it.

In [None]:
%time g_jit(x).block_until_ready()

Note the speed gain.

This is because 

1. the loop is compiled and
2. the array operations are fused and no intermediate arrays are created.


Incidentally, a more common syntax when targeting a function for the JIT
compiler is

In [None]:
@jax.jit
def g_jit_2(x):
    y = jnp.zeros_like(x)
    for i in range(10):
        y += x**i
    return y

In [None]:
%time g_jit_2(x).block_until_ready()

In [None]:
%time g_jit_2(x).block_until_ready()

#### Static arguments

Because the compiler specializes on array sizes, it needs to recompile code when array sizes change.

As a result, any argument that determines sizes of arrays should be flagged by `static_argnums` -- a signal that JAX can treat that variable as a compile-time constant (and recompile when it changes).

Here's an example.

In [None]:
def f(n, seed=1234):
    key = jax.random.PRNGKey(seed)
    x = jax.random.normal(key, (n, ))
    return x.std()

In [None]:
f(5)

In [None]:
f_jitted = jax.jit(f)

In [None]:
f_jitted(5)

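The jitted version fails, because `n` is now traced and the shape `(n, )` is not known at compile time. Let's fix this: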

In [None]:
f_jitted = jax.jit(f, static_argnums=(0, ))   # First argument is static
f_jitted(5)
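
The same idea can be written in decorator form by combining `functools.partial` with `jax.jit`. This is a sketch under the assumption that the first argument fixes the array size. The name `sample_std` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# `sample_std` is a hypothetical name; argument 0 (n) is marked static
@partial(jax.jit, static_argnums=(0,))
def sample_std(n, seed=1234):
    key = jax.random.PRNGKey(seed)
    x = jax.random.normal(key, (n,))   # shape depends on the static argument
    return x.std()

s5 = sample_std(5)   # compiled for n = 5
s6 = sample_std(6)   # a new value of n triggers a fresh compilation
print(s5, s6)
```

Since each new value of a static argument leads to recompilation, this approach suits arguments that change rarely.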

## Functional Programming

From JAX's documentation:

*When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.*


In other words, JAX assumes a functional programming style.

The major implication is that JAX functions should be pure.
    
A pure function will always return the same result if invoked with the same inputs.

In particular, a pure function has

* no dependence on global variables and
* no side effects

### Examples: Python/NumPy/Numba style code is not pure

#### Example 1

Here's an example to show that not all NumPy functions are pure:

In [None]:
np.random.randn()

In [None]:
np.random.randn()

This function returns different results when called on the same inputs!

The issue is that the function maintains state between function calls --- the state of the random number generator.

#### Example 2

Here's a function that's not pure because it depends on a global variable:

In [None]:
a = 10
def f(x): return a * x

f(1)

In [None]:
a = 20
f(1)

(Notice that the output of the function cannot be fully predicted from the inputs.)

#### Example 3

Here's a function that fails to be pure because it modifies external state.

In [None]:
def double_input(x):   # Not pure -- side effects
    x[:] = 2 * x
    return None

x = np.ones(5)
x

In [None]:
double_input(x)
x

Here's a pure version:

In [None]:
def double_input(x):
    y = 2 * x
    return y

#### Example 4

The following function is also not pure, since it modifies a global variable (similar to the last example).

In [None]:
a = 1
def f():
    global a
    a += 1
    return None

In [None]:
a

In [None]:
f()

In [None]:
a

### Compiling impure functions

JAX does not insist on pure functions.

For example, JAX will not usually throw errors when compiling impure functions.

However, execution becomes unpredictable!

Here's an illustration of this fact, using global variables:

In [None]:
a = 1  # global

@jax.jit
def f(x):
    return a + x

In [None]:
x = jnp.ones(2)

In [None]:
x

In [None]:
f(x)

In the code above, the global value `a=1` is fused into the jitted function.

Even if we change `a`, the output of `f` will not be affected --- as long as the same compiled version is called.

In [None]:
a = 42

In [None]:
f(x)

Notice that the change in the value of `a` takes effect in the code below:

In [None]:
x = jnp.ones(3)

In [None]:
f(x)

Can you explain why?

#### Moral

Moral of the story: write pure functions when using JAX!
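
For the global-variable example above, one pure alternative is to pass the value in as an argument. The following is a minimal sketch, with `f_pure` as a hypothetical name.

```python
import jax
import jax.numpy as jnp

# `f_pure` is a hypothetical name; `a` is now an explicit input, not a global
@jax.jit
def f_pure(a, x):
    return a + x

x = jnp.ones(2)
out1 = f_pure(1, x)    # every element equals 2
out2 = f_pure(42, x)   # every element equals 43
print(out1, out2)
```

Now the output is fully determined by the inputs, so a change in `a` always takes effect.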

## Gradients

JAX can use automatic differentiation to compute gradients.

This can be extremely useful for optimization and solving nonlinear systems.

We will see significant applications later in this lecture series.

For now, here's a very simple illustration involving the function

In [None]:
def f(x):
    return (x**2) / 2

Let's take the derivative:

In [None]:
f_prime = jax.grad(f)

In [None]:
f_prime(10.0)

Let's plot the function and derivative, noting that $f'(x) = x$.

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x_grid = jnp.linspace(-4, 4, 200)
ax.plot(x_grid, f(x_grid), label="$f$")
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
ax.legend(loc='upper center')
plt.show()
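
As a side note, the derivative can also be evaluated on a whole grid at once by combining `jax.grad` with `jax.vmap`, instead of using a Python list comprehension. This is a sketch only, and `f_prime_vec` and `derivs` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def f(x):
    return (x**2) / 2

# `f_prime_vec` is a hypothetical name: the scalar derivative mapped over an array
f_prime_vec = jax.vmap(jax.grad(f))

x_grid = jnp.linspace(-4, 4, 200)
derivs = f_prime_vec(x_grid)

# Since f'(x) = x, the derivative values should match the grid itself
matches = bool(jnp.allclose(derivs, x_grid))
print(matches)
```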

## Writing vectorized code

Writing fast JAX code requires shifting repetitive tasks from loops to array processing operations, so that the JAX compiler can easily understand the whole operation and generate more efficient machine code.

This procedure is called **vectorization** or **array programming**, and will be familiar to anyone who has used NumPy or MATLAB.

In some ways, vectorization is the same in JAX as it is in NumPy.

But there are also differences, which we highlight here.

As a running example, consider the function

$$
    f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2}
$$

Suppose that we want to evaluate this function on a square grid of $x$ and $y$ points.


### A slow version with loops

To clarify, here is the slow `for` loop version, which we run in a setting where `len(x) = len(y)` is very small.

In [None]:
@jax.jit
def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

n = 80
x = jnp.linspace(-2, 2, n)
y = x

z_loops = np.empty((n, n))

In [None]:
%%time
for i in range(n):
    for j in range(n):
        z_loops[i, j] = f(x[i], y[j])

Even for this very small grid, execution is extremely slow.

(Notice that we used a NumPy array for `z_loops` because we wanted to write to it.)

OK, so how can we do the same operation in vectorized form?

If you are new to vectorization, you might guess that we can simply write

In [None]:
z_bad = f(x, y)

But this gives us the wrong result: `f` is applied elementwise to the pairs `(x[i], y[i])`, so the output is a flat array rather than the `n x n` grid produced by the nested for loop.

In [None]:
z_bad.shape

Here is what we actually wanted:

In [None]:
z_loops.shape

### Vectorization attempt 1

To get the right shape and the correct nested for loop calculation, we can use a `meshgrid` operation that originated in MATLAB and was replicated in NumPy and then JAX:

In [None]:
x_mesh, y_mesh = jnp.meshgrid(x, y)

Now we get what we want and the execution time is fast.

In [None]:
z_mesh = f(x_mesh, y_mesh) 

Let's confirm that we got the right answer.

In [None]:
jnp.allclose(z_mesh, z_loops)

Now we can set up a serious grid and run the same calculation (on the larger grid) in a short amount of time.

In [None]:
n = 6000
x = jnp.linspace(-2, 2, n)
y = x
x_mesh, y_mesh = jnp.meshgrid(x, y)

In [None]:
%%time
z_mesh = f(x_mesh, y_mesh).block_until_ready()

In [None]:
%%time
z_mesh = f(x_mesh, y_mesh).block_until_ready()

But there is one problem here: the mesh grids use a lot of memory.

In [None]:
(x_mesh.nbytes + y_mesh.nbytes) / 1_000_000  # MB of memory

By comparison, the flat array `x` is just

In [None]:
x.nbytes / 1_000_000   # and y is just a pointer to x

This extra memory usage can be a big problem in actual research calculations.

In [None]:
del x_mesh  # Free memory
del y_mesh  # Free memory

### Vectorization attempt 2

We can achieve a similar effect through NumPy style broadcasting rules.

In [None]:
x_reshaped = jnp.reshape(x, (n, 1))   # Give x another dimension (column)
y_reshaped = jnp.reshape(y, (1, n))   # Give y another dimension (row)

When we evaluate $f$ on these reshaped arrays, we replicate the nested for loops in the original version.

In [None]:
%time z_reshaped = f(x_reshaped, y_reshaped).block_until_ready()

In [None]:
%time z_reshaped = f(x_reshaped, y_reshaped).block_until_ready()

Let's check that we got the same result:

In [None]:
jnp.allclose(z_reshaped, z_mesh)

The memory usage for the inputs is much more moderate.

In [None]:
(x_reshaped.nbytes + y_reshaped.nbytes) / 1_000_000
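
To make the broadcasting rule concrete, here is a small sketch on a grid of four points. The names `col`, `row` and `grid` are illustrative.

```python
import jax.numpy as jnp

x = jnp.linspace(-2, 2, 4)

col = jnp.reshape(x, (4, 1))   # column: varies along axis 0
row = jnp.reshape(x, (1, 4))   # row: varies along axis 1

# Shapes (4, 1) and (1, 4) broadcast to (4, 4),
# with grid[i, j] = x[i]**2 + x[j]**2
grid = col**2 + row**2
print(grid.shape)   # prints (4, 4)
```

Only the two thin input arrays are stored. The full grid exists only in the output.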

### Vectorization attempt 3


There's another approach to vectorization we can pursue, using [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html).

It turns out that, when we are working with complex functions and operations, this `vmap` approach can be the easiest to implement.

It's also very memory parsimonious.

The first step is to vectorize the function `f` in `y`.

In [None]:
f_vec_y = jax.vmap(f, in_axes=(None, 0))  

In the line above, `(None, 0)` indicates that we are vectorizing in the second argument, which is `y`.

Next, we vectorize in the first argument, which is `x`.

In [None]:
f_vec = jax.vmap(f_vec_y, in_axes=(0, None))

Finally, we JIT-compile the result:

In [None]:
f_vec = jax.jit(f_vec)

With this construction, we can now call the function $f$ on flat (low memory) arrays.

In [None]:
%%time
z_vmap = f_vec(x, y).block_until_ready()

In [None]:
%%time
z_vmap = f_vec(x, y).block_until_ready()

Let's check we produce the correct answer:

In [None]:
jnp.allclose(z_vmap, z_mesh)
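
To see how the two `in_axes` settings determine the layout of the output, here is a small sketch with a deliberately non-symmetric function. The function `g` and the input values are hypothetical.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    return x - 2 * y     # hypothetical function, deliberately not symmetric

g_vec_y = jax.vmap(g, in_axes=(None, 0))       # map over y, hold x fixed
g_vec = jax.vmap(g_vec_y, in_axes=(0, None))   # then map over x

x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([10.0, 20.0])

z = g_vec(x, y)      # z[i, j] = g(x[i], y[j])
print(z.shape)       # prints (3, 2)
print(z[2, 1])       # g(2.0, 20.0) = -38.0
```

The outer `vmap` contributes the first axis of the output, so rows index `x` and columns index `y`.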

Let's finish by cleaning up.

In [None]:
del z_mesh
del z_vmap
del z_reshaped

### Exercises

Repeat the exercise of computing the approximation to $\pi$ by simulation:

1. draw $n$ observations of a bivariate uniform on the unit square
2. count the fraction that fall in the circle of radius 0.5 centered on (0.5, 0.5)
3. multiply the result by 4

Use JAX.

In [None]:
for i in range(12):
    print("Solution below 🐠")

In [None]:
def approx_pi(n, key):
    u = jax.random.uniform(key, (2, n))
    distances = jnp.sqrt((u[0, :] - 0.5)**2 + (u[1, :] - 0.5)**2)
    fraction_in_circle = jnp.mean(distances < 0.5)
    return fraction_in_circle * 4  # dividing by radius**2

n = 1_000_000 # sample size for Monte Carlo simulation
key = jax.random.PRNGKey(1234)

In [None]:
%time approx_pi(n, key).block_until_ready()

In [None]:
%time approx_pi(n, key).block_until_ready()

In [None]:
approx_pi_jitted = jax.jit(approx_pi, static_argnums=(0,))

In [None]:
%time approx_pi_jitted(n, key).block_until_ready()

In [None]:
%time approx_pi_jitted(n, key).block_until_ready()

**Exercise**

In a previous notebook we used Monte Carlo to price a European call option and
constructed a solution using Numba.

The code looked like this:

In [None]:
import numba
from numpy.random import randn
M = 10_000_000

n, β, K = 20, 0.99, 100
μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0

@numba.jit(parallel=True)
def compute_call_price_parallel(β=β,
                                μ=μ,
                                S0=S0,
                                h0=h0,
                                K=K,
                                n=n,
                                ρ=ρ,
                                ν=ν,
                                M=M):
    current_sum = 0.0
    # For each sample path
    for m in numba.prange(M):
        s = np.log(S0)
        h = h0
        # Simulate forward in time
        for t in range(n):
            s = s + μ + np.exp(h) * randn()
            h = ρ * h + ν * randn()
        # And add the value max{S_n - K, 0} to current_sum
        current_sum += np.maximum(np.exp(s) - K, 0)
        
    return β**n * current_sum / M

Let's run it once to compile it:

In [None]:
compute_call_price_parallel()

And now let's time it:

In [None]:
%%time 
compute_call_price_parallel()

Try writing a version of this operation for JAX, using all the same
parameters.

If you are running your code on a GPU, you should be able to achieve
significantly faster execution.

In [None]:
for i in range(12):
    print("Solution below 🐠")

**Solution**

Here is one solution:

In [None]:
@jax.jit
def compute_call_price_jax(β=β,
                           μ=μ,
                           S0=S0,
                           h0=h0,
                           K=K,
                           n=n,
                           ρ=ρ,
                           ν=ν,
                           M=M,
                           key=jax.random.PRNGKey(1)):

    s = jnp.full(M, np.log(S0))
    h = jnp.full(M, h0)
    for t in range(n):
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
        
    return β**n * expectation

Let's run it once to compile it:

In [None]:
compute_call_price_jax()

And now let's time it:

In [None]:
%%time 
compute_call_price_jax().block_until_ready()