**S02P02_tutorial_jit_compilation_with_jax.ipynb**

Arz

2024 APR 09 (TUE)

reference:
https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# how JAX transforms work

Python code -> [tracing] -> jaxpr -> [transformation]

In [4]:
from jax import make_jaxpr

In [5]:
g = []  # global list

def log2(x):
    g.append(x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.)
    return ln_x/ln_2

In [6]:
print(make_jaxpr(log2))
print(make_jaxpr(log2)(3.))

<function make_jaxpr(log2) at 0x7248eabb39c0>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


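note: **jaxpr ignores side-effects**

ex) g.append(x) in log2 does not appear in the jaxpr above; it only runs in Python while the function is being traced.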
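A minimal sketch of what this means in practice, reusing the log2 definition above (the names jaxpr_text, n_recorded and recorded_is_float are only for illustration): the append is executed by Python during tracing, so g receives a tracer rather than the number 3.0, and the jaxpr contains nothing about the list.

```python
import jax
import jax.numpy as jnp

g = []  # global list mutated as a side effect

def log2(x):
    g.append(x)  # side effect: executed by Python while tracing
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2

jaxpr_text = str(jax.make_jaxpr(log2)(3.0))

n_recorded = len(g)                          # one entry, added at trace time
recorded_is_float = isinstance(g[0], float)  # False: g[0] is a tracer, not 3.0

print(jaxpr_text)   # the traced computation, with no mention of the append
print(type(g[0]))   # a tracer type
```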

In [7]:
def log2_with_print(x):
    print("x:", x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.)
    return ln_x/ln_2

In [8]:
print(make_jaxpr(log2_with_print)(3.))

x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


In [9]:
def log2_with_cond(x):
    if x.ndim == 2:
        ln_x = jnp.log(x)
        ln_2 = jnp.log(2.)
        return ln_x/ln_2
    else:
        return x

In [10]:
print(make_jaxpr(log2_with_print)(2.))

x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


In [11]:
print(make_jaxpr(log2_with_cond)(3.))

{ lambda ; a:f32[]. let  in (a,) }
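The branch on x.ndim works because the shape is already known at trace time, so only the branch that is actually taken gets recorded. A small sketch, reusing log2_with_cond from above with a hypothetical 2x2 input for comparison:

```python
import jax
import jax.numpy as jnp

def log2_with_cond(x):
    if x.ndim == 2:  # depends on the shape only, so it is resolved while tracing
        return jnp.log(x) / jnp.log(2.0)
    else:
        return x

jaxpr_scalar = jax.make_jaxpr(log2_with_cond)(3.0)               # else branch
jaxpr_matrix = jax.make_jaxpr(log2_with_cond)(jnp.ones((2, 2)))  # if branch

print(jaxpr_scalar)  # no equations: the input is returned unchanged
print(jaxpr_matrix)  # contains the log and division operations
```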


# JIT compiling a function

## ex) SELU

In [12]:
def selu(x, alpha=1.67, lamda=1.05):
    return lamda*jnp.where(x > 0, x, alpha*jnp.exp(x) - alpha)

In [13]:
x = jnp.arange(1000000)

### sending one operation at a time to the accelerator

this limits the ability of the XLA compiler to optimize our functions.

In [14]:
%timeit selu(x).block_until_ready()

1.71 ms ± 266 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
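To see which operations are involved, one can print the jaxpr of selu. This is only a sketch on a small hypothetical input: without jit these steps are dispatched more or less one by one, while jit hands the whole sequence to XLA at once. The numerical result should be the same either way.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lamda=1.05):
    return lamda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x_small = jnp.arange(5.0)  # hypothetical small input, just for inspection

jaxpr_text = str(jax.make_jaxpr(selu)(x_small))
print(jaxpr_text)  # lists the comparison, exp, multiply, subtract and select steps

eager_out = selu(x_small)           # op-by-op execution
jitted_out = jax.jit(selu)(x_small)  # whole function compiled by XLA
same_result = bool(jnp.allclose(eager_out, jitted_out))
```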


### JIT

In [15]:
selu_jit = jit(selu)

# warm-up
selu_jit(x).block_until_ready()  # because first run includes compiling time

%timeit selu_jit(x).block_until_ready()

87.1 µs ± 14.4 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


# why can't we just JIT everything?

cases where jit fails as-is: Python control flow that depends on the value of a traced argument.

## ex) if conditioned on the value

In [16]:
def f(x):
    if x > 0:
        return x
    else:
        return -x

In [17]:
f_jit = jit(f)

# f_jit(1.)  # forbidden: x is a tracer under jit, so "x > 0" has no concrete truth value and JAX raises an error
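A minimal sketch of the failure, catching the error instead of letting it propagate. The jnp.where variant at the end is an assumed alternative that is not part of the original notebook: it selects by value inside the compiled computation, so no Python-level branch on a tracer is needed.

```python
import jax
import jax.numpy as jnp

def f(x):
    if x > 0:  # needs a concrete boolean, which a tracer cannot provide
        return x
    else:
        return -x

f_jit = jax.jit(f)

try:
    f_jit(1.0)
    raised = False
except jax.errors.ConcretizationTypeError as err:
    raised = True
    print(type(err).__name__)

# Assumed alternative: evaluate both candidates and pick one by value.
@jax.jit
def f_where(x):
    return jnp.where(x > 0, x, -x)

result = float(f_where(-2.0))
```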

## ex) while loop conditioned on the value

In [18]:
def f(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

In [19]:
f_jit = jit(f)

# f_jit(7., 30.)  # forbidden: n is a tracer under jit, so "i < n" has no concrete truth value and JAX raises an error

### solution #1: jit part of the function

In [20]:
@jit
def loop_body(i):
    return i + 1

def f(x, n):
    i = 0
    while i < n:
        i = loop_body(i)
    return x + i

In [21]:
f(7., 30.)

Array(37., dtype=float32, weak_type=True)

### solution #2: specify static arguments

⚠️ caveat: whenever the value of the static argument changes, the function is recompiled, which degrades performance. therefore, this method is useful only if such changes are rare.

In [22]:
def f(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

In [23]:
f_jit = jit(f, static_argnums = (1,))

f_jit(7., 30.)

Array(37., dtype=float32, weak_type=True)

In [24]:
f_jit = jit(f, static_argnums = 1)

f_jit(7., 30.)

Array(37., dtype=float32, weak_type=True)

In [25]:
f_jit = jit(f, static_argnames=['n'])

f_jit(7., 30.)

Array(37., dtype=float32, weak_type=True)
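The recompilation caveat can be made visible with a Python side effect, which runs only while the function is being traced. A minimal sketch, where trace_log is a hypothetical helper list that is not part of the original notebook:

```python
import jax

trace_log = []  # hypothetical helper: gets one entry per trace of f

def f(x, n):
    trace_log.append(n)  # runs only while JAX traces f, not on cached calls
    i = 0
    while i < n:
        i += 1
    return x + i

f_jit = jax.jit(f, static_argnames=["n"])

f_jit(7.0, 30)   # first call: trace and compile for n == 30
f_jit(8.0, 30)   # same static value: the cached version is reused
traces_same_n = len(trace_log)

f_jit(7.0, 31)   # new static value: traced and compiled again
traces_new_n = len(trace_log)

print(traces_same_n, traces_new_n)
```

Every distinct value of n costs one more trace and compilation, which is why static arguments pay off only when they change rarely.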

to apply jit as a decorator while passing static arguments, use *partial* from Python's functools.

In [26]:
from functools import partial

In [27]:
@partial(jit, static_argnames=['n'])
def f(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

In [28]:
f(7., 30.)

Array(37., dtype=float32, weak_type=True)

# when to use JIT

the examples above are actually not worth jitting, because jit itself adds overhead (tracing, compilation, and dispatch).

use jit on the largest possible chunk of your computation:
functions that are complex enough and called often enough for that overhead to pay off.

- ex) entire update step in machine learning 
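As a sketch of what jitting an entire update step can look like, here is a toy linear regression. The model, the synthetic data, the learning rate and all names are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value chosen for this toy problem

def loss(params, x, y):
    w, b = params
    return jnp.mean((w * x + b - y) ** 2)

@jax.jit  # gradient and parameter update are compiled together as one unit
def update(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

x = jnp.linspace(-1.0, 1.0, 100)
y = 3.0 * x + 1.0  # synthetic data with slope 3 and intercept 1

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(200):
    params = update(params, x, y)  # compiled on the first call, cached afterwards

w_fit, b_fit = float(params[0]), float(params[1])
print(w_fit, b_fit)
```

The jitted function is called many times with the same shapes and dtypes, so the one-time compilation cost is spread over all iterations.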

In [29]:
def f(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

In [30]:
f_jit = jit(f, static_argnames=['n'])

In [31]:
n = 30

print("f:")
%timeit f(7., n)

print("f_jit:")
%timeit f_jit(7., n).block_until_ready()

f:
581 ns ± 0.842 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
f_jit:
77.3 µs ± 6.04 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


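in this case jit does not help: f_jit is about 130 times slower than plain f (77.3 µs vs. 581 ns), because the overhead of the jitted call dominates such a tiny computation.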

# caching

⚠️ avoid calling jax.jit inside loops.

because the cache is keyed on the hash of the function object, jit misses the cache whenever an equivalent function is redefined (e.g. a new partial or lambda on every iteration) and has to compile again.

In [32]:
def loop_body(i):
    return i + 1

# not good
def f_jit_partial(x, n):
    i = 0
    while i < n:
        # don't do this!
        # at each iteration, partial returns a function with a different hash
        i = jax.jit(partial(loop_body))(i)
    return x + i

# not good
def f_jit_lambda(x, n):
    i = 0
    while i < n:
        # don't do this!
        # at each iteration, lambda returns a function with a different hash
        i = jax.jit(lambda x: loop_body(x))(i)
    return x + i

# ok but refrain from using
def f_jit(x, n):
    i = 0
    while i < n:
        # ok
        # JAX can find the cached, compiled function
        i = jax.jit(loop_body)(i)
    return x + i

In [33]:
print("jit called in a loop with partials:")
%timeit f_jit_partial(7., 30).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit f_jit_lambda(7., 30).block_until_ready()

print("jit called in a loop with caching:")
%timeit f_jit(7., 30).block_until_ready()

jit called in a loop with partials:
465 ms ± 31.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
391 ms ± 22.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
7.3 ms ± 255 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


**better: pre-jit, i.e. call jit once outside the loop and reuse the result.**

In [34]:
def loop_body(i):
    return i + 1

loop_body_jit = jit(loop_body)

def f(x, n):
    i = 0
    while i < n:
        i = loop_body_jit(i)
    return x + i

In [35]:
print("pre-jit, cached case:")
%timeit f(7., 30).block_until_ready()

pre-jit, cached case:
5.52 ms ± 678 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
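One way to see where the time goes in the lambda and partial variants is to count how often loop_body is traced. A minimal sketch, where trace_log and the two run_* helpers are hypothetical names introduced only for this illustration:

```python
import jax

trace_log = []  # hypothetical helper: gets one entry per trace of loop_body

def loop_body(i):
    trace_log.append(1)  # runs only while JAX traces the function
    return i + 1

def run_with_lambda(n):
    i = 0
    while i < n:
        # a fresh lambda on every iteration: the cache is never hit
        i = jax.jit(lambda v: loop_body(v))(i)
    return i

loop_body_jit = jax.jit(loop_body)  # pre-jit: one function object, reused

def run_prejit(n):
    i = 0
    while i < n:
        i = loop_body_jit(i)
    return i

trace_log.clear()
run_with_lambda(5)
lambda_traces = len(trace_log)   # one trace per iteration

trace_log.clear()
run_prejit(5)
prejit_traces = len(trace_log)   # far fewer, typically a single trace

print(lambda_traces, prejit_traces)
```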
