# JIT with Jax

## How Jax transforms work

To transform Python functions, Jax first converts the function into an intermediate language called jaxpr. The Jax transformations then operate on this jaxpr representation of the function.

Jax is designed for functionally pure code and does not track side effects. In the code below we therefore expect the jaxpr to ignore global_list.append(x). This call is a side effect: it modifies state outside the function's local environment, that is, it has an observable effect other than returning a value (here, the modification of global_list).

In [1]:
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


This is a purely functional representation of the function: the jaxpr omits the side effect.

It's important to note that global_list is still modified during the first pass, when Jax calls the function with tracer objects in place of its arguments in order to build the jaxpr. The tracers do not record side effects, so the append does not appear in the jaxpr, but it does happen during the trace itself. This is why global_list now holds a tracer rather than the value 3.0:

In [2]:
global_list

[Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>]

**One should not rely on this, as it is strictly an implementation detail.**
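The same thing matters under jax.jit. Because the compiled code is cached and reused for inputs with the same shape and dtype, a Python side effect runs only when Jax traces the function, not on every call. A minimal sketch, where trace_log and log2_jit are hypothetical names mirroring global_list and log2 above:

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical global list, mirrors global_list above

@jax.jit
def log2_jit(x):
    # Python side effect: executes only while Jax is tracing the function
    trace_log.append(x)
    return jnp.log(x) / jnp.log(2.0)

log2_jit(3.0)          # first call: trace + compile, so the append runs
log2_jit(4.0)          # same shape/dtype: cached compiled code, no append
print(len(trace_log))  # 1

log2_jit(jnp.ones(3))  # new input shape: Jax retraces, so the append runs again
print(len(trace_log))  # 2
```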

An important thing to understand is that a jaxpr captures the function as executed on the parameters given. Python control flow, such as an if statement on x.ndim, is evaluated during tracing, so the jaxpr only contains the branch that was taken.

In [3]:
def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))

{ lambda ; a:i32[3]. let  in (a,) }


In [4]:
print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([[1.0, 2.0],[1.0,2.0]])))

{ lambda ; a:f32[2,2]. let
    b:f32[2,2] = log a
    c:f32[] = log 2.0
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
    e:f32[2,2] = div b d
  in (e,) }
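The if statement in log2_if_rank_2 works because x.ndim is already known at trace time. A condition on the array's values is not known during tracing, so a plain Python if on it fails. A minimal sketch of one alternative, jax.lax.cond, which stages both branches into the jaxpr and picks one when the code runs (log2_if_positive_sum is a hypothetical example function):

```python
import jax
import jax.numpy as jnp

def log2_if_positive_sum(x):
    # A Python `if jnp.sum(x) > 0:` would raise an error under tracing,
    # because the sum's value is unknown until runtime. lax.cond records
    # both branches; the predicate is evaluated when the code executes.
    return jax.lax.cond(
        jnp.sum(x) > 0,
        lambda v: jnp.log(v) / jnp.log(2.0),  # branch for a positive sum
        lambda v: v,                          # branch otherwise
        x,
    )

# The printed jaxpr contains a `cond` holding both branches
print(jax.make_jaxpr(log2_if_positive_sum)(jnp.array([1.0, 2.0, 4.0])))
print(log2_if_positive_sum(jnp.array([1.0, 2.0, 4.0])))  # approximately [0, 1, 2]
```

Both branches must return arrays of the same shape and dtype, since the jaxpr has to describe a single output type.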


## JIT compiling a function

Jax allows exactly the same code to run on CPU, GPU or TPU.

Here we are going to look at the SELU operation:

$$ SELU(x) = \lambda \begin{cases}
    \mbox{$x$} & \mbox{if } x > 0\\
    \mbox{$\alpha e^x-\alpha$} & \mbox{if } x \leq 0
    \end{cases} $$

In [5]:
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

803 µs ± 86.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


Without jit, this code sends one operation at a time to the accelerator, which limits the XLA compiler's ability to optimize the function.
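One way to see how many separate operations selu is made of is to print its jaxpr. This sketch reuses the selu definition from the cell above; the exact primitives listed may vary between Jax versions, so no particular count is assumed:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)

# Each equation in the jaxpr is one operation. Called eagerly (without jit),
# selu dispatches these operations roughly one after another, so XLA only
# sees a small piece of the computation at a time.
selu_jaxpr = jax.make_jaxpr(selu)(x)
print(selu_jaxpr)
print("number of equations:", len(selu_jaxpr.jaxpr.eqns))
```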


For this code to be as performant as possible, we want to give the XLA compiler as much code as possible at once. For this, Jax provides the jax.jit transformation, which just-in-time compiles a Jax-compatible function so that it runs faster.

In [6]:
selu_jit = jax.jit(selu)

# Warm up
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

90.1 µs ± 808 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


As we can see, this is roughly an order of magnitude faster (about 90 µs versus 803 µs per loop). So what have we done?

1. We wrapped selu with jax.jit(selu) and called the result selu_jit. Nothing is compiled at this point yet.
2. We then ran selu_jit once on x, so that Jax could trace it. The resulting jaxpr is compiled by XLA into efficient code, and subsequent calls use this optimised compiled code.

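As a result, subsequent calls with inputs of the same shape and dtype no longer run the original Python implementation at all. A call with a new shape or dtype triggers a fresh trace and compilation.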
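jax.jit performs the trace and compile steps implicitly on the first call. A minimal sketch that makes them explicit, assuming a reasonably recent Jax version that provides the ahead-of-time API (jax.jit(...).lower(...) returning an object with as_text() and compile()):

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)

# Step 1: trace selu for this input's shape/dtype and lower it for XLA
lowered = jax.jit(selu).lower(x)
print(lowered.as_text()[:500])  # beginning of the program handed to XLA

# Step 2: let XLA compile the lowered program into an executable
compiled = lowered.compile()

# The compiled executable is called like the original function
y = compiled(x)
print(jnp.allclose(y, selu(x)))  # True
```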

*If we didn't include the warm-up call separately, everything would still work as expected, but the compilation time would be included in the benchmark as well. It would still be faster, but it would not be a fair comparison.*
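To see the cost that the warm-up call keeps out of the benchmark, one can time the first and second calls separately. A rough sketch using time.perf_counter rather than %timeit; the actual numbers depend entirely on the hardware and Jax version:

```python
import time
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
selu_jit = jax.jit(selu)

# First call: tracing and XLA compilation in addition to execution
start = time.perf_counter()
selu_jit(x).block_until_ready()
first_call = time.perf_counter() - start

# Second call: reuses the cached executable, so only execution is timed
start = time.perf_counter()
selu_jit(x).block_until_ready()
second_call = time.perf_counter() - start

print(f"first call:  {first_call * 1e3:.2f} ms")
print(f"second call: {second_call * 1e3:.2f} ms")
```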