In [1]:
import jax.numpy as jnp, numpy as np
from jax import grad, jit, vmap, random, config
# Enable 64-bit types: with this flag, JAX arrays created without an explicit
# dtype default to float64 (see the dtype=float64 output below), which likely
# also affects the timings in this notebook (float64 math is slower on most GPUs).
config.update("jax_enable_x64", True)

## JAX Quick Start

Follows the official [JAX Quickstart](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html).

### Multiplying Matrices

In [2]:
# A fixed seed makes the sampled values reproducible across runs.
key = random.PRNGKey(seed=0)
x = random.normal(key, shape=(10, 2))
x

Array([[ 0.05904905, -1.78900426],
       [ 1.22788155,  0.45910259],
       [ 0.23551187,  0.78353811],
       [-1.1250817 ,  0.74172367],
       [ 0.54964903, -1.47316439],
       [-0.28765989, -1.28873064],
       [-0.02191405,  1.6508528 ],
       [-0.388559  ,  2.28720945],
       [-0.30024257, -0.80271294],
       [-0.76154052,  0.16121649]], dtype=float64)

In [3]:
# Side length of the square matrices used by the timing cells below.
size = 3000
# Reuses `key` from above; x gets the default float dtype (float64 with x64 enabled).
x = random.normal(key, (size, size))
# block_until_ready() waits until the result is actually computed (JAX dispatches
# work asynchronously), so the timing covers the computation, not just its launch.
# Runs on JAX's default backend (GPU/TPU if available, otherwise CPU).
%timeit jnp.dot(x, x.T).block_until_ready()

134 ms ± 53.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Same size and dtype, but x is now a host-side NumPy array, so jnp.dot
# presumably has to transfer it to the device on every timed call.
# NOTE(review): unseeded NumPy RNG -- fine for timing, values not reproducible.
x = np.random.normal(size=(size, size)).astype(np.float64)
%timeit jnp.dot(x, x.T).block_until_ready()

148 ms ± 1.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


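`device_put()` copies a host (NumPy) array to the default device once, up front, so the timed `jnp.dot` no longer pays for a host-to-device transfer on every call. Its output still acts like an ndarray, but it only copies values back to the CPU when they’re needed for printing, plotting, saving to disk, branching, etc.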

In [5]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

1.61 ms ± 7.61 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### JIT

In [6]:
def selu(x, alpha=1.67, lamb=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lamb * x`` where ``x > 0`` and ``lamb * alpha * (exp(x) - 1)``
    elsewhere. The two-decimal defaults are presumably rounded from the SELU
    paper's constants.

    Args:
        x: array-like input.
        alpha: scale of the negative (exponential) branch.
        lamb: overall output scale (lambda).

    Returns:
        A JAX array with the same shape as ``x``.
    """
    # expm1(x) computes exp(x) - 1 without the cancellation of
    # `alpha * jnp.exp(x) - alpha`, which loses precision near 0 and returns
    # exactly 0 for tiny negative inputs.
    return lamb * jnp.where(x > 0, x, alpha * jnp.expm1(x))

# One million standard-normal samples (reusing the earlier `key`).
x = random.normal(key, (1000000,))
# Un-jitted baseline: the jnp operations inside selu run one at a time.
%timeit selu(x).block_until_ready()

370 µs ± 30.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

184 µs ± 5.26 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### GRAD

In [8]:
def sum_logistic(x):
    """Sum of the logistic sigmoid ``1 / (1 + exp(-x))`` over all elements of ``x``.

    Uses ``jax.nn.sigmoid`` rather than spelling out ``1.0 / (1.0 + jnp.exp(-x))``.
    For very negative inputs the explicit form overflows ``exp(-x)`` to inf, and
    its gradient becomes ``nan`` (a 0 * inf product in the chain rule). The
    gradient of ``sigmoid`` stays finite there.
    """
    from jax.nn import sigmoid  # local import keeps this cell self-contained

    return jnp.sum(sigmoid(x))


x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661193 0.10499359]


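Beyond `grad()`, JAX exposes the underlying autodiff primitives:

- `jax.vjp()`: reverse-mode vector-Jacobian products
- `jax.jvp()`: forward-mode Jacobian-vector products

`jacrev()` and `jacfwd()` build full Jacobians from these. Composing them (forward-over-reverse) gives an efficient Hessian, as in the next cell.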

In [9]:
from jax import jacfwd, jacrev
def hessian(fun):
    """Return a jitted function that computes the Hessian of ``fun``.

    ``jacrev`` (reverse mode) yields the gradient; ``jacfwd`` (forward mode)
    then differentiates that gradient once more.
    """
    grad_jacobian = jacfwd(jacrev(fun))
    return jit(grad_jacobian)

### VMAP

In [10]:
# Split the key instead of reusing `key` for both draws: JAX PRNG keys carry no
# hidden state, and the JAX docs advise never reusing a key for independent draws.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))


def apply_matrix(v):
    """Multiply the global (150, 100) matrix ``mat`` by a single length-100 vector ``v``."""
    return jnp.dot(mat, v)

In [11]:
def naively_batched_apply_matrix(v_batched):
    """Apply ``apply_matrix`` to each row of ``v_batched`` in a Python loop, then stack."""
    results = []
    for row in v_batched:
        results.append(apply_matrix(row))
    return jnp.stack(results)

print('Naively batched')
# Not jitted: the Python loop issues one matrix-vector product per row of batched_x.
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched


1.2 ms ± 67.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [12]:
@jit
def batched_apply_matrix(v_batched):
    """Batched ``apply_matrix``: one matrix-matrix product instead of a per-row loop.

    Row i of the result equals ``mat @ v_batched[i]``.
    """
    mat_t = mat.T
    return jnp.dot(v_batched, mat_t)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched


113 µs ± 1.48 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
@jit
def vmap_batched_apply_matrix(v_batched):
    """Batched ``apply_matrix`` derived automatically by ``vmap`` over the leading axis."""
    batched_fn = vmap(apply_matrix)
    return batched_fn(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap


114 µs ± 700 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
