[Reference](https://medium.com/@hylke.donker/jax-fast-as-pytorch-simple-as-numpy-a0c14893a738)

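# Functional Programming

JAX transformations (`jit`, `grad`, `vmap`) are designed for pure functions, whose output depends only on their arguments. `impure_example` below reads the global `bias`, so its result can change even when its input does not. `pure_example` receives `weights` and `bias` explicitly instead.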

In [1]:
import jax.numpy as jnp

bias = jnp.array(0)


def impure_example(x):
    """Return ``x`` shifted by the module-level ``bias``.

    Deliberately impure: ``bias`` is looked up in global scope at call time
    instead of being passed in. The result therefore depends on state that is
    not among the function's arguments.
    """
    return x + bias

In [2]:
def pure_example(x, weights, bias):
    """Compute the affine map ``weights @ x + bias``.

    Pure: every value the result depends on arrives as an argument, so equal
    inputs always give equal outputs.
    """
    pre_bias = weights @ x
    return pre_bias + bias

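# Deterministic Samplers

JAX has no hidden global random state. Every sampler takes an explicit PRNG key, and the same key always produces the same sample. To draw new numbers, `split` the key and consume each resulting subkey exactly once.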

In [3]:
import jax

# Fixed seed: re-running this cell always draws the same scalar in [0, 1).
key = jax.random.PRNGKey(42)
u = jax.random.uniform(key=key)

In [4]:
key = jax.random.PRNGKey(43)
# Each round splits a fresh subkey off `key` and consumes it in one draw;
# `key` itself is only ever used for further splitting, never for sampling.
for _draw in range(2):
    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey)

In [5]:
from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
    """Scaled exponential linear unit, compiled with ``jit``.

    Returns ``λ * x`` where ``x > 0`` and ``λ * α * (exp(x) - 1)`` elsewhere.
    The defaults are rounded versions of the SELU constants (α≈1.6733, λ≈1.0507).

    Args:
        x: scalar or array input.
        α: scale of the negative (exponential) branch.
        λ: overall output scale.
    """
    # `jnp.where` evaluates BOTH branches. Feeding raw large positive `x` into
    # the exponential overflows to inf there, and the backward pass then
    # computes 0 * inf = NaN for `grad(selu)`. Replace positive inputs with 0
    # before exponentiating; the masked value never reaches the output.
    safe_x = jnp.where(x > 0, 0.0, x)
    # expm1(x) == exp(x) - 1, but without cancellation error for x near 0.
    return λ * jnp.where(x > 0, x, α * jnp.expm1(safe_x))

# ∇ Automatic differentiation with `grad`

In [6]:
from jax import grad

def f(x):
    """Toy scalar function f(x) = x + x²/2, so f'(x) = 1 + x and f''(x) = 1."""
    return x + x**2 / 2


# `grad` returns a new function, so derivatives compose: differentiating the
# first derivative gives the second.
df_dx = grad(f)
d2f_dx2 = grad(df_dx)

# vmap and pmap

In [7]:
# Example weight matrix (3 outputs x 2 inputs). No earlier cell defines
# `weights`, so without this `linear` (and the batched versions below) raised
# NameError on a fresh kernel.
weights = jnp.array([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0]])


def linear(x):
    """Apply the global ``weights`` matrix to a single (unbatched) example ``x``.

    ``weights`` is looked up at call time, so redefining it in a later cell
    changes what ``linear`` computes.
    """
    return weights @ x

In [8]:
def naively_batched_linear(X_batched):
    """Apply ``linear`` to each example with a Python loop and stack the outputs.

    Baseline for the ``vmap`` version: it makes one ``linear`` call per row of
    ``X_batched`` and stacks the results along a new leading axis.
    """
    outputs = []
    for example in X_batched:
        outputs.append(linear(example))
    return jnp.stack(outputs)

In [9]:
# `vmap` was never imported in any earlier cell, so calling this raised NameError.
from jax import vmap


def vmap_batched_linear(X_batched):
    """Apply ``linear`` to every example in ``X_batched`` using ``jax.vmap``.

    ``vmap`` maps ``linear`` over the leading axis of ``X_batched`` (its default
    ``in_axes=0``). For a non-empty batch this gives the same stacked result as
    ``naively_batched_linear``, without a Python-level loop.
    """
    return vmap(linear)(X_batched)