In [10]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [11]:
# JAX has no global random state: every sampler takes an explicit PRNG key.
key = random.PRNGKey(seed=0)
# Draw a length-10 vector from the standard normal distribution and show it.
x = random.normal(key, shape=(10,))
print(x)

[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.59086424  0.73168874  0.56730247]


In [14]:
# Side length of the square matrices used by the timing cells that follow.
size = 3000
# Sample a (size, size) float32 matrix from the standard normal distribution.
x = random.normal(key, shape=(size, size), dtype=jnp.float32)
print(x)

# NOTE: the timing cells below call block_until_ready() because JAX dispatches work asynchronously; without it, %timeit would only measure the time to enqueue the computation

[[ 1.3890219e+00 -3.2292128e-01  1.5543434e-01 ...  1.6672325e-01
   1.0217547e+00  9.6981682e-02]
 [ 1.0637624e+00 -1.8089767e+00 -7.7910066e-02 ...  1.1778634e+00
  -4.3357384e-01 -2.7877539e-01]
 [-4.4029760e-01 -3.2537556e-01  2.7817249e-01 ...  6.8317264e-01
  -6.1108202e-01 -6.3071579e-01]
 ...
 [ 2.9218221e-01 -4.0055814e-01 -1.4978162e+00 ...  3.0673573e+00
  -1.1350130e+00  4.0964663e-01]
 [ 2.7635777e-01  1.5621802e-01  2.2996697e-03 ...  6.8930730e-02
  -4.0692575e-02  4.1683865e-01]
 [ 1.0231307e+00 -2.7423620e-01 -8.0369943e-01 ...  1.9415880e+00
   1.0946989e+00  2.1876075e+00]]


In [13]:
# JAX Numpy functions work on regular NumPy arrays
import numpy as np
x = np.random.normal(size=(size,size)).astype(np.float32)
%timeit jnp.dot(x,x.T).block_until_ready()

185 ms ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [15]:
# you can ensure that an NDArray is backed by device memory using device_put()
from jax import device_put

x = np.random.normal(size=(size,size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

190 ms ± 15.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
# using jit() to speed up functions
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied element-wise.

    Computes ``lmbda * x`` where ``x > 0`` and
    ``lmbda * alpha * (exp(x) - 1)`` everywhere else.

    Args:
        x: input scalar or array.
        alpha: scale of the exponential (non-positive) branch.
        lmbda: overall output scale.

    Returns:
        The element-wise SELU of ``x``.
    """
    is_positive = x > 0
    # jnp.where evaluates both branches. Feeding the raw x into the exponential
    # branch overflows to inf for large positive inputs, and the resulting
    # 0 * inf turns the gradient into NaN even though that branch is not
    # selected. Substituting 0 for the unselected entries keeps it finite.
    safe_x = jnp.where(is_positive, 0.0, x)
    # expm1 evaluates exp(x) - 1 without cancellation error for x close to 0.
    return lmbda * jnp.where(is_positive, x, alpha * jnp.expm1(safe_x))

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

2.33 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
# we can speed it up with jit(): the wrapped function is compiled the first time it is called, and the cached compilation is reused on later calls
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

662 µs ± 18.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### Taking derivatives with grad()

In [19]:
def sum_logistic(x):
    """Return the sum of the logistic function 1 / (1 + exp(-x)) over ``x``."""
    denominator = 1.0 + jnp.exp(-x)
    logistic_values = 1.0 / denominator
    return jnp.sum(logistic_values)


# Differentiate sum_logistic and evaluate the gradient at [0., 1., 2.].
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [20]:
def first_finite_difference(f, x, eps=1e-3):
    """Estimate the gradient of ``f`` at ``x`` with central finite differences.

    Args:
        f: function evaluated at perturbed copies of ``x`` (scalar-valued when
            the result is meant to be compared against a gradient).
        x: 1-D point, given as an array or a plain sequence of numbers.
        eps: step taken along each coordinate in both directions. The default
            matches the value that used to be hard-coded.

    Returns:
        Array holding one difference quotient per entry of ``x``.
    """
    # Accept lists/tuples as well as arrays.
    x = jnp.asarray(x)
    # Each row of the identity matrix is the unit vector of one coordinate.
    basis = jnp.eye(len(x))
    return jnp.array(
        [(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in basis]
    )
print(first_finite_difference(sum_logistic,x_small))

[0.24998187 0.1964569  0.10502338]


In [21]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.035325598


In [22]:
from jax import jacfwd, jacrev
def hessian(fun):
    """Build a jit-compiled Hessian of ``fun`` as jacfwd applied to jacrev."""
    reverse_jacobian = jacrev(fun)
    forward_over_reverse = jacfwd(reverse_jacobian)
    return jit(forward_over_reverse)

### Auto-vectorization with vmap()

In [23]:
# Use a separate subkey for each draw: passing the same key to two samplers is
# key reuse, which makes the two arrays statistically dependent.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
    """Multiply the module-level matrix ``mat`` by ``v``."""
    product = jnp.dot(mat, v)
    return product

In [24]:
def naively_batched_apply_matrix(v_batched):
    """Call ``apply_matrix`` on each element of ``v_batched`` in a Python loop and stack the results."""
    results = []
    for v in v_batched:
        results.append(apply_matrix(v))
    return jnp.stack(results)

print("Naively batched")
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
1.47 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [25]:
@jit
def batched_apply_matrix(v_batched):
    """Apply ``mat`` to the whole batch with one product against its transpose (jit-compiled)."""
    mat_transposed = mat.T
    return jnp.dot(v_batched, mat_transposed)
print('manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

manually batched
12.2 µs ± 850 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [26]:
@jit
def vmap_batched_apply_matrix(v_batched):
    """Apply ``apply_matrix`` to ``v_batched`` through its ``vmap``-transformed version (jit-compiled)."""
    vectorized_apply = vmap(apply_matrix)
    return vectorized_apply(v_batched)
print('auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

auto-vectorized with vmap
17.8 µs ± 892 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
