In [3]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

## Multiplying Matrices

In [4]:
# Fixed seed: JAX random functions are deterministic given the key.
key = random.PRNGKey(0)
x = random.normal(key, shape=(10,))
print(x)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [5]:
size = 3000
x = random.normal(key,(size,size),dtype=jnp.float32)
%timeit jnp.dot(x,x.T).block_until_ready()

752 ms ± 30.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
import numpy as np

In [7]:
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

667 ms ± 53.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Notes

* `jit()` to compile functions with XLA and speed them up
* `grad()` to compute derivatives
* `vmap()` for automatic vectorization (batching)

## jit()

In [8]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit (SELU), applied elementwise.

    Computes ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
    elsewhere. The defaults are rounded versions of the published SELU
    constants (alpha ~= 1.6733, lambda ~= 1.0507).

    Args:
        x: Input array or scalar.
        alpha: Scale of the negative (exponential) branch.
        lmbda: Overall output scale (named ``lmbda`` because ``lambda`` is a
            Python keyword).

    Returns:
        Array with the same shape as ``x``.
    """
    # jnp.where evaluates BOTH branches. Feed the exponential branch 0 wherever
    # it is not selected: otherwise exp(x) overflows to inf for large positive
    # x and the gradient through the unselected branch becomes 0 * inf = nan.
    exp_input = jnp.where(x > 0, 0.0, x)
    # expm1(t) == exp(t) - 1, but is more accurate for t near 0.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(exp_input))

x = random.normal(key,(1000000,))
%timeit selu(x).block_until_ready()

8.56 ms ± 996 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
selu_jit =jit(selu)
%timeit selu_jit(x).block_until_ready()

1.51 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## grad()

In [10]:
def sum_logistic(x):
    """Sum of the logistic sigmoid 1 / (1 + exp(-x)) over all elements of x."""
    sigmoid = 1.0 / (1.0 + jnp.exp(-x))
    return jnp.sum(sigmoid)

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
%timeit derivative_fn(x_small)
jit_derivative_fn = jit(derivative_fn)
%timeit jit_derivative_fn(x_small)

22.3 ms ± 4.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.92 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [13]:
def first_finite_differences(f, x, eps=1e-3):
    """Estimate the derivative of ``f`` at ``x`` by central differences.

    Each partial derivative is approximated as
    ``(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)``, where ``e_i`` is the
    unit perturbation of the i-th element of ``x``.

    Args:
        f: Function of an array shaped like ``x`` (typically scalar-valued).
        x: Point at which to differentiate; any shape, including a scalar.
        eps: Finite-difference step size (default 1e-3).

    Returns:
        Array of shape ``x.shape + f_output_shape``. For scalar-valued ``f``
        this is the estimated gradient, shaped like ``x``.
    """
    x = jnp.asarray(x)
    n = x.size
    # One perturbation direction per element of x, each shaped like x, so
    # inputs of any rank (not only vectors) are perturbed element by element.
    directions = jnp.eye(n).reshape((n,) + x.shape)
    estimates = jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                           for v in directions])
    # Restore x's shape on the leading axis; any trailing axes come from f's output.
    return estimates.reshape(x.shape + estimates.shape[1:])


# Central-difference estimate of grad(sum_logistic) at x_small.
fd_gradient = first_finite_differences(sum_logistic, x_small)
print(fd_gradient)

[0.24998187 0.1965761  0.10502338]


In [14]:
# Third derivative of sum_logistic, built by nesting grad with jit in between.
third_derivative = grad(jit(grad(jit(grad(sum_logistic)))))
print(third_derivative(1.0))

-0.0353256


For more advanced autodiff, you can use `jax.vjp()` for reverse-mode vector-Jacobian products and `jax.jvp()` for forward-mode Jacobian-vector products.

In [15]:
from jax import jacfwd, jacrev
def hessian(fun):
    """Return a jitted function computing the Hessian of ``fun``.

    Forward-mode differentiation (jacfwd) is applied over reverse-mode
    differentiation (jacrev), and the composite is compiled with jit.
    """
    second_derivative = jacfwd(jacrev(fun))
    return jit(second_derivative)

## Auto-vectorization with vmap()

In [26]:
# Draw mat and batched_x from independent subkeys. Reusing one key for both
# calls would make their samples statistically dependent.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
    """Multiply the module-level matrix ``mat`` by the vector ``v``."""
    product = jnp.dot(mat, v)
    return product

In [27]:
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
5.92 ms ± 188 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [28]:
def batched_apply_matrix(v_batched):
    return jnp.dot(v_batched,mat.T)

print("Manually batched")
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
497 µs ± 74.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [29]:
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

print("Auto-vectorized with vmap")
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
115 µs ± 13.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
