# An Introduction to JAX

----

#### John Stachurski

#### Prepared for the CBC Workshop (May 2024)

----


This lecture provides a short introduction to [Google JAX](https://github.com/google/jax).

What GPUs do we have access to?

In [None]:
!nvidia-smi

## Speed Test -- JAX vs NumPy

Let's start with some speed comparisons between NumPy and JAX.

(After that we'll learn how to use JAX more systematically.)


### Transformations

Let's evaluate the cosine function at 50 points.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

x = np.linspace(0, 10, 50)
y = np.cos(x)

fig, ax = plt.subplots()
ax.scatter(x, y)
plt.show()

Now suppose we want to evaluate the cosine function at many points.

In [None]:
n = 50_000_000
x = np.linspace(0, 10, n)

### With NumPy

In [None]:
%time np.cos(x)

In [None]:
%time np.cos(x)

The next line of code frees some memory.

In [None]:
x = None

### With JAX

In [None]:
x_jax = jnp.linspace(0, 10, n)

Let's run the same operation in JAX.

(The `block_until_ready()` method is explained a bit later.)

In [None]:
%time jnp.cos(x_jax).block_until_ready()

In [None]:
%time jnp.cos(x_jax).block_until_ready()

In [None]:
x_jax = None  # Free memory

### Evaluating a more complicated function

In [None]:
def f(x):
    y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2
    return y

In [None]:
fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
ax.plot(x, f(x))
ax.scatter(x, f(x))
plt.show()

Now let's try with a large array.

### With NumPy

In [None]:
n = 50_000_000
x = np.linspace(0, 10, n)

In [None]:
%time f(x)

In [None]:
%time f(x)

### With JAX

In [None]:
def f(x):
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2
    return y

In [None]:
x_jax = jnp.linspace(0, 10, n)

In [None]:
%time f(x_jax).block_until_ready()

In [None]:
%time f(x_jax).block_until_ready()

### Compiling the Whole Function

In [None]:
f_jax = jax.jit(f)

In [None]:
%time f_jax(x_jax).block_until_ready()

In [None]:
%time f_jax(x_jax).block_until_ready()

## JAX as a NumPy Replacement

Now let's slow down and try to figure out how JAX works.

One way to use JAX is as a plug-in NumPy replacement. 

Let's look at the similarities and differences.

### Similarities


The following import is standard, replacing `import numpy as np`:

In [None]:
import jax
import jax.numpy as jnp

Now we can use `jnp` in place of `np` for the usual array operations:

In [None]:
a = jnp.array((1.0, 3.2, -1.5))

In [None]:
print(a)

In [None]:
print(jnp.sum(a))

In [None]:
print(jnp.mean(a))

In [None]:
print(jnp.dot(a, a))

In [None]:
print(a @ a)  # Equivalent

However, the array object `a` is not a NumPy array:

In [None]:
a

In [None]:
type(a)

Even scalar-valued maps on arrays return JAX arrays.

In [None]:
jnp.sum(a)

Operations on higher dimensional arrays are also similar to NumPy:

In [None]:
A = jnp.ones((2, 2))
B = jnp.identity(2)
A @ B

In [None]:
jnp.linalg.inv(B)   # Inverse of identity is identity

In [None]:
result = jnp.linalg.eigh(B)  # Computes eigenvalues and eigenvectors
result.eigenvalues

In [None]:
result.eigenvectors

### Differences

Let's now look at the differences between JAX and NumPy.

#### 32-bit floats

JAX uses 32-bit floats by default.

If necessary, we can enforce 64-bit floats via

In [None]:
jax.config.update("jax_enable_x64", True)

Let's check this works:

In [None]:
jnp.ones(3)
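Why might 64 bits matter? Here is a short sketch of the precision gap. It requests `float32` explicitly, so it behaves the same whether or not x64 mode is on. In 32-bit arithmetic, a small increment to 1.0 can be lost entirely to rounding:

```python
import numpy as np
import jax.numpy as jnp

# In float32, adding 1e-8 to 1.0 is lost to rounding
x32 = jnp.array(1.0, dtype=jnp.float32) + jnp.array(1e-8, dtype=jnp.float32)
print(x32 == 1.0)   # True

# Machine epsilon: the gap between 1.0 and the next representable number
eps32 = np.finfo(np.float32).eps   # about 1.19e-07
eps64 = np.finfo(np.float64).eps   # about 2.22e-16
print(eps32, eps64)
```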

#### Mutability

NumPy arrays are mutable, whereas JAX arrays are immutable.

For example, with NumPy we can write

In [None]:
import numpy as np
a = np.linspace(0, 1, 3)
a

and then mutate the data in memory:

In [None]:
a[0] = 1
a

In JAX this fails because arrays are immutable:

In [None]:
a = jnp.linspace(0, 1, 3)
a

In [None]:
# a[0] = 1   # uncommenting produces a TypeError

Why???
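We cannot mutate a JAX array, but we can create a modified copy using the array's `.at` property. The JAX documentation recommends this in place of in-place updates. Here is a minimal sketch, with variable names of our own choosing:

```python
import jax.numpy as jnp

a = jnp.linspace(0, 1, 3)      # elements 0.0, 0.5, 1.0
b = a.at[0].set(1.0)           # new array with element 0 replaced by 1.0
c = a.at[1:].add(10.0)         # new array with 10.0 added to elements 1 and 2

print(a)  # unchanged
print(b)
print(c)
```

According to the JAX documentation, XLA can often perform such updates in place inside a jitted function when the original array is not used again. In that case the copy need not be costly.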

## Random Numbers

Random numbers also work a bit differently in JAX than in NumPy.

(There are good reasons why, which we'll discuss later.)

### Controlling the state

In JAX, the state of the random number generator needs to be controlled explicitly.


First we produce a key, which seeds the random number generator.

In [None]:
key = jax.random.PRNGKey(1)

In [None]:
type(key)

In [None]:
print(key)

Now we can use the key to generate some random numbers:

In [None]:
x = jax.random.normal(key, (3, 3))
x

If we use the same key again, we initialize at the same seed, so the random numbers are the same:

In [None]:
jax.random.normal(key, (3, 3))

### Generating fresh draws

To produce (quasi-)independent draws, we can use `split`:

In [None]:
new_keys = jax.random.split(key, 5)   # Generate 5 new keys

In [None]:
len(new_keys)

In [None]:
for key in new_keys:
    print(jax.random.normal(key, (3, )))
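A common idiom, which also appears in the option pricing solution at the end of this lecture, is to split off a fresh subkey whenever we need new draws and carry the other key forward. Here is a minimal sketch:

```python
import jax

key = jax.random.PRNGKey(1)
draws = []
for t in range(3):
    # Keep `key` for future splits; use `subkey` exactly once
    key, subkey = jax.random.split(key)
    draws.append(jax.random.normal(subkey, (2,)))

# Each subkey is used only once, so the draws differ across iterations
for d in draws:
    print(d)
```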

## JIT compilation

The JAX just-in-time (JIT) compiler generates efficient, parallelized machine
code optimized for either the CPU or the GPU/TPU, depending on whether one of
these accelerators is detected.

### A first example

To see the JIT compiler in action, consider the following function.

In [None]:
def f(x):
    a = jnp.sin(x) + jnp.cos(x**2)
    return jnp.sum(a)

Let's build an array to call the function on.

In [None]:
n = 50_000_000
x = jnp.ones(n)

How long does the function take to execute?

In [None]:
%time f(x).block_until_ready()

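What is `block_until_ready()` for?

JAX dispatches work asynchronously. A call such as `f(x)` returns control to Python as soon as the computation has been queued, possibly before the result has been computed. Calling `block_until_ready()` makes Python wait until the result is actually available, so `%time` measures the full computation rather than just the dispatch.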

If we run it a second time it becomes much faster:

In [None]:
%time f(x).block_until_ready()

Why?

### When does JAX recompile?

Let's run `f()` on new data:

In [None]:
y = jnp.ones(n + 1)

In [None]:
%time f(y).block_until_ready()

Notice that the execution time increases again --- why??!

(This wouldn't happen with Julia/Numba, which recompile only if we change the
*types* of variables in a function call.)

In [None]:
%time f(y).block_until_ready()

And now runtime goes down again --- why?

Note that the previous compiled versions are still available in memory
too, and the following call is dispatched to the correct compiled code.

In [None]:
%time f(x).block_until_ready()
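One way to watch this happening is to put a Python side effect inside a jitted function. The Python body only runs when JAX *traces* the function, which happens on the first call for each new combination of input shapes and dtypes. Here is a minimal sketch; the counter `trace_count` is a hypothetical helper and is deliberately impure for illustration:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only during tracing

@jax.jit
def h(x):
    global trace_count
    trace_count += 1          # Python side effect: runs at trace time only
    return jnp.sum(jnp.sin(x))

h(jnp.ones(10))   # first call: trace and compile
h(jnp.ones(10))   # same shape and dtype: reuse the cached version
h(jnp.ones(11))   # new shape: trace and compile again
print(trace_count)  # 2
```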

### Compiling user-built functions

As we saw above, we can also instruct JAX to compile an entire user-defined function.

For example, consider

In [None]:
def g(x):
    y = jnp.zeros_like(x)
    for i in range(10):
        y += x**i
    return y

In [None]:
n = 1_000_000
x = jnp.ones(n)

Let's time it.

In [None]:
%time g(x).block_until_ready()

In [None]:
%time g(x).block_until_ready()

In [None]:
g_jit = jax.jit(g)   # target for JIT compilation

Let's run once to compile it:

In [None]:
g_jit(x)

And now let's time it.

In [None]:
%time g_jit(x).block_until_ready()

Why do we get a speed gain?

Incidentally, a more common syntax when targeting a function for the JIT
compiler is

In [None]:
@jax.jit
def g_jit_2(x):
    y = jnp.zeros_like(x)
    for i in range(10):
        y += x**i
    return y

In [None]:
%time g_jit_2(x).block_until_ready()

In [None]:
%time g_jit_2(x).block_until_ready()
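To get a feel for what the compiler receives, we can inspect the traced program with `jax.make_jaxpr`. The Python `for` loop in `g` runs only at trace time, so it is unrolled into a flat sequence of array operations. XLA can then optimize these as a whole, for example by fusing them and avoiding intermediate arrays. Here is a short sketch:

```python
import jax
import jax.numpy as jnp

def g(x):
    y = jnp.zeros_like(x)
    for i in range(10):
        y += x**i
    return y

# Trace g on a small example input and print the resulting program (jaxpr)
closed_jaxpr = jax.make_jaxpr(g)(jnp.ones(4))
print(closed_jaxpr)

# The loop has been unrolled: there are operations for every iteration
n_ops = len(closed_jaxpr.jaxpr.eqns)
print(n_ops)
```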

## Functional Programming

From JAX's documentation:

*When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.*

In other words, JAX assumes a functional programming style.

The major implication is that JAX functions should be pure.
    
A pure function

* always returns the same result when invoked with the same inputs (no dependence on external state), and
* has no side effects.

### Examples: Python/NumPy/Numba style code is not generally pure

#### Example 1

Here's an example showing that NumPy's random number functions are *not* pure:

In [None]:
np.random.randn(3)

In [None]:
np.random.randn(3)

This function returns different results when called on the same inputs!

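The issue is that the function maintains state between calls: the state of the random number generator. For example, the third element of `np.random.get_state()` records the current position in the generator's internal state array, and it changes when we draw: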

In [None]:
np.random.get_state()[2]

In [None]:
np.random.randn(3)

In [None]:
np.random.get_state()[2]

#### Example 2

Is this guy pure?

In [None]:
a = 10
def f(x): 
    return a * x

In [None]:
f(1)

Now let's change the global:

In [None]:
a = 20

In [None]:
f(1)

#### Example 3

Is this guy pure?  Why? / Why not?

In [None]:
def change_input(x):   # Not pure -- side effects
    x[0] = 42
    return None

x = np.ones(5)
x

In [None]:
change_input(x)
x

### Compiling impure functions

JAX will *not* usually throw errors when compiling impure functions.

However, execution becomes unpredictable!

Here's an illustration of this fact, using global variables:

In [None]:
a = 1  # global

@jax.jit
def f(x):
    return a + x

In [None]:
x = jnp.ones(2)

In [None]:
f(x)

In the code above, the global value `a=1` is fused into the jitted function.

Even if we change `a`, the output of `f` will not be affected --- as long as the same compiled version is called.

In [None]:
a = 42

In [None]:
f(x)

Notice that the change in the value of `a` takes effect in the code below:

In [None]:
x = jnp.ones(3)

In [None]:
f(x)

Why?

#### Moral

Moral of the story: write pure functions when using JAX!
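For the global-variable example above, the pure alternative is to pass `a` as an argument, so that the output depends only on the inputs. Here is a minimal sketch, where `f_pure` is a name introduced here for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_pure(x, a):
    # `a` is now an input rather than a global, so f_pure has no hidden state
    return a + x

x = jnp.ones(2)
print(f_pure(x, 1.0))    # each element equals 2.0
print(f_pure(x, 42.0))   # each element equals 43.0: the new value of `a` is used
```

Both calls pass a scalar `a` of the same type, so the second call should reuse the compiled version rather than trigger recompilation.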

## Gradients

JAX can use automatic differentiation to compute gradients.

This can be extremely useful for optimization and solving nonlinear systems.

We will see significant applications later in this lecture series.

For now, here's a very simple illustration involving the function

In [None]:
def f(x):
    return (x**2) / 2

Let's take the derivative:

In [None]:
f_prime = jax.grad(f)

In [None]:
f_prime(10.0)

Let's plot the function and derivative, noting that $f'(x) = x$.

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x_grid = jnp.linspace(-4, 4, 200)
ax.plot(x_grid, f(x_grid), label="$f$")
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
ax.legend(loc='upper center')
plt.show()
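`jax.grad` also works for functions of several variables, provided the function returns a scalar. Here is a minimal sketch using a vector analogue of $f$, namely $f(x) = \sum_i x_i^2 / 2$, whose gradient is $x$ itself. The name `f_vec` is ours:

```python
import jax
import jax.numpy as jnp

def f_vec(x):
    # Scalar-valued function of a vector: sum of x_i^2 / 2
    return jnp.sum(x**2) / 2

grad_f_vec = jax.grad(f_vec)

x = jnp.array([1.0, -2.0, 3.0])
print(grad_f_vec(x))  # the gradient equals x
```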

## Writing vectorized code

Consider the function

$$
    f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2}
$$

Suppose that we want to evaluate this function on a square grid of $x$ and $y$ points.


### A slow version with loops

Here's a slow `for` loop version, which we run in a setting where `len(x) = len(y)` is very small.

In [None]:
@jax.jit
def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

n = 80
x = jnp.linspace(-2, 2, n)
y = x

z_loops = np.empty((n, n))

In [None]:
%%time
for i in range(n):
    for j in range(n):
        z_loops[i, j] = f(x[i], y[j])

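Even for this small grid, the run time is long. Each iteration makes a separate call from Python into the jitted function, and this per-call overhead adds up.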

(Notice that we used a NumPy array for `z_loops` because we wanted to write to it.)

OK, so how can we do the same operation in vectorized form?

If you are new to vectorization, you might guess that we can simply write

In [None]:
z_bad = f(x, y)

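But this gives the wrong result. `f` operates elementwise, so `f(x, y)` computes $f(x_i, y_i)$ for each $i$ and returns a one-dimensional array of length $n$. It does not evaluate $f(x_i, y_j)$ over all pairs $(i, j)$.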

In [None]:
z_bad.shape

In [None]:
z_loops.shape

### Vectorization attempt 1

To get the right shape and the correct nested for loop calculation, we can use a `meshgrid` operation that originated in MATLAB and was replicated in NumPy and then JAX:

In [None]:
x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')  # x_mesh[i, j] = x[i], y_mesh[i, j] = y[j]

Now we get what we want and the execution time is fast.

In [None]:
z_mesh = f(x_mesh, y_mesh) 

Let's confirm that we got the right answer.

In [None]:
jnp.allclose(z_mesh, z_loops)

Now we can set up a much larger grid and run the same calculation quickly.

In [None]:
n = 6000
x = jnp.linspace(-2, 2, n)
y = x
x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')  # same layout as the nested loops

In [None]:
%time z_mesh = f(x_mesh, y_mesh).block_until_ready()

In [None]:
%time z_mesh = f(x_mesh, y_mesh).block_until_ready()

But there is one problem here: the mesh grids use a lot of memory.

In [None]:
(x_mesh.nbytes + y_mesh.nbytes) / 1_000_000  # MB of memory

By comparison, the flat array `x` is just

In [None]:
x.nbytes / 1_000_000   # y refers to the same array as x, so it adds no memory

This extra memory usage can be a big problem in actual research calculations.

In [None]:
del x_mesh  # Free memory
del y_mesh  # Free memory

### Vectorization attempt 2

We can achieve a similar effect through NumPy-style broadcasting rules.

In [None]:
x_reshaped = jnp.reshape(x, (n, 1))   # Give x another dimension (column)
y_reshaped = jnp.reshape(y, (1, n))   # Give y another dimension (row)

When we evaluate $f$ on these reshaped arrays, we replicate the nested for loops in the original version.
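Here is a small sketch of the broadcasting rule at work. An array of shape `(3, 1)` combined with one of shape `(1, 3)` is stretched to shape `(3, 3)`, and entry `[i, j]` is built from the `i`-th element of the first array and the `j`-th element of the second. This matches the nested loops exactly. The names `x_small` and `y_small` are ours:

```python
import jax.numpy as jnp

x_small = jnp.array([1.0, 2.0, 3.0])
y_small = jnp.array([10.0, 20.0, 30.0])

# (3, 1) + (1, 3) broadcasts to (3, 3)
z = x_small.reshape((3, 1)) + y_small.reshape((1, 3))
print(z.shape)   # (3, 3)

# Entry [i, j] combines x_small[i] with y_small[j], as in the nested loops
for i in range(3):
    for j in range(3):
        assert z[i, j] == x_small[i] + y_small[j]
```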

In [None]:
%time z_reshaped = f(x_reshaped, y_reshaped).block_until_ready()

In [None]:
%time z_reshaped = f(x_reshaped, y_reshaped).block_until_ready()

Let's check that we got the same result

In [None]:
jnp.allclose(z_reshaped, z_mesh)

The memory usage for the inputs is much more moderate.

In [None]:
(x_reshaped.nbytes + y_reshaped.nbytes) / 1_000_000

### Vectorization attempt 3


There's another approach to vectorization we can pursue, using [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html):

In [None]:
f = jax.vmap(f, in_axes=(None, 0))   # vectorize in y
f = jax.vmap(f, in_axes=(0, None))   # and then vectorize in x

With this construction, we can now call the function $f$ on flat (low memory) arrays.

In [None]:
%time z_vmap = f(x, y).block_until_ready()

In [None]:
%time z_vmap = f(x, y).block_until_ready()

Let's check we produce the correct answer:

In [None]:
jnp.allclose(z_vmap, z_mesh)

Let's finish by cleaning up.

In [None]:
del z_mesh
del z_vmap
del z_reshaped

### Exercises

Compute an approximation to $\pi$ by simulation:

1. draw $n$ observations of a bivariate uniform on the unit square
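2. count the fraction that fall in the circle of radius 0.5 centered at (0.5, 0.5)
3. multiply the result by 4

This works because the circle has area $\pi \times 0.5^2 = \pi/4$ while the unit square has area 1, so the fraction of points inside the circle approximates $\pi/4$.

Use JAX.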

In [None]:
for i in range(12):
    print("Solution below 🐠")

In [None]:
def approx_pi(n, key):
    u = jax.random.uniform(key, (2, n))
    distances = jnp.sqrt((u[0, :] - 0.5)**2 + (u[1, :] - 0.5)**2)
    fraction_in_circle = jnp.mean(distances < 0.5)
    return fraction_in_circle * 4  # dividing by radius**2

n = 1_000_000 # sample size for Monte Carlo simulation
key = jax.random.PRNGKey(1234)

In [None]:
%time approx_pi(n, key)

In [None]:
%time approx_pi(n, key)

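Because `n` determines the shape of the array of draws, and JAX needs array shapes to be fixed at compile time, we mark `n` as a static argument:

In [None]:
approx_pi_jitted = jax.jit(approx_pi, static_argnums=(0,))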

In [None]:
%time approx_pi_jitted(n, key)

In [None]:
%time approx_pi_jitted(n, key)
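Since `approx_pi` is a pure function of its key, we can also generate many independent estimates at once by mapping it over a batch of keys with `jax.vmap`, and then average them. Here is a sketch that repeats `approx_pi` from the solution so that it runs on its own. The number of keys and the sample size are arbitrary choices:

```python
import jax
import jax.numpy as jnp

def approx_pi(n, key):
    u = jax.random.uniform(key, (2, n))
    distances = jnp.sqrt((u[0, :] - 0.5)**2 + (u[1, :] - 0.5)**2)
    return jnp.mean(distances < 0.5) * 4

n = 100_000
keys = jax.random.split(jax.random.PRNGKey(0), 20)   # 20 independent keys

# Map over the keys (axis 0) while holding the sample size n fixed
estimates = jax.vmap(lambda k: approx_pi(n, k))(keys)
print(estimates.mean(), estimates.std())
```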

**Exercise**

In a previous notebook we used Monte Carlo to price a European call option and
constructed a solution using Numba.

The code looked like this:

In [None]:
import numba
from numpy.random import randn
M = 10_000_000

n, β, K = 20, 0.99, 100
μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0

@numba.jit(parallel=True)
def compute_call_price_parallel(β=β,
                                μ=μ,
                                S0=S0,
                                h0=h0,
                                K=K,
                                n=n,
                                ρ=ρ,
                                ν=ν,
                                M=M):
    current_sum = 0.0
    # For each sample path
    for m in numba.prange(M):
        s = np.log(S0)
        h = h0
        # Simulate forward in time
        for t in range(n):
            s = s + μ + np.exp(h) * randn()
            h = ρ * h + ν * randn()
        # And add the value max{S_n - K, 0} to current_sum
        current_sum += np.maximum(np.exp(s) - K, 0)
        
    return β**n * current_sum / M

Let's run it once to compile it:

In [None]:
compute_call_price_parallel()

And now let's time it:

In [None]:
%%time 
compute_call_price_parallel()

Try writing a version of this operation for JAX, using all the same
parameters.

If you are running your code on a GPU, you should be able to achieve
significantly faster execution.

In [None]:
for i in range(12):
    print("Solution below 🐠")

**Solution**

Here is one solution:

In [None]:
@jax.jit
def compute_call_price_jax(β=β,
                           μ=μ,
                           S0=S0,
                           h0=h0,
                           K=K,
                           n=n,
                           ρ=ρ,
                           ν=ν,
                           M=M,
                           key=jax.random.PRNGKey(1)):

    s = jnp.full(M, np.log(S0))
    h = jnp.full(M, h0)
    for t in range(n):
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
        
    return β**n * expectation

Let's run it once to compile it:

In [None]:
compute_call_price_jax()

And now let's time it:

In [None]:
%%time 
compute_call_price_jax().block_until_ready()
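In the solution above, the Python `for` loop is unrolled at trace time, so compile time grows with `n`. For long horizons, one alternative is `jax.lax.fori_loop`, which keeps the loop inside the compiled program. Here is a sketch under our own assumptions:

* `compute_call_price_fori` is a hypothetical name.
* The key is passed explicitly.
* `n` and `M` are marked static because `M` sets array shapes.
* The defaults mirror the parameter values above, except that `M` is reduced to 100,000 for a quick run.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n", "M"))
def compute_call_price_fori(key, β=0.99, μ=0.0001, S0=10.0, h0=0.0,
                            K=100.0, n=20, ρ=0.1, ν=0.001, M=100_000):
    # Initial state for all M paths; adding to jnp.zeros(M) gives the
    # default float dtype, matching the dtype of the normal draws below
    s = jnp.log(S0) + jnp.zeros(M)
    h = h0 + jnp.zeros(M)

    def update(t, state):
        # One time step for all M paths at once
        s, h, key = state
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
        return s, h, key

    # The loop body is compiled once rather than unrolled n times
    s, _, _ = jax.lax.fori_loop(0, n, update, (s, h, key))
    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
    return β**n * expectation


price = compute_call_price_fori(jax.random.PRNGKey(1))
print(price)
```

For `n = 20` the unrolled version is perfectly reasonable. The loop-based form mainly pays off when the number of time steps is large.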