# JIT with JAX

## How JAX transforms work

To transform Python functions, JAX first converts the function into an intermediate language called jaxpr. The JAX transformations then operate on the jaxpr representation of the function.

JAX is functional in the sense that it does not track side effects. A side effect is any observable effect of a function other than returning a value, such as modifying state outside its local environment. In the code below, global_list.append(x) is a side effect because it modifies global_list, so we would expect the jaxpr to ignore it.

In [11]:
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


This is a purely functional representation of the function passed to make_jaxpr; the side effect is ignored.
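
As a quick check of this claim, the sketch below compares the primitives recorded for a function with an append against a variant without it. The names side_effect_log, log2_pure and primitive_names are illustrative and not part of the original notebook, and the code assumes that the object returned by jax.make_jaxpr exposes .jaxpr.eqns, with a .primitive.name on each equation, as in recent JAX releases.

```python
import jax
import jax.numpy as jnp

side_effect_log = []  # hypothetical stand-in for global_list

def log2_with_side_effect(x):
  side_effect_log.append("called")  # side effect: mutates outer state
  return jnp.log(x) / jnp.log(2.0)

def log2_pure(x):
  return jnp.log(x) / jnp.log(2.0)

def primitive_names(fn, *args):
  # Names of the primitives recorded in the traced jaxpr, in order.
  closed = jax.make_jaxpr(fn)(*args)
  return [eqn.primitive.name for eqn in closed.jaxpr.eqns]

names_with = primitive_names(log2_with_side_effect, 3.0)
names_pure = primitive_names(log2_pure, 3.0)
print(names_with)
print(names_pure)
```

If the two lists match, the append has left nothing behind in the jaxpr, even though side_effect_log itself was modified while tracing.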

It's important to note that global_list is still modified during the first pass, when JAX runs the function with tracer objects in order to construct the jaxpr. The tracers do not record side effects, so they do not appear in the jaxpr, but they do happen during tracing itself.

In [13]:
global_list

[Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>]

**One should not rely on this, as it is strictly an implementation detail.**
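
To illustrate why this is fragile, here is a minimal sketch using jax.jit. It assumes the usual behaviour that a jitted function is traced on its first call, and that later calls with arguments of the same shape and dtype reuse the compiled result without running the Python body again. The names trace_log and log2_jitted are illustrative.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical list used only to observe tracing

@jax.jit
def log2_jitted(x):
  trace_log.append("traced")  # runs only while JAX traces the function
  return jnp.log(x) / jnp.log(2.0)

first = log2_jitted(jnp.float32(8.0))
second = log2_jitted(jnp.float32(16.0))  # same shape and dtype: no re-trace expected

print(first, second)
print(len(trace_log))
```

Under these assumptions the side effect fires once per trace rather than once per call, so any logic that depends on it would quietly stop working after the first call.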

An important thing to understand is that a jaxpr captures the function as executed on the arguments given. Therefore, if there is a conditional, the jaxpr will only contain the operations of the branch that was taken.

In [16]:
def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))

{ lambda ; a:i32[3]. let  in (a,) }


In [20]:
print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([[1.0, 2.0],[1.0,2.0]])))

{ lambda ; a:f32[2,2]. let
    b:f32[2,2] = log a
    c:f32[] = log 2.0
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
    e:f32[2,2] = div b d
  in (e,) }
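
The check on x.ndim works during tracing because the rank is part of the shape information a tracer carries, so the branch is resolved in ordinary Python before any operation is recorded. As a rough sketch, the helper below (count_equations, an illustrative name) counts the equations in the jaxpr produced for each input rank. It assumes the .jaxpr.eqns attribute on the result of jax.make_jaxpr.

```python
import jax
import jax.numpy as jnp

def log2_if_rank_2(x):
  if x.ndim == 2:
    return jnp.log(x) / jnp.log(2.0)
  else:
    return x

def count_equations(fn, x):
  # Number of equations recorded in the jaxpr traced for this input.
  return len(jax.make_jaxpr(fn)(x).jaxpr.eqns)

rank_1_count = count_equations(log2_if_rank_2, jnp.array([1.0, 2.0, 3.0]))
rank_2_count = count_equations(log2_if_rank_2, jnp.ones((2, 2)))
print(rank_1_count, rank_2_count)
```

The rank-1 input takes the else branch and records no equations, while the rank-2 input records the log and divide operations. The exact count for the rank-2 case may vary between JAX versions.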
