# JAX basics: random numbers, jit, grad, and batching

In [12]:
import os

# TF_CPP_MIN_LOG_LEVEL is an environment variable. Assigning a Python variable
# with that name does not change the process environment, so the value is also
# exported through os.environ. The Python name is kept for any cell that reads it.
# NOTE(review): the export is placed before the jax imports on the assumption
# that the native runtime reads the variable when it is first loaded — confirm.
TF_CPP_MIN_LOG_LEVEL = 0
os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(TF_CPP_MIN_LOG_LEVEL)

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

In [8]:
# Build a PRNG key from a fixed seed, then draw ten standard-normal samples with it.
# Both the key and the samples are printed so the cell shows what each looks like.
key = random.PRNGKey(seed=0)
print(key)
x = random.normal(key, shape=(10,))
print(x)

[0 0]
[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [11]:
#multiply two matrices
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
print(x)
%timeit jnp.dot(x, x.T).block_until_ready()  # time execution?


[[ 1.3890220e+00 -3.2292119e-01  1.5543443e-01 ...  1.6672333e-01
   1.0217550e+00  9.6981764e-02]
 [ 1.0637628e+00 -1.8089763e+00 -7.7909984e-02 ...  1.1778636e+00
  -4.3357372e-01 -2.7877533e-01]
 [-4.4029754e-01 -3.2537547e-01  2.7817255e-01 ...  6.8317270e-01
  -6.1108190e-01 -6.3071573e-01]
 ...
 [ 2.9218230e-01 -4.0055802e-01 -1.4978158e+00 ...  3.0673659e+00
  -1.1350130e+00  4.0964666e-01]
 [ 2.7635786e-01  1.5621810e-01  2.2997444e-03 ...  6.8930797e-02
  -4.0692501e-02  4.1683877e-01]
 [ 1.0231308e+00 -2.7423611e-01 -8.0369931e-01 ...  1.9415886e+00
   1.0946991e+00  2.1876085e+00]]
923 ms ± 9.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
# numpy try
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

835 ms ± 20.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
# using jit to speed up function
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
  everywhere else.

  Args:
    x: input value or array.
    alpha: scale of the exponential branch.
    lmbda: overall output scale.

  Returns:
    Array with the same shape as ``x``.
  """
  is_positive = x > 0
  # jnp.where evaluates both branches, and the unselected one still takes part
  # in differentiation. Feeding 0 to the exponential wherever the linear branch
  # is chosen keeps exp() from overflowing there, so grad(selu) stays finite for
  # large positive inputs instead of becoming NaN.
  safe_x = jnp.where(is_positive, 0.0, x)
  # expm1 avoids the cancellation in exp(x) - 1 for inputs close to zero.
  negative_branch = alpha * jnp.expm1(safe_x)
  return lmbda * jnp.where(is_positive, x, negative_branch)

x = random.normal(key, (1000000,))

print("without jit")
%timeit selu(x).block_until_ready()

print("\n")

print("with jit")
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

without jit
7.44 ms ± 503 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


with jit
1.46 ms ± 53.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [19]:
# using grad to do autograd
def sum_logistic(x):
  """Return the sum of the logistic function ``1 / (1 + exp(-x))`` over all elements of ``x``."""
  logistic = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(logistic)

# Differentiate sum_logistic with respect to its input, then evaluate the
# resulting gradient function at [0., 1., 2.].
derivative_fn = grad(sum_logistic)
x_small = jnp.arange(3.0)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [21]:
# using vmap to apply functions to multiple inputs
# NOTE(review): no vmap call appears in the visible part of this section —
# confirm whether a vmap variant follows.
#
# Data for the batching comparison: one (150, 100) matrix and a batch of ten
# length-100 vectors. Each array is drawn from its own subkey; passing `key`
# itself to both draws would reuse the same random stream, so the two arrays
# would not be independent. `key` is left unchanged.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
  """Return the product of the module-level matrix ``mat`` with ``v``."""
  product = jnp.dot(mat, v)
  return product

def naively_batched_apply_matrix(v_batched):
  """Call ``apply_matrix`` on each entry of ``v_batched`` in a Python loop and stack the results."""
  per_vector = []
  for v in v_batched:
    per_vector.append(apply_matrix(v))
  return jnp.stack(per_vector)

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

@jit
def batched_apply_matrix(v_batched):
  """Multiply the whole batch by the transpose of the module-level ``mat`` in one jit-compiled dot."""
  mat_transposed = mat.T
  return jnp.dot(v_batched, mat_transposed)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Naively batched
2.81 ms ± 51.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Manually batched
27.3 µs ± 638 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
