# 01 Just-In-Time Compilation

Original Documentation: https://docs.jax.dev/en/latest/jit-compilation.html


In [32]:
import jax
import jax.numpy as jnp

## How JAX transformations work

The JIT compiler reduces a JAX function into a sequence of primitive operations. Primitives are units of computation; most functions in `jax.lax` represent a primitive.

`jax.make_jaxpr` can show the sequence of primitives in a JAX function:


In [33]:
global_list = []


def log2(x):
    global_list.append(x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2


print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0:f32[]
    d:f32[] = div b c
  in (d,) }


Notice that the side-effect of appending to the `global_list` is not captured.

Functions passed to JAX transformations should be pure, because side effects are not captured by tracing. (`jax.experimental.io_callback()` does allow side-effecting calls, at some cost in performance.)

When tracing a function, JAX wraps each argument in a Tracer object. The Tracer objects record all JAX operations performed on them. Then, JAX uses these records to reconstruct the entire function to produce the jaxpr.
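
A minimal sketch of what this implies under `jax.jit`: the Python body, including any side effect, runs only while JAX traces the function. The names `trace_log` and `log2_logged` are illustrative and do not appear in the original notebook.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical list, used only to observe when tracing happens


@jax.jit
def log2_logged(x):
    trace_log.append(x)  # side effect: executes only while JAX traces the function
    return jnp.log(x) / jnp.log(2.0)


first = log2_logged(jnp.array(8.0))   # first call: traces, then compiles
second = log2_logged(jnp.array(2.0))  # same shape and dtype: reuses the compiled code

print(len(trace_log))      # how many times the Python body actually ran
print(type(trace_log[0]))  # the recorded argument is a Tracer, not a concrete array
```

The second call never re-enters the Python body, so the list is not appended to again.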

For example, since `print()` is a side-effect, any calls to it will only happen at trace-time and will not appear in the jaxpr:


In [34]:
def log2_with_print(x):
    print("printed x:", x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2


print(jax.make_jaxpr(log2_with_print)(3.0))

printed x: JitTracer<~float32[]>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0:f32[]
    d:f32[] = div b c
  in (d,) }


Notice that the printed `x` is a `Tracer` object, not a concrete value. A trace-time `print()` like this can still be useful while debugging, for example to inspect the shape and dtype of intermediate values.
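
To print runtime values rather than Tracers, one option is `jax.debug.print`, which is staged into the compiled computation. The sketch below assumes a hypothetical function name, `log2_debug`, and is only one way to do this.

```python
import jax
import jax.numpy as jnp


@jax.jit
def log2_debug(x):
    ln_x = jnp.log(x)
    # jax.debug.print is part of the traced computation, so it prints the
    # runtime value on every call instead of a Tracer at trace time.
    jax.debug.print("ln_x = {v}", v=ln_x)
    return ln_x / jnp.log(2.0)


result = log2_debug(jnp.array(8.0))
result.block_until_ready()
```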

JAX will only capture the function based on the execution at trace-time. For example, if we have a Python conditional, the jaxpr will only know about the branch we take:


In [35]:
def log2_if_rank_2(x):
    if x.ndim == 2:
        ln_x = jnp.log(x)
        ln_2 = jnp.log(2.0)
        return ln_x / ln_2
    else:
        return x


print(jax.make_jaxpr(log2_if_rank_2)(jnp.array([1, 2, 3])))

{ lambda ; a:i32[3]. let  in (a,) }


Since we took the else branch at trace-time, the jaxpr only knows to immediately return the first argument.
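
As a rough illustration, tracing the same function with inputs of different rank yields different jaxprs. The sketch below counts the recorded equations in each trace. The variable names `rank_1` and `rank_2` are illustrative.

```python
import jax
import jax.numpy as jnp


def log2_if_rank_2(x):
    if x.ndim == 2:  # ndim is static, so this Python branch is resolved while tracing
        return jnp.log(x) / jnp.log(2.0)
    return x


rank_1 = jax.make_jaxpr(log2_if_rank_2)(jnp.ones((3,)))
rank_2 = jax.make_jaxpr(log2_if_rank_2)(jnp.ones((2, 2)))

# Count the primitive applications recorded in each trace.
print(len(rank_1.jaxpr.eqns))  # else branch: the input is returned unchanged
print(len(rank_2.jaxpr.eqns))  # if branch: the log and division are recorded
```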

## JIT compiling a function

As an example, we time a `selu` function with and without `jax.jit`:


In [36]:
import time


def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


x = jnp.arange(1_000_000)
start = time.time()
selu(x).block_until_ready()
print(time.time() - start)

start = time.time()
jax.jit(selu)(x).block_until_ready()
print(time.time() - start)

0.0027778148651123047
0.048252105712890625


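Without `jax.jit`, JAX sends one operation at a time to the accelerator, which limits the XLA compiler’s ability to optimize across operations. Note that the second timing above measures the first jitted call, so it includes the one-time tracing and compilation cost. That is why it is slower than the un-jitted call here.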

`jax.jit()` gives the XLA compiler the whole computation upfront. On the first call, JAX traces the function and emits a jaxpr, which XLA then compiles into efficient accelerator-specific instructions.

Any subsequent calls will use the efficient compiled accelerator instructions directly.
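
To compare steady-state performance, a common approach is to run one warm-up call before timing. The sketch below is one way to do that. The helper `time_call`, the `repeats` parameter, and the `jnp.linspace` input are assumptions made for illustration, and the measured numbers will vary by machine and backend.

```python
import time

import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


def time_call(fn, *args, repeats=10):
    """Return the mean wall-clock seconds per call over `repeats` calls."""
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jnp.linspace(-5.0, 5.0, 1_000_000)
selu_jit = jax.jit(selu)

selu_jit(x).block_until_ready()  # warm-up: pays the tracing and compilation cost once

print("eager:", time_call(selu, x))
print("jit (after warm-up):", time_call(selu_jit, x))
```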

## Why can’t we JIT everything?

We cannot JIT compile a function whose Python conditionals or loops depend on the value of a traced argument.


In [37]:
@jax.jit
def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i


print("Answer:", g(10, 20))

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /var/folders/r0/zlgwr7551fg116vdpn_9s7ph0000gn/T/ipykernel_85568/222786207.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

The sequence of operations depends on a runtime value that is not known at trace-time, so the function cannot be compiled.

Traced values in JIT like `x` and `n` can only affect control-flow based on their static attributes like `shape` or `dtype`, not their values:


In [None]:
@jax.jit
def g(x, y):
    if x.dtype == jax.dtypes.bfloat16:
        return x
    else:
        return y


print("Answer:", g(10, 20))

Answer: 20


Generally, avoid Python conditionals on traced values. If such a conditional is necessary, use a structured control-flow operation such as `jax.lax.cond()`.
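
A minimal sketch of this approach, using `jax.lax.cond` for a value-dependent branch and `jax.lax.while_loop` for the value-dependent loop from the failing example. The function names `abs_via_cond` and `g_while` are hypothetical.

```python
import jax


@jax.jit
def abs_via_cond(x):
    # Both branches are traced; the predicate selects one at run time.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


@jax.jit
def g_while(x, n):
    # Carry the counter through jax.lax.while_loop so that `n` can stay traced.
    i = jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, 0)
    return x + i


print(abs_via_cond(-3.0))
print("Answer:", g_while(10, 20))
```

Both branches of `jax.lax.cond` must return values of the same shape and dtype, and the loop carry of `jax.lax.while_loop` must keep the same type across iterations.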

However, sometimes even that is not enough. In that case, we can JIT compile part of the function. For example, if the computationally expensive portion of the function is the loop body, we could JIT compile that:


In [None]:
@jax.jit
def loop_body(prev_i):  # JIT compiled: stands in for the expensive part of the loop
    return prev_i + 1


def g(x, n):  # plain Python function: the loop condition uses concrete values
    i = 0
    while i < n:
        i = loop_body(i)
    return x + i


print("Answer:", g(10, 20))

Answer: 30


## Marking arguments as static

If we really need to JIT compile a function whose control flow depends on the value of an argument, we can mark that argument as static. JAX then does not trace it and instead treats it as a concrete Python value at trace-time.

The downside is that the emitted jaxpr is specific to the value passed (the value is treated as a constant in the jaxpr). For every new value of the static argument, JAX will recompile the function.

It is recommended only to use static arguments for inputs with a fixed set of possible values.


In [None]:
from functools import partial


@partial(jax.jit, static_argnames=("n",))
def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i


print("Answer:", g(10, 20))

Answer: 30
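
To make the recompilation behavior visible, the sketch below counts how often the Python body is traced. The counter `trace_count` and the name `g_static` are hypothetical and are used only for illustration.

```python
from functools import partial

import jax

trace_count = 0  # hypothetical counter, incremented only while tracing


@partial(jax.jit, static_argnames=("n",))
def g_static(x, n):
    global trace_count
    trace_count += 1  # Python side effect: runs once per trace, not once per call
    i = 0
    while i < n:  # allowed: `n` is a concrete Python int here
        i += 1
    return x + i


g_static(10, n=20)  # new static value: trace and compile
g_static(11, n=20)  # same static value: cache hit
g_static(10, n=5)   # different static value: trace and compile again
print(trace_count)
```

Static arguments must be hashable, because their values form part of the cache key.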


## JIT and caching

When a function is marked with `@jax.jit`, it is compiled on its first call and the resulting XLA executable is cached. Subsequent calls with matching argument shapes and dtypes reuse the cached executable, so the compilation overhead is amortized.

If we specify `static_argnums` / `static_argnames`, the cached executable is specific to the values passed for the static arguments. If any of those values changes, the function must be recompiled.

Do not call `jax.jit()` on temporary functions defined inside loops. The cache relies on the hash of the function, so each redefinition of an equivalent function counts as a new function and is compiled again.


In [None]:
def loop_body(i):
    return i + 1


def g_temporary_inner_function(x, n):
    i = 0
    while i < n:
        i = jax.jit(lambda x: loop_body(x))(i)  # Bad! Recompilation on each iteration.
    return x + i


def g_inner_function(x, n):
    i = 0
    while i < n:
        i = jax.jit(loop_body)(i)  # Good! Subsequent calls will hit XLA cache.
    return x + i


start = time.time()
g_temporary_inner_function(10, 20).block_until_ready()
print(time.time() - start)

start2 = time.time()
g_inner_function(10, 20).block_until_ready()
print(time.time() - start2)

0.22260403633117676
0.011249303817749023


The first function is significantly slower because it creates a new lambda, and therefore recompiles, on every loop iteration. The second function compiles `loop_body` once and hits the cache on later iterations.
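
One way to check this without relying on wall-clock timings is to count traces. The sketch below wraps the body once outside the loop and compares it with the per-iteration lambda pattern. The `trace_counts` dictionary and the `make_body` helper are hypothetical bookkeeping.

```python
import jax
import jax.numpy as jnp

trace_counts = {"hoisted": 0, "lambda": 0}  # hypothetical bookkeeping


def make_body(key):
    def body(i):
        trace_counts[key] += 1  # runs only when JAX traces `body`
        return i + 1
    return body


hoisted_body = make_body("hoisted")
lambda_body = make_body("lambda")
hoisted_jit = jax.jit(hoisted_body)  # wrapped once, outside the loop

i = jnp.int32(0)
for _ in range(5):
    i = hoisted_jit(i)  # same jitted function object: traced once

j = jnp.int32(0)
for _ in range(5):
    j = jax.jit(lambda v: lambda_body(v))(j)  # new lambda each time: retraced

print(trace_counts)
```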
