- Why JAX?
  - Automatically differentiates native Python and NumPy code, and runs on CPU, GPU, and TPU
  - JAX uses XLA to compile and run NumPy code on hardware accelerators.
    - JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels
    

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random, device_put
import numpy as onp

In [3]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.59086424  0.73168874  0.56730247]


In [4]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

113 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
ox = onp.random.random((size, size))
%timeit onp.dot(ox, ox.T)  # plain NumPy on the CPU, using the NumPy array ox

116 ms ± 1.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


- What in devil's name is this thing? How is it 10 times faster than NumPy?
  - JAX can run on the GPU and is thus 10 times faster

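- JAX functions also work on regular NumPy arrays
  - but this is slower, because JAX has to copy the NumPy array to the GPU on every call
  - device_put() copies the array to the device once, so later calls skip the transfer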
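The host-versus-device point above can be sketched as follows. This is a minimal illustration, not a benchmark: the array names `host_x` and `dev_x` are made up for the example, and the backend printed depends on how JAX was installed.

```python
import jax
import jax.numpy as jnp
import numpy as onp

# Which backend is JAX using? Typically "cpu", "gpu", or "tpu".
print(jax.default_backend())
print(jax.devices())

host_x = onp.ones((4, 4), dtype=onp.float32)  # NumPy array in host memory
dev_x = jax.device_put(host_x)                # copied once to the default device

# Both calls compute the same product; the first has to transfer host_x first.
from_host = jnp.dot(host_x, host_x.T)
from_device = jnp.dot(dev_x, dev_x.T)
print(from_device)
```

On a CPU-only install both arrays live in host memory, so the difference only matters once an accelerator is present.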

In [11]:
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

34.7 ms ± 878 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [15]:
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

20.9 ms ± 178 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


- JAX is more than a GPU backend for NumPy
  - jit(), for speeding up your code
    - On the first call, jit() traces the Python function and compiles it with XLA into optimized native code, then caches the result.
    - Later calls with the same input shapes and dtypes reuse the cached compiled code instead of re-running the Python code.
  - grad(), for taking derivatives
  - vmap(), for automatic vectorization or batching.
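To make the jit() behaviour concrete, here is a small sketch that times the first call (which includes compilation) against a second call. The function `scaled_tanh` is a made-up example, and the measured times will vary by machine, so none are shown.

```python
import time

import jax.numpy as jnp
from jax import jit


def scaled_tanh(x):
    # Hypothetical example function, chosen only to have something to compile.
    return 2.0 * jnp.tanh(x) + 1.0


fast = jit(scaled_tanh)
x = jnp.linspace(-1.0, 1.0, 1000)

t0 = time.perf_counter()
fast(x).block_until_ready()   # first call: compile, then run
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
fast(x).block_until_ready()   # second call: runs the already compiled code
second_call = time.perf_counter() - t0

print(f"first call:  {first_call * 1e3:.3f} ms")
print(f"second call: {second_call * 1e3:.3f} ms")
```

The first call is usually the slower one, which is why benchmarks such as `%timeit` should be read as measuring the compiled path.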

In [4]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

1.77 ms ± 50.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [5]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

219 µs ± 18.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [6]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [10]:
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

first_finite_differences(sum_logistic, x_small)

DeviceArray([0.24998187, 0.1964569 , 0.10502338], dtype=float32)
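As a further sanity check on the two results above, the gradient can also be written by hand: the derivative of the logistic function s(x) is s(x) * (1 - s(x)), so the gradient of the sum is that expression applied elementwise. The sketch below compares grad() against it; `analytic_grad` is a helper introduced only for this illustration.

```python
import jax.numpy as jnp
from jax import grad


def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


def analytic_grad(x):
    # Hand-derived gradient: s(x) * (1 - s(x)), elementwise.
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s)


x_small = jnp.arange(3.)
auto = grad(sum_logistic)(x_small)
exact = analytic_grad(x_small)

# Largest absolute difference between automatic and hand-derived gradients.
print(jnp.max(jnp.abs(auto - exact)))
```

Unlike finite differences, grad() does not depend on a step size `eps`, so it agrees with the closed form up to floating-point rounding.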

In [39]:
mat = random.normal(key, (5000, 1000))
v = random.normal(key, (10, 1000))

@jit
def apply_matrix(v):
    return jnp.dot(v, mat.T)

def naive_apply_matrix(v):
    return jnp.stack([apply_matrix(r) for r in v])

@jit
def vmap_apply_matrix(v):
    return vmap(apply_matrix)(v)

In [40]:
%timeit apply_matrix(v).block_until_ready()  # wait for the async result before stopping the clock

582 µs ± 8.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [41]:
%timeit naive_apply_matrix(v).block_until_ready()  # wait for the async result before stopping the clock

6.38 ms ± 89.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [42]:
%timeit vmap_apply_matrix(v).block_until_ready()

602 µs ± 183 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
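The cells above apply vmap() to a function that closes over `mat`. A common variation, sketched here under assumed small shapes, passes the matrix as an explicit argument and uses `in_axes` to say which arguments carry the batch dimension. The names `apply_one`, `batch`, and `batched` are illustrative only.

```python
import jax.numpy as jnp
from jax import jit, random, vmap

# Split the key so the matrix and the batch come from independent random streams.
key_mat, key_v = random.split(random.PRNGKey(0))
mat = random.normal(key_mat, (50, 20))
batch = random.normal(key_v, (8, 20))


def apply_one(mat, v):
    # Written for a single vector: (50, 20) @ (20,) -> (50,)
    return jnp.dot(mat, v)


# in_axes=(None, 0): do not map over mat, map over axis 0 of the vectors.
batched = jit(vmap(apply_one, in_axes=(None, 0)))

out = batched(mat, batch)
looped = jnp.stack([apply_one(mat, v) for v in batch])  # Python-loop reference

print(out.shape)
```

The function body stays written for one example, and vmap() adds the batch dimension without a Python loop.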


- Running on the GPU, together with jit() and vmap(), really speeds things up
