In [1]:
import jax
import jax.numpy as jnp

In [2]:
def selu(
    x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946
):
    """Scaled exponential linear unit activation function."""
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [3]:
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,))
x.shape

(1000000,)

In [4]:
selu_jit = jax.jit(selu)

In [5]:
%timeit -n100 selu(x).block_until_ready()

The slowest run took 19.28 times longer than the fastest. This could mean that an intermediate result is being cached.
302 µs ± 523 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
%timeit -n100 selu_jit(x).block_until_ready()

The slowest run took 27.88 times longer than the fastest. This could mean that an intermediate result is being cached.
61.3 µs ± 119 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
@jax.jit
def selu(
    x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946
):
    """Scaled exponential linear unit activation function."""
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [8]:
z = selu(x)  # warmup

In [9]:
%timeit -n100 selu(x).block_until_ready()

13.4 µs ± 1.52 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Compiling to specific hardware

In [10]:
def selu(
    x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946
):
    """Scaled exponential linear unit activation function."""
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [11]:
selu_jit_cpu = jax.jit(selu, backend="cpu")
selu_jit_gpu = jax.jit(selu, backend="gpu")

In [12]:
%timeit -n100 selu(x).block_until_ready()

121 µs ± 19 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
%timeit -n100 selu_jit_cpu(x).block_until_ready()

1.44 ms ± 151 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
%timeit -n100 selu_jit_gpu(x).block_until_ready()

The slowest run took 14.77 times longer than the fastest. This could mean that an intermediate result is being cached.
47.1 µs ± 58.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
x_cpu = jax.device_put(x, jax.devices("cpu")[0])
x_gpu = jax.device_put(x, jax.devices("gpu")[0])

In [16]:
%timeit -n100 selu(x_cpu).block_until_ready()

793 µs ± 187 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
%timeit -n100 selu(x_gpu).block_until_ready()

99.6 µs ± 21.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
%timeit -n100 selu_jit_cpu(x_cpu).block_until_ready()

138 µs ± 80.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
%timeit -n100 selu_jit_cpu(x_gpu).block_until_ready()

1.47 ms ± 85.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
%timeit -n100 selu_jit_gpu(x_gpu).block_until_ready()

22.3 µs ± 10.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
%timeit -n100 selu_jit_gpu(x_cpu).block_until_ready()

858 µs ± 199 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## static arg

1. Reuse of Compiled Code : fn call multiple time with same value then reuse the compiled code
2. Recompilation Overhead : When the value of a static argument changes, JAX needs to recompile the function to generate optimized code specific to the new value.
3. Optimized Code: The performance gains from running optimized machine code often significantly outweigh the recompilation overhead, especially for computationally intensive operations.

In [22]:
import jax
import jax.numpy as np


def f(x, y):
    print("compile")  # side effect
    return x * y


# JIT-compile the function with static_argnums
jit_my_function = jax.jit(f, static_argnums=(1,))

In [23]:
jit_my_function(1.0, 5.0)

compile


Array(5., dtype=float32, weak_type=True)

In [24]:
jit_my_function(2.0, 5.0)

Array(10., dtype=float32, weak_type=True)

In [25]:
jit_my_function(1.0, 10.0)

compile


Array(10., dtype=float32, weak_type=True)

In [26]:
jit_my_function(1.0, 10.0)

Array(10., dtype=float32, weak_type=True)

### passing the func as the argument

In [27]:
def dense_layer(x, w, b, act_fn):
    return act_fn(x * w + b)


x = jnp.array([2.0, 4.0, 5.0])
w = jnp.ones((3, 3))
b = jnp.ones((3,))

x.shape, w.shape, b.shape

((3,), (3, 3), (3,))

In [28]:
dense_layer_jit = jax.jit(dense_layer)
try:
    dense_layer_jit(x, w, b, selu)  #! throw error since we are passing the fn
except TypeError as e:
    print(e)

Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute


In [29]:
dense_layer_jit = jax.jit(dense_layer, static_argnames=("act_fn",))
dense_layer_jit(x, w, b, selu_jit)

Array([[3.152103, 5.253505, 6.304206],
       [3.152103, 5.253505, 6.304206],
       [3.152103, 5.253505, 6.304206]], dtype=float32)

# Jit Internals

## Side Effect are not Logged

* Side Effect happen during tracing .Once tracing is done and jaxpr is created .Then side effect are not logged in jaxpr
* Global Variable are stored as const

In [30]:
import jax


global_state = jnp.ones(1)


def impure_func(x):
    print(f"side effect {x}")
    return x * global_state


jaxpr = jax.make_jaxpr(impure_func)(1.0)
jaxpr

side effect Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


{ lambda a:f32[1]; b:f32[]. let
    c:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    d:f32[1] = mul c a
  in (d,) }

In [31]:
jaxpr.consts

[Array([1.], dtype=float32)]

In [32]:
impure_func_jit = jax.jit(impure_func)
impure_func_jit(10)
# ? during the first run IR is created called jaxpr
# ? side effect are not logged in jaxpr
# ? so during second func behavior is diff from first run

side effect Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


Array([10.], dtype=float32)

In [33]:
impure_func_jit(10)
# ? print is not logged

Array([10.], dtype=float32)

In [34]:
global_state = 2
impure_func_jit(10)

# ? even after the global state change it is not affect in the func

Array([10.], dtype=float32)

In [35]:
impure_func(10)

side effect 10


20

## Jaxpr [JAX exPRession]

```python
jaxpr ::= { lambda Var* ; Var+.
            let Eqn*
            in [Expr+] 
            }
```

* Var*  --> Constant
* Var+  --> Input
* Eqn*  --> Intermediate Equation
* Expr+ --> Output

In [36]:
import jax.numpy as jnp


def f(x, y):
    const_array = jnp.array([1.0, 2.0])  # Automatically a constvar
    z = x * y + const_array
    return z

In [37]:
jaxpr = jax.make_jaxpr(f)(jnp.array([4.0, 5.0]), jnp.array([2.0, 3.0]))
jaxpr

{ lambda a:f32[2]; b:f32[2] c:f32[2]. let
    d:f32[2] = mul b c
    e:f32[2] = add d a
  in (e,) }

### Constant Vars

When some value not depend on the input argument , When these value appear as Eqn then these value are Constvars.

In [38]:
def f(x, y):
    a = jnp.array([10.0])  # ? constvar
    return x + y + a


jaxpr = jax.make_jaxpr(f)(jnp.array([2.0]), jnp.array([3.0]))

f_jit = jax.jit(f)
jaxpr_jit = jax.make_jaxpr(f_jit)(jnp.array([2.0]), jnp.array([3.0]))

f_jit_static = jax.jit(f, static_argnums=(0, 1))
jaxpr_jit_static = jax.make_jaxpr(f_jit_static)(jnp.array([2.0]), jnp.array([3.0]))

In [39]:
jaxpr

{ lambda a:f32[1]; b:f32[1] c:f32[1]. let
    d:f32[1] = add b c
    e:f32[1] = add d a
  in (e,) }

In [40]:
jaxpr_jit

{ lambda ; a:f32[1] b:f32[1]. let
    c:f32[1] = pjit[
      name=f
      jaxpr={ lambda d:f32[1]; e:f32[1] f:f32[1]. let
          g:f32[1] = add e f
          h:f32[1] = add g d
        in (h,) }
    ] a b
  in (c,) }

In [41]:
jaxpr_jit_static

{ lambda a:f32[1]; b:f32[1] c:f32[1]. let
    d:f32[1] = pjit[
      name=f
      jaxpr={ lambda ; e:f32[1] f:f32[1] g:f32[1]. let
          h:f32[1] = add f g
          i:f32[1] = add h e
        in (i,) }
    ] a b c
  in (d,) }

In [42]:
jaxpr.consts, jaxpr_jit.consts, jaxpr_jit_static.consts

([array([10.], dtype=float32)], [], [array([10.], dtype=float32)])

In [43]:
jaxpr_jit.eqns[0].params["jaxpr"].consts

[array([10.], dtype=float32)]

### Equation

```python
Eqn  ::= let Var+ = Primitive [ Param* ] Expr+

Example: 
    a:f32[] = reduce_sum[axes=(0,)] b
    c:f32[1] = integer_pow[y=2] b
```

In [44]:
def f(x):
    y = x + x
    return jnp.power(y, 2)


jaxpr = jax.make_jaxpr(f)(jnp.array([1.0]))
jaxpr

{ lambda ; a:f32[1]. let
    b:f32[1] = add a a
    c:f32[1] = integer_pow[y=2] b
  in (c,) }

In [45]:
jaxpr.eqns

[a:f32[1] = add b b, a:f32[1] = integer_pow[y=2] b]

In [46]:
jaxpr.eqns[1]

a:f32[1] = integer_pow[y=2] b

## Tracing

* Python Program  --> Jaxpr static typed Expression us done by process called Tracing.
* During Tracing the argument are wrapped by Tracer object. By default `SharedArray` is tracer Object
* It record all the operation during the function call.
* The output of the tracing is jaxpr

## Wrapped by tracer object

In [47]:
def f(x):
    print(f"This is wrapped by tracer object {x}")
    return x


jaxpr = jax.make_jaxpr(f)(1)

This is wrapped by tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [48]:
jaxpr

{ lambda ; a:i32[]. let  in (a,) }

In [49]:
x = jnp.array([1.0, 1.0, 1.0])
y = jnp.ones((3, 3)) * 2.0
z = jnp.array([2.0, 1.0, 0.0]).T


def f(x, y):
    print(f"x={x}, y={y}, z={z}")
    return jnp.sum(x + y * z)


jaxpr = jax.make_jaxpr(f)(x, y)
jaxpr

x=Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>, y=Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=1/0)>, z=[2. 1. 0.]


{ lambda a:f32[3]; b:f32[3] c:f32[3,3]. let
    d:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] a
    e:f32[3,3] = mul c d
    f:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] b
    g:f32[3,3] = add f e
    h:f32[] = reduce_sum[axes=(0, 1)] g
  in (h,) }

In [50]:
jaxpr.consts

[Array([2., 1., 0.], dtype=float32)]

### Control flow in tracing

* Input should used to control the flow the program execution.
* but it use the `Input shape` to control the flow

In [62]:
def f(x: int):
    y = 0
    for i in range(x):  #! Throws Error since we are using input to control the flow
        y += i
    return i


try:
    jax.make_jaxpr(f)(5)
except jax.errors.TracerIntegerConversionError as e:
    print(e)

The __index__() method was called on traced array with shape int32[].
The error occurred while tracing the function f at /tmp/ipykernel_25766/2111049755.py:1 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError


In [63]:
def f(x: jax.Array):
    y = 0
    for i in x:  #* it fine since using the input shape to control the flow
        y += i
    return i

jax.make_jaxpr(f)(jnp.arange(5))
# try:
    
# except jax.errors.TracerIntegerConversionError as e:
#     print(e)

{ lambda ; a:i32[5]. let
    b:i32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] a
    c:i32[] = squeeze[dimensions=(0,)] b
    d:i32[] = add 0 c
    e:i32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] a
    f:i32[] = squeeze[dimensions=(0,)] e
    g:i32[] = add d f
    h:i32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=(1,)] a
    i:i32[] = squeeze[dimensions=(0,)] h
    j:i32[] = add g i
    k:i32[1] = slice[limit_indices=(4,) start_indices=(3,) strides=(1,)] a
    l:i32[] = squeeze[dimensions=(0,)] k
    m:i32[] = add j l
    n:i32[1] = slice[limit_indices=(5,) start_indices=(4,) strides=(1,)] a
    o:i32[] = squeeze[dimensions=(0,)] n
    _:i32[] = add m o
  in (o,) }

In [65]:
def relu(x):
    if x>0:
        return x
    return 0.0

try:
    jaxpr = jax.make_jaxpr(relu)(10.)
except jax.errors.TracerBoolConversionError as e:
    print(e)

Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function relu at /tmp/ipykernel_25766/1624187287.py:1 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError


### Use static arg to use the input in control statement

Here is the tradeoff. Your function is now being compiled on any call with a new input parameter value

In [69]:
jaxpr = jax.make_jaxpr(relu,static_argnums=0)(10.)
jaxpr

{ lambda ; . let  in (10.0,) }

In [70]:
jax.make_jaxpr(relu,static_argnums=0)(-0.)

{ lambda ; . let  in (0.0,) }

In [71]:
jax.make_jaxpr(relu,static_argnums=0)(1000.)

{ lambda ; . let  in (1000.0,) }

### use lax control flow operator to control the execution

In [75]:
def relu(x):
    return jax.lax.cond(x > 0, lambda x: x, lambda x: 0, x)

jaxpr = jax.make_jaxpr(relu)(1)
jaxpr

{ lambda ; a:i32[]. let
    b:bool[] = gt a 0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:i32[] = cond[
      branches=(
        { lambda ; e:i32[]. let  in (0,) }
        { lambda ; f:i32[]. let  in (f,) }
      )
      linear=(False,)
    ] c a
  in (d,) }

In [77]:
jax.make_jaxpr(relu)(11)

{ lambda ; a:i32[]. let
    b:bool[] = gt a 0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:i32[] = cond[
      branches=(
        { lambda ; e:i32[]. let  in (0,) }
        { lambda ; f:i32[]. let  in (f,) }
      )
      linear=(False,)
    ] c a
  in (d,) }

In [85]:
def f(x):
    y = 0
    for i in range(x):
        y += i
    return y

jaxpr_static = jax.make_jaxpr(f,static_argnums=0)(5)
jaxpr_static

{ lambda ; . let  in (10,) }

In [86]:
jax.jit(f,static_argnums=0)(5)

Array(10, dtype=int32, weak_type=True)

```python
def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
```

In [87]:
def f(x):
    return jax.lax.fori_loop(0,x,jnp.add,0)
    
jaxpr_lax =jax.make_jaxpr(f)(5)
jaxpr_lax

{ lambda ; a:i32[]. let
    _:i32[] _:i32[] b:i32[] = while[
      body_jaxpr={ lambda ; c:i32[] d:i32[] e:i32[]. let
          f:i32[] = add c 1
          g:i32[] = add c e
        in (f, d, g) }
      body_nconsts=0
      cond_jaxpr={ lambda ; h:i32[] i:i32[] j:i32[]. let
          k:bool[] = lt h i
        in (k,) }
      cond_nconsts=0
    ] 0 a 0
  in (b,) }

In [88]:
jax.jit(f)(5)

Array(10, dtype=int32, weak_type=True)