In [2]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

import numpy as onp  # original CPU-backed NumPy
from jax import device_put

Generating random normal data. Every JAX random function requires an explicit PRNG (pseudo-random number generator) key.

In [3]:
# Build an explicit PRNG key from seed 0, then draw ten normal samples with it.
key = random.PRNGKey(0)
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


Multiplying a $3000 \times 3000$ matrix by its transpose. Note that we can, and should, explicitly request single precision (`float32`) for the dtype.

In [4]:
# Side length of the square matrices used by the matrix-product benchmarks.
size = 3000
# Sample a (size, size) matrix with an explicit float32 dtype.
# NOTE(review): this passes the same `key` that was already used for the
# ten-sample draw above rather than a fresh subkey.
x = random.normal(key, (size, size), dtype=np.float32)
# The timed expression ends with .block_until_ready(), so each run waits for
# the product to be available before the timer stops.
%timeit np.dot(x, x.T).block_until_ready()  # runs on the GPU

9.09 ms ± 730 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# JAX NumPy functions work on regular NumPy arrays.
# Here `x` is built with the original NumPy (`onp`), cast to float32, and
# handed directly to the JAX `np.dot` inside the timed expression.
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T).block_until_ready()

40.8 ms ± 607 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
# Same benchmark, but the NumPy array is passed through device_put first.
# The device_put call sits outside the timed expression, so the timing below
# only covers the product on device_put's result.
x = onp.random.normal(size=(size, size)).astype(onp.float32)
x = device_put(x)
%timeit np.dot(x, x.T).block_until_ready()

9.19 ms ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Running the same computation with the original NumPy (`onp`) for comparison.

In [9]:
# Baseline: both the array and the dot product come from the original NumPy
# (`onp`); no JAX function is involved in the timed expression.
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit onp.dot(x, x.T)

135 ms ± 887 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Using JIT

In [10]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Entries with ``x > 0`` are passed through; all other entries become
  ``alpha * exp(x) - alpha``. The selected values are then scaled by ``lmbda``.
  """
  negative_branch = alpha * np.exp(x) - alpha
  selected = np.where(x > 0, x, negative_branch)
  return lmbda * selected

# One million normal samples as the benchmark input for the un-jitted selu.
# NOTE(review): `key` is passed here without being split first.
x = random.normal(key, (1000000,))
# Time selu called directly (no jit wrapper), waiting for the result each run.
%timeit selu(x).block_until_ready()

1.6 ms ± 647 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
# Wrap selu with jit once, outside the timed expression, then time calls to
# the wrapped function on the same input `x`.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

111 µs ± 1.84 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## Taking Derivatives with `grad`

In [12]:
def sum_logistic(x):
  """Return the sum over all entries of ``1 / (1 + exp(-x))``."""
  logistic = 1.0 / (1.0 + np.exp(-x))
  return np.sum(logistic)

# Evaluate the gradient of sum_logistic at the points 0., 1., 2.
x_small = np.arange(3.)
derivative_fn = grad(sum_logistic)
x_small_grad = derivative_fn(x_small)
print(x_small_grad)

[0.25       0.19661197 0.10499357]


In [13]:
def first_finite_differences(f, x, eps=1e-3):
  """Estimate the gradient of ``f`` at ``x`` with central finite differences.

  Args:
    f: Callable evaluated at perturbed copies of ``x``; its results are
      subtracted and divided by ``2 * eps``.
    x: Point at which to estimate the gradient. ``len(x)`` sets the number of
      perturbation directions (assumes ``x`` is 1-D).
    eps: Step size of the perturbation along each unit vector. Defaults to
      ``1e-3``, the value that used to be hard-coded.

  Returns:
    Array whose i-th entry is
    ``(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)``, where ``e_i`` is
    the i-th row of the identity matrix.
  """
  # Each row of the identity matrix is one unit perturbation direction.
  unit_vectors = np.eye(len(x))
  return np.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in unit_vectors])


# Numerical estimate of the same gradient, for comparison with grad's result.
fd_estimate = first_finite_differences(sum_logistic, x_small)
print(fd_estimate)

[0.24998187 0.1965761  0.10502338]


In [15]:
# Iterating over the 3x3 identity matrix yields its rows one at a time.
for unit_row in np.eye(3):
    print(unit_row)

[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]


## Auto-Vectorization with `vmap`

In [16]:
# Draw the matrix and the batch of vectors from two distinct subkeys.
# Passing one key to both random.normal calls would feed the same random
# stream into both arrays. `key` itself is not reassigned here, so cells that
# use `key` later see the same value as before and this cell stays re-runnable.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

In [17]:
def apply_matrix(v):
  """Return ``np.dot(mat, v)`` using the module-level matrix ``mat``."""
  product = np.dot(mat, v)
  return product

In [18]:
def naively_batched_apply_matrix(v_batched):
  """Call ``apply_matrix`` on each item of ``v_batched`` in a Python loop and stack the results."""
  per_item = []
  for v in v_batched:
    per_item.append(apply_matrix(v))
  return np.stack(per_item)

# Label the output, then time the Python-loop version on the whole batch.
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
2.96 ms ± 255 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
@jit
def batched_apply_matrix(v_batched):
  """Multiply ``v_batched`` by the transpose of the module-level ``mat`` in a single ``np.dot``."""
  mat_transposed = mat.T
  return np.dot(v_batched, mat_transposed)

# Label the output, then time the jit-decorated single-dot version.
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
112 µs ± 855 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [20]:
@jit
def vmap_batched_apply_matrix(v_batched):
  """Apply the ``vmap``-transformed ``apply_matrix`` to ``v_batched``."""
  batched_fn = vmap(apply_matrix)
  return batched_fn(v_batched)

# Label the output, then time the jit + vmap version on the same batch.
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
111 µs ± 1.28 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [21]:
def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, evaluated through ``tanh``.

    The identity ``sigmoid(x) = 0.5 * (tanh(x / 2) + 1)`` gives the same
    function without ever forming ``exp(-x)``. The direct form overflows
    ``exp`` to ``inf`` for very negative float32 inputs, which makes its
    gradient NaN (``0 * inf``); the tanh form saturates to 0 with a zero
    gradient instead.

    Args:
        x: Scalar or array input.

    Returns:
        The elementwise sigmoid of ``x``.
    """
    return 0.5 * (np.tanh(x / 2) + 1)


def predict(W, b, inputs):
    """Return ``sigmoid(np.dot(inputs, W) + b)``."""
    pre_activation = np.dot(inputs, W) + b
    return sigmoid(pre_activation)

In [22]:
# Toy dataset: four rows of three features each, with one boolean target per row.
inputs = np.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])
targets = np.array([True, True, False, True])

In [23]:
def loss(W, b):
    """Negative sum of log-probabilities that ``predict`` assigns to ``targets``.

    Uses the module-level ``inputs`` and ``targets``. For each row, the
    probability of the actual label is ``preds`` where the target is true and
    ``1 - preds`` where it is false.
    """
    preds = predict(W, b, inputs)
    prob_when_true = preds * targets
    prob_when_false = (1 - preds) * (1 - targets)
    label_probs = prob_when_true + prob_when_false
    return -np.sum(np.log(label_probs))

In [24]:
# Split the key three ways: one replacement for `key`, one subkey per parameter.
subkeys = random.split(key, 3)
key, W_key, b_key = subkeys
W = random.normal(W_key, shape=(3,))
b = random.normal(b_key, shape=())

In [25]:
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)

W_grad [-0.16965581 -0.8774646  -1.4901344 ]
W_grad [-0.16965581 -0.8774646  -1.4901344 ]
b_grad -0.29227242
W_grad [-0.16965581 -0.8774646  -1.4901344 ]
b_grad -0.29227242
