
# JAX quick start

Based on the official JAX quickstart: https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

JAX lets you just-in-time compile your own Python functions into **XLA-optimized** kernels using a one-function API, `jit`. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get high performance without leaving Python.

In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

## Matrix multiplication

In [3]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [4]:
# Multiply a JAX array by its own transpose
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

16.2 ms ± 3.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
# Multiply a regular NumPy array using the same JAX function
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

54.3 ms ± 5.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
# That’s slower because the NumPy data has to be transferred to the GPU on every call.
# You can ensure that an array is backed by device memory using device_put().
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

18.4 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
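
As a small sanity check of the transfer step (a sketch, not a benchmark): `device_put` should change only where the data lives, not the result of the computation. The names `x_np` and `x_dev` and the reduced `size` are illustrative choices, not part of the cells above.

```python
import numpy as np
import jax
import jax.numpy as jnp

size = 256  # kept small so the check is quick; the timing cells use 3000
rng = np.random.default_rng(0)
x_np = rng.normal(size=(size, size)).astype(np.float32)  # host (NumPy) array
x_dev = jax.device_put(x_np)  # copied to the default device once, up front

# Same function and same values; only the location of the input differs.
out_from_host = jnp.dot(x_np, x_np.T).block_until_ready()
out_from_device = jnp.dot(x_dev, x_dev.T).block_until_ready()

is_jax_array = isinstance(x_dev, jax.Array)
results_match = bool(
    jnp.allclose(out_from_host, out_from_device, rtol=1e-4, atol=1e-4)
)

print(jax.devices())  # lists the devices JAX can see (CPU, GPU, or TPU)
print(is_jax_array, results_match)
```

When the same array is reused across many calls, paying for the transfer once is what brings the timing back in line with the pure-JAX case.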


In [7]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

815 µs ± 126 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:

- jit(), for speeding up your code

- grad(), for taking derivatives

- vmap(), for automatic vectorization or batching.

## jit()

In [8]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

1.1 ms ± 196 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [9]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

139 µs ± 5.75 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
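
One way to see what `jit` is doing is to count how often the Python body of the function actually runs. The sketch below assumes a hypothetical `trace_log` list used purely for illustration: the function is traced once per input shape and dtype, and later calls with the same signature reuse the compiled code.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []  # hypothetical helper: records each time the Python body runs

def selu(x, alpha=1.67, lmbda=1.05):
    trace_log.append(x.shape)  # Python side effect: happens only while tracing
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu)

x = jnp.linspace(-2.0, 2.0, 5)
selu_jit(x).block_until_ready()  # first call: trace and compile
selu_jit(x).block_until_ready()  # same shape and dtype: compiled code is reused
traces_same_shape = len(trace_log)

x_longer = jnp.linspace(-2.0, 2.0, 7)
selu_jit(x_longer).block_until_ready()  # new shape: traced and compiled again
traces_new_shape = len(trace_log)

print(traces_same_shape, traces_new_shape)
```

This is also why the first call to a jitted function is slower than later ones, and why `%timeit` numbers reflect the compiled path rather than the compilation itself.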


## grad(): automatic differentiation

In [10]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
print(x_small)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0. 1. 2.]
[0.25       0.19661194 0.10499357]


In [11]:
# Numerical check of the gradient using central finite differences
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1965761  0.10502338]
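
Finite differences are only approximate, so a closed-form check is a useful complement. For the logistic function s(x) = 1 / (1 + exp(-x)) the derivative is s(x)(1 - s(x)); the sketch below compares that with `grad`. The helper name `analytic_grad` is introduced here for illustration.

```python
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def analytic_grad(x):
    # d/dx of the logistic function: s(x) * (1 - s(x)), applied elementwise
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s)

x_small = jnp.arange(3.)
auto = grad(sum_logistic)(x_small)   # automatic differentiation
exact = analytic_grad(x_small)       # closed-form derivative
max_abs_diff = float(jnp.max(jnp.abs(auto - exact)))

print(auto)
print(exact)
print(max_abs_diff)
```

The two results should agree to within float32 rounding, which is much tighter than the agreement with the finite-difference estimate.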


In [12]:
# Nest jit and grad: three applications of grad give the third derivative at 1.0
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.0353256
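
The value printed above is the third derivative of the logistic function at 1.0. As a hedged cross-check, differentiating s' = s(1 - s) twice more by hand gives s''' = s(1 - s)(1 - 6s + 6s²), which the sketch below evaluates next to the nested `grad` call. The function `third_derivative_closed_form` is an illustrative helper, not part of JAX.

```python
import jax.numpy as jnp
from jax import grad, jit

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def third_derivative_closed_form(x):
    # s''' = s (1 - s) (1 - 6 s + 6 s^2), derived by hand from s' = s (1 - s)
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s) * (1.0 - 6.0 * s + 6.0 * s ** 2)

# Same nesting as the cell above: jit and grad compose freely.
third_auto = float(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
third_exact = float(third_derivative_closed_form(1.0))

print(third_auto, third_exact)
```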


## vmap(): auto vectorization

In [13]:
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))  # a batch of 10 vectors, each of length 100

def apply_matrix(v):
  return jnp.dot(mat, v)

In [14]:
def naively_batched_apply_matrix(v_batched):
  # loop over batch
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
3.68 ms ± 152 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
84.2 µs ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [17]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
82.2 µs ± 12.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
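
`vmap` can also map over only some arguments of a function. The sketch below is a variant of the cells above in which the matrix is passed explicitly and `in_axes=(None, 0)` tells `vmap` to hold the matrix fixed while mapping over the leading axis of the batch. The function name `apply_matrix_explicit` is illustrative, and the key is split so that the two random arrays are independent (the cells above reuse a single key for both).

```python
import jax.numpy as jnp
from jax import random, vmap

# Separate keys for the matrix and the batch, so the draws are independent.
key_mat, key_x = random.split(random.PRNGKey(0))
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_x, (10, 100))  # a batch of 10 vectors

def apply_matrix_explicit(m, v):
    # Written for a single vector v of shape (100,)
    return jnp.dot(m, v)

# None: do not map over m; 0: map over the leading axis of the second argument.
batched_apply = vmap(apply_matrix_explicit, in_axes=(None, 0))

out = batched_apply(mat, batched_x)
reference = jnp.dot(batched_x, mat.T)  # the manually batched formulation
matches_reference = bool(jnp.allclose(out, reference, rtol=1e-3, atol=1e-2))

print(out.shape, matches_reference)
```

Passing the matrix as an argument instead of closing over a global keeps the function reusable with different matrices, at the cost of a slightly longer call.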
