# Learning JAX

JAX is a Python library developed by Google Research for high-performance numerical computing, especially well-suited for machine learning research. It combines the familiar NumPy API with powerful transformations like automatic differentiation, JIT compilation, and vectorization, enabling you to write highly efficient code that can run on CPUs, GPUs, and TPUs.


This document gives a quick overview of essential JAX features to help you get started:

- JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.

- JAX features built-in just-in-time (JIT) compilation via [OpenXLA](https://github.com/openxla), an open-source machine learning compiler ecosystem.

- JAX functions support efficient evaluation of gradients via JAX's automatic differentiation transformations.

- JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.

In [8]:
# !pip install jax

In [9]:
# !pip install -U "jax[cuda12]"

## JAX as NumPy

In [10]:
import jax.numpy as jnp

In [11]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


## Just-in-time compilation with jax.jit()

In [12]:
from jax import random

key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

3.18 ms ± 150 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()

345 μs ± 51.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
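
The "compiles on first call" comment above can be made concrete with a small sketch. It assumes a hypothetical function `scaled_sum` and a hypothetical Python list `trace_log` (neither is part of JAX) and relies on the fact that Python side effects inside a jitted function run only while JAX traces it, not when the cached compiled code is reused:

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records each time the Python body actually runs

@jax.jit
def scaled_sum(x):
    # Python side effect: executed only during tracing, not on cached calls
    trace_log.append(x.shape)
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))  # first call: traces and compiles for shape (3,)
scaled_sum(jnp.ones(3))  # same shape and dtype: reuses the compiled code
scaled_sum(jnp.ones(4))  # new shape: traces and compiles again
print(trace_log)
```

Only two entries should be recorded, one per distinct input shape, which is why the benchmark above warms up `selu_jit` before timing it.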


## Taking derivatives with jax.grad()

In [14]:
from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [15]:
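# Check the grad() result with central finite differences,
# perturbing x along each coordinate direction in turn.
def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])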

print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1964569  0.10502338]


In [16]:
# grad() and jit() compose freely; this evaluates the third derivative at x = 1.0
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.0353256
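
As a sanity check on the nested `grad()` call above, the result can be compared with the closed-form third derivative of the logistic function, s''' = s(1 - s)(1 - 6s + 6s²). This is a minimal sketch; the helper name `logistic_third_derivative` is hypothetical and the formula is derived by hand rather than taken from the JAX documentation:

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def logistic_third_derivative(x):
    # Hand-derived closed form: s''' = s (1 - s) (1 - 6 s + 6 s^2)
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s) * (1.0 - 6.0 * s + 6.0 * s**2)

# Three nested grad() calls give the third derivative of the scalar function
third = jax.grad(jax.grad(jax.grad(sum_logistic)))
autodiff_value = third(1.0)
closed_form_value = logistic_third_derivative(1.0)
print(autodiff_value, closed_form_value)
```

Both values should agree with the printed result above to within float32 precision.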


## Computing Jacobians with jax.jacobian()

In [17]:
from jax import jacobian
print(jacobian(jnp.exp)(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]


For more advanced autodiff operations, you can use jax.vjp() for reverse-mode vector-Jacobian products, and jax.jvp() and jax.linearize() for forward-mode Jacobian-vector products. The two modes can be composed arbitrarily with one another and with other JAX transformations. For example, jax.jvp() and jax.vjp() underlie jax.jacfwd() and jax.jacrev(), which compute Jacobians in forward and reverse mode, respectively. Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:

In [18]:
from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))

[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]
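
The lower-level jax.jvp() and jax.vjp() calls mentioned above can be sketched as follows. This is an illustrative example only; the variable names (`v`, `tangent_out`, `cotangent_in`) are chosen here for clarity, and `jnp.exp` is used because its Jacobian is diagonal and easy to check:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3.0)
v = jnp.ones_like(x)

# Forward mode: push a tangent vector v through the function, giving J @ v
primal_out, tangent_out = jax.jvp(jnp.exp, (x,), (v,))

# Reverse mode: pull a cotangent vector v back through the function, giving v @ J
primal_out2, vjp_fn = jax.vjp(jnp.exp, x)
(cotangent_in,) = vjp_fn(v)

# Compare both products against the explicit Jacobian matrix
J = jax.jacobian(jnp.exp)(x)
print(jnp.allclose(tangent_out, J @ v), jnp.allclose(cotangent_in, v @ J))
```

Neither call materializes the full Jacobian, which is what makes these products cheaper than jax.jacobian() when only one direction is needed.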


## Auto-vectorization with jax.vmap()
Another useful transformation is vmap(), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping over function calls, it transforms the function into a natively vectorized version for better performance. When composed with jit(), it can be just as performant as manually rewriting your function to operate over an extra batch dimension.

Let’s work through a simple example: promoting matrix-vector products into matrix-matrix products using vmap(). Although this is easy to do by hand in this specific case, the same technique applies to more complicated functions.

In [19]:
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  return jnp.dot(mat, x)

In [20]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
579 μs ± 17.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [21]:
import numpy as np

@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
16.7 μs ± 2.18 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [22]:
from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
20.9 μs ± 2.15 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
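
When the function takes several arguments and only some of them carry a batch dimension, the `in_axes` parameter of vmap() selects which ones to map over. The sketch below is a variation on the example above, not part of the original benchmark: the function `apply` takes the matrix as an explicit argument instead of closing over it, and the names `apply` and `batched_apply` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random

key1, key2 = random.split(random.key(0))
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply(mat, x):
    # Single-example function: one matrix times one vector
    return jnp.dot(mat, x)

# in_axes=(None, 0): share mat across the batch, map over axis 0 of x
batched_apply = jax.jit(jax.vmap(apply, in_axes=(None, 0)))
out = batched_apply(mat, batched_x)
print(out.shape)  # (10, 150)
```

Passing `None` for an argument tells vmap() to broadcast it unchanged, so the same matrix is applied to every row of `batched_x`.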
