In [1]:
import jax.numpy as jnp
from jax.scipy.special import expit

In [10]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied element-wise.

    Computes ``lmbda * x`` where ``x > 0`` and
    ``lmbda * alpha * (exp(x) - 1)`` everywhere else.

    Args:
        x: Input scalar or array.
        alpha: Scale of the negative (exponential) branch.
        lmbda: Overall output scale.

    Returns:
        Array of SELU values with the same shape as ``x``.
    """
    is_positive = x > 0
    # jnp.where evaluates (and differentiates) both branches. Feeding the
    # exponential branch 0 wherever it is not selected keeps exp() from
    # overflowing for large positive x, which would otherwise turn the
    # gradient into 0 * inf = NaN.
    safe_x = jnp.where(is_positive, 0.0, x)
    # expm1 avoids the cancellation in exp(x) - 1 for x close to zero.
    negative_branch = alpha * jnp.expm1(safe_x)
    return lmbda * jnp.where(is_positive, x, negative_branch)

# Quick sanity check of selu on the small input [0., 1., 2., 3., 4.].
x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


### Section 1: The four main function transformations of JAX

#### Sec 1.1: jax.jit() -- to transform functions into just-in-time compiled versions

In [12]:
from jax import random

# Draw one million standard-normal samples from a fixed PRNG key to use as
# the benchmark input. Note: this rebinds `x` from the earlier sanity check.
key = random.key(1701)
x = random.normal(key, (1_000_000,))
# Baseline timing of the un-jitted selu; block_until_ready() makes the
# measurement wait for the computed result rather than only the dispatch.
%timeit selu(x).block_until_ready()

6.14 ms ± 168 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
from jax import jit

# Wrap selu with jit, then call it once outside the timing loop so the
# first-call tracing/compilation is not part of the measurement.
selu_jit = jit(selu)
_ = selu_jit(x)
%timeit selu_jit(x).block_until_ready()

1.33 ms ± 6.33 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [16]:
# Same benchmark with a second jit wrapper around selu, this time without
# an explicit warm-up call before timing.
selu_jit1 = jit(selu)
%timeit selu_jit1(x).block_until_ready()

1.33 ms ± 9.39 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Sec 1.2: jax.grad() -- to compute the gradient of a function

In [17]:
from jax import grad

def sum_logistic(x):
    """Sum of the logistic sigmoid ``1 / (1 + exp(-x))`` over all elements of ``x``.

    Args:
        x: Input scalar or array.

    Returns:
        Scalar array holding the summed sigmoid values.
    """
    # expit is used instead of the hand-written formula: differentiating
    # 1 / (1 + exp(-x)) yields 0 * inf = NaN once exp(-x) overflows
    # (strongly negative x), whereas expit keeps the gradient finite there.
    return jnp.sum(expit(x))

# grad() returns a new function that evaluates the gradient of
# sum_logistic with respect to its argument; apply it to [0., 1., 2.].
x_small = jnp.arange(3.0)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [24]:
# grad and jit can be nested: both lines take the third derivative of
# sum_logistic at x = 1.0, with jit applied at different nesting levels,
# and print the same value.
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
print(grad(jit(grad(jit(grad(jit(sum_logistic))))))(1.0))

-0.0353256
-0.0353256


In [25]:
from jax import jacobian

# Full Jacobian of the element-wise exp at x_small: a 3x3 matrix whose
# diagonal holds exp(x_small) and whose off-diagonal entries are zero.
print(jacobian(jnp.exp)(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]


In [26]:
from jax import jacfwd, jacrev

def hessian(fn):
    """Build a jit-compiled function that evaluates the Hessian of ``fn``.

    The Hessian is obtained by applying forward-mode differentiation
    (``jacfwd``) on top of reverse-mode differentiation (``jacrev``).

    Args:
        fn: Function to differentiate twice.

    Returns:
        A jit-wrapped callable computing the second derivatives of ``fn``.
    """
    first_derivative = jacrev(fn)
    second_derivative = jacfwd(first_derivative)
    return jit(second_derivative)

# Hessian of sum_logistic at x_small (3x3 for the 3-element input).
print(hessian(sum_logistic)(x_small))

[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]


#### Sec 1.3: jax.vmap() -- for automatic vectorization of operations

In [27]:
# Split the key so the matrix and the batch are drawn with different keys.
# mat has shape (150, 100); batched_x holds 10 vectors of length 100.
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2,(10, 100))

def apply_matrix(v):
    """Multiply the module-level matrix ``mat`` with ``v``.

    Args:
        v: Array to multiply; in this notebook a single vector of length 100.

    Returns:
        The result of ``jnp.dot(mat, v)``.
    """
    product = jnp.dot(mat, v)
    return product

def naively_batched_apply_matrix(batched_v):
    """Apply ``apply_matrix`` to each entry of ``batched_v`` one at a time.

    Args:
        batched_v: Iterable of inputs along the leading (batch) axis.

    Returns:
        The per-entry results stacked along a new leading axis.
    """
    results = []
    for single_v in batched_v:
        results.append(apply_matrix(single_v))
    return jnp.stack(results)

# Baseline: one apply_matrix call per batch entry, driven by a Python loop.
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
1.8 ms ± 477 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [28]:
@jit
def batched_apply_matrix(batched_v):
    """Apply the module-level matrix ``mat`` to a whole batch in one product.

    Args:
        batched_v: Batch of inputs, one per row.

    Returns:
        ``jnp.dot(batched_v, mat.T)``, i.e. one output row per input row.
    """
    mat_transposed = mat.T
    return jnp.dot(batched_v, mat_transposed)

# Hand-vectorized version: the whole batch in a single jit-compiled product.
print("Manually batched")
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
14.2 μs ± 520 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [29]:
from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_v):
    """Apply ``apply_matrix`` across the leading axis of ``batched_v`` via ``vmap``.

    Args:
        batched_v: Batch of inputs; ``vmap`` maps over its leading axis.

    Returns:
        The per-entry ``apply_matrix`` results, batched along the leading axis.
    """
    vectorized_apply = vmap(apply_matrix)
    return vectorized_apply(batched_v)

# Same computation, but with the batching derived automatically by vmap.
print("Auto-vectorized with vmap")
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
25.5 μs ± 5.6 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### Sec 1.4: jax.pmap() -- for easy parallelization of computations