In [1]:
import os

# Pin this process to GPU index 2. The variable must be set before JAX
# initializes its GPU backend, which is why this is the very first cell.
os.environ.update({'CUDA_VISIBLE_DEVICES': '2'})

In [2]:
# JAX is a library for array-oriented numerical computation, with automatic differentiation and JIT compilation.
# JAX features built-in JIT compilation via OpenXLA, an open-source machine learning compiler ecosystem.
# JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.

In [3]:
import jax.numpy as jnp

In [4]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * (alpha * exp(x) - alpha)``
    elsewhere.

    Args:
        x: input array (or scalar).
        alpha: scale of the exponential branch used for non-positive inputs.
        lmbda: overall output scale (spelled ``lmbda`` because ``lambda`` is
            a Python keyword).

    Returns:
        An array with the same shape as ``x``.
    """
    # jnp.where evaluates *both* branches for every element. Without the clamp,
    # exp(x) overflows to inf for large positive x. The forward pass is still fine,
    # but the unused branch then contributes 0 * inf = NaN to the gradient.
    # Clamping the exp argument to <= 0 keeps that branch finite. Forward values
    # are unchanged because minimum(x, 0) == x wherever this branch is selected.
    negative_branch = alpha * jnp.exp(jnp.minimum(x, 0.0)) - alpha
    return lmbda * jnp.where(x > 0, x, negative_branch)

In [5]:
# Evaluate selu on the small vector [0., 1., 2., 3., 4.]. Every input is
# non-negative, so each output is simply lmbda * x.
x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


In [6]:
from jax import random

In [7]:
# JAX PRNG keys must not be reused. Drawing from `key` and later splitting the
# same `key` would derive new subkeys from the stream that already produced `x`.
# Split once, consume only `subkey` here, and keep `key` as the parent for
# future splits.
key, subkey = random.split(random.key(0))
x = random.normal(subkey, (1_000_000,))

In [8]:
# block_until_ready() makes %timeit wait until the result is actually computed,
# instead of timing only JAX's asynchronous dispatch.
%timeit selu(x).block_until_ready()

595 μs ± 49.5 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
from jax import jit

selu_jit = jit(selu)
# Warm-up: the first call traces and compiles selu for x's shape/dtype.
# Blocking on the result makes the compile-and-run finish here, before any
# timing starts, rather than overlapping the first timed loop. The trailing `;`
# suppresses the cell's output. It replaces the former `_ = ...`, which in
# IPython shadows the last-output variable `_`.
selu_jit(x).block_until_ready();

In [10]:
# Same benchmark against the jit-compiled version, which was warmed up in the
# previous cell so compilation time is excluded.
%timeit selu_jit(x).block_until_ready()

50.6 μs ± 3.37 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [11]:
# Element type of the random benchmark input.
x.dtype

dtype('float32')

In [12]:
from jax import grad

In [13]:
def sum_logistic(x):
    """Sum of the logistic sigmoid 1 / (1 + exp(-x)) over all elements of ``x``.

    The sigmoid is computed with the exact identity
    ``1 / (1 + exp(-x)) == 0.5 * (1 + tanh(x / 2))``. In the direct form,
    ``exp(-x)`` overflows to inf for very negative ``x``, so the gradient
    evaluates 0 * inf = NaN. ``tanh`` saturates at -1 instead, so the value
    and all (higher-order) derivatives stay finite.

    Args:
        x: input array (or scalar).

    Returns:
        A scalar array holding the sum.
    """
    return jnp.sum(0.5 * (1.0 + jnp.tanh(0.5 * x)))

In [14]:
# Three sample points [0., 1., 2.]. The recorded outputs of the later cells that
# display x_small and its 3x3 Jacobian were produced with this value.
# jnp.arange(1.0) would give only [0.], so Restart & Run All would not reproduce them.
x_small = jnp.arange(3.0)
# Gradient of sum_logistic. Because the function is an elementwise sum, this is
# the logistic derivative evaluated at each element.
derivative_fn = grad(sum_logistic)

In [23]:
# Fourth derivative of sum_logistic: grad nested four times, compiled once with
# jit, then evaluated at the scalar 1.0 (the cell's displayed output).
fourth_derivative_fn = jit(grad(grad(grad(grad(sum_logistic)))))
fourth_derivative_fn(1.0)

Array(0.12350688, dtype=float32, weak_type=True)

In [16]:
# Show the sample points that the Jacobian example is evaluated at.
x_small

Array([0., 1., 2.], dtype=float32)

In [25]:
from jax import jacobian

In [27]:
# jacobian(jnp.exp) produces the full (n, n) derivative matrix for an n-element
# input. exp acts elementwise, so only the diagonal entries exp(x_i) are non-zero.
exp_jacobian_fn = jit(jacobian(jnp.exp))
print(exp_jacobian_fn(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]


In [28]:
# Auto-vectorization with jax.vmap(). It has the familiar semantics of mapping a function along array axes, but instead of looping
# over function calls explicitly, it transforms the function into a natively vectorized version for better performance.

In [29]:
# Derive two fresh subkeys, one per random draw below, instead of reusing `key` directly.
key1, key2 = random.split(key, num=2)

In [30]:
# A fixed (150, 100) matrix, and a batch of 10 input vectors of length 100.
mat = random.normal(key1, shape=(150, 100))
batched_x = random.normal(key2, shape=(10, 100))

In [31]:
def apply_matrix(x):
    """Multiply the module-level matrix ``mat`` by a single input ``x``.

    Written for one unbatched vector. Batching is added on top by the naive
    Python-loop version and by ``vmap``. With ``mat`` of shape (150, 100), a
    length-100 ``x`` yields a length-150 result.
    """
    product = jnp.dot(mat, x)
    return product

In [37]:
def naively_batched_apply_matrix(v_batched):
    """Batch ``apply_matrix`` with an explicit Python loop over the leading axis.

    Each row of ``v_batched`` goes through a separate ``apply_matrix`` call, and
    the per-row results are stacked along a new leading axis. This is the
    baseline that the ``vmap`` version is compared against.
    """
    per_row_results = [apply_matrix(row) for row in v_batched]
    return jnp.stack(per_row_results)

print('Naively batched')
# Times the un-jitted Python-loop version. block_until_ready() makes the timer
# wait for the stacked result, not just for dispatch.
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
2.31 ms ± 475 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [45]:
from jax import vmap

In [None]:
@jit
def vmap_batched_apply_matrix(batched_x):
    """Apply ``apply_matrix`` to every row of ``batched_x`` using ``vmap``.

    ``vmap`` maps over axis 0 (its default ``in_axes``). With ``mat`` of shape
    (150, 100), a (10, 100) batch yields a (10, 150) result. The ``@jit``
    decorator compiles the whole batched computation.
    """
    batched_fn = vmap(apply_matrix, in_axes=0)
    return batched_fn(batched_x)

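# vmap maps a function over an axis of its input (axis 0 by default). apply_matrix is written for a single
# length-100 vector; the vmapped version accepts the whole (10, 100) batch and returns a (10, 150) array.
# Rather than looping, vmap rewrites the operations into batched ones (here effectively one matrix-matrix
# product). It does not by itself promise that the work runs in parallel.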

(10, 150)