In [1]:
import jax
import jax.numpy as jnp
from jax import lax

A basic matrix operation that takes a lot of memory, most of which is redundant.

In [2]:
def loop_body(_, params, size=10000):
    """Scale ``a`` by one random factor via a deliberately wasteful matrix product.

    A dense ``size x size`` matrix is built only to read back its top-left
    entry, so the result is simply ``a * r`` where ``r`` is a single scalar
    drawn uniformly from [-2, 2).

    Args:
        _: Loop index; unused, present so the function matches the
            ``(i, carry)`` body signature expected by ``lax.fori_loop``.
        params: Carry tuple ``(a, key)``: the running scalar and a JAX PRNG key.
        size: Side length of the throwaway matrix. Defaults to 10000.

    Returns:
        Tuple ``(a_new, key)`` where ``a_new = a * r`` and ``key`` is the first
        half of ``jax.random.split`` applied to the incoming key.
    """
    a, key = params
    # A single scalar sample scales the whole identity: r on the diagonal, 0 elsewhere.
    scale = jax.random.uniform(key, minval=-2, maxval=2)
    huge_matrix = a * (jnp.identity(size) * scale)
    # First standard basis vector; e0 @ M @ e0 picks out M[0, 0].
    huge_vector = jnp.zeros(size).at[0].set(1.0)
    a_new = huge_vector @ huge_matrix @ huge_vector
    # NOTE(review): `key` was already consumed by `uniform` above and is split
    # again here. JAX recommends splitting first and sampling from a subkey; the
    # stream is left as-is so the recorded notebook outputs stay valid.
    key = jax.random.split(key)[0]
    return a_new, key

Run the cell below. JAX will build the whole graph and then optimize it (takes 16 s on the first run, 3.5 s on the second).

In [3]:
@jax.jit
def multiply(key):
    """Chain 200 applications of ``loop_body`` with a plain Python loop.

    The loop runs in Python while the function is traced, so every one of the
    200 ``loop_body`` calls is traced separately.

    Args:
        key: JAX PRNG key that seeds the first iteration.

    Returns:
        The scalar left in the carry after the final iteration.
    """
    carry = (1, key)
    for step in range(200):
        carry = loop_body(step, carry)
    product, _ = carry
    return product

multiply(jax.random.PRNGKey(0))


Array(-3.969895e-37, dtype=float32)

In [4]:
# Call `multiply` a second time with the same PRNG key.
multiply(jax.random.PRNGKey(0))

Array(-3.969895e-37, dtype=float32)

Run the cell below. JAX will compile the loop body, optimize it, and then build the recursive graph (RAM usage does not explode).

In [None]:
@jax.jit
def multiply(key):
    """Chain 200 applications of ``loop_body`` using ``lax.fori_loop``.

    Args:
        key: JAX PRNG key that seeds the first iteration.

    Returns:
        The scalar left in the carry after the final iteration.
    """
    # Start from a float so the initial carry already has the floating dtype
    # that `loop_body` returns, instead of relying on weak-type promotion of an
    # integer `1`.
    result, _ = lax.fori_loop(0, 200, loop_body, (1.0, key))

    return result

multiply(jax.random.PRNGKey(0))


Array(-3.969895e-37, dtype=float32)

In [8]:
# Call `multiply` a second time with the same PRNG key.
multiply(jax.random.PRNGKey(0))

Array(-3.969895e-37, dtype=float32)