In [1]:
from time import perf_counter
from time import time as tt
from typing import Tuple

import jax
import jax.numpy as jnp

# Vectorized function
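Here is the function that we will differentiate. It solves a tridiagonal system: find $X \in \mathbb R^n$ such that $MX = F$, where $M \in \mathbb R^{n \times n}$ is a tridiagonal matrix. The solver (Thomas algorithm) is detailed [here](https://fr.wikipedia.org/wiki/Matrice_tridiagonale). It is written with two `jax.lax.scan` loops: a forward elimination followed by a backward substitution. It takes the 3 diagonals of the matrix $M$ and the right-hand side $F$ as input and returns the solution $X$. The first entry of the sub-diagonal `a` and the last entry of the super-diagonal `c` are not part of $M$ and are ignored.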

In [2]:
def tridiag_solve(a: jnp.ndarray, b: jnp.ndarray, c: jnp.ndarray, f: jnp.ndarray) -> jnp.ndarray:
    """Solve the tridiagonal system ``M x = f`` with the Thomas algorithm.

    Row ``i`` of ``M`` reads ``a[i] * x[i-1] + b[i] * x[i] + c[i] * x[i+1] = f[i]``.
    Consequently ``a[0]`` and ``c[-1]`` never influence the result.

    The forward sweep rewrites every row as ``x[i] = f'[i] + q[i] * x[i+1]``.
    The backward sweep then recovers ``x`` from the last row upwards.
    No pivoting is performed, so a vanishing pivot ``b[i] + a[i] * q[i-1]`` yields inf/nan.

    Args:
        a: sub-diagonal coefficients (``a[0]`` unused).
        b: main-diagonal coefficients.
        c: super-diagonal coefficients (``c[-1]`` unused).
        f: right-hand side.

    Returns:
        The solution ``x``, with the same length as ``f``.
    """
    def eliminate(prev: Tuple[float, float], row: jnp.ndarray) -> Tuple[Tuple[float, float], Tuple[float, float]]:
        # prev holds (f'[i-1], q[i-1]); row holds (a[i], b[i], c[i], f[i]).
        f_prev, q_prev = prev
        a_i, b_i, c_i, f_i = row
        inv_pivot = 1./(b_i + a_i*q_prev)
        reduced = inv_pivot*(f_i - a_i*f_prev), -inv_pivot*c_i
        return reduced, reduced

    # Row 0 has no sub-diagonal term, so it is reduced directly.
    first = f[0]/b[0], -c[0]/b[0]
    rows = jnp.stack([a, b, c, f])[:, 1:].T
    _, (f_rest, q_rest) = jax.lax.scan(eliminate, first, rows)
    f_red = jnp.concat([jnp.array([first[0]]), f_rest])
    q_red = jnp.concat([jnp.array([first[1]]), q_rest])

    def substitute(x_next: float, coeffs: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[float, float]:
        # x[i] = f'[i] + q[i] * x[i+1]
        q_i, f_i = coeffs
        x_i = f_i + q_i*x_next
        return x_i, x_i

    # The last row has no super-diagonal term: x[-1] = f'[-1].
    x_last = f_red[-1]
    # reverse=True visits rows n-2 .. 0 but stores the outputs in natural order,
    # so no explicit flipping is needed.
    _, x_head = jax.lax.scan(substitute, x_last, (q_red[:-1], f_red[:-1]), reverse=True)
    return jnp.concat([x_head, jnp.array([x_last])])

# Usage
A simple call with random coefficients ($n = 10$).

In [3]:
nz = 10
key_a, key_b, key_c, key_f = jax.random.split(jax.random.PRNGKey(0), 4)
a = jax.random.uniform(key_a, (nz,), minval=.5, maxval=1.5)
b = jax.random.uniform(key_b, (nz,), minval=2., maxval=3.)
c = jax.random.uniform(key_c, (nz,), minval=.5, maxval=1.5)
f = jax.random.uniform(key_f, (nz,))
x_sol = tridiag_solve(a, b, c, f)

# Sanity check: plug the solution back into every row,
# a[i]*x[i-1] + b[i]*x[i] + c[i]*x[i+1] = f[i], treating x[-1] and x[nz] as 0
# (which is why a[0] and c[-1] play no role). The residual must vanish up to float32 rounding.
residual = a*jnp.pad(x_sol[:-1], (1, 0)) + b*x_sol + c*jnp.pad(x_sol[1:], (0, 1)) - f
assert jnp.max(jnp.abs(residual)) < 1e-4, "tridiag_solve residual too large"
x_sol


Array([ 0.03539155,  0.13971858,  0.30265936,  0.08421622, -0.24612431,
        0.43353987,  0.19535881, -0.02925097,  0.30980924,  0.25675532],      dtype=float32)

# Scalar function to differentiate
Using this tridiagonal solver, we build a function $\mathbb R^d \rightarrow \mathbb R$ that we will differentiate afterwards. This function builds the diagonals and the right-hand side by repeating the $d$ parameters cyclically, with different shifts and offsets. It then solves the tridiagonal system and returns the $L^2$ norm of its solution. We take $n=100$ for the rest of the notebook.

In [4]:
nz = 100
def scal_fun(params: jnp.ndarray) -> float:
    """Map a parameter vector to the L2 norm of a tridiagonal solve.

    Each of the four inputs of ``tridiag_solve`` is ``params`` rolled by 0, 1, 2 or 3
    positions and repeated cyclically up to the global length ``nz``. The constant
    offsets are +0.5 for the sub- and super-diagonals, +2 for the main diagonal,
    and none for the right-hand side.

    Note: ``nz`` is read from the module scope. A ``jax.jit``-compiled version
    therefore keeps the value of ``nz`` seen when it was traced.
    """
    reps = nz//params.shape[0] + 1

    def periodic(shift: int) -> jnp.ndarray:
        # Cyclically shift the parameters, repeat them `reps` times, keep exactly nz entries.
        return jnp.tile(jnp.roll(params, shift=shift), reps)[:nz]

    sub_diag = periodic(0) + .5
    main_diag = periodic(1) + 2
    super_diag = periodic(2) + .5
    rhs = periodic(3)
    solution = tridiag_solve(sub_diag, main_diag, super_diag, rhs)
    return jnp.linalg.norm(solution)

# Gradient and jitification inside and outside
Here we apply `jax.jit` "inside" the gradient (`grad(jit(f))`) and "outside" it (`jit(grad(f))`). We also create a gradient without jit for reference.

In [5]:
# Compiled version of the scalar function itself, shared by the "inside" variant.
jit_scal_fun = jax.jit(scal_fun)
# Reference: plain jax.grad, no jax.jit anywhere.
grad_scal_fun = jax.grad(scal_fun)
# jit "inside": the primal function is compiled, then differentiated.
grad_jit_scal_fun = jax.grad(jit_scal_fun)
# jit "outside": the whole gradient function is compiled by jax.jit.
jit_grad_scal_fun = jax.jit(grad_scal_fun)

# Consistency
We check that these three functions compute the same gradient (up to floating-point rounding).

In [13]:
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (5, ))
# Reference gradient without jit. It is by far the slowest variant, so compute it only once.
grad_ref = grad_scal_fun(x)
# Compare with a tolerance rather than `==`. jit may fuse or reorder floating-point
# operations, so bitwise equality between compiled and non-compiled runs is not guaranteed.
print(f'grad=grad(jit) : {jnp.allclose(grad_ref, grad_jit_scal_fun(x))}')
print(f'grad=jit(grad) : {jnp.allclose(grad_ref, jit_grad_scal_fun(x))}')

grad=grad(jit) : True
grad=jit(grad) : True


# Time benchmark
Here we compare the execution time of these 3 versions of the gradient. JAX dispatches computations asynchronously, so `block_until_ready` is used to wait until the gradient has actually been computed. Without it, we would only measure the dispatch time. Each function is called once before the benchmark so that the compilation cost is not counted. Between two function benchmarks, we clear JAX's caches (`jax.clear_caches`) so that compilations done for one version cannot make the following ones faster.

In [10]:
n_trials = 100
n_par = 5

key = jax.random.PRNGKey(0)
all_x = jax.random.uniform(key, (n_trials+1, n_par))


def mean_grad_time(grad_fun, inputs: jnp.ndarray) -> float:
    """Mean wall-clock time (in seconds) of ``grad_fun`` over ``inputs[1:]``.

    JAX caches are cleared first. Then ``inputs[0]`` is used for a warm-up call,
    so compilation is excluded from the measurement.

    Each input row is sliced and materialised *before* the clock starts, so only
    ``grad_fun`` itself is timed. ``jax.block_until_ready`` waits for JAX's
    asynchronous dispatch to finish.

    Raises:
        ValueError: if ``inputs`` has fewer than two rows (nothing left to time).
    """
    n_timed = inputs.shape[0] - 1
    if n_timed < 1:
        raise ValueError("need one warm-up row plus at least one timed row")
    jax.clear_caches()
    jax.block_until_ready(grad_fun(inputs[0]))
    elapsed = 0.
    for i in range(1, n_timed + 1):
        # Slicing is itself a dispatched JAX op: keep it out of the timed region.
        x_trial = jax.block_until_ready(inputs[i])
        start = perf_counter()
        jax.block_until_ready(grad_fun(x_trial))
        elapsed += perf_counter() - start
    return elapsed/n_timed


for label, grad_fun in [('no jit', grad_scal_fun),
                        ('grad(jit)', grad_jit_scal_fun),
                        ('jit(grad)', jit_grad_scal_fun)]:
    print(f'{label} : {mean_grad_time(grad_fun, all_x)}s')

no jit : 0.2117520809173584s
grad(jit) : 0.000748436450958252s
jit(grad) : 0.00021117210388183594s


# Conclusion
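- Jitting the gradient is almost mandatory for computational cost. In the run above, both jitted versions are several hundred times faster than the non-jitted one.
- It is better to jit "outside" the gradient. In the run above, `jit(grad(f))` was about 3.5× faster than `grad(jit(f))`.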