## Test the installation of JAX

In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# Build an explicit PRNG key from a fixed seed, then draw ten samples from
# the standard normal distribution as a quick installation smoke test.
key = random.PRNGKey(seed=0)
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


## Try JAX with simple matrix multiplication

In [3]:
# multiple two matrices
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T)

275 ms ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# use regular NumPy arrays for the same computation
import numpy as np
x = np.random.normal(size=(size,size)).astype(np.float32)
%timeit jnp.dot(x, x.T)

252 ms ± 5.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Use *jit()* to speed up functions in JAX

In [5]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lmbda * x`` where ``x > 0`` and
    ``lmbda * (alpha * exp(x) - alpha)`` everywhere else.
    """
    # Value used wherever x is not strictly positive.
    exp_branch = alpha * jnp.exp(x) - alpha
    selected = jnp.where(x > 0, x, exp_branch)
    return lmbda * selected

# call selu function without jit decorator
x = random.normal(key, (1000000,))
%timeit selu(x)

4.8 ms ± 99.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
# make selu function jit
selu_jit = jit(selu)
%timeit selu_jit(x)

515 µs ± 14.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
# Alternative: apply jit as a decorator at definition time
@jit
def selu_with_decorator(x, alpha=1.67, lmbda=1.05):
    """jit-wrapped SELU: ``lmbda * x`` for ``x > 0``, else ``lmbda * (alpha * exp(x) - alpha)``."""
    negative_part = alpha * jnp.exp(x) - alpha
    activated = jnp.where(x > 0, x, negative_part)
    return lmbda * activated

%timeit selu_with_decorator(x)

572 µs ± 32.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Use *grad()* to take derivatives

In [11]:
def sum_logistic(x):
    """Sum of the logistic function ``1 / (1 + exp(-x))`` over all elements of ``x``."""
    logistic = 1.0 / (1.0 + jnp.exp(-x))
    return jnp.sum(logistic)

# Small sample input: [0., 1., 2.]
x_small = jnp.arange(3.)
print(x_small)

# grad() turns the scalar-valued function into a function computing its
# gradient with respect to x
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

# Differentiate the derivative once more (and jit it); it is evaluated at a
# scalar point because grad() requires a scalar-valued function
derivative2_fn = jit(grad(derivative_fn))
print(derivative2_fn(1.0))

[0. 1. 2.]
[0.25       0.19661197 0.10499357]
-0.09085775


## Use *vmap()* for auto-vectorization

In [13]:
# 3 ways to promote matrix-vector products into matrix-matrix products
# Split the key so `mat` and `batched_x` come from independent random streams;
# drawing both from the same key would make the two arrays correlated.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
    """Return the product of the notebook-level matrix ``mat`` with ``v``."""
    product = jnp.dot(mat, v)
    return product

# (1). Naive: loop over the batch in Python
def naive_batched_apply_matrix(v_batched):
    """Call ``apply_matrix`` on each entry of ``v_batched`` and stack the results."""
    per_row_results = []
    for row in v_batched:
        per_row_results.append(apply_matrix(row))
    return jnp.stack(per_row_results)
print('(1) Naively batched')
%timeit naive_batched_apply_matrix(batched_x)

(1) Naively batched
2.96 ms ± 57 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
# (2). Manually batched: express the whole batch as one matrix product
@jit
def batched_apply_matrix(v_batched):
    """Multiply ``v_batched`` by the transpose of the notebook-level ``mat``."""
    mat_transposed = mat.T
    return jnp.dot(v_batched, mat_transposed)

print('(2) Manually batched')
%timeit batched_apply_matrix(batched_x)

(2) Manually batched
12.5 µs ± 78 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [None]:
# (3). Auto-vectorized with vmap: let JAX add the batch dimension
@jit
def vmap_batched_apply_matrix(v_batched):
    """Apply ``apply_matrix`` across the leading axis of ``v_batched`` via ``vmap``."""
    vectorized_apply = vmap(apply_matrix)
    return vectorized_apply(v_batched)

print('(3) Auto-vectorize batched with vmap')
%timeit vmap_batched_apply_matrix(batched_x)