
#SELU

In [1]:
import jax
import jax.numpy as jnp

In [2]:
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit activation function"""
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) #jnp.where(condition, value_if_true, value_if_false)
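
For reference, the piecewise function that the `jnp.where` call above computes can be written as follows, where $\alpha$ and $\lambda$ stand for the `alpha` and `scale` arguments (the symbol $\lambda$ is a label chosen here, not a name from the notebook). Note that `alpha * jnp.exp(x) - alpha` is the same expression as $\alpha\,(e^{x} - 1)$.

$$
\mathrm{selu}(x) = \lambda \cdot
\begin{cases}
x, & x > 0 \\
\alpha\,(e^{x} - 1), & x \le 0
\end{cases}
$$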

In [3]:
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,)) #generate a million random numbers
selu_jit = jax.jit(selu) #obtain a JIT-transformed version of the function
%timeit -n100 selu(x).block_until_ready()

1.3 ms ± 765 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
%timeit -n100 selu_jit(x).block_until_ready()

The slowest run took 5.28 times longer than the fastest. This could mean that an intermediate result is being cached.
222 µs ± 196 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
@jax.jit
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit activation function"""
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [6]:
z = selu(x) #warm up the function: the first call triggers compilation

In [7]:
%timeit -n100 selu_jit(x).block_until_ready() #note: this times selu_jit from In [3]; call selu(x) to time the decorated function

147 µs ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
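
The "slowest run" warning and the warm-up call above both come from the same effect: the first call to a jitted function pays for tracing and compilation, and later calls reuse the cached executable. A minimal sketch of how one might measure the two separately, assuming plain `time.perf_counter` timing instead of `%timeit` (the variable names `first_call` and `second_call` are illustrative, and the constants are shortened to float64 precision):

```python
import time

import jax
import jax.numpy as jnp


def selu(x, alpha=1.6732632423543772, scale=1.0507009873554805):
    """Scaled exponential linear unit activation function."""
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,))
selu_jit = jax.jit(selu)

# First call: tracing + compilation + execution.
start = time.perf_counter()
selu_jit(x).block_until_ready()
first_call = time.perf_counter() - start

# Second call: same shape and dtype, so the cached executable is reused.
start = time.perf_counter()
selu_jit(x).block_until_ready()
second_call = time.perf_counter() - start

print(f"first call:  {first_call * 1e3:.3f} ms")
print(f"second call: {second_call * 1e3:.3f} ms")
```

The exact numbers depend on the hardware and JAX version, so none are quoted here.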


In [2]:
#backend control
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
  """Scaled exponential linear unit activation function"""
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [3]:
selu_jit_cpu = jax.jit(selu, backend = 'cpu')
selu_jit_gpu = jax.jit(selu, backend = 'gpu')

In [5]:
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,))

In [6]:
%timeit -n100 selu(x).block_until_ready() #uses gpu, just not JIT compiled

1.3 ms ± 799 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
%timeit -n100 selu_jit_cpu(x).block_until_ready()

2.36 ms ± 680 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%timeit -n100 selu_jit_gpu(x).block_until_ready()

The slowest run took 5.09 times longer than the fastest. This could mean that an intermediate result is being cached.
227 µs ± 196 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


##Controlling both backend and tensor device placement

In [10]:
x_cpu = jax.device_put(x, jax.devices('cpu')[0])
x_gpu = jax.device_put(x, jax.devices('gpu')[0])

In [11]:
%timeit -n100 selu(x_cpu).block_until_ready()

7.53 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
%timeit -n100 selu(x_gpu).block_until_ready()

1.31 ms ± 754 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
%timeit -n100 selu_jit_cpu(x_cpu).block_until_ready()

949 µs ± 65.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
%timeit -n100 selu_jit_gpu(x_gpu).block_until_ready()

154 µs ± 8.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
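
The timings above suggest that where the input lives matters as much as the `backend` argument. A small sketch of one way to check placement, assuming a recent JAX version in which arrays expose a `devices()` method; it only uses the CPU device so that it also runs on a machine without a GPU:

```python
import jax
import jax.numpy as jnp


def selu(x, alpha=1.6732632423543772, scale=1.0507009873554805):
    """Scaled exponential linear unit activation function."""
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


cpu = jax.devices("cpu")[0]  # first CPU device
x = jax.random.normal(jax.random.PRNGKey(42), (1_000,))
x_cpu = jax.device_put(x, cpu)  # explicitly place the array on the CPU

# No backend argument: the jitted computation runs where its committed input lives.
y = jax.jit(selu)(x_cpu)

print(jax.default_backend())  # platform JAX picks by default, e.g. 'cpu' or 'gpu'
print(x_cpu.devices())        # set of devices holding x_cpu
print(y.devices())            # set of devices holding the result
```

Placing the data with `jax.device_put` and leaving `jax.jit` without a `backend` argument is an alternative to the `backend=` approach used above, not a replacement prescribed by the notebook.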


##Static arguments

In [15]:
def dense_layer(x, w, b, activation_func): #function parameterized by another function
  return activation_func(x*w+b) #note: x*w is an element-wise (broadcast) product, not a matrix product

In [16]:
x = jnp.array([1.0, 2.0, 3.0])
w = jnp.ones((3, 3))
b = jnp.ones(3)

In [17]:
dense_layer_jit = jax.jit(dense_layer)

In [18]:
dense_layer_jit(x, w, b, selu)

TypeError: Error interpreting argument to <function dense_layer at 0x79c804315a80> as an abstract array. The problematic value is of type <class 'function'> and was passed to the function at path activation_func.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

In [19]:
dense_layer_jit = jax.jit(dense_layer, static_argnums=3)

In [20]:
dense_layer_jit(x, w, b, selu)

Array([[2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804]], dtype=float32)
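
The error message above also mentions `static_argnames`. A minimal sketch of the same fix expressed by name rather than by position, which can be easier to keep correct if the signature changes later:

```python
import jax
import jax.numpy as jnp


def selu(x, alpha=1.6732632423543772, scale=1.0507009873554805):
    """Scaled exponential linear unit activation function."""
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


def dense_layer(x, w, b, activation_func):
    # Same element-wise computation as in the cells above.
    return activation_func(x * w + b)


# Mark the function-valued argument as static by its name.
dense_layer_jit = jax.jit(dense_layer, static_argnames="activation_func")

x = jnp.array([1.0, 2.0, 3.0])
w = jnp.ones((3, 3))
b = jnp.ones(3)

out = dense_layer_jit(x, w, b, activation_func=selu)
print(out)
```

Static arguments must be hashable, which plain Python functions such as `selu` are.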

In [21]:
def dist(order, x, y):
  print('Compiling')
  return jnp.power(jnp.sum(jnp.abs(x-y)**order), 1.0/order)

In [22]:
dist_jit = jax.jit(dist, static_argnums=0)

In [25]:
dist_jit(1, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0])) #compile function for the given parameter value and run

Compiling


Array(4., dtype=float32)

In [26]:
dist_jit(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0])) #compile function for another parameter value and run

Compiling


Array(2.828427, dtype=float32)

In [27]:
dist_jit(1, jnp.array([10.0, 10.0]), jnp.array([2.0, 2.0])) #function already compiled

Array(16., dtype=float32)
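
The `print('Compiling')` trick works because Python side effects run only while JAX traces the function. A sketch of the same idea with a counter instead of a print, so the number of traces can be checked programmatically (the `counter` dictionary is a hypothetical helper, not part of JAX):

```python
import jax
import jax.numpy as jnp

counter = {"traces": 0}  # hypothetical helper: incremented only during tracing


def dist(order, x, y):
    counter["traces"] += 1  # Python side effect, executed at trace time only
    return jnp.power(jnp.sum(jnp.abs(x - y) ** order), 1.0 / order)


dist_jit = jax.jit(dist, static_argnums=0)

a = jnp.array([0.0, 0.0])
b = jnp.array([2.0, 2.0])

dist_jit(1, a, b)                        # new static value: traced
dist_jit(2, a, b)                        # another static value: traced again
dist_jit(1, jnp.array([10.0, 10.0]), b)  # order=1 with the same shapes: cache hit
traces_after_values = counter["traces"]

dist_jit(1, jnp.zeros(3), jnp.ones(3))   # new input shape: traced again
traces_after_shape = counter["traces"]

print(traces_after_values, traces_after_shape)
```

The last call illustrates that a new input shape also triggers a retrace, in addition to a new static value.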

##Static arguments for the jit decorator

In [28]:
from functools import partial

@partial(jax.jit, static_argnums = 0)
def dist(order, x, y):
  return jnp.power(jnp.sum(jnp.abs(x-y)**order), 1.0/order)
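
A short usage sketch for the decorated form, repeating the definition so the block runs on its own. `functools.partial` is needed here because `@jax.jit` on its own leaves no place to pass `static_argnums`; the input arrays are the ones used with `dist_jit` earlier.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def dist(order, x, y):
    return jnp.power(jnp.sum(jnp.abs(x - y) ** order), 1.0 / order)


x = jnp.array([0.0, 0.0])
y = jnp.array([2.0, 2.0])

d1 = dist(1, x, y)  # compiled for order=1
d2 = dist(2, x, y)  # compiled separately for order=2

print(d1, d2)
```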