# JAX tutorial

- What is JAX?

- JAX basics

- Things no one tells you

- Sampling using NumPyro HMC and JAX

## What is JAX?

JAX is a library for array-oriented numerical computation.

- NumPy-like interface that runs on CPU, GPU, or TPU.
- Just-In-Time (JIT) compilation.
- Automatic differentiation of functions.
- Automatic vectorization.

## JAX basics

In [None]:
# Run this if you've never used JAX or NumPyro.

# !pip install numpyro
# !pip install jax

In [None]:
import jax
import jax.numpy as jnp
import jax.scipy as js
import jax.random as jr

# for comparison
import numpy as np
import scipy

# for sampling later
import numpyro
import numpyro.distributions as dist
numpyro.set_host_device_count(10)

# plotting packages
import corner
import matplotlib.pyplot as plt
plt.style.use('dark_background')

### Single-precision

In [None]:
# by default JAX uses single-precision (float32), whereas np.cos on a Python
# float returns a float64, so the two printed values may differ in the
# trailing digits.
# Uncomment the next line to switch JAX to double precision (float64):
# jax.config.update('jax_enable_x64', True)
print(np.cos(np.pi / 3))    # NumPy result
print(jnp.cos(jnp.pi / 3))  # JAX result

### NumPy-like arrays

In [None]:
# two 1D JAX arrays built from Python lists (same call signature as np.array)
array1 = jnp.array([1., -2.5, 0.6])
array2 = jnp.array([5.3, 5.002, -11.5])

# element-wise addition
print(array1 + array2)
print()

# inner (dot) product of the two vectors -> a scalar
print(jnp.inner(array1, array2))
print()

# array constructors mirror their NumPy counterparts
a = jnp.zeros((3, 3))       # 2D array of zeros
b = jnp.ones((3, 3, 5))     # 3D array of ones
c = jnp.arange(0, 10, 2)    # [0, 2, 4, 6, 8]
d = jnp.linspace(0, 1, 5)   # [0., 0.25, 0.5, 0.75, 1.]
e = jnp.eye(4)              # 4x4 identity matrix

In [None]:
# index slicing and broadcasting
# b has shape (3, 3, 5), so valid indices along its first axis are 0, 1, 2.
# Note: JAX does not raise on out-of-bounds reads. b[3, 1] would silently be
# clamped to b[2, 1], unlike NumPy, which raises an IndexError.
print(b[2, 1])   # 1D slice of length 5
print(c[1:3])    # [2, 4]

# a[..., None] * b:  (3, 3, 1) * (3, 3, 5) -> (3, 3, 5)
# [..., None]     :  (3, 3, 5) -> (3, 3, 5, 1)
# / outer(c, d)   :  (3, 3, 5, 1) / (5, 5) -> (3, 3, 5, 5)
# (a is all zeros, so the entries are 0 or nan; only the shape matters here)
f = (a[..., None] * b)[..., None] / jnp.outer(c, d)
print(f.shape)   # (3, 3, 5, 5)

## Bayesian hierarchical linear regression

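We perform linear regression on a number of groups, each with the same number of data points and the same noise distribution. The coefficients of each group's line (its slope $m$ and intercept $y_0$) are modeled hierarchically. They are drawn from normal distributions with means $\mu_m$, $\mu_{y0}$ and standard deviations $\sigma_m$, $\sigma_{y0}$, and these hyperparameters are inferred along with the per-group coefficients.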

\begin{equation}
    p(\vec{m}, \vec{y}_0, \mu_m, \sigma_m, \mu_{y0}, \sigma_{y0} | \vec{d}) \propto p(\vec{d} | \vec{m}, \vec{y}_0) \cdot p(\vec{m} | \mu_m, \sigma_m) \cdot p(\vec{y}_0 | \mu_{y0}, \sigma_{y0}) \cdot p(\mu_m, \sigma_m, \mu_{y0}, \sigma_{y0})
\end{equation}

In [None]:
# simulate data

# data simulation random key, split once into independent sub-keys.
# Every random draw below gets its own sub-key: JAX random numbers are a pure
# function of the key, so draws that share a key are not independent.
simulation_key = jr.key(170817)
(slope_simulation_key, intercept_key, noise_key,
 batch_slope_key, batch_intercept_key) = jr.split(simulation_key, 5)

# number of groups (how many linear regressions)
num_groups = 1000

# how many samples per linear regression
Nt = 10
times = jnp.linspace(0., 1., Nt)

# hyper-parameters injected
slopes_mean = 3.6
slopes_stdev = 3.77
y_inter_mean = -1.1
y_inter_stdev = 4.9
hypers_inj = jnp.array([slopes_mean, slopes_stdev, y_inter_mean, y_inter_stdev])
num_hypers = hypers_inj.shape[0]

# linear regression parameters injected: one slope and one intercept per group
slopes_inj = slopes_mean + jr.normal(key=slope_simulation_key, shape=(num_groups,)) * slopes_stdev
y_intercepts_inj = y_inter_mean + jr.normal(key=intercept_key, shape=(num_groups,)) * y_inter_stdev

# all injected parameters, ordered [hyper-parameters, slopes, intercepts]
x_inj = jnp.concatenate((hypers_inj, slopes_inj, y_intercepts_inj))

# noise distribution
noise_stdev = 7.2

# lines injected into data, shape (num_groups, Nt)
lines_inj = slopes_inj[:, None] * times[None, :] + y_intercepts_inj[:, None]

# add noise to data
noise = jr.normal(key=noise_key, shape=(num_groups, Nt)) * noise_stdev
data = lines_inj + noise

# later on we will demonstrate automatic vectorization, so let's make many slope and intercept inputs.
# These draws use their own sub-keys, distinct from the keys used for the
# injected parameters, so the batch is independent of slopes_inj / y_intercepts_inj.
num_batch = 1000
slopes_stacked = slopes_mean + jr.normal(key=batch_slope_key, shape=(num_batch, num_groups)) * slopes_stdev
y_intercepts_stacked = y_inter_mean + jr.normal(key=batch_intercept_key, shape=(num_batch, num_groups)) * y_inter_stdev

## Just-In-Time (JIT) compilation

In [None]:
# likelihood function in NumPy
# need to convert objects to NumPy arrays
# (they are defined as JAX arrays)
data_np, times_np, slopes_inj_np, y_intercepts_inj_np = map(
    np.array, (data, times, slopes_inj, y_intercepts_inj)
)

def numpy_lnlike(slopes, y_intercepts, data):
    """Gaussian log-likelihood of all groups (up to an additive constant), in NumPy.

    slopes, y_intercepts : per-group line parameters, one entry per row of data.
    data : observations; row i is compared with the line of group i evaluated
        at the module-level ``times_np``.

    Returns -0.5 * sum(residuals**2) / noise_stdev**2, using the module-level
    ``noise_stdev``.
    """
    model_lines = y_intercepts[:, None] + slopes[:, None] * times_np[None, :]
    chi_squared = np.sum((data - model_lines) ** 2.) / noise_stdev ** 2.
    return -0.5 * chi_squared

# likelihood function in JAX
def jax_lnlike(slopes, y_intercepts, data):
    """Gaussian log-likelihood of all groups (up to an additive constant), in JAX.

    Same computation as ``numpy_lnlike`` but written with jax.numpy, so it can
    be jitted, vectorized with vmap, and differentiated with grad/hessian.
    Uses the module-level ``times`` and ``noise_stdev``.
    """
    model_lines = y_intercepts[:, None] + slopes[:, None] * times[None, :]
    chi_squared = jnp.sum((data - model_lines) ** 2.) / noise_stdev ** 2.
    return -0.5 * chi_squared

In [None]:
%timeit numpy_lnlike(slopes_inj_np, y_intercepts_inj_np, data_np)

In [None]:
%timeit jax_lnlike(slopes_inj, y_intercepts_inj, data)

In [None]:
# JIT JAX likelihood: jax.jit compiles jax_lnlike with XLA. Compilation happens
# on the first call (and again for new input shapes/dtypes); later calls reuse
# the compiled function, which is where the speed-up comes from.
fast_lnlike = jax.jit(jax_lnlike)

# test: this first call also triggers compilation; the value should agree with
# jax_lnlike(slopes_inj, y_intercepts_inj, data) up to floating-point rounding
print(fast_lnlike(slopes_inj, y_intercepts_inj, data))

In [None]:
%timeit fast_lnlike(slopes_inj, y_intercepts_inj, data)

## Automatic vectorization

In [None]:
# suppose we want to evaluate the likelihood for many inputs:
# NumPy copies of the stacked (num_batch, num_groups) parameter arrays
slopes_stacked_np, y_intercepts_stacked_np = map(np.array, (slopes_stacked, y_intercepts_stacked))

In [None]:
%timeit [numpy_lnlike(slopes, y_intercepts, data) for slopes, y_intercepts in zip(slopes_stacked_np, y_intercepts_stacked_np)]

In [None]:
# we can automatically vectorize JAX functions:
# jax.vmap maps fast_lnlike over axis 0 of slopes and y_intercepts
# (in_axes=0) while passing the same data to every call (in_axes=None),
# so the result holds one log-likelihood per row, i.e. shape (num_batch,).
vectorized_lnlike = jax.jit(jax.vmap(fast_lnlike, in_axes=(0, 0, None)))

# test: one log-likelihood value per stacked (slopes, y_intercepts) pair
print(vectorized_lnlike(slopes_stacked, y_intercepts_stacked, data))

In [None]:
%timeit vectorized_lnlike(slopes_stacked, y_intercepts_stacked, data)

## Automatic differentiation

Suppose we want the partial derivatives of

\begin{equation*}
    f(x_1, x_2) = x_1 \,\text{exp}\bigg[-\frac{1}{2}(x_1^2 + x_2^2)\bigg]\,.
\end{equation*}

<img src="computational_graph.png" alt="Computational graph of f(x1, x2)" style="filter: invert(1);">

<img src="computational_graph_forward.png" alt="Computational graph of f(x1, x2) traversed in the forward direction" style="filter: invert(1);">

In [None]:
# we can take derivatives of JAX functions:
# jax.grad differentiates the scalar log-likelihood with respect to the
# arguments listed in argnums (0 = slopes, 1 = y_intercepts) and returns a
# tuple (d lnlike / d slopes, d lnlike / d y_intercepts), each shaped like
# the corresponding input.
grad_lnlike = jax.jit(jax.grad(fast_lnlike, argnums=(0, 1)))

# test
print(grad_lnlike(slopes_inj, y_intercepts_inj, data))

In [None]:
%timeit grad_lnlike(slopes_inj, y_intercepts_inj, data)

In [None]:
# Fisher matrices!
# jax.hessian with argnums=(0, 1) returns the second derivatives as a nested
# 2x2 tuple of blocks ((d2/dm dm, d2/dm dy0), (d2/dy0 dm, d2/dy0 dy0)), each of
# shape (num_groups, num_groups) for the inputs below. The log-likelihood is
# quadratic in the slopes and intercepts, so its Hessian does not depend on the
# data and equals minus the Fisher information matrix, hence the name.
neg_Fisher = jax.jit(jax.hessian(fast_lnlike, argnums=(0, 1)))
print(neg_Fisher(slopes_inj, y_intercepts_inj, data))

In [None]:
%timeit neg_Fisher(slopes_inj, y_intercepts_inj, data)

In [None]:
# try yourself
# jax.hessian
# jax.jacobian
# jax.jit(jax.vmap(jax.hessian(...)))

## JAX difficulties

What no one tells you...

In [None]:
# plain-Python parity test (fails under jax.jit, see below)
def is_even(integer):
    """Return True if ``integer`` is even and False if it is odd.

    The branch is a Python ``if`` on the value. That is fine for concrete
    Python/NumPy numbers, but inside ``jax.jit`` the argument is an abstract
    tracer whose truth value is unknown at trace time, so the jitted version errors.
    """
    if integer % 2 != 0:  # odd
        return False
    return True

fast_is_even = jax.jit(is_even)

# print(is_even(6))
# print(fast_is_even(6))

In [None]:
# conditionals work-around: jax.lax.cond picks one of two branch functions
# based on a (possibly traced) boolean predicate, so it works under jax.jit
def spit_True():
    """Branch used by jax.lax.cond when the predicate is true."""
    return True

def spit_False():
    """Branch used by jax.lax.cond when the predicate is false."""
    return False

def is_even_redo(integer):
    """jit-compatible parity test; returns a JAX boolean that is True for even input."""
    return jax.lax.cond(integer % 2 == 0, spit_True, spit_False)

fast_is_even_redo = jax.jit(is_even_redo)
# print(is_even_redo(7))
# print(fast_is_even_redo(6))

In [None]:
# %timeit is_even(6)

In [None]:
# %timeit fast_is_even_redo(6)

In [None]:
# changing elements of an array
# NumPy arrays are mutable: item assignment updates arr in place.
arr = np.array([0, 1, 2])
arr[2] = 0
print(arr)

# JAX arrays are immutable: uncommenting the item assignment below raises an
# error. The functional alternative `.at[index].set(value)` (commented out
# below it) returns a NEW array with the update applied and leaves the
# original untouched, so its result has to be re-assigned.
# jax_arr = jnp.array([0, 1, 2])
# jax_arr[2] = 0
# # jax_arr = jax_arr.at[2].set(0)
# print(jax_arr)

In [None]:
# %timeit arr[2] = 0

In [None]:
# %timeit jax_arr.at[2].set(0)

In [None]:
# %timeit np.array([5, 5, 5])

In [None]:
# %timeit jnp.array([5, 5, 5])

## Best practices

- JIT compile functions.

- Compose auto-diff, vectorization, and JIT wrappers.

- Avoid Python if/else statements on array values inside jitted functions; use `jax.lax.cond` (or `jnp.where`) instead.

- Create constant JAX arrays once, up front. Avoid creating or updating JAX arrays in code that runs many times.

- Don't mix JAX and NumPy within a function.

- It’s easier to start a project in JAX from the beginning than to add it later.

## Sampling with NumPyro

In [None]:
def sampling_model(data):
    """NumPyro model for the hierarchical linear regression.

    Hyper-priors are uniform on the population means (-10, 10) and standard
    deviations (0, 10) of the slopes and intercepts. Each of the
    ``num_groups`` slopes and intercepts is drawn from a normal with those
    hyperparameters, and the data enter through ``fast_lnlike`` as a
    ``numpyro.factor`` term.
    """
    # hyper-prior: population-level location and scale parameters
    mu_m = numpyro.sample('slopes_mean', dist.Uniform(-10., 10.))
    sigma_m = numpyro.sample('slopes_stdev', dist.Uniform(0., 10.))
    mu_y0 = numpyro.sample('intercepts_mean', dist.Uniform(-10., 10.))
    sigma_y0 = numpyro.sample('intercepts_stdev', dist.Uniform(0., 10.))

    # prior: one slope and one intercept per group
    group_shape = (num_groups,)
    m = numpyro.sample('slopes', dist.Normal(mu_m, sigma_m).expand(group_shape))
    y0 = numpyro.sample('intercepts', dist.Normal(mu_y0, sigma_y0).expand(group_shape))

    # likelihood: add the Gaussian log-likelihood to the model's log density
    numpyro.factor('lnlike', fast_lnlike(m, y0, data))


# No-U-Turn Sampler (an adaptive HMC variant) for the hierarchical model above
nuts_kernel = numpyro.infer.NUTS(model=sampling_model)
# 5,000 warm-up iterations followed by 10,000 retained samples from one chain
mcmc = numpyro.infer.MCMC(sampler=nuts_kernel,
                          num_warmup=int(5e3),
                          num_samples=int(1e4),
                          num_chains=1,
                          )
# the first argument is the PRNG key; the remaining positional argument (data)
# is passed on to sampling_model
mcmc.run(jr.key(150914), data)
# dict mapping each sample-site name ('slopes_mean', 'slopes', ...) to its draws
samples_dict = mcmc.get_samples()

In [None]:
# organize samples
# columns: [slopes_mean, slopes_stdev, intercepts_mean, intercepts_stdev,
#           slopes (num_groups), intercepts (num_groups)]
hyper_samples = np.column_stack([samples_dict[site] for site in ('slopes_mean', 'slopes_stdev',
                                                                  'intercepts_mean', 'intercepts_stdev')])
samples = np.hstack((hyper_samples, samples_dict['slopes'], samples_dict['intercepts']))

# LaTeX labels in the same column order as samples
hyper_labels = np.array([r'$\mu_m$', r'$\sigma_m$', r'$\mu_{y_0}$', r'$\sigma_{y_0}$'])
slope_labels = np.array([rf'$m_{{{k}}}$' for k in range(1, num_groups + 1)])
intercept_labels = np.array([rf'$y0_{{{k}}}$' for k in range(1, num_groups + 1)])
labels = np.hstack((hyper_labels, slope_labels, intercept_labels))

In [None]:
# trace plot of every hyperparameter (the first len(hyper_labels) columns of
# samples), with its injected value drawn as a horizontal line of the same color
for i in range(len(hyper_labels)):
    plt.plot(samples[:, i], color=f'C{i}', alpha=0.5)
    plt.axhline(x_inj[i], color=f'C{i}', alpha=0.8, label=labels[i])
plt.xlabel('HMC iteration')
plt.ylabel('parameter values')
plt.legend()  # show the labels attached to the axhline calls
plt.show()

In [None]:
# corner plot
# samples columns are ordered [4 hyperparameters, num_groups slopes,
# num_groups intercepts]. ndxs therefore selects the 4 hyperparameters, the
# first two slopes (columns 4, 5) and the first two intercepts (columns
# num_groups + 4, num_groups + 5). x_inj uses the same ordering, so its entries
# give the injected (true) values drawn as truth lines.
ndxs = np.r_[:6, num_groups + 4 : num_groups + 6]
fig = corner.corner(samples[:, ndxs],
                    bins=40,
                    labels=labels[ndxs],
                    truths=x_inj[ndxs])

In [None]:
from emcee.autocorr import integrated_time

# Integrated auto-correlation time of every parameter. samples has shape
# (n_steps, n_params) with no walker axis, hence has_walkers=False.
# quiet=True emits a warning instead of raising AutocorrError when the chain is
# shorter than tol (default 50) auto-correlation times, so the estimate and
# plot are still produced for a short or slowly mixing chain.
auto_corr_per_parameter = integrated_time(samples, has_walkers=False, quiet=True)
print(f'maximum integrated auto-correlation time = {np.max(auto_corr_per_parameter)}')
plt.figure(figsize=(10, 4))
plt.bar(np.arange(samples.shape[1]), auto_corr_per_parameter)
plt.xlabel('parameter indices')
plt.ylabel('integrated auto-correlation time')
plt.show()

In [None]:
r_circ = 1.
def circular_lnlike(x, y):
    """Ring-shaped log-likelihood centred on the circle of radius ``r_circ``.

    Equals -10 * (r - r_circ)**2 with r = sqrt(x**2 + y**2). It is 0 on the
    circle and falls off quadratically with radial distance from it.
    """
    radial_offset = jnp.sqrt(x**2 + y**2) - r_circ
    return -10. * radial_offset**2

fast_circular_lnlike = jax.jit(circular_lnlike)

In [None]:
def circular_model():
    """NumPyro model: uniform priors on x and y over (-10, 10), plus the
    ring-shaped log-likelihood ``fast_circular_lnlike`` added as a factor."""
    box_prior = dist.Uniform(-10., 10.)
    x = numpyro.sample('x', box_prior)
    y = numpyro.sample('y', box_prior)
    numpyro.factor('lnlike', fast_circular_lnlike(x, y))

# NUTS on the ring-shaped target: 10 chains, each with 5,000 warm-up
# iterations and 100,000 retained samples.
# NOTE(review): running the chains in parallel presumably relies on
# numpyro.set_host_device_count(10) at the top of the notebook; confirm.
nuts_kernel = numpyro.infer.NUTS(model=circular_model)
mcmc = numpyro.infer.MCMC(sampler=nuts_kernel,
                          num_warmup=int(5e3),
                          num_samples=int(1e5),
                          num_chains=10,
                          )
mcmc.run(jr.key(1))
# dict {'x': ..., 'y': ...}; with the default arguments the draws of all
# chains are combined into one array per site
samples_dict = mcmc.get_samples()

In [None]:
# stack the x and y draws into a (num_draws, 2) array for the corner plot
xy_samples = np.array([samples_dict['x'], samples_dict['y']]).T
fig = corner.corner(xy_samples,
                    labels=['x', 'y'],
                    bins=40)

In [None]:
# Euclidean distance between successive samples *within each chain*.
# get_samples() returns the chains concatenated one after another, so reshape
# to (num_chains, num_samples, 2) first. Otherwise the jump from the end of one
# chain to the start of the next would be counted as an HMC step.
xy_chains = xy_samples.reshape(mcmc.num_chains, -1, 2)
Euc_dist = np.linalg.norm(np.diff(xy_chains, axis=1), axis=-1).ravel()
plt.hist(Euc_dist, density=True, color='C0', alpha=0.8, bins=60)
plt.xlabel('Euclidean distance between successive samples')
plt.ylabel('density')
plt.show()