# JAX / NumPyro tutorial

- What is JAX?

- JAX basics

- Things no one tells you

- Sampling using NumPyro HMC and JAX

### NVIDIA stock
<img src="nvidia.png" alt="" style="">

## What is JAX?

JAX is a Python library for array-oriented numerical computation.

- NumPy-like interface that runs on CPU, GPU, or TPU.
- Just-In-Time (JIT) compilation.
- Automatic vectorization.
- Automatic differentiation of functions.

## JAX basics

In [None]:
# Run this if you've never used JAX or NumPyro.

# !pip install numpyro
# !pip install jax
# !pip install corner

In [None]:
import jax
import jax.numpy as jnp
import jax.scipy as js
import jax.random as jr

# for comparison
import numpy as np
import scipy

# for sampling later
import numpyro
import numpyro.distributions as dist
# expose 10 host (CPU) devices so that 10 MCMC chains can run in parallel
numpyro.set_host_device_count(10)

# plotting packages
import corner
import matplotlib.pyplot as plt
plt.style.use('dark_background')

### Single-precision

In [None]:
# by default JAX uses single precision (float32);
# uncomment the next line to enable double precision (float64)
# jax.config.update('jax_enable_x64', True)
print(np.cos(np.pi / 3))
print(jnp.cos(jnp.pi / 3))
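
To make the precision difference concrete, here is a small sketch that inspects the dtypes directly. It assumes the default configuration, i.e. `jax_enable_x64` has not been switched on.

```python
import numpy as np
import jax.numpy as jnp

# NumPy defaults to double precision; JAX defaults to single precision
numpy_value = np.cos(np.pi / 3)
jax_value = jnp.cos(jnp.pi / 3)

print(numpy_value.dtype)  # float64
print(jax_value.dtype)    # float32 unless jax_enable_x64 is set

# the two results agree only to roughly single-precision accuracy
abs_difference = abs(float(jax_value) - float(numpy_value))
print(abs_difference)
```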

### NumPy-like arrays

In [None]:
numpy_array = np.array([4.3, -7.3, 0.001])
jax_array = jnp.array([4.3, -7.3, 0.001])
print(type(numpy_array))
print(type(jax_array))

In [None]:
print(type(np.array(jax_array)))
print(type(jnp.array(numpy_array)))

In [None]:
array1 = jnp.array([1., -2.5, 0.6])
array2 = jnp.array([5.3, 5.002, -11.5])
array1 + array2

In [None]:
jnp.inner(array1, array2)

In [None]:
jnp.outer(array1, array2)

In [None]:
jnp.linalg.eig(jnp.outer(array1, array2))

In [None]:
# other common arrays
a = jnp.zeros((3, 3))       # 2D array of zeros
b = jnp.ones((3, 3, 5))     # 3D array of ones
c = jnp.arange(0, 10, 2)    # [0, 2, 4, 6, 8]
d = jnp.linspace(0, 1, 5)   # [0., 0.25, 0.5, 0.75, 1.]
e = jnp.eye(4)              # 4x4 identity matrix

In [None]:
# array indexing and slicing
b[2, 1]   # valid indices on axis 0 are 0-2; JAX clamps out-of-bounds reads instead of raising
c[0:4]

In [None]:
# broadcasting
f = (a[..., None] * b)[..., None] / jnp.outer(c, d)
print(f.shape)

In [None]:
# jax.random uses PRNG keys
random_seed = 150914
a_PRNG_key = jr.key(random_seed)
random_numbers = jr.normal(key=a_PRNG_key, shape=(100, 100))

# need many PRNG keys?
many_PRNG_keys = jr.split(key=a_PRNG_key, num=50)
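
The key-based design means randomness is explicit: reusing a key reproduces the same numbers, and splitting a key yields independent streams. A minimal sketch (the seed `0` and the variable names are arbitrary choices for illustration):

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.key(0)

# the same key always produces the same numbers
first_draw = jr.normal(key, (3,))
second_draw = jr.normal(key, (3,))
same_numbers = bool(jnp.all(first_draw == second_draw))

# split the key to get independent streams
key_a, key_b = jr.split(key)
draw_a = jr.normal(key_a, (3,))
draw_b = jr.normal(key_b, (3,))
different_numbers = not bool(jnp.all(draw_a == draw_b))

print(same_numbers, different_numbers)
```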

### Some things are different...

In [None]:
# JAX arrays are expensive to create!!!
%timeit np.array([5, 5, 5])
%timeit jnp.array([5, 5, 5])

In [None]:
# Can you modify JAX arrays in place? No!

numpy_array = np.array([0, 1, 2, 3, 4, 5])
jax_array = jnp.array([0, 1, 2, 3, 4, 5])

numpy_array[3] = 0
print(numpy_array)

# jax_array[3] = 0    # error: JAX arrays are immutable
# use the functional update instead; it returns a new array
jax_array = jax_array.at[3].set(0)
print(jax_array)

# in-place increments follow the same pattern
numpy_array[2] += 1
# jax_array[2] += 1   # error!
jax_array = jax_array.at[2].add(1)

In [None]:
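%timeit numpy_array[3] = 0
# JAX dispatches work asynchronously; block so the timing includes the computation
%timeit jax_array.at[3].set(0).block_until_ready()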

In [None]:
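# Can we speed up JAX array modification?
arr = jnp.arange(10, dtype=jnp.float32)
zeros = jnp.zeros(5, dtype=jnp.float32)
# block_until_ready() makes the timing include the (asynchronous) computation
%timeit arr.at[-5:10].set(0.).block_until_ready()
%timeit jnp.concatenate((arr[:-5], zeros), axis=0).block_until_ready()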

#### Takeaways:

- JAX arrays behave much like NumPy arrays, but they are immutable.
- JAX arrays are expensive to create and modify!
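
As a compact illustration of the functional-update pattern, the sketch below shows that `.at[...]` updates return a new array and leave the original untouched (variable names are illustrative):

```python
import jax.numpy as jnp

original = jnp.arange(6)

# .at[...].set(...) and .at[...].add(...) return new arrays
updated = original.at[3].set(0)
incremented = original.at[2].add(1)

print(original)     # [0 1 2 3 4 5]
print(updated)      # [0 1 2 0 4 5]
print(incremented)  # [0 1 3 3 4 5]
```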

## Just-In-Time (JIT) compilation

Suppose we have the function $f: \mathbb{R}^2 \rightarrow \mathbb{R}$,

\begin{equation*}
    f(x, y) = \text{Tr}\big[(\mathbf{A} + x\mathbf{I})^\text{T}\,(\mathbf{A} + y\mathbf{I})^3\big]\,,
\end{equation*}

where $\mathbf{A}$ is a constant matrix and $\mathbf{I}$ is the identity.

In [None]:
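def numpy_func(x, y, A):
    identity = np.eye(A.shape[0])
    # (A + y I)^3 is a matrix power, not an elementwise cube
    M = (A + x * identity).T @ np.linalg.matrix_power(A + y * identity, 3)
    return np.trace(M)

def jax_func(x, y, A):
    identity = jnp.eye(A.shape[0])
    # same computation with the JAX drop-in replacements
    M = (A + x * identity).T @ jnp.linalg.matrix_power(A + y * identity, 3)
    return jnp.trace(M)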

# some inputs
x = 2.5
y = -3.9
jax_A = jr.normal(jr.key(200129), (100, 100))
numpy_A = np.array(jax_A)

print(numpy_func(x, y, numpy_A))
print(jax_func(x, y, jax_A))

In [None]:
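%timeit numpy_func(x, y, numpy_A)
# JAX dispatches work asynchronously; block so the timing includes the computation
%timeit jax_func(x, y, jax_A).block_until_ready()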

In [None]:
# JIT-compile the JAX function
fast_jax_func = jax.jit(jax_func)

# test
print(fast_jax_func(x, y, jax_A))

In [None]:
%timeit fast_jax_func(x, y, jax_A).block_until_ready()
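
Two details matter when timing JIT-compiled code: the first call includes tracing and compilation, and JAX returns before the computation has finished unless you block on the result. The sketch below separates the two calls; `sum_of_squares` is a hypothetical stand-in function, not part of the tutorial.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def sum_of_squares(values):
    return jnp.sum(values**2)

values = jnp.arange(1000, dtype=jnp.float32)

# first call: tracing + compilation + execution
start = time.perf_counter()
first_result = sum_of_squares(values).block_until_ready()
first_call_seconds = time.perf_counter() - start

# second call: reuses the cached compiled function
start = time.perf_counter()
second_result = sum_of_squares(values).block_until_ready()
second_call_seconds = time.perf_counter() - start

print(first_call_seconds, second_call_seconds)
```

On most machines the second call should be much faster, but the exact numbers depend on hardware and backend.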

### Can you JIT everything? No...

In [None]:
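# return True if even, False if odd
def is_even(integer):
    if integer % 2 == 0:  # even
        return True
    else:  # odd
        return False

fast_is_even = jax.jit(is_even)

# print(is_even(6))     # works: plain Python
# the next line raises an error under JIT: the Python `if` needs a concrete
# value, but `integer` is an abstract tracer during compilation
print(fast_is_even(6))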

In [None]:
# conditionals work-around
def spit_True():
    return True

def spit_False():
    return False

def is_even_redo(integer):
    even = (integer % 2 == 0)
    return jax.lax.cond(even, spit_True, spit_False)

fast_is_even_redo = jax.jit(is_even_redo)
print(fast_is_even_redo(7))

In [None]:
%timeit is_even(6)
%timeit fast_is_even_redo(6)

#### Takeaways:

- JIT compile expensive functions which are called many times (e.g. likelihood evaluations)
- You can't JIT compile everything: Python control flow that depends on traced values has to be rewritten (e.g. with `jax.lax.cond`).
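
Besides `jax.lax.cond`, two other common work-arounds are sketched below: replacing the branch with `jnp.where`, or marking the argument as static with `static_argnums` so that the Python `if` sees a concrete value. The function names are illustrative, and the static approach triggers a recompilation for every new argument value, so it only suits arguments that take few distinct values.

```python
import jax
import jax.numpy as jnp

# option 1: replace the Python branch with an array expression
@jax.jit
def is_even_where(integer):
    return jnp.where(integer % 2 == 0, True, False)

# option 2: keep the Python branch, but make the argument static
def is_even_python(integer):
    if integer % 2 == 0:
        return True
    return False

is_even_static = jax.jit(is_even_python, static_argnums=0)

print(is_even_where(6), is_even_where(7))
print(is_even_static(6), is_even_static(7))
```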

## Automatic vectorization

In [None]:
# suppose we want to evaluate the function over many inputs
num_evaluations = 100
many_x = jr.normal(jr.key(230814), (num_evaluations,))
many_y = jr.normal(jr.key(250114), (num_evaluations,))
jax_A = jr.normal(jr.key(170817), (10, 10))

In [None]:
# block on the whole list so the timing includes every evaluation
%timeit jax.block_until_ready([fast_jax_func(x, y, jax_A) for x, y in zip(many_x, many_y)])

In [None]:
# we can automatically vectorize JAX functions
vectorized_jax_func = jax.jit(jax.vmap(fast_jax_func, in_axes=(0, 0, None)))

# test
print(vectorized_jax_func(many_x, many_y, jax_A))

In [None]:
%timeit vectorized_jax_func(many_x, many_y, jax_A).block_until_ready()
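
The `in_axes` argument tells `jax.vmap` which axis of each input to map over, with `None` meaning "share this argument across the batch". A small sketch with a hypothetical helper `scaled_trace`, checked against an explicit loop:

```python
import jax
import jax.numpy as jnp

def scaled_trace(scale, matrix):
    # written for a single scalar `scale`
    return scale * jnp.trace(matrix)

matrix = jnp.eye(3)
scales = jnp.array([1.0, 2.0, 3.0])

# in_axes=(0, None): map over axis 0 of `scales`, share `matrix`
batched_scaled_trace = jax.vmap(scaled_trace, in_axes=(0, None))
batched_result = batched_scaled_trace(scales, matrix)

# reference result from an explicit Python loop
loop_result = jnp.array([scaled_trace(s, matrix) for s in scales])

print(batched_result)  # [3. 6. 9.]
```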

## Automatic differentiation

Suppose $f(x) = x\cos(3x)$, $g(x) = e^{x^2}\,\ln(x)$, and $h(x) = \cos(x)\,e^{\tan(x)}$. The *analytic* derivative of the product,

\begin{equation*}
    \frac{d}{dx}\bigg[f(x)\,g(x)\,h(x)\bigg] = f'\, g\, h + f\, g'\, h + f\, g\, h'\,,
\end{equation*}

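grows rapidly in size as more factors and nested functions are added, because the product and chain rules multiply the number of terms.

Numerical derivatives, such as finite differences, are computationally expensive in high dimensions (every partial derivative needs extra function evaluations), and the step size must be tuned for numerical stability.

Automatic differentiation decomposes a function into a directed acyclic graph (DAG) of elementary operations and applies the chain rule along that graph. This is *fast* and *exact* (to floating-point precision), even for high-dimensional functions. Consider the example function,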

\begin{equation*}
    f(x_1, x_2) = x_1 \,\text{exp}\bigg[-\frac{1}{2}(x_1^2 + x_2^2)\bigg]\,.
\end{equation*}

<img src="graph.png" alt="Negative" style="filter: invert(1);">
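
As a quick check of the example function above, the sketch below compares `jax.grad` with the hand-derived gradient, $\partial f/\partial x_1 = (1 - x_1^2)\,e^{-(x_1^2 + x_2^2)/2}$ and $\partial f/\partial x_2 = -x_1 x_2\,e^{-(x_1^2 + x_2^2)/2}$. The evaluation point is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

def example_func(x1, x2):
    return x1 * jnp.exp(-0.5 * (x1**2 + x2**2))

def analytic_gradient(x1, x2):
    # hand-derived partial derivatives of example_func
    envelope = jnp.exp(-0.5 * (x1**2 + x2**2))
    return (1.0 - x1**2) * envelope, -x1 * x2 * envelope

# differentiate with respect to both arguments
autodiff_gradient = jax.grad(example_func, argnums=(0, 1))

x1, x2 = 0.7, -1.2
autodiff_values = autodiff_gradient(x1, x2)
analytic_values = analytic_gradient(x1, x2)
print(autodiff_values)
print(analytic_values)
```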

In [None]:
# we can take derivatives of JAX functions
grad_jax_func = jax.jit(jax.grad(fast_jax_func, argnums=(0, 1)))

# test
print(grad_jax_func(x, y, jax_A))

In [None]:
%timeit fast_jax_func(x, y, jax_A).block_until_ready()
# the gradient is a tuple, so block on the whole pytree
%timeit jax.block_until_ready(grad_jax_func(x, y, jax_A))

In [None]:
# try these
# jax.hessian
# jax.jacobian
# jax.jit(jax.vmap(jax.hessian(...)))
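
For a starting point, here is a small sketch of these transformations on two toy functions (both are hypothetical examples, chosen because their derivatives are easy to verify by hand):

```python
import jax
import jax.numpy as jnp

def quadratic(x):
    # f(x) = x0^2 + 3 x0 x1 + 2 x1^2, so the Hessian is [[2, 3], [3, 4]]
    return x[0]**2 + 3.0 * x[0] * x[1] + 2.0 * x[1]**2

def polar_to_cartesian(polar):
    r, theta = polar
    return jnp.array([r * jnp.cos(theta), r * jnp.sin(theta)])

point = jnp.array([1.0, -2.0])
hessian_matrix = jax.hessian(quadratic)(point)

# Jacobian at r = 2, theta = 0 is diag(1, 2)
jacobian_matrix = jax.jacobian(polar_to_cartesian)(jnp.array([2.0, 0.0]))

# composed transformations: one Hessian per row of `points`
batched_hessian = jax.jit(jax.vmap(jax.hessian(quadratic)))
points = jnp.array([[1.0, -2.0], [0.5, 3.0]])
batched_hessians = batched_hessian(points)

print(hessian_matrix)
print(jacobian_matrix)
print(batched_hessians.shape)  # (2, 2, 2)
```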

## Best practices

- Initialize and store as many JAX arrays as possible at the start of your code. Avoid creating or modifying JAX arrays in code that is executed many times.

- JIT compile functions.

- Compose auto-diff, auto-vec, and JIT wrappers.

- Break up your code into JAX blocks, and non-JAX blocks. Use JAX for heavy repeated computations (e.g. likelihood evaluations), NumPy for bookkeeping (e.g. modifying chains).

- It’s easier to start a project in JAX from the beginning than to add it later.

## Sampling with HMC in NumPyro

For a target density, $\pi(\mathbf{q})$, from which we want to sample, Hamiltonian Monte Carlo (HMC) defines a Hamiltonian,

\begin{align*}
    H(\mathbf{q}, \mathbf{p}) &= T(\mathbf{p}) + V(\mathbf{q}) \\
    &= T(\mathbf{p}) - \ln\pi(\mathbf{q})\,.
\end{align*}

If we start from a sample $\mathbf{q}_0$ (and randomized initial momentum $\mathbf{p}_0$), subsequent samples are proposed by integrating Hamilton's equations,

\begin{align*}
    \dot{q}_i &= \frac{\partial H}{\partial p_i}\,, \\
    \dot{p}_i &= -\frac{\partial H}{\partial q_i}\,,
\end{align*}

for a fixed trajectory length. After integrating to the point $(\mathbf{q}_\text{final}, \mathbf{p}_\text{final})$, the proposal is accepted with probability

\begin{equation*}
    \alpha = \text{min}\bigg(1,\,\frac{\text{exp}[-H(\mathbf{q}_\text{final}, \mathbf{p}_\text{final})]}{\text{exp}[-H(\mathbf{q}_0, \mathbf{p}_0)]}\bigg)\,.
\end{equation*}
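
The steps above can be sketched in a few lines of JAX. This is a minimal illustration, not NumPyro's implementation: it assumes a standard-normal target, a unit mass matrix so that $T(\mathbf{p}) = \mathbf{p}\cdot\mathbf{p}/2$, a leapfrog integrator, and arbitrary values for `step_size` and `num_steps`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def potential(q):
    # V(q) = -ln pi(q) for a standard-normal target (illustrative choice)
    return 0.5 * jnp.sum(q**2)

def hamiltonian(q, p):
    # T(p) = p.p / 2 assumes a unit mass matrix
    return 0.5 * jnp.sum(p**2) + potential(q)

grad_potential = jax.grad(potential)

def leapfrog(q, p, step_size, num_steps):
    # integrate Hamilton's equations with the leapfrog scheme
    p = p - 0.5 * step_size * grad_potential(q)
    for _ in range(num_steps - 1):
        q = q + step_size * p
        p = p - step_size * grad_potential(q)
    q = q + step_size * p
    p = p - 0.5 * step_size * grad_potential(q)
    return q, p

def hmc_step(key, q0, step_size=0.1, num_steps=20):
    momentum_key, accept_key = jr.split(key)
    p0 = jr.normal(momentum_key, q0.shape)  # randomized initial momentum
    q_final, p_final = leapfrog(q0, p0, step_size, num_steps)
    # alpha = min(1, exp[H(q0, p0) - H(q_final, p_final)])
    energy_change = hamiltonian(q0, p0) - hamiltonian(q_final, p_final)
    alpha = jnp.minimum(1.0, jnp.exp(energy_change))
    accept = jr.uniform(accept_key) < alpha
    return jnp.where(accept, q_final, q0), alpha

q_new, alpha = hmc_step(jr.key(0), jnp.zeros(3))
print(q_new, alpha)
```

Because the leapfrog integrator nearly conserves $H$ for small step sizes, the acceptance probability stays close to one even for long trajectories.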

#### Pros
- Scales very well to high dimensions (easily samples 1000s of dimensions)
- Finds a (local) peak very quickly
- Long-jump proposals, i.e. low autocorrelation in the chain

#### Cons
- Need partial derivatives of target density (not a con if you're using JAX)
- Struggles with multi-modal distributions
- Difficult to mix in other proposals in current implementations


Animation of HMC: https://chi-feng.github.io/mcmc-demo/app.html

Start by defining a likelihood in JAX.

In [None]:
jax.config.update('jax_enable_x64', False)

In [None]:
r_unit = 1.
def ring_lnlike(x, y):
    r = jnp.sqrt(x**2 + y**2)
    return -5. * (r - r_unit)**2

vectorized_ring_lnlike = jax.vmap(ring_lnlike, in_axes=(0, 0))

def mobius_ladder_lnlike(xs, coupling_weights):
    n = xs.shape[0]

    # rails of ladder
    rail_terms = vectorized_ring_lnlike(xs, jnp.roll(xs, 1)) + vectorized_ring_lnlike(xs, jnp.roll(xs, -1))

    # rungs of ladder
    rung_terms = vectorized_ring_lnlike(xs, jnp.roll(xs, n // 2))
    
    total_lnlike_val = jnp.sum((rail_terms - 0.1 * rung_terms) * coupling_weights)
    return total_lnlike_val

fast_mobius_ladder_lnlike = jax.jit(mobius_ladder_lnlike)

In [None]:
# application of auto-differentiation
example_x_input = jr.normal(jr.key(190521), (1000,))
example_coupling_weights = jnp.ones_like(example_x_input)

grad_mobius_ladder_lnlike = jax.jit(jax.grad(fast_mobius_ladder_lnlike, argnums=(0)))
print(grad_mobius_ladder_lnlike(example_x_input, example_coupling_weights))

In [None]:
%timeit fast_mobius_ladder_lnlike(example_x_input, example_coupling_weights).block_until_ready()

In [None]:
%timeit grad_mobius_ladder_lnlike(example_x_input, example_coupling_weights).block_until_ready()

In [None]:
get_negative_Fisher = jax.jit(jax.hessian(fast_mobius_ladder_lnlike, argnums=(0)))
print(get_negative_Fisher(example_x_input, example_coupling_weights))
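
The name `get_negative_Fisher` reflects that the Hessian of a log-likelihood is minus the (observed) Fisher matrix. A small sanity check on a toy Gaussian log-likelihood, where the Fisher matrix should equal the inverse covariance (the matrix values below are made up for illustration):

```python
import jax
import jax.numpy as jnp

# symmetric inverse covariance of a toy 2D Gaussian
inverse_covariance = jnp.array([[2.0, 0.5],
                                [0.5, 1.0]])

def gaussian_lnlike(x):
    # ln L = -x^T C^{-1} x / 2, up to an additive constant
    return -0.5 * x @ inverse_covariance @ x

# the Hessian of ln L is minus the Fisher matrix
negative_fisher = jax.hessian(gaussian_lnlike)(jnp.zeros(2))
fisher = -negative_fisher

print(fisher)
```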

A fun hierarchical model,

\begin{equation*}
    p(\mathbf{x}, \sigma, k | d) \propto p(d | \mathbf{x}) \cdot p(\mathbf{x} | \sigma) \cdot p(\sigma | k) \cdot p(k)\,,
\end{equation*}

where

\begin{align*}
    d | \mathbf{x} \,\sim & \,\text{"sum of rings (whose covariance is topologically a Möbius ladder)"} \\
    \mathbf{x} | \sigma \,\sim & \,\mathcal{N}(0, \sigma^2) \\
    \sigma | k \,\sim & \,\chi^2(k) \\
    k \,\sim & \,\text{Uniform}(0, 100)\,.
\end{align*}

In [None]:
# NumPyro sampling model
def model(ndim, coupling_weights):
    
    # hyper-hyper-prior
    k = numpyro.sample('k', dist.Uniform(0., 100.))

    # hyper-prior
    sigma = numpyro.sample('sigma', dist.Chi2(k))

    # prior
    x = numpyro.sample('x', dist.Normal(0., sigma).expand((ndim,)))

    # likelihood
    numpyro.factor('lnlike', fast_mobius_ladder_lnlike(x, coupling_weights))

In [None]:
# define No U-Turn Sampling kernel
nuts_kernel = numpyro.infer.NUTS(model=model)

# set-up MCMC
mcmc = numpyro.infer.MCMC(sampler=nuts_kernel,
                          num_warmup=int(5e3),
                          num_samples=int(1e5),
                          num_chains=10)

# run MCMC
ndim = 8
# ndim = 100  # try this...
coupling_weights = 3.5 * jnp.sin(jnp.pi * jnp.arange(ndim) / ndim)
mcmc.run(jr.key(170817), ndim, coupling_weights)

# save chain
samples = mcmc.get_samples()

In [None]:
# plot distribution on low-level parameters
x_samples = np.array(samples['x'])
x_labels = np.array([rf'$x_{{{i}}}$' for i in range(1, ndim + 1)])
fig = corner.corner(x_samples,
                    labels=x_labels,
                    bins=40,
                    label_kwargs={'fontsize': 14})

In [None]:
# # plot distribution on hyper-parameters
# fig = corner.corner(np.array([samples['k'], samples['sigma']]).T,
#                     labels=[r'$k$', r'$\sigma$'],
#                     bins=40,
#                     range=[0.99]*2)