<a href="https://colab.research.google.com/github/Peter-obi/JAX/blob/main/JIT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SELU

In [1]:
import jax
import jax.numpy as jnp

In [None]:
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit (SELU) activation function.

  Element-wise, returns ``scale * x`` where ``x > 0`` and
  ``scale * alpha * (exp(x) - 1)`` elsewhere.

  Args:
    x: Input scalar or array.
    alpha: Multiplier applied to the exponential branch.
    scale: Multiplier applied to the whole output.

  Returns:
    The activated values.
  """
  # Feed 0 to the exponential wherever the linear branch is selected. exp of a
  # large positive input overflows to inf; jnp.where discards that value in the
  # forward pass, but its gradient (0 * inf) would be NaN.
  safe_x = jnp.where(x > 0, 0.0, x)
  # expm1 keeps precision near zero, where exp(x) - 1 suffers from cancellation.
  negative_branch = alpha * jnp.expm1(safe_x)
  return scale * jnp.where(x > 0, x, negative_branch) #jnp.where(condition, value_if_true, value_if_false)

In [None]:
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,)) #generate a million random numbers
selu_jit = jax.jit(selu) #obtain a JIT-transformed version of the function
%timeit -n100 selu(x).block_until_ready()

1.3 ms ± 765 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit -n100 selu_jit(x).block_until_ready()

The slowest run took 5.28 times longer than the fastest. This could mean that an intermediate result is being cached.
222 µs ± 196 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
@jax.jit
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit (SELU) activation function, JIT-compiled.

  Element-wise, returns ``scale * x`` where ``x > 0`` and
  ``scale * alpha * (exp(x) - 1)`` elsewhere.

  Args:
    x: Input scalar or array.
    alpha: Multiplier applied to the exponential branch.
    scale: Multiplier applied to the whole output.

  Returns:
    The activated values.
  """
  # Feed 0 to the exponential wherever the linear branch is selected. exp of a
  # large positive input overflows to inf; jnp.where discards that value in the
  # forward pass, but its gradient (0 * inf) would be NaN.
  safe_x = jnp.where(x > 0, 0.0, x)
  # expm1 keeps precision near zero, where exp(x) - 1 suffers from cancellation.
  negative_branch = alpha * jnp.expm1(safe_x)
  return scale * jnp.where(x > 0, x, negative_branch)

In [None]:
z = selu(x) #warmup the function

In [None]:
%timeit -n100 selu_jit(x).block_until_ready()

147 µs ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
#backend control
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit (SELU) activation function.

  Element-wise, returns ``scale * x`` where ``x > 0`` and
  ``scale * alpha * (exp(x) - 1)`` elsewhere.

  Args:
    x: Input scalar or array.
    alpha: Multiplier applied to the exponential branch.
    scale: Multiplier applied to the whole output.

  Returns:
    The activated values.
  """
  # Feed 0 to the exponential wherever the linear branch is selected. exp of a
  # large positive input overflows to inf; jnp.where discards that value in the
  # forward pass, but its gradient (0 * inf) would be NaN.
  safe_x = jnp.where(x > 0, 0.0, x)
  # expm1 keeps precision near zero, where exp(x) - 1 suffers from cancellation.
  negative_branch = alpha * jnp.expm1(safe_x)
  return scale * jnp.where(x > 0, x, negative_branch)

In [None]:
selu_jit_cpu = jax.jit(selu, backend = 'cpu')
selu_jit_gpu = jax.jit(selu, backend = 'gpu')

In [None]:
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,))

In [None]:
%timeit -n100 selu(x).block_until_ready() #uses gpu, just not JIT compiled

1.3 ms ± 799 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit -n100 selu_jit_cpu(x).block_until_ready()

2.36 ms ± 680 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit -n100 selu_jit_gpu(x).block_until_ready()

The slowest run took 5.09 times longer than the fastest. This could mean that an intermediate result is being cached.
227 µs ± 196 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Controlling both backend and tensor device placement

In [None]:
x_cpu = jax.device_put(x, jax.devices('cpu')[0])
x_gpu = jax.device_put(x, jax.devices('gpu')[0])

In [None]:
%timeit -n100 selu(x_cpu).block_until_ready()

7.53 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit -n100 selu(x_gpu).block_until_ready()

1.31 ms ± 754 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit -n100 selu_jit_cpu(x_cpu).block_until_ready()

949 µs ± 65.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit -n100 selu_jit_gpu(x_gpu).block_until_ready()

154 µs ± 8.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Static arguments

In [None]:
def dense_layer(x, w, b, activation_func): #function parameterized by another function
  """Apply ``activation_func`` to ``x * w + b``.

  ``*`` here is element-wise (broadcasting) multiplication, not a matrix
  product. ``activation_func`` is called as a plain Python callable.
  """
  pre_activation = x * w + b
  return activation_func(pre_activation)

In [None]:
x = jnp.array([1.0, 2.0, 3.0])
w = jnp.ones((3, 3))
b = jnp.ones(3)

In [None]:
dense_layer_jit = jax.jit(dense_layer)

In [None]:
dense_layer_jit(x, w, b, selu)

TypeError: Error interpreting argument to <function dense_layer at 0x79c804315a80> as an abstract array. The problematic value is of type <class 'function'> and was passed to the function at path activation_func.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

In [None]:
dense_layer_jit = jax.jit(dense_layer, static_argnums=3)

In [None]:
dense_layer_jit(x, w, b, selu)

Array([[2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804]], dtype=float32)

In [None]:
def dist(order, x, y):
  """Return ``(sum(|x - y| ** order)) ** (1 / order)``.

  Prints 'Compiling' every time the Python body runs; under ``jax.jit`` that
  happens only while the function is being traced.
  """
  print('Compiling')
  abs_diff = jnp.abs(x - y)
  powered_sum = jnp.sum(abs_diff ** order)
  return jnp.power(powered_sum, 1.0 / order)

In [None]:
dist_jit = jax.jit(dist, static_argnums=0)

In [None]:
dist_jit(1, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0])) #compile function for the given parameter value and run

Compiling


Array(4., dtype=float32)

In [None]:
dist_jit(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0])) #compile function for another parameter value and run

Compiling


Array(2.828427, dtype=float32)

In [None]:
dist_jit(1, jnp.array([10.0, 10.0]), jnp.array([2.0, 2.0])) #function already compiled

Array(16., dtype=float32)

# Static arguments for the jit decorator

In [None]:
from functools import partial

@partial(jax.jit, static_argnums=(0,))
def dist(order, x, y):
  """Return ``(sum(|x - y| ** order)) ** (1 / order)``; ``order`` is a static argument."""
  abs_diff = jnp.abs(x - y)
  powered_sum = jnp.sum(abs_diff ** order)
  return jnp.power(powered_sum, 1.0 / order)

## Compiling an impure function

In [None]:
# Global state read by the impure function below. Impure functions depend on
# global state and/or have side effects. JAX runs the Python side effects only
# during the first (tracing) call, and they are not logged in the jaxpr.
global_state = 1

def impure_function(x):
  """Return ``x`` multiplied by the global ``global_state``, printing ``x`` first."""
  print(f'Side-efect: printing x={x}') #side effect of an impure function
  return x * global_state

In [None]:
impure_function_jit = jax.jit(impure_function)

In [None]:
impure_function_jit(10)

Side-efect: printing x=JitTracer<~int32[]>


Array(10, dtype=int32, weak_type=True)

In [None]:
impure_function_jit(10) #no side effects during second run

Array(10, dtype=int32, weak_type=True)

In [None]:
global_state = 2

In [None]:
impure_function_jit(10) #changed global state has no influence on the compiled function

Array(10, dtype=int32, weak_type=True)

In [None]:
impure_function(10)

Side-efect: printing x=10


20

## JAXPR

In [None]:
def f1(x, y, z):
  """Return the sum of all elements of ``x + y * z`` (operands are broadcast)."""
  scaled = y * z
  return jnp.sum(x + scaled)

In [None]:
x = jnp.array([1.0, 1.0, 1.0]) #vector of shape (3,)
y = jnp.ones((3,3))*2.0 #3x3 matrix filled with 2.0
z = jnp.array([2.0, 2.0, 0.0]).T #.T leaves a 1-D array unchanged, so z keeps shape (3,)

In [None]:
jax.make_jaxpr(f1) (x, y, z)  #generates jaxpr

{ [34;1mlambda [39;22m; a[35m:f32[3][39m b[35m:f32[3,3][39m c[35m:f32[3][39m. [34;1mlet
    [39;22md[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] c
    e[35m:f32[3,3][39m = mul b d
    f[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] a
    g[35m:f32[3,3][39m = add f e
    h[35m:f32[][39m = reduce_sum[axes=(0, 1)] g
  [34;1min [39;22m(h,) }

In [None]:
def f2(x, y):
  """Return the sum of all elements of ``x + y * z``, reading ``z`` from global scope.

  Prints ``x``, ``y`` and ``z`` as a side effect each time the Python body runs.
  """
  print(f'x={x}, y={y}, z={z}') #side effect
  scaled = y * z #uses global variable z
  return jnp.sum(x + scaled)

In [None]:
f2_jaxpr = jax.make_jaxpr(f2) (x,y) #side effect z is present

x=JitTracer<float32[3]>, y=JitTracer<float32[3,3]>, z=[2. 2. 0.]


In [None]:
f2_jaxpr.jaxpr #doesn't capture side effect

{ [34;1mlambda [39;22ma[35m:f32[3][39m; b[35m:f32[3][39m c[35m:f32[3,3][39m. [34;1mlet
    [39;22md[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] a
    e[35m:f32[3,3][39m = mul c d
    f[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] b
    g[35m:f32[3,3][39m = add f e
    h[35m:f32[][39m = reduce_sum[axes=(0, 1)] g
  [34;1min [39;22m(h,) }

In [None]:
f2_jaxpr.consts #global variable z is now a constant

[Array([2., 2., 0.], dtype=float32)]

## Tracing

In [None]:
def f3(x):
  """Return ``x`` plus each of the integers 0 through 4.

  The loop bound is a Python constant, so it does not depend on the traced
  input. The argument ``x`` is left unmodified.
  """
  y = x
  for i in range(5): #loop does not depend on an input parameter - good.
    # Rebind instead of `y += i`: y aliases x, so an in-place add would mutate
    # a mutable argument (e.g. a NumPy array) owned by the caller.
    y = y + i
  return y

In [None]:
jax.make_jaxpr(f3)(0) #unroll loop

{ [34;1mlambda [39;22m; a[35m:i32[][39m. [34;1mlet
    [39;22mb[35m:i32[][39m = add a 0:i32[]
    c[35m:i32[][39m = add b 1:i32[]
    d[35m:i32[][39m = add c 2:i32[]
    e[35m:i32[][39m = add d 3:i32[]
    f[35m:i32[][39m = add e 4:i32[]
  [34;1min [39;22m(f,) }

In [None]:
jax.jit(f3) (0)

Array(10, dtype=int32, weak_type=True)

In [None]:
def f4(x):
  """Return ``x`` plus each of its elements along the first axis.

  The loop bound comes from ``x.shape[0]``, a static Python integer, not from
  the values stored in ``x``. The argument ``x`` is left unmodified.
  """
  y = x
  for i in range(x.shape[0]): #loop depends on an input parameter shape- good.
    # Rebind instead of `y += x[i]`: y aliases x, so an in-place add on a
    # mutable array (e.g. NumPy) would change x and corrupt later x[i] reads.
    y = y + x[i]
  return y

In [None]:
jax.make_jaxpr(f4)(jnp.array([1.0, 2.0, 3.0])) #loop is unrolled

{ [34;1mlambda [39;22m; a[35m:f32[3][39m. [34;1mlet
    [39;22mb[35m:f32[1][39m = slice[limit_indices=(1,) start_indices=(0,) strides=None] a
    c[35m:f32[][39m = squeeze[dimensions=(0,)] b
    d[35m:f32[3][39m = add a c
    e[35m:f32[1][39m = slice[limit_indices=(2,) start_indices=(1,) strides=None] a
    f[35m:f32[][39m = squeeze[dimensions=(0,)] e
    g[35m:f32[3][39m = add d f
    h[35m:f32[1][39m = slice[limit_indices=(3,) start_indices=(2,) strides=None] a
    i[35m:f32[][39m = squeeze[dimensions=(0,)] h
    j[35m:f32[3][39m = add g i
  [34;1min [39;22m(j,) }

In [None]:
jax.jit(f4)(jnp.array([1.0, 2.0, 3.0]))

Array([7., 8., 9.], dtype=float32)

In [None]:
#loop bound depends on the input value - fails when the input is traced!
def f5(x):
  """Return ``0 + 1 + ... + (x - 1)`` using a Python loop bounded by ``x``."""
  total = 0
  for step in range(x): #loop depends on input parameter
    total = total + step
  return total

In [None]:
f5(5)

10

In [None]:
jax.make_jaxpr(f5)(5)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function f5 at /tmp/ipython-input-1812372770.py:2 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [None]:
#same problem: a Python branch on an input value
def relu(x):
  """Return ``x`` when it is positive, otherwise ``0.0``."""
  return x if x > 0 else 0.0   #condition depends on input parameter

In [None]:
relu(10.0)

10.0

In [None]:
jax.make_jaxpr(relu)(10.0)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function relu at /tmp/ipython-input-739987649.py:2 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [None]:
jax.make_jaxpr(f5, static_argnums=0) (5) #same argument, but static_argnums marks the first parameter as static (as with jax.jit), which avoids the tracing error

{ [34;1mlambda [39;22m; . [34;1mlet[39;22m  [34;1min [39;22m(10:i32[],) }

In [None]:
jax.jit(f5, static_argnums=0)(5)

Array(10, dtype=int32, weak_type=True)

In [None]:
jax.make_jaxpr(relu, static_argnums=0)(10.0)

{ [34;1mlambda [39;22m; . [34;1mlet[39;22m  [34;1min [39;22m(10.0:f32[],) }

In [None]:
jax.jit(relu, static_argnums=0)(10.0)

Array(10., dtype=float32, weak_type=True)

In [None]:
#another way to solve: use fori_loop -> lax primitive; this is a pure-Python sketch of its semantics
def fori_loop(lower, upper, body_fun, init_val):
  """Fold ``body_fun(i, val)`` over ``i`` in ``range(lower, upper)``.

  Starts from ``init_val`` and returns the final value; returns ``init_val``
  unchanged when the range is empty.
  """
  acc = init_val
  for idx in range(lower, upper):
    acc = body_fun(idx, acc)
  return acc

In [None]:
def f5(x):
  """Return ``0 + 1 + ... + (x - 1)`` computed with ``jax.lax.fori_loop``."""
  def add_index(i, running_total):
    return running_total + i
  return jax.lax.fori_loop(0, x, add_index, 0) #replace the for loop

In [None]:
f5(5)

Array(10, dtype=int32, weak_type=True)

In [None]:
jax.make_jaxpr(f5)(5)

{ [34;1mlambda [39;22m; a[35m:i32[][39m. [34;1mlet
    [39;22m_[35m:i32[][39m _[35m:i32[][39m b[35m:i32[][39m = while[
      body_jaxpr={ [34;1mlambda [39;22m; c[35m:i32[][39m d[35m:i32[][39m e[35m:i32[][39m. [34;1mlet
          [39;22mf[35m:i32[][39m = add c 1:i32[]
          g[35m:i32[][39m = add e c
        [34;1min [39;22m(f, d, g) }
      body_nconsts=0
      cond_jaxpr={ [34;1mlambda [39;22m; h[35m:i32[][39m i[35m:i32[][39m j[35m:i32[][39m. [34;1mlet
          [39;22mk[35m:bool[][39m = lt h i
        [34;1min [39;22m(k,) }
      cond_nconsts=0
    ] 0:i32[] a 0:i32[]
  [34;1min [39;22m(b,) }

In [None]:
jax.jit(f5)(5)

Array(10, dtype=int32, weak_type=True)

## HLO

In [2]:
def f(x, y, z):
  """Return the sum of all elements of ``x + y * z`` (operands are broadcast)."""
  scaled = y * z
  return jnp.sum(x + scaled)

In [3]:
x = jnp.array([1.0, 1.0, 1.0]) #vector of shape (3,)
y = jnp.ones((3,3))*2.0 #3x3 matrix filled with 2.0
z = jnp.array([2.0, 2.0, 0.0]).T #.T leaves a 1-D array unchanged, so z keeps shape (3,)

In [7]:
f_jitted = jax.jit(f)

In [11]:
f_lowered= f_jitted.lower(x, y, z) #lowers the function, generating StableHLO

In [13]:
print(f_lowered.as_text())

module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3xf32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [0, 1] : (tensor<1x3xf32>) -> tensor<3x3xf32>
    %2 = stablehlo.multiply %arg1, %1 : tensor<3x3xf32>
    %3 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x3xf32>) -> tensor<3x3xf32>
    %5 = stablehlo.add %4, %2 : tensor<3x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %6 = stablehlo.reduce(%5 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x3xf32>, tensor<f32>) -> tensor<f32>
    return %6 : tensor<f32>
  }
}



In [14]:
f_compiled = f_jitted.lower(x, y, z).compile() #compiles the lowered function for the specific backend (generating HLO code)
print(f_compiled.as_text())

HloModule jit_f, is_scheduled=true, entry_computation_layout={(f32[3]{0}, f32[3,3]{1,0}, f32[3]{0})->f32[]}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}

%region_0.1 (reduce_sum.3: f32[], reduce_sum.4: f32[]) -> f32[] {
  %reduce_sum.3 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  %reduce_sum.4 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.5 = f32[] add(%reduce_sum.3, %reduce_sum.4), metadata={op_name="jit(f)/reduce_sum" source_file="/tmp/ipython-input-1761176124.py" source_line=2 source_end_line=2 source_column=9 source_end_column=26}
}

%fused_computation (param_0.3: f32[3,3], param_1.4: f32[3], param_2.3: f32[3]) -> f32[] {
  %param_2.3 = f32[3]{0} parameter(2)
  %add.1 = f32[3,3]{1,0} broadcast(%param_2.3), dimensions={1}, metadata={op_name="jit(f)/add" source_file="/tmp/ipython-input-1761176124.py" source_line=2 source_end_line=2 source_column=17 source_end_column=25}
  %param

## Compile with JIT and AOT

In [2]:
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit (SELU) activation function.

  Element-wise, returns ``scale * x`` where ``x > 0`` and
  ``scale * alpha * (exp(x) - 1)`` elsewhere. Prints "Function run" each time
  the Python body executes.

  Args:
    x: Input scalar or array.
    alpha: Multiplier applied to the exponential branch.
    scale: Multiplier applied to the whole output.

  Returns:
    The activated values.
  """
  print("Function run") #side effect
  # Feed 0 to the exponential wherever the linear branch is selected. exp of a
  # large positive input overflows to inf; jnp.where discards that value in the
  # forward pass, but its gradient (0 * inf) would be NaN.
  safe_x = jnp.where(x > 0, 0.0, x)
  # expm1 keeps precision near zero, where exp(x) - 1 suffers from cancellation.
  negative_branch = alpha * jnp.expm1(safe_x)
  return scale * jnp.where(x > 0, x, negative_branch)

In [3]:
selu_jit = jax.jit(selu)

In [4]:
selu_aot = jax.jit(selu).lower(1.0).compile() #AOT-compiled function,fictional 1.0 argument; needed for AOT to infer types

Function run


In [5]:
selu_jit(17.8)

Array(18.702477, dtype=float32, weak_type=True)

In [6]:
selu_aot(17.8)

Array(18.702477, dtype=float32, weak_type=True)