# JAX Tutorial
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

- JAX can automatically differentiate native Python and NumPy code. It can differentiate through a large subset of Python’s features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives of derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.

- JAX uses XLA to compile and run NumPy code on accelerators, like GPUs and TPUs.
    - Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed.
    - JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API.
    - Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without having to leave Python.

- XLA (Accelerated Linear Algebra) is an open-source compiler for machine learning. The XLA compiler takes models from popular frameworks such as PyTorch, TensorFlow, and JAX, and optimizes the models for high-performance execution across different hardware platforms including GPUs, CPUs, and ML accelerators.
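
The first bullet’s claim about differentiating through ordinary Python control flow can be illustrated with a small sketch. The function `cube_or_negate` is a hypothetical example, not part of JAX; it uses a plain `if` and a `for` loop, and `grad` is applied up to three times. This works without `jit`; branching on a traced value under `jit` is a separate matter.

```python
from jax import grad

def cube_or_negate(x):
    # Hypothetical example: plain Python branch and loop on the input.
    if x > 0:
        y = x
        for _ in range(2):
            y = y * x        # after the loop, y == x ** 3
        return y
    return -x

d1 = grad(cube_or_negate)                # 3 * x**2 for x > 0
d2 = grad(grad(cube_or_negate))          # 6 * x    for x > 0
d3 = grad(grad(grad(cube_or_negate)))    # 6        for x > 0

print(d1(2.0), d2(2.0), d3(2.0))
print(d1(-1.0))                          # derivative of -x
```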

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
key = random.key(0)
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

71.5 ms ± 1.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # NumPy input: copied to the device on every call

74.4 ms ± 864 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


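When `jnp` functions are called on regular NumPy arrays, the data has to be transferred to the device on every call. You can ensure that an NDArray is backed by device memory using `device_put()`. The output of `device_put()` still acts like an NDArray, but it only copies values back to the CPU when they’re needed for printing, plotting, saving to disk, branching, etc. The behavior of `device_put()` is equivalent to the function `jit(lambda x: x)`, but it’s faster.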

In [5]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

70.6 ms ± 2.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
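
As a rough sketch of the round trip described above, the snippet below places a small NumPy array on the default device and then copies it back to host memory. The variable names are illustrative, and the platform that gets printed depends on the machine.

```python
import numpy as np
import jax

host_x = np.arange(6, dtype=np.float32).reshape(2, 3)

# Place the array on the default device (GPU/TPU if present, otherwise CPU).
dev_x = jax.device_put(host_x)

print(type(dev_x).__name__)            # a JAX array type, not numpy.ndarray
print(jax.devices()[0].platform)       # 'cpu', 'gpu', or 'tpu'

# Converting back to NumPy copies the values to host memory.
round_trip = np.asarray(dev_x)
```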


## Running on GPU with jit()
JAX runs transparently on the GPU or TPU (falling back to CPU if you don’t have one). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the @jit decorator to compile multiple operations together using XLA. Let’s try that.

In [6]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

1.01 ms ± 60 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

292 µs ± 5.16 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
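
The `%timeit` figures above exclude the one-time cost of compilation, because `%timeit` runs the function many times. A minimal sketch, using `time.perf_counter` instead of the IPython magic, makes that cost visible by timing the first and second calls separately. The actual numbers depend on hardware, so none are shown.

```python
import time
import jax.numpy as jnp
from jax import jit, random

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu)
x = random.normal(random.key(0), (1_000_000,))

t0 = time.perf_counter()
selu_jit(x).block_until_ready()    # first call: traces and compiles
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
selu_jit(x).block_until_ready()    # later calls reuse the compiled kernel
second_call = time.perf_counter() - t0

print(f"first call:  {first_call * 1e3:.2f} ms")
print(f"second call: {second_call * 1e3:.2f} ms")
```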


## Auto-differentiation with grad()
In addition to evaluating numerical functions, we also want to transform them. One transformation is automatic differentiation. In JAX, just like in Autograd, you can compute gradients with the `grad()` function.

In [8]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]

Let’s verify with finite differences that the result is correct.


In [9]:
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1964569  0.10502338]


Taking derivatives is as easy as calling `grad()`. `grad()` and `jit()` compose and can be mixed arbitrarily. In the above example we took the derivative of sum_logistic directly; the two transformations can also be nested to any depth:

In [10]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.0353256
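
Besides finite differences, the gradient of `sum_logistic` can be checked against the closed form of the logistic derivative, s'(x) = s(x)(1 - s(x)). The helper `logistic_derivative` below is written for this check only and is not part of JAX.

```python
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def logistic_derivative(x):
    # Closed form: s'(x) = s(x) * (1 - s(x)), applied elementwise.
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s)

x_small = jnp.arange(3.)
autodiff = grad(sum_logistic)(x_small)
closed_form = logistic_derivative(x_small)

print(autodiff)
print(closed_form)
print(jnp.allclose(autodiff, closed_form, atol=1e-6))
```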


For more advanced autodiff, you can use `jax.vjp()` for reverse-mode vector-Jacobian products and `jax.jvp()` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:

In [11]:
from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
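
A short usage sketch for `hessian`: the test function `cubic` is a hypothetical example chosen because its Hessian is known in closed form, namely a diagonal matrix with entries `6 * x`.

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

def cubic(x):
    # Hypothetical test function: f(x) = sum(x_i ** 3), so H = diag(6 * x).
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])
H = hessian(cubic)(x)
print(H)
```

Applying forward mode over reverse mode, as here, is a common choice for scalar-valued functions, since the inner `jacrev` yields the gradient in one pass and `jacfwd` then differentiates that vector-valued result.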

## Auto-vectorization with vmap()
JAX has one more transformation in its API that you might find useful: `vmap()`, the vectorizing map. It has the familiar semantics of **mapping a function along array axes**, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with jit(), it can be just as fast as adding the batch dimensions by hand.

We’re going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap(). Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

In [12]:
mat = random.normal(key, (150, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

Given a function such as apply_matrix, we can loop over a batch dimension in Python, but usually the performance of doing so is poor.

In [None]:
batched_x = random.normal(key, (10, 100))

In [13]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
259 µs ± 2.76 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


We know how to batch this operation manually. In this case, jnp.dot handles extra batch dimensions transparently.

In [14]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
9.37 µs ± 64.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


However, suppose we had a more complicated function without batching support. We can use vmap() to add batching support automatically.

In [15]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
13.3 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
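
`vmap()` also accepts an `in_axes` argument that selects which arguments are batched and along which axis. The sketch below is a variation on the example above, not the tutorial’s own code: the matrix is passed explicitly and marked as unbatched with `None`, while the vectors are mapped over axis 0. The function name `apply` is illustrative.

```python
import jax.numpy as jnp
from jax import random, vmap

key_m, key_v = random.split(random.key(0))
mat = random.normal(key_m, (150, 100))
batched_x = random.normal(key_v, (10, 100))

def apply(m, v):
    # Written for a single vector v of shape (100,).
    return jnp.dot(m, v)

# None: share the matrix across the batch; 0: map over axis 0 of the vectors.
batched_apply = vmap(apply, in_axes=(None, 0))

out = batched_apply(mat, batched_x)
reference = jnp.dot(batched_x, mat.T)   # manual batching, for comparison
print(out.shape)
```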


## jax.numpy vs jax.lax
Key Concepts:
- `jax.numpy` is a high-level wrapper that provides a familiar interface.
- `jax.lax` is a lower-level API that is stricter and often more powerful.

All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler.
- If you look at the source of jax.numpy, you’ll see that all the operations are eventually expressed in terms of functions defined in jax.lax.
- You can think of jax.lax as a stricter, but often more powerful, API for working with multi-dimensional arrays.
    - For example, while jax.numpy will implicitly promote arguments to allow operations between mixed data types, jax.lax will not:

In [22]:
import jax.numpy as jnp
from jax import lax
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.

Array(2., dtype=float32, weak_type=True)

In [23]:
# lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.

In [24]:
lax.add(jnp.float32(1), 1.0)

Array(2., dtype=float32)

In [25]:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)

Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

In [27]:
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]

Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

## To JIT or not to JIT
Key Concepts: **By default JAX executes operations one at a time, in sequence**
- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.
- Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.

The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently.
- For example, consider this function that standardizes the columns of a 2D matrix (subtracting each column’s mean and dividing by its standard deviation), expressed in terms of jax.numpy operations:

In [28]:
import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

A just-in-time compiled version of the function can be created using the jax.jit transform:

In [29]:
from jax import jit
norm_compiled = jit(norm)

This function returns the same results as the original, up to standard floating-point accuracy:

In [30]:
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)

True

But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case. The actual gain depends on the function, the input size, and the hardware; in the run shown below the two timings are nearly identical. Note the use of block_until_ready() to account for JAX’s asynchronous dispatch:

In [31]:
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()

264 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
256 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


That said, jax.jit does have limitations: in particular, it requires all arrays to have static shapes. That means that some JAX operations are incompatible with JIT compilation.

For example, this operation can be executed in op-by-op mode:

In [36]:
def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)

Array([-0.16529104, -0.5616348 , -0.02158209, -0.37602973, -0.82070136],      dtype=float32)

In [38]:
# jit(get_negatives)(x)  # fails under jit: the output shape depends on the values in x
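
The jitted call fails because the size of `x[x < 0]` depends on the values in `x`, which are unknown at compile time. One possible workaround, sketched here and not part of the original tutorial, is to keep the output shape fixed and mask out the unwanted entries with `jnp.where`. The function names are hypothetical.

```python
import jax.numpy as jnp
from jax import jit

@jit
def masked_negatives(x):
    # Same shape as the input: non-negative entries are replaced by 0.
    return jnp.where(x < 0, x, 0.0)

@jit
def sum_of_negatives(x):
    # Reductions over the masked array also have a static output shape.
    return jnp.sum(jnp.where(x < 0, x, 0.0))

x = jnp.array([1.0, -2.0, 3.0, -4.0])
print(masked_negatives(x))
print(sum_of_negatives(x))
```

Whether masking is an acceptable substitute depends on what the downstream code needs; it does not produce a compacted array of negatives.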

## JIT mechanics: tracing and static variables

Key Concepts: **JIT and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type.**
- Variables that you don’t want to be traced can be marked as static.

To use jax.jit effectively, it is useful to understand how it works. Let’s put a few print() statements within a JIT-compiled function and then call the function:



In [39]:
@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1) # jax step
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([5.132693 , 5.024117 , 2.1233816], dtype=float32)

Notice that the print statements execute, but **rather than printing the data we passed to the function, they print tracer objects that stand in for it.**

- These tracer objects are what jax.jit uses to extract the **sequence of operations specified by the function**.
- Basic tracers are stand-ins that **encode the shape and dtype of the arrays**, but are **agnostic to the values**. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.

When we call the compiled function again on matching inputs, **no re-compilation is required** and nothing is printed, because the result is computed in compiled XLA rather than in Python:

In [40]:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)

Array([ 5.9691086, 13.150997 , 10.325281 ], dtype=float32)

The extracted sequence of operations is encoded in a JAX expression, or jaxpr for short. You can view the jaxpr using the jax.make_jaxpr transformation:

In [41]:
from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }
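
The tracing and caching behavior described above can be observed by recording every time the Python body of a jitted function runs. In this sketch the list `trace_log` and the function `add_one` are illustrative names; the append is a Python side effect, so it happens only during tracing.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []

@jit
def add_one(x):
    trace_log.append(x.shape)   # Python side effect: runs only while tracing
    return x + 1

add_one(jnp.ones(3))            # traced and compiled for float32[3]
add_one(jnp.zeros(3))           # same shape and dtype: cached, no retrace
print(len(trace_log))           # 1
add_one(jnp.ones(4))            # new shape: traced again
print(len(trace_log))           # 2
```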

Note one consequence of this: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails:

In [43]:
# @jit
# def f(x, neg):
#   return -x if neg else x

# f(1, True)
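
When the condition has to remain a traced value, one option is to express the branch with `jax.lax.cond` instead of a Python `if`. This is a sketch of that alternative, not the approach the tutorial itself takes; `maybe_negate` is a hypothetical name.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def maybe_negate(x, neg):
    # Both branches are traced; the choice between them is made at run time.
    return lax.cond(neg, lambda v: -v, lambda v: v, x)

print(maybe_negate(jnp.float32(1.0), True))    # -1.0
print(maybe_negate(jnp.float32(1.0), False))   # 1.0
```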

If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:

In [44]:
from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)

Array(-1, dtype=int32, weak_type=True)

Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:

In [45]:
f(1, False)

Array(1, dtype=int32, weak_type=True)
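
The re-compilation triggered by a new static value can be made visible with the same side-effect trick used for tracing. This sketch uses `static_argnames`, the keyword-based counterpart of `static_argnums`; the names `negate_if` and `traces` are illustrative.

```python
from functools import partial
import jax.numpy as jnp
from jax import jit

traces = []

@partial(jit, static_argnames=("neg",))
def negate_if(x, neg):
    traces.append(neg)          # runs only when the function is (re)traced
    return -x if neg else x

x = jnp.arange(3.0)
negate_if(x, neg=True)          # first trace, with neg=True
negate_if(x + 1, neg=True)      # cached: same shape/dtype, same static value
negate_if(x, neg=False)         # new static value: traced again
print(traces)                   # [True, False]
```

Because each distinct static value produces its own compiled version, static arguments are best suited to values that take only a few different settings.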

## Static vs Traced Operations
Key Concepts: **Just as values can be either static or traced, operations can be static or traced.**
- Static operations **are evaluated at compile-time** in Python
- Traced operations are compiled & evaluated at run-time in XLA.

Use numpy for operations that you want to be static; use jax.numpy for operations that you want to be traced. This distinction between static and traced values makes it important to think about how to keep a static value static.

Consider this function:

In [47]:
import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
# f(x)

This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let’s add some print statements to the function to understand why this is happening:

In [48]:
@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)

x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>


Notice that although x is traced, x.shape is a static value.

However, when we use jnp.array and jnp.prod on this static value, it becomes a traced value, at which point it cannot be used in a function like reshape() that requires a static input (recall: array shapes must be static).

A useful pattern is to use numpy for operations that should be static (i.e. done at compile-time), and use jax.numpy for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:

In [49]:
from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)

Array([1., 1., 1., 1., 1., 1.], dtype=float32)
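
Other ways of keeping the size computation static follow the same idea. The sketch below, with hypothetical function names, uses `math.prod` from the standard library and the array’s `size` attribute; both yield plain Python integers during tracing, so `reshape()` receives a concrete shape.

```python
import math
import jax.numpy as jnp
from jax import jit

@jit
def flatten_with_math(x):
    # x.shape is a tuple of Python ints, so math.prod stays static.
    return x.reshape((math.prod(x.shape),))

@jit
def flatten_with_size(x):
    # x.size is also a static Python int while tracing.
    return x.reshape((x.size,))

x = jnp.ones((2, 3))
print(flatten_with_math(x).shape, flatten_with_size(x).shape)
```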