# Jit with Jax

## How Jax transforms work

To transform Python functions, Jax first converts the function into an intermediate language called jaxpr. The Jax transformations then work on the jaxpr representation of the function.

Jax is functional and does not track side effects. A side effect is any observable effect of a function besides its return value, such as modifying state outside its local environment. In the code below, `global_list.append(x)` is a side effect because it modifies `global_list`, so we would expect the jaxpr to ignore it.

In [2]:
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


This is a purely functional representation of the function passed to `jax.make_jaxpr`; the side effect does not appear in it.

It's important to note that `global_list` is still modified during the first pass, when Jax runs the Python function with tracer objects in order to construct the jaxpr. The tracers do not record side effects, so they do not appear in the jaxpr, but they do happen during the trace itself.

In [3]:
global_list

[Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>]

**One should not rely on this, as it is strictly an implementation detail.**
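The same behaviour can be observed under `jax.jit`. The sketch below is an illustration rather than part of the original notebook: `trace_log` and `log2_logged` are hypothetical names, and it assumes both calls share one cached trace because the inputs have the same shape and dtype. The Python-level `append` runs only while tracing, so the second call does not repeat it.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical list, used only to observe the side effect

@jax.jit
def log2_logged(x):
    # Side effect: executed while Jax traces the function, not on cached calls.
    trace_log.append("traced")
    return jnp.log(x) / jnp.log(2.0)

first = log2_logged(3.0)   # first call: traces, compiles, then runs
second = log2_logged(4.0)  # same shape and dtype: reuses the compiled code

print(len(trace_log))  # how many times the Python body actually ran
print(second)
```

This is another reason not to rely on side effects inside functions that Jax transforms.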

An important thing to understand is that a jaxpr captures the function as executed on the arguments given. Therefore, if there is a conditional, the jaxpr will only contain the branch that was taken.

In [4]:
def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))

{ lambda ; a:i32[3]. let  in (a,) }


In [5]:
print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([[1.0, 2.0],[1.0,2.0]])))

{ lambda ; a:f32[2,2]. let
    b:f32[2,2] = log a
    c:f32[] = log 2.0
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
    e:f32[2,2] = div b d
  in (e,) }


## JIT compiling a function

Jax allows exactly the same code to be run on the CPU, GPU, or TPU.

Here we are going to look at the SELU operation:

$$ \mathrm{SELU}(x) = \lambda \begin{cases}
    x & \text{if } x > 0\\
    \alpha e^x - \alpha & \text{if } x \leq 0
    \end{cases} $$

In [6]:
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

994 µs ± 88.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In this example the code sends one operation at a time to the accelerator, limiting the XLA compiler's ability to optimize the function.


For this code to be as performant as possible, we want to give the XLA compiler as much code as possible. For this, Jax provides the `jax.jit` transformation, which JIT-compiles a Jax-compatible function to speed it up.

In [7]:
selu_jit = jax.jit(selu)

# Warm up
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

89.1 µs ± 713 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


As we can see, that is an order of magnitude faster. So what have we done?

1. We defined `selu_jit` as the compiled version of `selu`, using `jax.jit(selu)`.
2. We then ran `selu_jit` once on `x`, which is when Jax does the tracing. The jaxpr is then compiled using XLA into efficient code. Subsequent calls use the optimised compiled code.

This means that subsequent calls no longer run the original Python implementation at all.

*If we didn't include the warm-up call separately, everything would still work as expected, but the compilation time would be included in the benchmark. It would still be faster, but it would not be a fair comparison.*
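Since the compiled code replaces the Python implementation on later calls, it can be reassuring to check that both versions agree. The following is a minimal sketch, not part of the original notebook: `x_check` and `same` are hypothetical names, and the tolerance `atol=1e-5` is an assumption chosen to absorb small floating-point differences that compilation may introduce.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)

# Hypothetical test input covering both branches of the SELU definition.
x_check = jnp.linspace(-5.0, 5.0, 1001)

# Compare the op-by-op result with the compiled result, up to a small tolerance.
same = bool(jnp.allclose(selu(x_check), selu_jit(x_check), atol=1e-5))
print(same)
```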

## Why can we not JIT everything?

After the above example, it may be tempting to think we can just jit everything. However, this is not the case: there are times when jitting is appropriate and times when it is not.

The following two examples are cases where it doesn't work.

### 1. Conditional functions

In [8]:
# Condition on value of x.

def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

f_jit = jax.jit(f)
f_jit(10)  # Should raise an error. 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at /tmp/ipykernel_4822/1879943193.py:3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

### 2. While loop conditions

In [12]:
def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

g_jit = jax.jit(g)
g_jit(10, 20)  # Should raise an error.

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function g at /tmp/ipykernel_4822/4172967325.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'n'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The problem is that we are conditioning on the value of an input inside a jitted function. This cannot work because the jaxpr depends on the actual values used to trace it, as described in the section on how Jax transforms work.

The more specific the information about the values we use in the trace, the more standard Python control flow we can use. However, being too specific means that we cannot reuse the same traced function for other values. Jax deals with this by tracing at different levels of abstraction.

For `jax.jit`, the default level is `ShapedArray`: each tracer has a concrete shape (which we can condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape. It also means that if we try to condition on a value we get an error, as the tracer has no concrete value.
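To illustrate conditioning on shape, here is a small sketch that jits a variant of the earlier `log2_if_rank_2`. The names `log2_if_rank_2_jit` and `traced_shapes` are hypothetical, and the sketch assumes that a retrace happens only when the input shape or dtype changes.

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # hypothetical log of the shapes seen while tracing

@jax.jit
def log2_if_rank_2_jit(x):
    traced_shapes.append(x.shape)  # runs only when Jax (re)traces
    if x.ndim == 2:  # fine under jit: the shape is concrete during tracing
        return jnp.log(x) / jnp.log(2.0)
    return x

out_a = log2_if_rank_2_jit(jnp.ones((2, 2)))
out_b = log2_if_rank_2_jit(4.0 * jnp.ones((2, 2)))  # same shape: no retrace
out_c = log2_if_rank_2_jit(jnp.ones(3))             # new shape: retrace

print(traced_shapes)
```

The `if` on `x.ndim` raises no error, and the function is traced once per distinct input shape rather than once per call.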

In `jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the strictest one. So, if you `jit(grad(f))`, `f` must not condition on value. The docs go into more detail about the sharp bits of Jax control flow.
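As a rough illustration of the difference, the sketch below differentiates the value-conditioned `f` from earlier with `jax.grad` alone, then tries `jit(grad(f))`. The variable names are hypothetical, and catching `ConcretizationTypeError` assumes the error class shown in the tracebacks above (or a subclass of it) is what gets raised.

```python
import jax
from jax.errors import ConcretizationTypeError

def f(x):
    if x > 0:
        return x
    else:
        return 2 * x

# grad alone traces with concrete values, so the Python `if` works.
grad_pos = jax.grad(f)(3.0)   # derivative of x
grad_neg = jax.grad(f)(-3.0)  # derivative of 2 * x

# Composing with jit imposes jit's stricter constraints on f.
try:
    jax.jit(jax.grad(f))(3.0)
    jit_grad_failed = False
except ConcretizationTypeError:
    jit_grad_failed = True

print(grad_pos, grad_neg, jit_grad_failed)
```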


The following table summarises which control flow constructs work under `jit` and `grad`:

| Construct      | jit    | grad |
|----------------|--------|------|
| if             | **NO** | YES  |
| for            | YES*   | YES  |
| while          | YES*   | YES  |
| lax.cond       | YES    | YES  |
| lax.while_loop | YES    | fwd  |
| lax.fori_loop  | YES    | fwd  |
| lax.scan       | YES    | YES  |

∗ = argument-value-independent loop condition - unrolls the loop

So one way to deal with the problem is to rewrite the code to avoid conditionals on value. Another option is to use special Jax control flow operations like `jax.lax.cond`. However, sometimes that is impossible. In situations like this you can consider jitting only part of the function, for example the more computationally expensive part inside the loop. **Make sure you check how caching works to avoid shooting yourself in the foot.**

In [13]:
# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20)

DeviceArray(30, dtype=int32, weak_type=True)
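For the `jax.lax` route mentioned earlier, here is a minimal sketch of how `f` and `g` might be rewritten so that the whole function can be jitted. The names `f_cond` and `g_while` are hypothetical, and this is one possible rewrite rather than the only one.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def f_cond(x):
    # lax.cond traces both branches and selects one at run time.
    return lax.cond(x > 0, lambda v: v, lambda v: 2 * v, x)

@jax.jit
def g_while(x, n):
    def cond_fun(i):
        return i < n  # may depend on the traced value n

    def body_fun(i):
        return i + 1

    # The loop is compiled as a single operation instead of being unrolled.
    i = lax.while_loop(cond_fun, body_fun, jnp.asarray(0, dtype=jnp.int32))
    return x + i

print(f_cond(10.0), f_cond(-10.0))
print(g_while(10, 20))
```

Both functions compile once and then accept any input value, at the cost of writing the control flow in a functional style.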

In situations where we really need to JIT a function that has a condition on the value of an input, we can tell Jax to use a less abstract tracer for a specific input by specifying `static_argnums`. The cost of this is that the resulting jaxpr is a lot less flexible, so Jax will have to recompile the function for every new value of the specified input. It is only a good strategy if the function is guaranteed to see a limited set of different values.

In [9]:
f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))

10


In [14]:
g_jit_correct = jax.jit(g, static_argnums=1)
print(g_jit_correct(10, 20))

30
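The recompilation cost of static arguments can be made visible with a small sketch. `g_logged`, `g_static` and `traced_n` are hypothetical names introduced for illustration; the body is the same as `g`, plus a Python-level log that only runs when Jax traces the function.

```python
import jax

traced_n = []  # hypothetical log of the static values that triggered a trace

def g_logged(x, n):
    traced_n.append(n)  # n is a plain Python int here because it is static
    i = 0
    while i < n:
        i += 1
    return x + i

g_static = jax.jit(g_logged, static_argnums=1)

r1 = g_static(10, 20)  # traces and compiles for n=20
r2 = g_static(11, 20)  # same static value: reuses the compiled code
r3 = g_static(10, 30)  # new static value: traces and compiles again

print(traced_n)
```

Each distinct value of `n` produces its own trace, which is why this approach only pays off when the static argument takes few different values.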


## When to use Jit

For `g_jit_correct` above, jitting is not actually worth it:

In [None]:
print("g jitted:")
%timeit g_jit_correct(10, 20).block_until_ready()

print("g:")
%timeit g(10, 20)

g jitted:


This is because `jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you plan to run it many times. Fortunately, this is a common use case in machine learning :-)

**Generally, you want to jit the largest possible chunk of computation, ideally the entire update step.**
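As an illustration of jitting an entire update step, the sketch below fits a line with plain gradient descent. Everything in it is an assumption made for the example: the linear model, the mean squared error loss, the `LEARNING_RATE` value, and the names `loss_fn` and `update`.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value chosen for this toy problem

def loss_fn(params, x, y):
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

@jax.jit
def update(params, x, y):
    # Gradient computation and parameter update are compiled together.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )

# Toy data generated from y = 3x + 1.
x = jnp.linspace(-1.0, 1.0, 50)
y = 3.0 * x + 1.0

params = (jnp.asarray(0.0), jnp.asarray(0.0))
for _ in range(500):
    params = update(params, x, y)

w, b = params
print(w, b)
```

The Python `for` loop stays outside the jitted function, so `update` is traced once and the compiled step is reused on every iteration.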