In [1]:
# JAX traces each function into a sequence of primitive operations expressed in an intermediate language called jaxpr.

In [2]:
import os
# Expose only physical GPU 3 to this process; set it before JAX first touches the GPU.
os.environ.update(CUDA_VISIBLE_DEVICES='3')

In [3]:
import jax
import jax.numpy as jnp

In [4]:
# Module-level list that log2 appends to: a deliberate Python side effect used to
# show that side effects are not recorded in the jaxpr.
global_list = []


def log2(x):
    """Return the base-2 logarithm of x, computed as ln(x) / ln(2).

    Impure on purpose: whenever the Python body runs (eagerly, or once per trace
    under jax.make_jaxpr / jax.jit) x is appended to global_list. That append is
    not part of the resulting jaxpr.
    """
    global_list.append(x)
    return jnp.log(x) / jnp.log(2.0)

In [6]:
# Trace log2 with a concrete float and print its jaxpr; the global_list.append call leaves no trace in it.
print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


In [7]:
# Note that the jaxpr doesn't capture the function's side effects (here, appending to global_list); it records only the purely functional part.
# This is by design: JAX transformations are built to work on side-effect-free (functionally pure) code.
# Impure functions are dangerous because under JAX transformations they are unlikely to behave as expected.
# If you need to print from inside a transformed function, use jax.debug.print(), at some cost in performance.
# When tracing, JAX wraps each argument in a tracer object that records every JAX operation performed on it.
# The tracing itself runs as ordinary Python; JAX then uses the recorded operations to reconstruct the whole function as a jaxpr.
# Because side effects are plain Python, they do not appear in the jaxpr, but they DO run while the function is being traced.

In [8]:
# Note that Python's print is not pure either: the text output is a side effect, so it doesn't appear in the jaxpr.

In [9]:
# Important: a jaxpr captures the function as executed on the arguments given to it. If the Python code has a branching condition,
# the jaxpr only knows about the branch that was taken.

In [46]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

        selu(x) = lmbda * x                          if x > 0
                  lmbda * (alpha * exp(x) - alpha)   otherwise

    The Python `if` below branches on ``ndim``, a static attribute, so it is
    safe under ``jax.jit``: only the branch actually taken is traced into the
    jaxpr.

    Args:
        x: array-like of rank 0 or 1 (Python scalars and lists are accepted).
        alpha: scale of the negative (exponential) branch.
        lmbda: overall output scale.

    Returns:
        A floating-point array with the same shape as ``x``.

    Raises:
        ValueError: if ``x`` has rank 2 or more (instead of silently
            returning ``None``).
    """
    # Accept Python scalars/lists too; a no-op for arrays and tracers.
    x = jnp.asarray(x)
    if x.ndim >= 2:
        raise ValueError(f"selu expects an input of rank < 2, got rank {x.ndim}")
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [47]:
# Benchmark input: the integers 0 .. 999_999 as a 1-D JAX array (integer dtype).
x = jnp.arange(0, 1_000_000)

In [48]:
# Eager (un-jitted) baseline: each jnp op inside selu is dispatched separately.
# block_until_ready() waits for the result, so we time the computation and not just its dispatch.
%timeit selu(x).block_until_ready()

445 μs ± 40.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [49]:
# With JAX's async dispatch this may return before the computation has finished.
a = selu(x)

In [50]:
# Displaying the array needs its values, so this waits for the computation to finish.
a

Array([0.0000000e+00, 1.0500000e+00, 2.0999999e+00, ..., 1.0499968e+06,
       1.0499979e+06, 1.0499989e+06], dtype=float32)

In [51]:
# Explicitly wait until `a` is computed; block_until_ready returns the same array.
b = a.block_until_ready()

In [52]:
# Same values as `a`: block_until_ready hands back its input array.
b

Array([0.0000000e+00, 1.0500000e+00, 2.0999999e+00, ..., 1.0499968e+06,
       1.0499979e+06, 1.0499989e+06], dtype=float32)

In [53]:
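# Note: block_until_ready waits for the array's computation to finish and then returns that same array.
# It is NOT called implicitly by `a = selu(x)`: because of JAX's async dispatch, selu returns right away and the work may still be running.
# Displaying `a` (or converting it to a NumPy array) is what forces the wait.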

In [54]:
# Without jit, the code above sends one operation at a time to the accelerator. This limits the ability of the XLA compiler to
# optimize our functions (e.g. by fusing operations together).

In [55]:
# Wrap selu for XLA compilation; tracing and compilation happen on the first call, not here.
selu_jit = jax.jit(selu)

In [56]:
# Warm-up call: the first call traces and compiles selu for this input's shape/dtype,
# so the %timeit that follows measures only the cached, compiled version.
selu_jit(x).block_until_ready()

Array([0.0000000e+00, 1.0500000e+00, 2.0999999e+00, ..., 1.0499968e+06,
       1.0499979e+06, 1.0499989e+06], dtype=float32)

In [57]:
# Time the compiled version; block_until_ready() again so async dispatch doesn't hide the cost.
%timeit selu_jit(x).block_until_ready()

48.5 μs ± 2.08 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [58]:
# The first call to selu_jit is where JAX does the tracing: a jaxpr is generated and then compiled by XLA for the
# specific hardware. Subsequent calls with inputs of the same shape and dtype reuse the compiled code directly.

In [59]:
# Inspect what jit traced: the whole of selu appears as one pjit with a nested jaxpr.
# The Python ndim check in selu does not appear; only the branch that was taken is traced.
jax.make_jaxpr(selu_jit)(x)

{ lambda ; a:i32[1000000]. let
    b:f32[1000000] = pjit[
      name=selu
      jaxpr={ lambda ; c:i32[1000000]. let
          d:bool[1000000] = gt c 0
          e:f32[1000000] = convert_element_type[
            new_dtype=float32
            weak_type=False
          ] c
          f:f32[1000000] = exp e
          g:f32[1000000] = mul 1.6699999570846558 f
          h:f32[1000000] = sub g 1.6699999570846558
          i:f32[1000000] = pjit[
            name=_where
            jaxpr={ lambda ; j:bool[1000000] k:i32[1000000] l:f32[1000000]. let
                m:f32[1000000] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] k
                n:f32[1000000] = select_n j l m
              in (n,) }
          ] d c h
          o:f32[1000000] = mul 1.0499999523162842 i
        in (o,) }
    ] a
  in (b,) }

In [60]:
# block_until_ready is required for accurate timing because of JAX's async dispatch: without it, %timeit would measure only the time to dispatch the work, not to finish it.

## When not to use JIT

In [68]:
# 1. Conditions on the value of x.

def f(x):
    """Print x, then return x if x > 10, else x * 2.

    Under jax.jit, x is a tracer, so `x > 10` cannot be converted to a Python
    bool and tracing fails with TracerBoolConversionError; that failure is the
    point of this example.
    """
    print(x)
    return x if x > 10 else x * 2

In [69]:
# Tracing f fails at `x > 10` because x is an abstract tracer under jit. The failure is the
# point of the example, so catch and show it instead of aborting "Restart & Run All".
try:
    jax.jit(f)(2)
except jax.errors.TracerBoolConversionError as err:
    print(f"{type(err).__name__}: {err}")

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>


TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_3704065/1268204656.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [72]:
# 2. While loop conditioned on the value of n (x is unused).

def g(x, n):
  """Count i up to n and return 1; x is unused.

  Eagerly this works, but under jax.jit n is a tracer, so the Python
  `while i < n` cannot evaluate the comparison to a concrete bool.
  """
  i = 0
  while i < n:
    i += 1
  return 1

# The error is the point of this example; catch and show it so "Run All" keeps going.
try:
  jax.jit(g)(10, 20)
except jax.errors.TracerBoolConversionError as err:
  print(f"{type(err).__name__}: {err}")

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_3704065/53159033.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values.

Traced values within JIT can only affect control flow via their static attributes such as shape or dtype. 

One workaround is to JIT-compile only the part of the function that has no value-dependent conditions: move that part into its own function and jit that.

You can also use jax.lax.cond() (and jax.lax.while_loop() for loops) to express value-dependent control flow inside a jit-compiled function. More on this later.

You can also mark inputs as static with static_argnums or static_argnames. You can then condition on those inputs, but JAX recompiles every time it sees a new value for them (static arguments must also be hashable).

In [73]:
def f(x):
  """Return x when x > 0, otherwise 2 * x (branches on the Python value of x)."""
  return x if x > 0 else 2 * x

In [74]:
# x is marked static by position, so `if x > 0` sees a concrete Python int;
# JAX compiles a separate version for each distinct value of x.
f_jit_correct = jax.jit(f, static_argnums=(0,))
print(f_jit_correct(10))

10


In [77]:
# Same idea, but x is marked static by name; jit maps the name to its position,
# so the positional call below still treats x as static.
f_jit_correct = jax.jit(f, static_argnames='x')
print(f_jit_correct(10))

10


In [78]:
# You can use Python's functools.partial to pass jit options (such as static_argnums) when using jax.jit as a decorator.

In [80]:
from functools import partial
@partial(jax.jit, static_argnums=(0,))
def f(x):
  """Return x when x > 0, otherwise 2 * x.

  x is static (static_argnums), so it is a concrete Python value at trace time
  and the Python conditional is allowed; each new value of x triggers a recompile.
  """
  return x if x > 0 else 2 * x

In [81]:
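# functools.partial(func, *args, **kwargs) returns a new callable that calls func with those positional and keyword
# arguments pre-filled. So partial(jax.jit, static_argnums=0) is a decorator equivalent to `lambda fn: jax.jit(fn, static_argnums=0)`.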

## JIT and caching

In [82]:
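# Avoid calling jax.jit on temporary functions defined inside loops or other scopes. JAX's compilation cache is keyed on the function,
# and every evaluation of a lambda or functools.partial creates a new function object with a different hash, so the cache misses.
# For example, jax.jit(lambda x: f(x)) inside a loop recompiles on every iteration.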