In [3]:
!pip install jax



# JAX as NumPy

In [9]:
# jax as numpy
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise to ``x``.

  Positive entries map to ``lmbda * x``; all other entries map to
  ``lmbda * (alpha * exp(x) - alpha)``.
  NOTE(review): the defaults look like rounded versions of the usual SELU
  constants (~1.6733, ~1.0507) -- confirm if exact values matter.
  """
  negative_branch = alpha * jnp.exp(x) - alpha
  return lmbda * jnp.where(x > 0, x, negative_branch)

x = jnp.arange(5.0)
# Show the input and its SELU image on consecutive lines.
print(x, selu(x), sep="\n")

[0. 1. 2. 3. 4.]
[0.        1.05      2.1       3.1499999 4.2      ]


# Just-in-time compilation with jax.jit()

In [7]:
# Without jit
from jax import random
# Fixed seed so the benchmark input is reproducible; `key` is also split later for the vmap example.
key = random.key(1701)
# One million standard-normal samples as the benchmark input.
x = random.normal(key, (1_000_000,))
# Baseline: selu is called directly, without whole-function compilation.
# block_until_ready() waits for JAX's asynchronous dispatch so %timeit measures the computation itself.
%timeit selu(x).block_until_ready()

# With jit
from jax import jit
selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()

4.51 ms ± 668 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
982 µs ± 7.87 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Taking derivatives with jax.grad()

In [10]:
from jax import grad

def sum_logistic(x):
  """Sum of the logistic sigmoid ``1 / (1 + exp(-x))`` over all elements of ``x``.

  The result is a scalar, so ``grad(sum_logistic)`` is well defined; its i-th
  component is the sigmoid's derivative evaluated at ``x[i]``.
  """
  sigmoid = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(sigmoid)

x_small = jnp.arange(3.)
# grad() turns the scalar-valued sum_logistic into a new function that evaluates
# its gradient: component i is d(sum_logistic)/dx_i, the sigmoid's derivative at x_small[i].
derivative_fn = grad(sum_logistic)
print(x_small, derivative_fn(x_small), sep="\n")

[0. 1. 2.]
[0.25       0.19661197 0.10499357]


Let's verify this result against a central finite-difference approximation of the gradient.

In [13]:
def first_finite_differences(f, x, eps=1E-3):
  """Estimate the derivative of ``f`` at ``x`` by central differences.

  Component i is ``(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)``, where
  ``e_i`` is the i-th unit perturbation of ``x``. Works for inputs of any shape,
  including 0-d scalars.

  Args:
    f: function of an array shaped like ``x``. A scalar output gives a gradient
      estimate directly comparable with ``jax.grad(f)(x)``.
    x: point at which to differentiate (array-like).
    eps: step size. With float32 inputs, errors on the order of 1e-4 relative
      to ``grad`` have been observed at the default.

  Returns:
    Array of shape ``x.shape + f(x).shape``. For scalar ``f`` this is ``x``'s shape.
  """
  x = jnp.asarray(x)
  n = x.size
  # Rows of the identity, reshaped like x: one unit perturbation per element of x.
  basis = jnp.eye(n).reshape((n,) + x.shape)
  estimates = jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                         for v in basis])
  # Leading axis enumerates elements of x; fold it back into x's shape.
  return estimates.reshape(x.shape + estimates.shape[1:])

# Central-difference estimate of the same gradient. It should agree with grad()
# up to float32 finite-difference error, so check that rather than eyeballing it.
fd_gradient = first_finite_differences(sum_logistic, x_small)
print(fd_gradient)
assert jnp.allclose(fd_gradient, grad(sum_logistic)(x_small), atol=1e-3), \
    "finite differences disagree with grad()"

[0.24998187 0.1964569  0.10502338]


The `grad()` and `jit()` transformations compose and can be mixed arbitrarily. In the previous example `sum_logistic` was differentiated directly, but it could just as well have been JIT-compiled first, or its derivative compiled afterwards. We can go further: the cell below nests `grad` three times, interleaved with `jit`, to evaluate the third derivative of the logistic function at 1.0:

In [14]:
# Third derivative of the logistic function at 1.0, built one derivative at a
# time; this is the same nesting as grad(jit(grad(jit(grad(sum_logistic))))).
d1_logistic = jit(grad(sum_logistic))
d2_logistic = jit(grad(d1_logistic))
d3_logistic = grad(d2_logistic)
print(d3_logistic(1.0))

-0.0353256


Full Jacobian matrices for vector-valued functions can be computed with `jax.jacobian()`:

In [15]:
from jax import jacobian
# jnp.exp acts elementwise, so its Jacobian at x_small is diagonal with entries exp(x_small).
exp_jacobian = jacobian(jnp.exp)
print(exp_jacobian(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]


For more advanced autodiff operations, you can use jax.vjp() for reverse-mode vector-Jacobian products, and jax.jvp() and jax.linearize() for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. For example, jax.jvp() and jax.vjp() are used to define the forward-mode jax.jacfwd() and reverse-mode jax.jacrev() for computing Jacobians in forward- and reverse-mode, respectively. Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:

In [16]:
from jax import jacfwd, jacrev
def hessian(fun):
  """Return a jit-compiled function that computes the Hessian of scalar-valued ``fun``.

  Reverse mode (``jacrev``) yields the gradient; forward mode (``jacfwd``) then
  differentiates that gradient, giving the matrix of second derivatives.
  """
  hessian_fn = jacfwd(jacrev(fun))
  return jit(hessian_fn)
# Each term of sum_logistic depends on a single element, so off-diagonal entries vanish.
logistic_hessian = hessian(sum_logistic)(x_small)
print(logistic_hessian)

[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]


# Auto-vectorization with jax.vmap()

Another useful transformation is vmap(), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping over function calls, it transforms the function into a natively vectorized version for better performance. When composed with jit(), it can be just as performant as manually rewriting your function to operate over an extra batch dimension.

In [None]:
# Split the existing PRNG key so `mat` and `batched_x` are drawn from separate keys.
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))        # applied to 100-vectors by apply_matrix
batched_x = random.normal(key2, (10, 100))   # batch of 10 input vectors, one per row
# Summarize rather than dumping all 16,000 raw values (and the keys) into the output.
print("mat:", mat.shape, mat.dtype)
print("batched_x:", batched_x.shape, batched_x.dtype)

def apply_matrix(x):
  """Return ``jnp.dot(mat, x)`` for a single input vector ``x``.

  ``mat`` is read from the notebook's global namespace rather than passed in.
  NOTE: inside a jit-compiled caller, ``mat`` is captured when the caller is
  traced, so reassigning ``mat`` afterwards is not reflected without retracing.
  """
  product = jnp.dot(mat, x)
  return product

The apply_matrix function maps a vector to a vector, but we may want to apply it row-wise across a matrix. We could do this by looping over the batch dimension in Python, but this usually results in poor performance.

In [20]:
def naively_batched_apply_matrix(v_batched):
  """Apply ``apply_matrix`` to each row of ``v_batched`` in Python, then stack the results.

  Baseline for the batching comparison: one ``apply_matrix`` call per row.
  """
  return jnp.stack(list(map(apply_matrix, v_batched)))

print('Naively batched')
# Not jitted: the Python loop calls apply_matrix once per row of batched_x.
# block_until_ready() makes %timeit wait for the result rather than just its dispatch.
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
2.93 ms ± 466 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


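Instead of computing a matrix–vector product for each row and stacking the results, we can compute a single matrix–matrix product. Each row of the output is `mat @ v` for a row `v` of `batched_x`, so the whole batch is `batched_x @ mat.T`: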

In [21]:
import numpy as np

@jit
def batched_apply_matrix(batched_x):
  """Apply the global ``mat`` to every row of ``batched_x`` with one matrix product.

  Row i of the result equals ``jnp.dot(mat, batched_x[i])`` because
  ``(mat @ v)^T == v^T @ mat^T``; stacking rows gives ``batched_x @ mat.T``.
  The parameter deliberately shadows the notebook-level ``batched_x``.
  """
  mat_transposed = mat.T
  return jnp.dot(batched_x, mat_transposed)

# Check the hand-batched product against the per-row loop before timing it.
# Loose tolerances allow for float32 rounding differences between the two paths.
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
# batched_apply_matrix is @jit; the call in the assert above already compiled it,
# so %timeit measures only the compiled execution.
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
37.6 µs ± 2.37 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone. The vmap() transformation is designed to automatically transform a function into a batch-aware version:

In [22]:
from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
  """Vectorize ``apply_matrix`` over the leading (row) axis of ``batched_x`` with ``vmap``."""
  per_row = vmap(apply_matrix, in_axes=0)
  return per_row(batched_x)

# Check vmap's result against the per-row loop before timing it.
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
# vmap_batched_apply_matrix is @jit; the assert above already triggered compilation,
# so %timeit measures only the compiled execution.
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
50.4 µs ± 2.28 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


As you would expect, vmap() can be arbitrarily composed with jit(), grad(), and any other JAX transformation.