We're going to follow all of JAX's tutorials so that we don't have to hobble around it anymore.

Using `jax.jit()` we can compile sequences of operations with XLA. We start by implementing a SELU (scaled exponential linear unit) function.

JAX provides a NumPy-like interface through `jax.numpy`.
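Here is a quick sketch of what "NumPy-like" means in practice. The same expression written with `numpy` and with `jax.numpy` gives the same values. The main catch is that JAX arrays are immutable, so NumPy-style in-place assignment is replaced by `.at[...].set(...)`, which returns a new array.

```python
import numpy as np
import jax.numpy as jnp

# Same input and same expression, once with NumPy and once with jax.numpy
a_np = np.linspace(-2.0, 2.0, 5, dtype=np.float32)
a_jnp = jnp.linspace(-2.0, 2.0, 5, dtype=jnp.float32)

np_result = np.where(a_np > 0, a_np, np.exp(a_np) - 1.0)
jnp_result = jnp.where(a_jnp > 0, a_jnp, jnp.exp(a_jnp) - 1.0)

# jax.numpy returns a jax.Array, but the values match NumPy up to float32 rounding
print(type(jnp_result))
print(np.allclose(np_result, np.asarray(jnp_result), atol=1e-6))

# JAX arrays are immutable: .at[...].set(...) builds a new array, jnp_result is untouched
b = jnp_result.at[0].set(0.0)
print(jnp_result[0], b[0])
```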

In [4]:
import jax.numpy as jnp

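# SELU: lmbda * x for x > 0, lmbda * alpha * (exp(x) - 1) otherwise (constants rounded from ~1.6733 and ~1.0507)
def selu(x, alpha=1.67, lmbda=1.05):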
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


In [5]:
from jax import random
key = random.key(69)
x = random.normal(key, (1000000,))
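print(selu(x).block_until_ready)  # no parentheses: this prints the bound method object, not the array
# JAX dispatches work asynchronously, so block_until_ready() is needed to time the actual computation.
# The printed type shows block_until_ready is a C++ method exposed to Python through nanobind.
%timeit selu(x).block_until_ready()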

<nanobind.nb_bound_method object at 0x11fd86d40>
1.46 ms ± 28 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
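Outside IPython there's no `%timeit`, so here's a minimal timing sketch. `time_call` is a hypothetical helper, not part of JAX. It makes one warm-up call so that compilation isn't counted, then waits on `block_until_ready()` after every call, because JAX dispatches asynchronously and would otherwise return before the computation finishes.

```python
import time

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

def time_call(fn, *args, repeats=100):
    """Hypothetical helper: average wall-clock seconds per call of fn(*args)."""
    fn(*args).block_until_ready()  # warm-up, also triggers compilation for jitted functions
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # without this we would mostly time the dispatch
    return (time.perf_counter() - start) / repeats

x = jax.random.normal(jax.random.key(69), (1_000_000,))
print(f"eager: {time_call(selu, x) * 1e6:.1f} us per call")
print(f"jit:   {time_call(jax.jit(selu), x) * 1e6:.1f} us per call")
```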


In [6]:
from jax import random
from jax import jit
key = random.key(69)
x = random.normal(key, (1000000,))
print(selu(x).block_until_ready)  # again prints the bound method, not the array
%timeit jit(selu)(x).block_until_ready()  # re-wraps selu on every loop; selu_jit = jit(selu) once is the usual pattern

<nanobind.nb_bound_method object at 0x11faf5380>
448 μs ± 9.85 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Damn, that's a nice ~3x speedup (1.46 ms down to 448 μs)!
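A sketch of why it's worth wrapping once: `jax.jit` traces the Python function into an XLA program the first time it sees a given input shape and dtype, then reuses the compiled version. The `trace_log` list is a hypothetical instrument just for this demo. Python side effects like `append` only run while JAX is tracing, so the log records exactly when (re)tracing happened.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical: records the input shape each time JAX traces selu

def selu(x, alpha=1.67, lmbda=1.05):
    trace_log.append(x.shape)  # runs during tracing only, not on cached calls
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)  # wrap once, reuse the wrapper

selu_jit(jnp.ones(3))   # first call with shape (3,): trace + XLA compile
selu_jit(jnp.zeros(3))  # same shape and dtype: cached executable, no retrace
selu_jit(jnp.ones(4))   # new shape (4,): traced and compiled again
print(trace_log)  # [(3,), (4,)]
```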

JAX also does automatic differentiation (autograd): `jax.grad` takes a function and returns a new function that computes its gradient.

In [3]:
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0/ (1.0 + jnp.exp(-x))) # sigmoid time

x_small = jnp.arange(3.)
print(x_small)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0. 1. 2.]
[0.25       0.19661197 0.10499357]
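As a sanity check that doesn't rely on finite differences, the logistic derivative is known in closed form, sigma'(x) = sigma(x) * (1 - sigma(x)), so each entry of the gradient above should match it. This sketch also uses `jax.value_and_grad`, which returns the function value together with the gradient.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.0)

# Closed form: d/dx_i sum_j sigmoid(x_j) = sigmoid(x_i) * (1 - sigmoid(x_i))
sig = jax.nn.sigmoid(x_small)
analytic = sig * (1.0 - sig)

# value_and_grad gives f(x) and grad f(x) in one call
value, autodiff = jax.value_and_grad(sum_logistic)(x_small)
print(value)
print(jnp.allclose(autodiff, analytic, atol=1e-6))
```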


In [4]:
import jax.numpy as jnp

# Check the gradient with central finite differences along each unit vector (the rows of the identity)
def first_finite_diff(f, x, eps=1e-2):
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in jnp.eye(len(x))])

print(first_finite_diff(sum_logistic, x_small))
print(jnp.eye(len(x_small)))

[0.25000572 0.19661188 0.10499954]
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]


Differentiation goes much further: JAX also provides forward- and reverse-mode Jacobians (`jax.jacfwd`, `jax.jacrev`) and Hessians (`jax.hessian`).
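A minimal sketch of those transforms on the same sigmoid example. `logistic` is an elementwise (vector to vector) helper introduced here so there's a non-trivial Jacobian. Forward and reverse mode should give the same matrix, and the Hessian of `sum_logistic` is diagonal because each term depends on a single input. `jax.jvp` computes the forward-mode directional derivative along a tangent vector, which is exactly what `first_finite_diff` approximates numerically.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def logistic(x):  # helper for this sketch: elementwise sigmoid, vector -> vector
    return 1.0 / (1.0 + jnp.exp(-x))

x_small = jnp.arange(3.0)

# Jacobian via forward mode and via reverse mode; both are (3, 3) and diagonal here
J_fwd = jax.jacfwd(logistic)(x_small)
J_rev = jax.jacrev(logistic)(x_small)
print(J_fwd)

# Hessian of the scalar function (jax.hessian is documented as jacfwd(jacrev(f)))
H = jax.hessian(sum_logistic)(x_small)
print(jnp.diag(H))

# Directional derivative along e_0, computed in forward mode without forming the full gradient
v = jnp.array([1.0, 0.0, 0.0])
_, directional = jax.jvp(sum_logistic, (x_small,), (v,))
print(directional)
```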

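We can also use `jax.vmap` to vectorize functions: it transforms a function written for a single example into one that operates on a whole batch, pushing the batch dimension down into the underlying operations instead of looping in Python.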

In [5]:
import jax
import jax.random as random
key1, key2 = random.split(random.PRNGKey(9), 2)

mat = random.normal(key2, shape=(150, 100))
batched_x = random.normal(key1, shape=(10, 100))

def mat_mul(x):
    return jnp.dot(mat, x)

Naive version: loop over the batch in Python and stack the results.

In [7]:
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([mat_mul(v) for v in v_batched])

print("naive")
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

naive
409 μs ± 3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
@jax.jit
def vmapped(x):
    return jax.vmap(mat_mul)(x)


%timeit vmapped(batched_x).block_until_ready()


29.6 μs ± 297 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
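One thing to keep in mind about these numbers: the naive loop isn't jitted while `vmapped` is, so part of the 409 μs to 29.6 μs gap likely comes from `jit` rather than `vmap` alone. Here's a sketch of two more `vmap` details: `in_axes` picks which arguments get mapped (`None` means broadcast unchanged), and the result matches a hand-batched matmul. `apply_matrix` is a hypothetical two-argument variant of `mat_mul` that takes the matrix explicitly instead of closing over it.

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(9), 2)
mat = jax.random.normal(key2, shape=(150, 100))
batched_x = jax.random.normal(key1, shape=(10, 100))

def apply_matrix(m, v):  # hypothetical variant of mat_mul: (150, 100) @ (100,) -> (150,)
    return jnp.dot(m, v)

# in_axes=(None, 0): broadcast the matrix, map over the leading axis of the batch
batched_apply = jax.vmap(apply_matrix, in_axes=(None, 0))
out_vmap = batched_apply(mat, batched_x)  # shape (10, 150)

# Hand-batched equivalent: one matmul against the transposed matrix
out_manual = batched_x @ mat.T

print(out_vmap.shape)
print(jnp.allclose(out_vmap, out_manual, atol=1e-3, rtol=1e-3))
```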
