In [11]:
import os
import time
import requests

import jax
import jax.numpy as jnp
from jax import jit, grad, random

from jax.config import config
%config IPCompleter.use_jedi = False

In [12]:
def apply_activation(x):
    """ReLU activation: elementwise ``max(0.0, x)``.

    Kept as the exact ``jnp.maximum(0.0, x)`` call so the traced jaxpr is a single
    ``max 0.0 <x>`` primitive.
    """
    floor_value = 0.0
    rectified = jnp.maximum(floor_value, x)
    return rectified

def get_dot_product(W, X):
    """Return ``jnp.dot(W, X)`` (the matrix product when both inputs are 2-D)."""
    product = jnp.dot(W, X)
    return product

In [13]:
# Always use a seed: an explicit root key makes every random draw reproducible.
SEED = 1234
key = random.PRNGKey(SEED)

# Never reuse a key: split FIRST, then consume each subkey exactly once.
# (Sampling with `key` and later splitting that same `key` would be reuse.)
# `key` is kept only as the parent for any future splits.
key, w_key, x_key = random.split(key, 3)
W = random.normal(key=w_key, shape=[1000, 10000], dtype=jnp.float32)
X = random.normal(key=x_key, shape=[10000, 20000], dtype=jnp.float32)

In [14]:
# Wrap both helpers with `jit`. Nothing is compiled yet: tracing and XLA
# compilation happen on the first call, for the argument shapes/dtypes seen.
activation_jit = jit(apply_activation)
dot_product_jit = jit(get_dot_product)

In [15]:
N_TIMING_RUNS = 3


def _run_and_time(fn, *args):
    """Call ``fn(*args)``, wait for the result, and return ``(result, seconds)``.

    JAX dispatches work asynchronously, so ``block_until_ready()`` is required;
    without it only the dispatch time would be recorded. ``time.perf_counter``
    is used because it is monotonic and high-resolution, unlike ``time.time``.
    """
    start = time.perf_counter()
    result = fn(*args).block_until_ready()
    return result, time.perf_counter() - start


# Warm-up: the first call of a jitted function traces and compiles it for the
# given shapes/dtypes. Run each function once outside the timed loop so the
# reported numbers reflect steady-state execution, not compilation.
Z = dot_product_jit(W, X).block_until_ready()
activation_jit(Z).block_until_ready()

for i in range(N_TIMING_RUNS):
    Z, dot_seconds = _run_and_time(dot_product_jit, W, X)
    print(f"Iteration: {i+1}")
    print(f"Time taken to execute dot product: {dot_seconds:.2f} seconds", end="")

    A, activation_seconds = _run_and_time(activation_jit, Z)
    print(f", activation function: {activation_seconds:.2f} seconds")

Iteration: 1
Time taken to execute dot product: 0.23 seconds, activation function: 0.03 seconds
Iteration: 2
Time taken to execute dot product: 0.23 seconds, activation function: 0.01 seconds
Iteration: 3
Time taken to execute dot product: 0.22 seconds, activation function: 0.01 seconds


In [16]:
# Trace the jitted activation into a jaxpr. `make_jaxpr` only traces the
# function (using Z's shape and dtype); it does not execute the computation.
activation_jaxpr = jax.make_jaxpr(activation_jit)(Z)
print(activation_jaxpr)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1000,20000][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[1000,20000][39m = xla_call[
      call_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; c[35m:f32[1000,20000][39m. [34m[22m[1mlet
          [39m[22m[22md[35m:f32[1000,20000][39m = max 0.0 c
        [34m[22m[1min [39m[22m[22m(d,) }
      name=apply_activation
    ] a
  [34m[22m[1min [39m[22m[22m(b,) }
