<a href="https://colab.research.google.com/github/Peter-obi/JAX/blob/main/JIT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SELU

In [1]:
import jax
import jax.numpy as jnp

In [2]:
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit (SELU) activation, applied elementwise.

  selu(x) = scale * x                      for x > 0
          = scale * alpha * (exp(x) - 1)   for x <= 0

  Args:
    x: input array or scalar.
    alpha: negative-branch saturation constant (default: standard SELU value).
    scale: output scale factor (default: standard SELU value).

  Returns:
    Array with the SELU activation applied to every element of ``x``.
  """
  # jnp.where computes *both* branches. Passing only non-positive values into the
  # exponential keeps the unused branch finite for large positive x; otherwise
  # exp(x) overflows to inf (x > ~88 in float32) and jax.grad yields 0 * inf = NaN.
  # expm1 also evaluates exp(x) - 1 more accurately for x close to 0.
  safe_x = jnp.where(x > 0, 0.0, x)
  return scale * jnp.where(x > 0, x, alpha * jnp.expm1(safe_x)) #jnp.where(condition, value_if_true, value_if_false)

In [3]:
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,)) #generate a million random numbers
selu_jit = jax.jit(selu) #obtain a JIT-transformed version of the functionn
%timeit -n100 selu(x).block_until_ready()

1.3 ms ± 765 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
%timeit -n100 selu_jit(x).block_until_ready()

The slowest run took 5.28 times longer than the fastest. This could mean that an intermediate result is being cached.
222 µs ± 196 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
@jax.jit
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit (SELU) activation, JIT-compiled via the decorator.

  selu(x) = scale * x                      for x > 0
          = scale * alpha * (exp(x) - 1)   for x <= 0

  The first call for a new input shape/dtype traces and compiles the function;
  subsequent calls reuse the compiled code.

  Args:
    x: input array or scalar.
    alpha: negative-branch saturation constant (default: standard SELU value).
    scale: output scale factor (default: standard SELU value).

  Returns:
    Array with the SELU activation applied to every element of ``x``.
  """
  # Both jnp.where branches are evaluated; clamping the exponential's input to
  # x <= 0 prevents exp overflow to inf for large positive x, which would make
  # the gradient 0 * inf = NaN. expm1 is also more accurate than exp(x) - 1 near 0.
  safe_x = jnp.where(x > 0, 0.0, x)
  return scale * jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))

In [6]:
# Warm-up call: the first call traces and compiles the @jax.jit-decorated selu.
# Blocking ensures this run has fully finished before the next cell starts timing.
z = selu(x).block_until_ready()

In [7]:
%timeit -n100 selu_jit(x).block_until_ready()

147 µs ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [2]:
# Backend control: redefine selu as a plain (un-jitted) function so jax.jit can be
# applied explicitly with different backends in the next cell.
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit (SELU) activation, applied elementwise.

  selu(x) = scale * x                      for x > 0
          = scale * alpha * (exp(x) - 1)   for x <= 0

  Args:
    x: input array or scalar.
    alpha: negative-branch saturation constant (default: standard SELU value).
    scale: output scale factor (default: standard SELU value).

  Returns:
    Array with the SELU activation applied to every element of ``x``.
  """
  # jnp.where evaluates both branches; feeding only x <= 0 into the exponential
  # avoids exp overflow (and the resulting NaN gradient) for large positive x.
  # expm1 computes exp(x) - 1 accurately near 0.
  safe_x = jnp.where(x > 0, 0.0, x)
  return scale * jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))

In [3]:
# One Python function, two compiled variants targeting different XLA backends.
# The 'gpu' variant needs a GPU-enabled JAX runtime.
selu_jit_cpu = jax.jit(selu, backend='cpu')
selu_jit_gpu = jax.jit(selu, backend='gpu')

In [5]:
x = jax.random.normal(jax.random.PRNGKey(42), shape=(1_000_000,))  # same key as before -> same one million samples

In [6]:
%timeit -n100 selu(x).block_until_ready() #uses gpu, just not JIT compiled

1.3 ms ± 799 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
%timeit -n100 selu_jit_cpu(x).block_until_ready()

2.36 ms ± 680 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%timeit -n100 selu_jit_gpu(x).block_until_ready()

The slowest run took 5.09 times longer than the fastest. This could mean that an intermediate result is being cached.
227 µs ± 196 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Controlling both backend and tensor device placement

In [10]:
# Commit copies of x to explicit devices; jax.devices('gpu') requires a GPU runtime
x_cpu = jax.device_put(x, device=jax.devices('cpu')[0])
x_gpu = jax.device_put(x, device=jax.devices('gpu')[0])

In [11]:
%timeit -n100 selu(x_cpu).block_until_ready()

7.53 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
%timeit -n100 selu(x_gpu).block_until_ready()

1.31 ms ± 754 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
%timeit -n100 selu_jit_cpu(x_cpu).block_until_ready()

949 µs ± 65.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
%timeit -n100 selu_jit_gpu(x_gpu).block_until_ready()

154 µs ± 8.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Static arguments

In [15]:
def dense_layer(x, w, b, activation_func):  # function parameterized by another function
  """Apply ``activation_func`` to the affine combination ``x*w + b``.

  ``activation_func`` is an arbitrary Python callable rather than an array, so
  under jax.jit it must be marked static (static_argnums / static_argnames).

  NOTE(review): ``x*w`` is an elementwise, broadcasting product, not a matrix
  product. With x of shape (3,) and w of shape (3, 3) the result is (3, 3).
  A conventional dense layer would use ``x @ w``; this is left unchanged to
  preserve the notebook's existing outputs.
  """
  pre_activation = x * w + b
  return activation_func(pre_activation)

In [16]:
# Toy inputs: a 3x3 weight matrix of ones, a bias vector of ones, and a length-3 input vector
w = jnp.ones((3, 3))
b = jnp.ones((3,))
x = jnp.array([1.0, 2.0, 3.0])

In [17]:
dense_layer_jit = jax.jit(dense_layer)  # no static args: every argument will be treated as an array when traced

In [18]:
# Expected failure: jit tries to trace the Python function `selu` as an array.
# The error is the point of this demo, so catch and display it instead of letting
# the cell raise (which would stop "Restart & Run All").
try:
  dense_layer_jit(x, w, b, selu)
except TypeError as err:
  print(f'TypeError: {err}')

TypeError: Error interpreting argument to <function dense_layer at 0x79c804315a80> as an abstract array. The problematic value is of type <class 'function'> and was passed to the function at path activation_func.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

In [19]:
# Mark the callable argument static by name (equivalent to static_argnums=3);
# static values are hashed into jit's cache key instead of being traced.
dense_layer_jit = jax.jit(dense_layer, static_argnames='activation_func')

In [20]:
dense_layer_jit(x, w, b, selu)  # works: selu is a static argument, so it is not traced as an array

Array([[2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804]], dtype=float32)

In [21]:
def dist(order, x, y):
  """Minkowski distance of the given ``order`` between ``x`` and ``y``.

  The print is a Python side effect: under jax.jit it runs only while the
  function is being traced, e.g. the first time a new static ``order`` is seen.
  """
  print('Compiling')
  total = jnp.sum(jnp.abs(x - y) ** order)
  return jnp.power(total, 1.0 / order)

In [22]:
# `order` is static (same as static_argnums=0): each new value triggers a retrace + recompile
dist_jit = jax.jit(dist, static_argnames='order')

In [25]:
dist_jit(1, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0])) # first call with order=1: trace + compile (prints 'Compiling'), then run

Compiling


Array(4., dtype=float32)

In [26]:
dist_jit(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0])) # new static value order=2: retrace + recompile (prints 'Compiling'), then run

Compiling


Array(2.828427, dtype=float32)

In [27]:
dist_jit(1, jnp.array([10.0, 10.0]), jnp.array([2.0, 2.0])) # order=1 with the same shapes/dtypes: cached compiled version reused, no 'Compiling'

Array(16., dtype=float32)

# Static arguments with the jit decorator

In [28]:
from functools import partial

# jax.jit used as a decorator with options: partial pre-binds the static argument spec
@partial(jax.jit, static_argnames='order')
def dist(order, x, y):
  """JIT-compiled Minkowski distance of the given (static) ``order`` between ``x`` and ``y``."""
  abs_diff = jnp.abs(x - y)
  return jnp.power(jnp.sum(abs_diff ** order), 1.0 / order)

## Compiling an impure function

In [29]:
# Global state read by an impure function. A function is impure when it depends on
# global state and/or has side effects. Under jax.jit, Python side effects (like
# print) run only while the function is traced, and the global's value at trace
# time is baked into the compiled code; neither shows up as an operation in the jaxpr.
global_state = 1

def impure_function(x):
  """Print ``x`` (a side effect) and return ``x`` scaled by the global ``global_state``."""
  print(f'Side-efect: printing x={x}')  # side effect of an impure function
  return x * global_state

In [30]:
impure_function_jit = jax.jit(impure_function)  # nothing is traced or compiled until the first call

In [33]:
impure_function_jit(10)  # first call: tracing runs the print once (x shows up as a tracer), global_state=1 is baked in

Side-efect: printing x=JitTracer<~int32[]>


Array(10, dtype=int32, weak_type=True)

In [34]:
impure_function_jit(10) #no side effects during second run: the cached compiled version is reused without retracing, so the print does not run

Array(10, dtype=int32, weak_type=True)

In [35]:
global_state = 2  # rebinding the global does not invalidate the jit cache (see the next call)

In [36]:
impure_function_jit(10) #changed global state has no influence on the already-compiled function: still returns 10

Array(10, dtype=int32, weak_type=True)

In [37]:
impure_function(10)  # plain Python call: the print runs and the current global_state (2) is read

Side-efect: printing x=10


20

## JAXPR

In [2]:
def f1(x, y, z):
  """Return the sum of all elements of ``x + y * z`` (operands broadcast together)."""
  product = y * z
  return jnp.sum(x + product)

In [8]:
x = jnp.array([1.0, 1.0, 1.0])
y = 2.0 * jnp.ones((3, 3))
# z is a 1-D vector; transposing a 1-D array is a no-op, so no .T is needed
z = jnp.array([2.0, 2.0, 0.0])

In [4]:
jax.make_jaxpr(f1)(x, y, z)  # trace f1 with these example inputs and display the resulting jaxpr

{ lambda ; a:f32[3] b:f32[3,3] c:f32[3]. let
    d:f32[1,3] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] c
    e:f32[3,3] = mul b d
    f:f32[1,3] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] a
    g:f32[3,3] = add f e
    h:f32[] = reduce_sum[axes=(0, 1)] g
  in (h,) }

In [2]:
def f2(x, y):
  """Like f1, but reads ``z`` from the global scope and prints its inputs (impure)."""
  print(f'x={x}, y={y}, z={z}')  # side effect: runs only while tracing under JAX transformations
  product = y * z  # z is a global, captured as a constant when traced
  return jnp.sum(x + product)

In [5]:
f2_jaxpr = jax.make_jaxpr(f2)(x, y)  # tracing executes the print once; the global z is captured as a constant

x=JitTracer<float32[3]>, y=JitTracer<float32[3,3]>, z=[2. 2. 0.]


In [6]:
f2_jaxpr.jaxpr #doesn't capture the print side effect; the global z appears as the constant input `a` (listed before the ';')

{ lambda a:f32[3]; b:f32[3] c:f32[3,3]. let
    d:f32[1,3] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] a
    e:f32[3,3] = mul c d
    f:f32[1,3] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] b
    g:f32[3,3] = add f e
    h:f32[] = reduce_sum[axes=(0, 1)] g
  in (h,) }

In [7]:
f2_jaxpr.consts #the global variable z is now stored as a constant of the jaxpr

[Array([2., 2., 0.], dtype=float32)]

## Tracing

In [6]:
def f3(x):
  """Return ``x + 0 + 1 + 2 + 3 + 4`` using a Python loop.

  The trip count is a fixed Python constant, so tracing works and the loop is
  unrolled into five separate additions in the jaxpr.
  """
  y = x
  for i in range(5): #loop does not depend on an input parameter - good.
    # Rebind instead of `y += i`: with a mutable input (e.g. a NumPy array) `+=`
    # would modify the caller's array in place, since y starts as an alias of x.
    y = y + i
  return y

In [7]:
jax.make_jaxpr(f3)(0) #unroll loop: the jaxpr contains five separate add ops

{ lambda ; a:i32[]. let
    b:i32[] = add a 0:i32[]
    c:i32[] = add b 1:i32[]
    d:i32[] = add c 2:i32[]
    e:i32[] = add d 3:i32[]
    f:i32[] = add e 4:i32[]
  in (f,) }

In [8]:
jax.jit(f3)(0)  # compile f3 and run it: 0 + (0+1+2+3+4) = 10

Array(10, dtype=int32, weak_type=True)

In [9]:
def f4(x):
  """Add every entry of ``x`` along its leading axis to ``x`` itself.

  For a 1-D input this returns ``x + sum(x)``. The trip count ``x.shape[0]`` is
  static during tracing, so the loop is unrolled (one slice/squeeze/add per element).
  """
  y = x
  for i in range(x.shape[0]): #loop depends on an input parameter shape- good.
    # Rebind instead of `y += x[i]`: y starts as an alias of x, so an in-place add
    # on a mutable input (e.g. NumPy) would corrupt both x and the later x[i] reads.
    y = y + x[i]
  return y

In [10]:
jax.make_jaxpr(f4)(jnp.array([1.0, 2.0, 3.0])) #loop is unrolled: one slice/squeeze/add group per element

{ lambda ; a:f32[3]. let
    b:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] a
    c:f32[] = squeeze[dimensions=(0,)] b
    d:f32[3] = add a c
    e:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] a
    f:f32[] = squeeze[dimensions=(0,)] e
    g:f32[3] = add d f
    h:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] a
    i:f32[] = squeeze[dimensions=(0,)] h
    j:f32[3] = add g i
  in (j,) }

In [11]:
jax.jit(f4)(jnp.array([1.0, 2.0, 3.0]))  # each element becomes x + (1 + 2 + 3) = x + 6

Array([7., 8., 9.], dtype=float32)

In [None]:
# The loop bound depends on the *value* of the input. This is fine in plain
# Python (f5(5) == 10), but jax.jit / make_jaxpr fail: x becomes an abstract
# tracer, and range() needs a concrete Python integer.
def f5(x):
  total = 0
  for step in range(x):
    total += step
  return total