In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# The config object is exposed directly on the top-level `jax` package.
# Importing it from there works on old and new releases alike, whereas the
# `jax.config` submodule path was deprecated and later removed.
from jax import config
# Explicitly keep 64-bit types disabled so arrays stay float32.
config.update("jax_enable_x64", False)

In [2]:
# Random initialisation of the weights and biases of one dense layer.
def random_layer_params(m, n, key, scale=1e-2):
  """Return a `(weights, biases)` pair for a dense layer.

  Args:
    m: number of inputs to the layer.
    n: number of outputs of the layer.
    key: JAX PRNG key; it is split once so weights and biases use
      different sub-keys.
    scale: factor applied to the standard-normal samples.

  Returns:
    A tuple of a float32 array of shape (n, m) and a float32 array of
    shape (n,).
  """
  weight_key, bias_key = random.split(key)
  weights = scale * random.normal(weight_key, (n, m), dtype=jnp.float32)
  biases = scale * random.normal(bias_key, (n,), dtype=jnp.float32)
  return weights, biases

# Build the parameters of every layer of a fully-connected network whose
# layer widths are given by "sizes"
def init_network_params(sizes, key):
  """Return a list with one `(weights, biases)` pair per consecutive size pair.

  `key` is split into `len(sizes)` sub-keys. There are only
  `len(sizes) - 1` layers, so `zip` leaves the last sub-key unused.
  """
  layer_keys = random.split(key, len(sizes))
  layers = []
  for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
    layers.append(random_layer_params(fan_in, fan_out, layer_key))
  return layers

# Architecture: one input, four hidden layers of 128 units, one output.
layer_sizes = [1] + [128] * 4 + [1]
# Parameters are initialised from a fixed seed.
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [3]:
def predict(params, x):
  """Forward pass of the network for a single input `x`.

  `params` is a sequence of `(weights, biases)` pairs. Every pair except the
  last is applied as an affine map followed by `tanh`; the last pair is
  applied as a plain affine map with no activation.
  """
  hidden_layers = params[:-1]
  out_w, out_b = params[-1]

  h = x
  for weight, bias in hidden_layers:
    h = jnp.tanh(jnp.dot(weight, h) + bias)

  return jnp.dot(out_w, h) + out_b

In [4]:
# Extend `predict` to whole batches with `vmap`.

# The parameters are shared by every example (in_axes None), while the
# inputs are mapped over their leading axis (in_axes 0).
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` is called exactly like `predict`; try it on 100 random
# scalar inputs and show the shape of the stacked predictions.
random_x = random.uniform(
    random.PRNGKey(0),
    shape=(100, 1),
    dtype=jnp.float32,
    minval=-1,
    maxval=1,
)
batched_preds = batched_predict(params, random_x)
print(batched_preds.shape)

(100, 1)


In [5]:
def mse(x, y):
  """Mean over all elements of the squared differences summed along axis 1."""
  squared_error = jnp.square(x - y)
  per_example = jnp.sum(squared_error, axis=1)
  return jnp.mean(per_example)

def loss(params, x, y):
  """Return `mse` between the batched predictions for `x` and the targets `y`."""
  return mse(batched_predict(params, x), y)

@jit
def update(params, x, y, lr=None):
  """Apply one gradient-descent step and return the new parameters.

  Args:
    params: list of `(weights, biases)` pairs.
    x: inputs passed to `loss`.
    y: targets passed to `loss`.
    lr: optional step size. When given it is an ordinary traced argument, so
      it can change between calls. When omitted, the module-level
      `step_size` is used; because this function is jit-compiled, that
      global is read while tracing and baked into the compiled code, so
      rebinding `step_size` later does not affect an already-compiled call.

  Returns:
    A new list of `(weights, biases)` pairs, each moved against its gradient.
  """
  # `lr is None` is decided at trace time: an omitted argument is never traced.
  rate = step_size if lr is None else lr
  grads = grad(loss)(params, x, y)
  return [(w - rate * dw, b - rate * db)
          for (w, b), (dw, db) in zip(params, grads)]

In [6]:
# Training inputs: 1000 float32 samples of shape (1000, 1) drawn uniformly
# with minval 0.0 and maxval 1.0 from a fixed seed.
x = random.uniform(
    random.PRNGKey(0),
    shape=(1000, 1),
    dtype=jnp.float32,
    minval=0.0,
    maxval=1.0,
)


def f(x):
  """Target function used to produce the training labels."""
  return jnp.sin(x)


# Training targets: the target function evaluated on every input.
y = f(x)

In [7]:
import time

# Number of full-batch gradient steps and the step size read by `update`.
epochs = 10000
step_size = 0.01

# perf_counter is a monotonic, high-resolution clock intended for timing.
start_time = time.perf_counter()
for epoch in range(1, epochs + 1):
  params = update(params, x, y)

# JAX dispatches computations asynchronously, so the loop above can return
# before the device has finished. Wait for the final parameters before
# reading the clock; otherwise the measured time can under-report the work.
for layer in params:
  for leaf in layer:
    leaf.block_until_ready()

# The first call to `update` also includes jit compilation time.
print("AVG epoch time: ", (time.perf_counter() - start_time)/epochs)

AVG epoch time:  0.00034999020099639894
