In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import jax
import jax.numpy as np

from jax import random
from jax import grad, jit, vmap

import numpy as onp

from IPython import display
from matplotlib import pyplot as plt

In [None]:
FIGSIZE = (10, 7.5)  # default (width, height) in inches for every figure below
plt.rcParams['figure.figsize'] = FIGSIZE

In [None]:
SEED = 0  # seed for the notebook-wide JAX PRNG key
key = random.PRNGKey(SEED)

# 3.2.1. Generating the Dataset

In [None]:
def synthetic_data(w, b, num_examples, key=None, noise_std=0.01):
    """Generate y = X w + b + noise.

    Args:
        w: weight vector; its length sets the number of feature columns.
        b: bias added to every example.
        num_examples: number of rows to draw.
        key: optional ``jax.random`` PRNG key. Defaults to ``PRNGKey(0)`` so
            calls without a key are still reproducible.
        noise_std: standard deviation of the additive Gaussian noise.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(num_examples, len(w))`` with
        standard-normal entries and ``y = X w + b + noise``.
    """
    if key is None:
        key = random.PRNGKey(0)
    # Features and noise come from independent subkeys. The original drew the
    # noise from NumPy's unseeded global RNG, so the dataset differed on every
    # fresh kernel run.
    x_key, noise_key = random.split(key)
    X = random.normal(x_key, (num_examples, len(w)))
    y = np.dot(X, w) + b
    y += noise_std * random.normal(noise_key, y.shape)
    return X, y

In [None]:
# Ground-truth parameters that training should recover.
true_w, true_b = np.array([2, -3.4]), 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

In [None]:
first_x, first_y = features[0], labels[0]
print('features:', first_x, '\nlabel:', first_y)

In [None]:
# Label vs. the second feature: the linear relationship (slope -3.4) should be
# visible. Explicit fig/ax interface, with labels so the figure stands alone.
fig, ax = plt.subplots()
ax.scatter(features[:, 1], labels, s=1)
ax.set(xlabel='second feature (features[:, 1])', ylabel='label',
       title='Labels vs. second feature');

# 3.2.2. Reading the Dataset

In [None]:
def data_iter(batch_size, features, labels, key=None):
    """Yield shuffled minibatches ``(features[idx], labels[idx])``.

    The original called ``random.shuffle(key, ...)`` and discarded the result.
    JAX never shuffles in place, so batches came out in dataset order. It also
    reused one global key, so even a working shuffle would repeat every epoch.

    Args:
        batch_size: rows per minibatch; the final batch may be smaller.
        features: array indexed along its first axis.
        labels: array with the same first-axis length as ``features``.
        key: optional ``jax.random`` key for a reproducible permutation. When
            None, a fresh permutation is drawn from NumPy's global RNG, so
            successive calls (epochs) see different orders.

    Yields:
        ``(X, y)`` minibatches covering every example exactly once.
    """
    num_examples = len(features)
    if key is None:
        order = np.asarray(onp.random.permutation(num_examples))
    else:
        order = random.permutation(key, num_examples)
    for start in range(0, num_examples, batch_size):
        # Slicing clamps at the end, so the last partial batch needs no min().
        batch_indices = order[start:start + batch_size]
        yield features[batch_indices], labels[batch_indices]

In [None]:
batch_size = 10

# Peek at the first minibatch only.
X, y = next(data_iter(batch_size, features, labels))
print(X, '\n', y)

# 3.2.3. Initializing Model Parameters

In [None]:
# Weights ~ N(0, 0.01^2), bias at zero. The original drew w from NumPy's
# unseeded global RNG, so the initialization changed on every fresh run.
# Derive a dedicated key from the notebook's `key` rather than reusing `key`.
w = 0.01 * random.normal(random.fold_in(key, 1), (2, 1))
b = np.zeros(1)

In [None]:
(w, b)  # display the freshly initialized parameters

# 3.2.4. Defining the Model

In [None]:
def linreg(X, w, b):
    """Linear regression model: return the predictions ``X w + b``.

    ``b`` is broadcast against the matrix product.
    """
    predictions = np.dot(X, w)
    return predictions + b

# 3.2.5. Defining the Loss Function

In [None]:
def squared_loss(y_hat, y):
    """Per-example squared loss ``0.5 * (y_hat - y)**2`` with no reduction.

    ``y`` is first reshaped to ``y_hat``'s shape, so a flat label vector lines
    up with column-vector predictions instead of broadcasting to a matrix.
    """
    residual = y_hat - y.reshape(y_hat.shape)
    return 0.5 * residual ** 2

In [None]:
# [NEW] JAX-specific addition: a scalar objective that jax.grad can
# differentiate with respect to the parameters w and b.
def linreg_loss(X, w, b, y):
    """Mean over the minibatch of ``0.5 * (X w + b - y)**2``.

    Fuses the model and squared loss so the result is a scalar; because it is a
    mean, its gradient is already averaged over the batch.
    """
    preds = np.dot(X, w) + b
    residual = preds - y.reshape(preds.shape)
    return np.mean(0.5 * residual ** 2)

# 3.2.6. Defining the Optimization Algorithm

In [None]:
def sgd(params, grad, lr, batch_size):
    """One minibatch SGD step: ``param - lr * grad / batch_size`` for each param.

    JAX arrays are immutable, so the original ``param -= ...`` only rebound a
    loop-local name and never updated the caller's parameters. The updated
    parameters are returned instead; callers must rebind, e.g.
    ``w, b = sgd([w, b], grads, lr, batch_size)``.

    Args:
        params: sequence of parameter arrays.
        grad: sequence of gradients, indexed in step with ``params``.
        lr: learning rate.
        batch_size: divisor for the gradient (for a summed minibatch loss).

    Returns:
        List of updated parameter arrays; the inputs are left untouched.
    """
    return [param - lr * grad[i] / batch_size for i, param in enumerate(params)]

# 3.2.7. Training

In [None]:
lr = 0.03  # Learning rate
num_epochs = 3  # Number of full passes over the dataset
net = linreg  # Our fancy linear model
loss = squared_loss  # per-example 0.5 * (y_hat - y)^2
# Gradient of the mean minibatch loss w.r.t. (w, b), i.e. argnums 1 and 2.
# The original grad(linreg_loss) differentiated w.r.t. X (argnum 0), which a
# training step never needs.
grad_loss = grad(linreg_loss, argnums=(1, 2))

In [None]:
# linreg_loss already averages over the minibatch, so its gradient is the mean
# gradient; the original divided by batch_size again, shrinking every step 10x.
# Both gradients come from one jitted call, so the w and b updates are
# evaluated at the same point.
step_grad = jit(grad(linreg_loss, argnums=(1, 2)))
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        grad_w, grad_b = step_grad(X, w, b, y)
        w = w - lr * grad_w
        b = b - lr * grad_b
    train_l = loss(net(features, w, b), labels)
    print('epoch %d, loss %f' % (epoch + 1, train_l.mean()))

In [None]:
# Compare the learned parameters against the ground truth.
w_error = true_w - w.reshape(true_w.shape)
b_error = true_b - b
print('Error in estimating w', w_error)
print('Error in estimating b', b_error)