In [1]:
import jax
import jax.numpy as jnp
import jax.lax as lax

JAX allows us to transform Python functions. A Python function is first converted to an intermediate language called *jaxpr*. The transformations then work on the *jaxpr* expressions.

In [2]:
def squarelog(x):
  print(f"printed x: {x}")
  y = jnp.log(x)
  z = y**2
  return z

In [3]:
sqx = jax.make_jaxpr(squarelog)(1.7)

printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [4]:
print(sqx)

{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = integer_pow[y=2] b in (c,) }


JAX generates *jaxpr* by 'tracing': each argument is replaced by a tracer object, which records every JAX operation performed on it during the function call. Tracers have no way of recording Python side effects. In the example above, the *print()* call is executed only during tracing and does not appear in the *jaxpr* expression.
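The same point can be checked programmatically. The sketch below is only an illustration: the list *calls*, the marker string and the function *squarelog_logged* are hypothetical names that are not part of the original notebook. The append runs in plain Python while the function is traced, and the printed *jaxpr* carries no record of it.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical log, used only to observe the Python side effect

def squarelog_logged(x):
    calls.append("SIDE_EFFECT")   # plain Python: executed while tracing, not recorded
    return jnp.log(x) ** 2

closed = jax.make_jaxpr(squarelog_logged)(1.7)
jaxpr_text = str(closed)

print(len(calls))                  # number of times the side effect ran
print("SIDE_EFFECT" in jaxpr_text) # whether the jaxpr mentions it
print(jaxpr_text)
```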

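A *jaxpr* only captures the function as executed on the arguments given to it. For a Python conditional, the *jaxpr* will only contain the branch that is taken during tracing (if the condition can be evaluated at all). In the example below the condition depends on *x.ndim*, which is known at trace time.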

In [5]:
def condlog(x):
    if x.ndim == 1:
        y = jnp.log(x)
        z = y**2
        return z
    elif x.ndim == 2:
        return x + 1000
    else:
        print("No can do")

In [6]:
dim1 = jax.make_jaxpr(condlog)(jnp.array([1.43, 1.72, 2.28]))

In [7]:
print(dim1)

{ lambda ; a:f32[3]. let
    b:f32[3] = log a
    c:f32[3] = integer_pow[y=2] b
  in (c,) }


In [8]:
dim2 = jax.make_jaxpr(condlog)(jnp.array([[1.43, 1.72], [3.15, 2.28]]))

In [9]:
print(dim2)

{ lambda ; a:f32[2,2]. let b:f32[2,2] = add a 1000.0 in (b,) }


In [10]:
dim0 = jax.make_jaxpr(condlog)(2.7)

No can do


In [11]:
print(dim0)

{ lambda ; a:f32[]. let  in () }


## JIT-ing a function

Since Python is an interpreted language, the code for *SELU* below sends one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize the execution.

In [12]:
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

318 µs ± 86.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


Let's allow the XLA compiler to analyse and optimize the code. The first call of *selu_jit* traces the function and translates it to XLA HLO. The code is subsequently subjected to target-independent and target-dependent (backend) optimizations.

In [13]:
selu_jit = jax.jit(selu)

# Warm up
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

60.3 µs ± 457 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Subsequent calls to *selu_jit* with the same input shape and dtype will use the optimized compiled code, skipping the Python implementation entirely.
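The stages hidden behind *jax.jit* can also be run by hand, which makes it possible to look at what is handed to the compiler. The following is a minimal sketch, assuming a small input *x_small* chosen here purely for inspection (it is not part of the original notebook): *lower()* traces and lowers the function without executing it, *as_text()* returns the textual form of the lowered module, and *compile()* produces an executable.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x_small = jnp.arange(-4.0, 4.0)          # hypothetical small input for inspection

lowered = jax.jit(selu).lower(x_small)   # trace and lower; nothing is executed yet
hlo_text = lowered.as_text()             # textual form of the lowered module
compiled = lowered.compile()             # let XLA optimize and build an executable

print(hlo_text[:300])                    # first few hundred characters only
print(jnp.allclose(compiled(x_small), selu(x_small)))
```

The compiled object only accepts inputs with the shape and dtype it was lowered for.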

### So why can't we *jit* everything?

In [14]:
def absval(x):
    print(type(x))
    if x >= 0:
        return x
    else:
        return -x

absval_jit = jax.jit(absval)

In [15]:
# Uncomment to see the error message.
# absval_jit(3.2)

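For *jax.jit*, tracers carry a *ShapedArray* abstract value by default. Each tracer has a concrete shape and dtype (and we are allowed to condition on these), but no concrete value. This is why *absval_jit* fails: the Python test *if x >= 0* needs a concrete boolean, which a tracer cannot supply. In return, the compiled function works on all possible inputs with the same shape and dtype.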
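One way to see that compilation is keyed on shape and dtype rather than on values is to count how often the Python body runs. This is a sketch only: *trace_log* and *double* are hypothetical names introduced for the illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records the shape seen at each trace

@jax.jit
def double(x):
    trace_log.append(x.shape)    # runs only while the function is being traced
    return 2 * x

double(jnp.ones(3))
double(jnp.zeros(3))             # same shape and dtype: the cached code is reused
n_after_same_shape = len(trace_log)

double(jnp.ones(4))              # new shape: the function is traced again
n_after_new_shape = len(trace_log)

print(n_after_same_shape, n_after_new_shape, trace_log)
```

Different values with the same shape reuse the compiled code, while a new shape triggers a fresh trace.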

We can relax the traceability constraints in multiple ways. For example, with the *static_argnums* argument to *jit* we can specify that some of the arguments are traced on their concrete (instead of abstract) values. The function is then recompiled for every new value of a static argument.

In [16]:
absval_con = jax.jit(absval, static_argnums=(0,))

In [17]:
print(absval_con(-3.2))

<class 'float'>
3.2
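Another option, sketched here as an assumption and not taken from the original notebook, is to avoid the Python *if* altogether and express the choice with *jnp.where*, the same function used in *selu*. The hypothetical *absval_where* below needs no static arguments, so it is compiled once per shape and dtype and also works elementwise on arrays.

```python
import jax
import jax.numpy as jnp

@jax.jit
def absval_where(x):
    # Both candidates are computed; jnp.where then selects per element.
    return jnp.where(x >= 0, x, -x)

scalar_result = absval_where(-3.2)
array_result = absval_where(jnp.array([-1.0, 2.0, -3.0]))

print(scalar_result)
print(array_result)
```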


The next example involves a loop which is statically unrolled.

In [18]:
def summer(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

summer_jit = jax.jit(summer, static_argnums=(1,))

In [19]:
summer_jit(jnp.array([1,2,3,4,5,6,7,8,9,0]), 4)

Array(10., dtype=float32, weak_type=True)
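The unrolling can be made visible by comparing the *jaxpr* for two values of the static argument. In this sketch the input array and the names *short* and *long* are illustrative choices. The exact number of equations depends on the JAX version, so only the growth with *n* is of interest.

```python
import jax
import jax.numpy as jnp

def summer(x, n):
    y = 0.
    for i in range(n):
        y = y + x[i]
    return y

x = jnp.arange(10.0)

# n is static, so each value of n yields its own, separately traced program.
short = jax.make_jaxpr(summer, static_argnums=(1,))(x, 2)
long = jax.make_jaxpr(summer, static_argnums=(1,))(x, 8)

# The loop body is repeated n times in the jaxpr: more iterations, more equations.
print(len(short.jaxpr.eqns), len(long.jaxpr.eqns))
```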

## Structured control flow primitives

JAX has structured control flow primitives that keep control flow traceable: they avoid recompilation for each new value of a condition and avoid unrolling large loops. To capture a conditional expression for dynamic execution, one must use *jax.lax.switch()* or *jax.lax.cond()*.

In [20]:
@jax.jit
def one_of_three(index, arg):
  return lax.switch(index, [lambda x: x + 1.,
                            lambda x: x - 2.,
                            lambda x: x + 3.], arg)

In [21]:
one_of_three(2, 7)

Array(10., dtype=float32, weak_type=True)

In [22]:
one_of_three(0, 18)

Array(19., dtype=float32, weak_type=True)

In [23]:
# lax.switch is expressed with the *cond* primitive (operation) of jax.lax.
# The first operand of this primitive is the branch index.

print(jax.make_jaxpr(one_of_three)(1, 5.))

{ lambda ; a:i32[] b:f32[]. let
    c:f32[] = pjit[
      jaxpr={ lambda ; d:i32[] e:f32[]. let
          f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
          g:i32[] = clamp 0 f 2
          h:f32[] = cond[
            branches=(
              { lambda ; i:f32[]. let j:f32[] = add i 1.0 in (j,) }
              { lambda ; k:f32[]. let l:f32[] = sub k 2.0 in (l,) }
              { lambda ; m:f32[]. let n:f32[] = add m 3.0 in (n,) }
            )
            

The signatures of these conditionals are:

*lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B*

*lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B*
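As an illustration of the second signature, the absolute-value function from earlier can be written with *lax.cond* so that it compiles without static arguments. This is a minimal sketch; *absval_cond* is a hypothetical name, and both branches must return values of the same shape and dtype.

```python
import jax
import jax.lax as lax

@jax.jit
def absval_cond(x):
    # pred is a traced boolean; the branch is chosen when the compiled code runs.
    return lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

neg_result = absval_cond(-3.2)
pos_result = absval_cond(1.5)

print(neg_result, pos_result)
```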

To capture a loop for dynamic execution the operations *jax.lax.while_loop()* and *jax.lax.fori_loop()* are available.

In [24]:
@jax.jit
def adder(arg, n):
  ones = jnp.ones(arg.shape)
  return lax.fori_loop(0, n,
                       lambda i, carry: carry + ones, ones)

In [25]:
adder(jnp.ones(3), 4)

Array([5., 5., 5.], dtype=float32)
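When the number of iterations depends on the data, *lax.while_loop* is the natural choice. The sketch below uses a hypothetical function *halvings* (not from the original notebook) that counts how many times a value must be halved before it drops below 1. The loop state is a tuple, and its structure and dtypes must stay the same across iterations.

```python
import jax
import jax.numpy as jnp
import jax.lax as lax

@jax.jit
def halvings(x):
    # Loop state: (current value, number of steps taken so far).
    def cond_fun(state):
        value, _ = state
        return value >= 1.0

    def body_fun(state):
        value, steps = state
        return value / 2.0, steps + 1

    _, steps = lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
    return steps

print(halvings(10.0))   # 10 -> 5 -> 2.5 -> 1.25 -> 0.625: four halvings
print(halvings(0.5))    # already below 1: the body never runs
```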

Dynamic loops use the *while* primitive. Notice also the *pjit* primitive, which wraps the jitted function (the counterpart of the older *xla_call* primitive).

In [26]:
print(jax.make_jaxpr(adder)(jnp.zeros(2), 6))

{ lambda ; a:f32[2] b:i32[]. let
    c:f32[2] = pjit[
      jaxpr={ lambda ; d:f32[2] e:i32[]. let
          f:f32[2] = broadcast_in_dim[broadcast_dimensions=() shape=(2,)] 1.0
          _:i32[] _:i32[] g:f32[2] = while[
            body_jaxpr={ lambda ; h:f32[2] i:i32[] j:i32[] k:f32[2]. let
                l:i32[] = add i 1
                m:f32[2] = add k h
              in (l, j, m) }
            body_nconsts=1
            cond_jaxpr={ lambda ; n:i32[] o:i32[] p:f32[2]. let
                q:bool[] = lt n o
              in

## Summary

\begin{split}
\begin{array} {r|rr}
\hline
\textrm{construct} & \textrm{jit} & \textrm{grad} \\
\hline
\textrm{if} & ❌ & ✔ \\
\textrm{for} & ✔u & ✔ \\
\textrm{while} & ✔u & ✔ \\
\textrm{lax.cond} & ✔ & ✔ \\
\textrm{lax.while_loop} & ✔ & \textrm{fwd} \\
\textrm{lax.fori_loop} & ✔ & \textrm{fwd} \\
\textrm{lax.scan} & ✔ & ✔ \\
\hline
\end{array}
\end{split}

*u* - unrolls the loop

*fwd* - in general, only forward-mode differentiation is supported
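*lax.scan* appears in the table but is not demonstrated above. The following is a minimal sketch, with *sum_of_squares* and its input chosen purely for illustration: the step function takes a carry and one element and returns the new carry plus a per-step output, and the resulting loop works under both *jit* and *grad*.

```python
import jax
import jax.numpy as jnp
import jax.lax as lax

def sum_of_squares(xs):
    def step(carry, x):
        new_carry = carry + x ** 2
        return new_carry, new_carry          # (next carry, per-step output)

    init = jnp.zeros((), xs.dtype)           # scalar carry with the input's dtype
    total, partials = lax.scan(step, init, xs)
    return total, partials

xs = jnp.array([1.0, 2.0, 3.0, 4.0])

total, partials = jax.jit(sum_of_squares)(xs)          # compiled without unrolling
grads = jax.grad(lambda v: sum_of_squares(v)[0])(xs)   # reverse-mode gradient: 2 * xs

print(total)      # 1 + 4 + 9 + 16 = 30
print(partials)   # running totals: 1, 5, 14, 30
print(grads)      # 2, 4, 6, 8
```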