In [1]:
import time
import numpy as np
import jax
from jax import random, pmap
import jax.numpy as jnp
from jax import jit

In [2]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
print(jax.default_backend())

gpu
gpu


In [3]:
# numpy summing 
rand_vec = np.random.normal(size = int(5e6))
start_time = time.time()
np.sum(rand_vec)
print("numpy %s seconds" % (time.time() - start_time))

numpy 0.003044605255126953 seconds


In [4]:
# jax summing
key = random.PRNGKey(0)
rand_vec_jax = random.normal(key, (1,int(5e6)))
start_time = time.time()
jnp.sum(rand_vec_jax)
print("jax %s seconds" % (time.time() - start_time))

jax 0.19688200950622559 seconds


In [12]:
# running jax again reduces the execution time
key = random.PRNGKey(0)
rand_vec_jax = random.normal(key, (1,int(5e6)))
start_time = time.time()
jnp.sum(rand_vec_jax)
print("jax %s seconds" % (time.time() - start_time))

jax 0.00021505355834960938 seconds


In [5]:
# multiply array size by 10
rand_vec = np.random.normal(size = int(5e7))
start_time = time.time()
np.sum(rand_vec)
print("numpy %s seconds" % (time.time() - start_time))

numpy 0.032800912857055664 seconds


In [6]:
key = random.PRNGKey(1)
rand_vec_jax = random.normal(key, (1,int(5e7)))
print(rand_vec_jax[:10])
start_time = time.time()
rand_sum = jnp.sum(rand_vec_jax)
print(rand_sum)
print("jax %s seconds" % (time.time() - start_time))

[[-0.3393011  -0.5583501  -1.677907   ...  0.0059462  -0.13327362
   0.7063387 ]]
3356.6104
jax 0.1278674602508545 seconds


In [7]:
# running jax again
key = random.PRNGKey(2)
rand_vec_jax = random.normal(key, (1,int(5e7)))
print(rand_vec_jax[:10])
start_time = time.time()
rand_sum = jnp.sum(rand_vec_jax)
print(rand_sum)
print("jax %s seconds" % (time.time() - start_time))

[[ 0.8741801   0.2930515   0.28383994 ...  1.0543811  -1.1015279
   2.0108752 ]]
-3500.2896
jax 0.0010759830474853516 seconds


JIT

In [19]:
def jax_random_sum(random_seed):
  key = random.PRNGKey(random_seed)
  rand_vec_jax = random.normal(key, (1,int(5e7)))
  start_time = time.time()
  rand_sum = jnp.sum(rand_vec_jax)
  return rand_sum

In [26]:
%time jax_random_sum(3)

CPU times: user 3.83 ms, sys: 953 µs, total: 4.79 ms
Wall time: 3.56 ms


DeviceArray(2262.0051, dtype=float32)

In [35]:
%time jax_random_sum(3)

CPU times: user 3.45 ms, sys: 771 µs, total: 4.22 ms
Wall time: 3.62 ms


DeviceArray(2262.0051, dtype=float32)

In [27]:
#using just in time compilation further reduces execution time
compiled_jax_random_sum = jit(jax_random_sum)

In [38]:
%time compiled_jax_random_sum(4)

CPU times: user 1.26 ms, sys: 0 ns, total: 1.26 ms
Wall time: 937 µs


DeviceArray(-9441.299, dtype=float32)

In [39]:
%time compiled_jax_random_sum(4)

CPU times: user 704 µs, sys: 125 µs, total: 829 µs
Wall time: 576 µs


DeviceArray(-9441.299, dtype=float32)