In [1]:
import jax
import numpy as np
import jax.numpy as jnp
from jax import random, jit, lax, vmap, pmap, grad, value_and_grad

In [4]:
rng = np.random.default_rng()
key = jax.random.PRNGKey(0)

## JIT
When I jit compile a function, the first call traces it. JAX runs the Python function with tracer objects in place of the real arguments. Each tracer carries an abstract value, a `ShapedArray`, which records only the shape (dims) and dtype, not the actual values. The operations applied to the tracers are recorded as a jaxpr and compiled with XLA, and the compiled version is then run on the real inputs to produce the result. Because the trace depends only on shapes and dtypes, each unique shape/dtype combination of the input arguments needs its own trace. So for every new combination, JAX traces and compiles the function anew.
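To see that tracing really needs only shapes and dtypes, here is a small sketch (the function name `dot_self` is my own) that traces and compiles without any concrete data, using `jax.ShapeDtypeStruct` as a stand-in for the input:

```python
import jax
import jax.numpy as jnp

def dot_self(x):
    return jnp.dot(x, x)

# Only a shape and a dtype, no data: this is all tracing needs.
spec = jax.ShapeDtypeStruct((100,), jnp.float32)

# eval_shape traces the function abstractly and reports the output shape/dtype.
out_spec = jax.eval_shape(dot_self, spec)
print(out_spec)

# Lower and compile ahead of time for exactly this shape/dtype, then run it.
compiled = jax.jit(dot_self).lower(spec).compile()
result = compiled(jnp.ones(100, dtype=jnp.float32))
print(result)
```

The `compiled` object is specialized to shape `(100,)` and `float32`, so calling it with, say, a length-50 array should be rejected rather than silently recompiled.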

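A gotcha is that there is not much point in jit compiling a lambda inline, e.g. `jit(lambda x: ...)(x)` inside a loop or a function body. Every evaluation of a `lambda` expression creates a new function object, and JAX's compilation cache is tied to the function object, so it cannot reuse the earlier compiled version and traces and compiles again. Assigning the jitted lambda once (`f = jit(lambda x: ...)`) and reusing `f` avoids this.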
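A sketch of the difference, using a hypothetical global `trace_count` that a Python side effect bumps each time JAX traces (names are mine, for illustration only):

```python
import jax
import jax.numpy as jnp

trace_count = 0

def counted_square(x):
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX is tracing
    return x ** 2

x = jnp.arange(3.0)

# One function object, jitted once and reused: traced once, then cached.
jitted_square = jax.jit(counted_square)
for _ in range(3):
    jitted_square(x)
traces_reused = trace_count  # 1

# A fresh lambda on every iteration: each is a new function object,
# so the earlier trace cannot be reused and JAX traces again each time.
trace_count = 0
for _ in range(3):
    jax.jit(lambda v: counted_square(v))(x)
traces_fresh_lambda = trace_count  # 3

print(traces_reused, traces_fresh_lambda)
```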

## Tracer in Action

On the first call, the body of the jitted function runs with tracers (wrapping `ShapedArray`s) instead of concrete arrays. It still returns the right answer, because JAX compiles the recorded operations and runs them on the actual inputs. On later calls with the same dims and dtypes, the cached compiled version is used. A call with a new dim/dtype combination triggers a new compilation, but this does not overwrite the older one: now two compiled versions exist.

In [2]:
@jit
def dotproduct(x, y):
    print("Inside dotproduct")
    print(f"x={x}")
    print(f"y={y}")
    z = jnp.dot(x, y)
    print(f"z={z}")
    return z

### First Invocation
Here the print statements are executed, and they show that the function body sees tracers wrapping a `ShapedArray` rather than concrete values, even though the final output is the usual `Array`.

In [5]:
npx = rng.random(100)
jnpx = jnp.asarray(npx)

In [6]:
np.dot(npx, npx.T)

27.961656218432985

In [7]:
dotproduct(jnpx, jnpx.T)

Inside dotproduct
x=Traced<ShapedArray(float32[100])>with<DynamicJaxprTrace(level=0/1)>
y=Traced<ShapedArray(float32[100])>with<DynamicJaxprTrace(level=0/1)>
z=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>


Array(27.961657, dtype=float32)

### Second Invocation
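The function has now been compiled for this dim and dtype, so calling it again skips the Python body entirely. Python side effects such as `print` only ran during tracing and are not part of the compiled computation, so this time I won't see any print statements.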

In [8]:
npx = rng.random(100)
jnpx = jnp.asarray(npx)

In [9]:
np.dot(npx, npx.T)

33.88908808848025

In [10]:
dotproduct(jnpx, jnpx.T)

Array(33.88908, dtype=float32)
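If I actually want to see values on every call, `jax.debug.print` is meant for that: unlike Python's `print`, it is staged into the compiled computation, so it runs on each invocation and prints concrete values rather than tracers. A minimal sketch (`dotproduct_logged` is just an illustrative name):

```python
import jax
import jax.numpy as jnp

@jax.jit
def dotproduct_logged(x, y):
    print("tracing")               # plain Python print: only during tracing
    z = jnp.dot(x, y)
    jax.debug.print("z={z}", z=z)  # staged into the computation: every call
    return z

x = jnp.arange(4.0)
first = dotproduct_logged(x, x)   # prints "tracing" plus the debug line
second = dotproduct_logged(x, x)  # cached: prints only the debug line
```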

### With different dims
This will trigger another compiled copy to be created and cached.

In [11]:
npx = rng.random(50)
jnpx = jnp.asarray(npx)

In [12]:
np.dot(npx, npx.T)

13.679525864218837

In [13]:
dotproduct(jnpx, jnpx.T)

Inside dotproduct
x=Traced<ShapedArray(float32[50])>with<DynamicJaxprTrace(level=0/1)>
y=Traced<ShapedArray(float32[50])>with<DynamicJaxprTrace(level=0/1)>
z=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>


Array(13.679525, dtype=float32)

### Does not overwrite old compiled function
If I call the function with the old dims, it will still use the old compiled function.

In [14]:
npx = rng.random(100)
jnpx = jnp.asarray(npx)

In [15]:
np.dot(npx, npx.T)

32.2235387654915

In [16]:
dotproduct(jnpx, jnpx.T)

Array(32.223534, dtype=float32)

## JIT and Pure Functions
Here are three ways impurities can get into a function:
  * Use of IO
  * Use of global states
  * Use of iterators

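We saw the use of IO above. IO such as `print` only runs while the function is being traced, so it is effectively erased from the jitted function after the first invocation for a given input dim/dtype.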

### Global States
For global states, JIT reads the global's value while tracing and bakes it into the compiled function as a constant, much like a closure over the value it had during the first invocation. If I change the value of the global after the function has been compiled, the change is not reflected in the function. Only a recompilation (e.g. a call with a new input dim or dtype) picks up the current value.

In [20]:
power = 5

def powerof(x):
    return x ** power

In [21]:
powerof(2)

32

In [22]:
power = 10
powerof(2)

1024

In [23]:
power = 5
jit(powerof)(2)

Array(32, dtype=int32, weak_type=True)

The value of `power` was baked into the jitted function when it was traced, so changing the global does not give me a different result.

In [24]:
power = 10
jit(powerof)(2)

Array(32, dtype=int32, weak_type=True)

Of course I can force a recompile by calling the function with a different input type (here a float32 array of shape (2,)), and the current value of `power` will be used when compiling that version.

In [25]:
power = 10
jit(powerof)(jnp.array([2., 2.]))

Array([1024., 1024.], dtype=float32)
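One way to avoid the stale-global problem altogether is to make the dependency explicit by passing `power` as an argument. A sketch (the name `powerof_explicit` is mine): since `power` is now a traced input rather than a baked-in constant, every call reflects the value passed in, and no recompilation is needed when it changes.

```python
import jax

@jax.jit
def powerof_explicit(x, power):
    return x ** power  # power is traced, not baked in as a constant

five = powerof_explicit(2.0, 5)
ten = powerof_explicit(2.0, 10)
print(five, ten)
```

The trade-off is that one compiled version now serves every value of `power`; declaring it static instead would compile one version per distinct value.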

### Iterators
Iterators cannot be passed as arguments to a JITted function. Below are two experiments. In the first, the iterator is rejected the first time the function is invoked with it; this check is probably done when `jit` inspects its arguments, before any tracing happens. In the second experiment I use `lax.fori_loop`, which traces its loop body instead of running it in plain Python, and have the body pull values from an iterator through a closure, to see a) whether this check is still performed, and b) what kind of error it throws.

In [26]:
@jit
def sum_of_sqs(xs):
    return jnp.sum(jnp.array([x**2 for x in xs]))

In [27]:
sum_of_sqs([0., 1., 2.])

Array(5., dtype=float32)

In [28]:
try:
    sum_of_sqs(iter([0., 1., 2.]))
except Exception as err:
    print(f"{type(err)}\n{err}")

<class 'TypeError'>
Argument '<list_iterator object at 0x1065ca9e0>' of type <class 'list_iterator'> is not a valid JAX type.
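If the data only exists as an iterator, a simple workaround is to materialize it outside the jitted function, e.g. into an array, so that `jit` receives a valid JAX type. A sketch (`sum_of_sqs_array` is an illustrative name):

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_of_sqs_array(xs):
    return jnp.sum(xs ** 2)

it = iter([0.0, 1.0, 2.0])
# Consume the iterator in plain Python, then hand jit an array.
total = sum_of_sqs_array(jnp.asarray(list(it)))
print(total)
```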


Below is how I'd call `lax.fori_loop` with a JAX array, indexing into it with the loop counter.

In [29]:
ary = jnp.array([0., 1., 2.])
lax.fori_loop(0, 3, lambda idx, acc: acc + ary[idx]**2, 0)

Array(5., dtype=float32)

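Now I call it with an iterator instead. The error thrown here is different from the one thrown earlier: rather than rejecting an argument type, JAX reports an `UnexpectedTracerError`. Per the message, calling `next(it)` inside the traced loop body counts as a side effect that let a traced intermediate value escape.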

In [30]:
ary = jnp.array([0., 1., 2.])
it = iter(ary)
try:
    lax.fori_loop(0, 3, lambda idx, acc: acc + next(it)**2, 0)
except Exception as err:
    print(f"{type(err)}\n{err}")

<class 'jax._src.errors.UnexpectedTracerError'>
Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was scanned_fun at /opt/miniconda3/envs/ai/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:1607 traced for scan.
------------------------------
The leaked intermediate value was created on line /var/folders/g0/8d67dwg94sj_ysrm2zmtkjrc0000gn/T/ipykernel_96597/2182326376.py:4 (<lambda>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/var/folders/g0/8d67dwg94sj_ysrm2zmtkjrc0000gn/T/ipykernel_96597/2182

#### !!Important!!
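Contrary to what I'd think, it is perfectly OK to use iterators within the function to mutate internal state! As long as the thing being iterated over has a structure known at trace time (here a Python list, which JAX treats as a pytree of scalars), the `for` loop simply runs during tracing and is unrolled into the compiled code.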

In [31]:
@jit
def sum_of_sqs_2(xs):
    # return jnp.sum(jnp.array([x**2 for x in xs]))
    answer = 0
    for x in xs:
        answer += x**2
    return answer

In [32]:
sum_of_sqs_2([0., 1., 2.])

Array(5., dtype=float32, weak_type=True)

In [33]:
sum_of_sqs_2([0., 1., 2., 3., 4.])

Array(30., dtype=float32, weak_type=True)
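The flip side is that the list's length is baked into the trace: a list of a different length is a different pytree structure, so it forces a retrace, while a list of the same length with new values reuses the cached version. A sketch with a hypothetical trace counter:

```python
import jax

trace_count = 0

@jax.jit
def sum_of_sqs_counted(xs):
    global trace_count
    trace_count += 1  # runs only while tracing
    answer = 0.0
    for x in xs:  # plain Python loop over a list pytree: unrolled at trace time
        answer += x ** 2
    return answer

r3 = sum_of_sqs_counted([0.0, 1.0, 2.0])
r3_again = sum_of_sqs_counted([3.0, 4.0, 5.0])      # same structure: cached
r5 = sum_of_sqs_counted([0.0, 1.0, 2.0, 3.0, 4.0])  # new length: retrace
print(trace_count)
```

For inputs whose length varies a lot, passing a single array and using `jnp.sum(xs ** 2)`, as in `sum_of_sqs`, keeps one trace per array shape instead of one per list length.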

## JAXPR

In [34]:
def myrelu(x):
    return jnp.maximum(0., x)

In [35]:
jax.make_jaxpr(myrelu)(5.)


```
{
  lambda ; a:f32[].
    let b:f32[] = max 0.0 a
  in (b,)
}
```
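`make_jaxpr` also makes the loop unrolling from the iterator section visible. In this sketch (the function is redefined under a new name so the block stands alone), the Python `for` loop leaves no loop construct in the jaxpr; instead, each list element contributes its own `integer_pow` equation:

```python
import jax

def sum_of_sqs_unrolled(xs):
    answer = 0.0
    for x in xs:          # ordinary Python loop, runs during tracing
        answer += x ** 2  # each iteration records its own ops
    return answer

jaxpr = jax.make_jaxpr(sum_of_sqs_unrolled)([0.0, 1.0, 2.0])
print(jaxpr)

# Count the squaring ops recorded in the trace.
n_pows = sum(1 for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "integer_pow")
print(n_pows)
```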

## Control Flow
Jax cannot jit compile a function with Python control flow that depends on the values of its traced arguments, because a `ShapedArray` only carries the dtype and dims of the input, not the actual value the branch depends on. There are two ways to work around this:

  * Jit the branches and keep the overall function as a pure Python function.
  * Use static args.

In [36]:
def f(x, n):
    if n > 0:
        return 3 * x**3 + 2 * x**2 + n
    else:
        return 2 * x - n

In [37]:
try:
    jf = jit(f)(2., 5.)
except Exception as err:
    print(f"{type(err)}\n{err}")

<class 'jax._src.errors.ConcretizationTypeError'>
Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
The error occurred while tracing the function f at /var/folders/g0/8d67dwg94sj_ysrm2zmtkjrc0000gn/T/ipykernel_96597/999900859.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'n'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError


A workaround is to jit the branches, but keep the overall function as a pure Python function.

In [38]:
@jit
def g(x):
    return 3 * x**3 + 2 * x**2

# This is a pure Python function
def f(x, n):
    if n > 0:
        return g(x) + n
    else:
        return 2 * x - n

In [39]:
f(2., 5.)

Array(37., dtype=float32, weak_type=True)

The second workaround is to declare `n` as a static argument. This means that in addition to the dtype and dims of `x`, the function will be compiled for **every distinct** value of `n`.

In [40]:
def f2(x, n):
    print(f"x={x} n={n}")
    if n > 0:
        return 3 * x**3 + 2 * x**2 + n
    else:
        return 2 * x - n

In [41]:
jitf = jit(f2, static_argnames=["n"])

Compiled for the first time for x as a float scalar and n = 5.

In [42]:
jitf(2., 5.)

x=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> n=5.0


Array(37., dtype=float32, weak_type=True)

In [43]:
jitf(3., 5.)

Array(104., dtype=float32, weak_type=True)

Compiled a second time even though x is still a float scalar, because the value of n is different.

In [44]:
jitf(2., 3.)

x=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> n=3.0


Array(35., dtype=float32, weak_type=True)

Here is a more concise way to declare n as static, using `functools.partial` as a decorator.

In [45]:
from functools import partial

@partial(jit, static_argnames=["n"])
def f3(x, n):
    print(f"x={x} n={n}")
    if n > 0:
        return 3 * x**3 + 2 * x**2 + n
    else:
        return 2 * x - n

In [46]:
f3(2., 5.)

x=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> n=5.0


Array(37., dtype=float32, weak_type=True)

In [47]:
f3(3., 5.)

Array(104., dtype=float32, weak_type=True)

In [48]:
f3(2., 4.)

x=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> n=4.0


Array(36., dtype=float32, weak_type=True)
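Since static values become part of the compilation cache key, they must be hashable, and each distinct value costs a trace. A sketch with a hypothetical trace counter (the name `poly` is mine) makes both points concrete:

```python
from functools import partial

import jax

trace_count = 0

@partial(jax.jit, static_argnames=["n"])
def poly(x, n):
    global trace_count
    trace_count += 1  # runs only while tracing
    if n > 0:  # fine: n is a plain Python value here, not a tracer
        return 3 * x**3 + 2 * x**2 + n
    return 2 * x - n

poly(2.0, 5.0)
poly(3.0, 5.0)  # same n: cached
poly(2.0, 3.0)  # new n: retrace
after_three_calls = trace_count  # 2

hash_error = None
try:
    poly(2.0, [5.0])  # a list is not hashable, so it cannot be a static arg
except (TypeError, ValueError) as err:
    hash_error = type(err).__name__
print(after_three_calls, hash_error)
```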