In [0]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

In [0]:
# JAX keeps no hidden global RNG state: every sampler takes an explicit PRNG key.
key = random.PRNGKey(0)
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [0]:
size = 3000
x = random.normal(key, (size, size), dtype=np.float32)
%timeit np.dot(x, x.T).block_until_ready()  # runs on the GPU

100 loops, best of 3: 7.09 ms per loop


In [0]:
import numpy as onp  # original CPU-backed NumPy
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T).block_until_ready()

10 loops, best of 3: 35.2 ms per loop


In [0]:
from jax import device_put

x = onp.random.normal(size=(size, size)).astype(onp.float32)
x = device_put(x)
%timeit np.dot(x, x.T).block_until_ready()

100 loops, best of 3: 7.15 ms per loop


In [0]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
  elsewhere. The defaults are rounded versions of the SELU constants
  (alpha ~= 1.6733, lambda ~= 1.0507).
  """
  # np.where differentiates *both* branches. Feeding a large positive x into
  # exp overflows to inf, and the masked-out branch then contributes 0 * inf = nan
  # to the gradient. Clamping the unselected inputs to 0 keeps gradients finite.
  # expm1 is also more accurate than exp(x) - 1 for x near zero.
  safe_x = np.where(x > 0, 0., x)
  return lmbda * np.where(x > 0, x, alpha * np.expm1(safe_x))

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

The slowest run took 271.46 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 1.48 ms per loop


In [0]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

The slowest run took 525.91 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 217 µs per loop


In [0]:
def sum_logistic(x):
  """Add up the logistic sigmoid 1 / (1 + exp(-x)) over every element of x."""
  sigmoid = 1.0 / (1.0 + np.exp(-x))
  return np.sum(sigmoid)

derivative_fn = grad(sum_logistic)  # gradient of the scalar sum with respect to x
x_small = np.arange(3.)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [0]:
from jax import jacfwd, jacrev

# Named `logistic` rather than reusing `sum_logistic`: this version does no
# reduction, and redefining the earlier name would silently change its meaning.
def logistic(x):
  """Elementwise logistic sigmoid 1 / (1 + exp(-x)); the result has the same shape as x."""
  return 1.0 / (1.0 + np.exp(-x))

x_small = np.arange(3.)
derivative_fn = jacrev(logistic)  # full Jacobian via reverse-mode autodiff
print(derivative_fn(x_small))

# Sanity check: forward-mode jacfwd must produce the same Jacobian.
assert np.allclose(jacfwd(logistic)(x_small), derivative_fn(x_small))

[[ 0.25       -0.         -0.        ]
 [-0.          0.19661197 -0.        ]
 [-0.         -0.          0.10499357]]


In [0]:
# Split the key so `mat` and `batched_x` come from separate random streams.
# JAX's docs warn against reusing one key for two draws: the outputs are not independent.
mat_key, x_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(x_key, (10, 100))

def apply_matrix(v):
  """Multiply the module-level `mat` by a single, unbatched input `v`."""
  product = np.dot(mat, v)
  return product

def naively_batched_apply_matrix(v_batched):
  return np.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

@jit
def batched_apply_matrix(v_batched):
  return np.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
The slowest run took 31.89 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 3: 4.34 ms per loop
Manually batched
The slowest run took 945.08 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 150 µs per loop
Auto-vectorized with vmap
The slowest run took 145.28 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 131 µs per loop
