# JAX

[JAX](https://jax.readthedocs.io/en/latest/index.html) is a Python library for high-performance array computing. This notebook demonstrates some basic features of JAX; the [documentation](https://jax.readthedocs.io/en/latest/index.html) covers much more.

## Automatic differentiation: autodiff

Use `jax.grad` to get the gradient of a function. `jax.grad` takes a function and returns a new function that evaluates its gradient:

In [1]:
from jax import grad

def f(x):
    return x**2.

df_dx = grad(f)

print(df_dx(1.))

2.0
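
Because `grad` returns an ordinary Python function, it can be applied again to get higher-order derivatives. The following is a minimal sketch; the cubic `f` and the names `d2f_dx2`, `first` and `second` are illustrative choices, not part of the original notebook.

```python
from jax import grad

def f(x):
    return x**3

# First derivative: 3 * x**2
df_dx = grad(f)

# Second derivative: 6 * x, obtained by differentiating the gradient function
d2f_dx2 = grad(grad(f))

first = df_dx(1.0)     # analytically 3.0 at x = 1
second = d2f_dx2(1.0)  # analytically 6.0 at x = 1
print(first, second)
```

Note that `grad` expects a function with a scalar output and a floating-point input, which is why the argument is written as `1.0` rather than `1`.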


## `jax.numpy`

JAX transformations such as `grad` do not work with standard `numpy` functions. For example, if we write a function that uses `np.cos` and then try to take its gradient,

In [2]:
import numpy as np

def f(x):
    return np.cos(x)

In [3]:
df_dx = grad(f)

print(df_dx(0.))

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[]
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

we get a `TracerArrayConversionError`.

Instead, we can use the [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) library, which provides a drop-in replacement for most `numpy` functions.

In [4]:
import jax.numpy as jnp

def f(x):
    return jnp.cos(x)

df_dx = grad(f)

x = jnp.array(0.)
print(df_dx(x))

-0.0


Similarly [`jax.scipy`](https://jax.readthedocs.io/en/latest/jax.scipy.html) provides a replacement for most `scipy` functions.
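
As a rough illustration, `jax.scipy.stats.norm.logpdf` can be differentiated in the same way as `jnp.cos`. The wrapper `log_density` and the variable names below are hypothetical and chosen only for this sketch.

```python
from jax import grad
from jax.scipy.stats import norm

def log_density(x):
    # Log-density of a standard normal: -x**2 / 2 - log(sqrt(2 * pi))
    return norm.logpdf(x)

# The derivative of the standard normal log-density is -x
dlogp_dx = grad(log_density)

slope = dlogp_dx(1.0)  # analytically -1.0 at x = 1
print(slope)
```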

## Random numbers in `JAX`

Setting a random seed allows you to produce reproducible random outputs.

In `numpy` you can set the random seed like this:

In [5]:
import numpy as np
np.random.seed(42) # where 42 can be any integer 

This does not work with JAX, which has no global random state. Instead, you create an explicit random key:

In [6]:
from jax import random

key = random.key(42)

This random key is passed as an input parameter to many `numpyro` functions.
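
The same pattern applies to JAX's own samplers in `jax.random`, which take the key as their first argument. The sketch below assumes a JAX version recent enough to provide `random.key`, as used above; the names `a`, `b`, `c` and `subkey` are illustrative.

```python
from jax import random

key = random.key(42)

# Reusing the same key reproduces exactly the same draws
a = random.normal(key, shape=(3,))
b = random.normal(key, shape=(3,))

# To get new random numbers, split the key and sample with the subkey
key, subkey = random.split(key)
c = random.normal(subkey, shape=(3,))

print(a)
print(b)
print(c)
```

Here `a` and `b` are identical because they were drawn with the same key, while `c` differs because it uses a fresh subkey. Splitting before each draw is the usual way to avoid accidentally reusing a key.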

## Other 

This notebook just scratches the surface of what you can do with JAX. 

Other useful features:

- [automatic vectorisation](https://jax.readthedocs.io/en/latest/automatic-vectorization.html#)
- [just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html)
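
These two features compose with `grad`. As a minimal sketch (the input array `xs` and the name `slopes` are illustrative assumptions), `vmap` applies a scalar gradient function across an array and `jit` compiles the result:

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def f(x):
    return jnp.cos(x)

# grad(f) works on a single scalar; vmap maps it over the leading axis
# of an array, and jit compiles the combined function on first call
df_dx = jit(vmap(grad(f)))

xs = jnp.array([0.0, 1.0, 2.0])
slopes = df_dx(xs)  # analytically -sin(xs)
print(slopes)
```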

Also see [The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) for common mistakes.