## **JAX**

JAX is a library for array-oriented numerical computation, with automatic differentiation and JIT compilation to enable high-performance machine learning research.

1. JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
2. JAX features built-in Just-in-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
3. JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.
4. JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.
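
As a rough illustration of points 3 and 4, the sketch below applies `jax.grad` and `jax.vmap` to a small function. The function `sum_of_squares` and the sample inputs are made up for this example and are not part of the tutorial's own code.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: f(x) = sum(x_i ** 2), scalar-valued.
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# jax.grad returns a new function computing df/dx; analytically df/dx = 2 * x.
grad_fn = jax.grad(sum_of_squares)
x = jnp.array([1.0, 2.0, 3.0])
gradient = grad_fn(x)

# jax.vmap maps the function over the leading axis of its input, so a
# per-example function is applied to a whole batch without a Python loop.
batch = jnp.stack([x, 2.0 * x])  # shape (2, 3): two examples
batched_values = jax.vmap(sum_of_squares)(batch)

print(gradient)
print(batched_values)
```

Both transformations return ordinary Python callables, so they can be composed with each other and with `jax.jit`.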

In [1]:
import jax.numpy as jnp

With the above import, we can immediately start using JAX in a similar manner to NumPy.

In [2]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


JAX works well for many numerical and scientific programs, but only if they are written with certain constraints, as explained in [tutorial_n.ipynb](#add_link_when_done).

### **Just-in-time compilation with `jax.jit()`**

JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above code, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the `jax.jit()` function to compile this sequence of operations together using XLA.


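We can use IPython's `%timeit` magic to quickly benchmark our `selu` function, using `block_until_ready()` to account for JAX's asynchronous dispatch: without it, we would only time how long it takes to enqueue the work, not to finish it. See [tutorial_async](#add_it_too) for more.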

In [18]:
from jax import random

key = random.key(135)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

845 μs ± 32.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
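
`%timeit` is only available inside IPython or Jupyter. Outside a notebook, a similar measurement could be sketched with `time.perf_counter`, as below. The helper `time_call` is a hypothetical name introduced for this example, and the timings it reports will differ from the figures shown in this tutorial.

```python
import time

import jax.numpy as jnp
from jax import random


def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


def time_call(fn, *args, repeats=10):
    """Hypothetical helper: best wall-clock time in seconds over `repeats` runs."""
    fn(*args).block_until_ready()  # warm-up run, excluded from the measurement
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # block_until_ready() waits for the computation to actually finish;
        # without it we would mostly measure the time to dispatch the work.
        fn(*args).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


x = random.normal(random.key(135), (1_000_000,))
elapsed = time_call(selu, x)
print(f"selu: {elapsed * 1e6:.0f} μs (best of 10 runs)")
```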


We can speed up this function with the `jax.jit()` transformation, which JIT-compiles `selu` the first time it is called. The compiled version is then cached and reused on later calls with inputs of the same shape and dtype.

In [19]:
from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # warmup
%timeit selu_jit(x).block_until_ready()

342 μs ± 32.9 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


This is just the execution time on CPU. The same code can be run on GPU or TPU, typically for an even greater speedup.
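
The compile-on-first-call behaviour of `jax.jit()` can be observed with a small experiment. The sketch below records a Python side effect that only runs while JAX traces the function. The `trace_log` list is a hypothetical device for this illustration, not something JAX provides.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical log: appended to only while JAX traces selu


@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    # This Python side effect runs during tracing, not when cached code executes.
    trace_log.append(x.shape)
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu(jnp.ones(5))   # first call: traced and compiled
selu(jnp.zeros(5))  # same shape and dtype: the cached compilation is reused
count_same_shape = len(trace_log)

selu(jnp.ones(8))   # new input shape: traced and compiled again
count_new_shape = len(trace_log)

print(count_same_shape, count_new_shape)
```

The second call adds nothing to the log, while the call with a different input shape triggers a fresh trace. This is why the benchmark above runs a warm-up call before timing.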