Functional Programming with Pure Functions

In [1]:
import jax
from jax import jit, vmap, random
import jax.numpy as jnp
import numpy as np

In [2]:
n_pies = 0
def add_pies(pies_to_add: int):
    return n_pies + pies_to_add

print(jit(add_pies)(10))

10


In [3]:

n_pies = 20
print(jit(add_pies)(10))


10


In [4]:


def pure_add_pies(n_pies: int, pies_to_add: int):
    return n_pies + pies_to_add

print(jit(pure_add_pies)(0, 10))
print(jit(pure_add_pies)(20, 10))

10
30


Understanding the basics of immutability. 

In [5]:
base = np.array([1, 2, 3, 4, 5])
base[0] = 100
print(base)
# >>> [100, 2, 3, 4, 5]

base_jax = jnp.array([1, 2, 3, 4, 5])
try:
    base_jax[0] = 100
except:
    print("This is not possible in JAX")
print(base_jax)
# >>> TypeError: JAX arrays are immutable


[100   2   3   4   5]
This is not possible in JAX
[1 2 3 4 5]


In [6]:

base = np.array([1, 2, 3, 4, 5])
base[0] = 100
print(base)
# >>> [100, 2, 3, 4, 5]

base_jax = jnp.array([1, 2, 3, 4, 5])
# Notice the square brackets
# and the assignment to the variable updated_base 
# we'll see why we need the assignment a bit later
updated_base = base_jax.at[0].set(100)
print(updated_base)
# >>> [100, 2, 3, 4, 5]

[100   2   3   4   5]
[100   2   3   4   5]


Understanding JIT


In [7]:



def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

1.86 ms ± 60 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
selu_jit = jax.jit(selu)

# Warm up
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

426 µs ± 27.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


VMAP explanation

In [9]:
keygen = random.PRNGKey(0)
weights = random.normal(keygen, shape=(20, 2)) # hidden layer with 20 neurons and 2 inputs
features = random.normal(keygen, shape=(5, 2)) # batch of 5 samples with two features

def dotproduct(w, x):
    return jnp.dot(w, x) 

# dotproduct works on a single sample, but we want to apply it to all samples in the batch.
dotproduct(10, 20) #works :)

# now we want to apply it to all samples in the batch
dotproduct = jit(dotproduct)
try:
    dotproduct(weights, features)
except:
    print("This is not possible in JAX")
# >>> TypeError: Incompatible shapes for dot: got (20, 2) and (5, 2).



This is not possible in JAX


In [10]:
@jit
def linear_forward(weights, features):
    batch_dim_weights = None
    batch_dim_features = 0
    return vmap(dotproduct, in_axes=(batch_dim_weights, batch_dim_features))(weights, features)



preds = linear_forward(weights, features)
preds.shape



(5, 20)

In [23]:
from jax import grad
w = jnp.array([20.])
x = jnp.array([10.])
b = jnp.array([1.])
y_true = jnp.array([50.])
alpha = 0.0001
def forward(w, x, b):
    return w * x + b

def loss(y, y_true):
    return jnp.mean((y - y_true) ** 2)
for i in range(10000):
    y_pred = forward(w, x, b)
    
    grad_w, grad_b = grad(loss, argnums=(0, 1))(y_pred, y_true)
    w = w - alpha * grad_w
    b = b - alpha * grad_b
    if i % 1000 == 0:
        print(loss(y_pred, y_true))

22801.0
620.98846
16.912619
0.4606071
0.012546455
0.00034088735
9.035924e-06
3.541354e-07
3.541354e-07
3.541354e-07
