### 3.3.1. Generating the Dataset

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
import flax
import optax
import matplotlib.pyplot as plt

In [None]:
key = jax.random.PRNGKey(1071)

def synthetic_data(w, b, num_examples, rng_key=None):
    """Generate a synthetic linear-regression dataset: y = Xw + b + noise.

    Args:
        w: weights; ``len(w)`` is the number of features. With ``w`` of shape
            (d, 1) the labels have shape (num_examples, 1); with shape (d,)
            they have shape (num_examples,).
        b: bias added to every label.
        num_examples: number of rows to generate.
        rng_key: optional JAX PRNG key; defaults to the module-level ``key``.

    Returns:
        (X, y): X is standard-normal with shape (num_examples, len(w)); y is
        X @ w + b plus Gaussian noise with standard deviation 0.01.
    """
    if rng_key is None:
        rng_key = key
    # Split the key so features and noise come from independent streams and
    # the whole dataset is reproducible from rng_key.
    x_key, noise_key = jax.random.split(rng_key)
    X = jax.random.normal(x_key, shape=(num_examples, len(w)))
    y = jnp.matmul(X, w) + b
    # Noise takes y's own shape, so a 1-D w does not broadcast y to (n, n).
    y = y + 0.01 * jax.random.normal(noise_key, shape=y.shape)
    return X, y
# Ground-truth parameters of the linear model the data is drawn from.
true_w = jnp.array([[2, -3.4]])
true_b = 4.2
num_examples = 1000

# true_w is a (1, 2) row; synthetic_data is given its (2, 1) transpose.
features, labels = synthetic_data(true_w.T, true_b, num_examples)

for arr in (features, labels):
    print(arr.shape)
print('features:', features[0], '\nlabel:', labels[0])

# Scatter the label against each of the two feature columns.
plt.scatter(features[:, 0], labels)
plt.scatter(features[:, 1], labels)

In [None]:
def data_iter(batch_size, features, labels, rng_key=None):
    """Yield shuffled minibatches of (features, labels).

    Args:
        batch_size: maximum number of examples per batch; the last batch may
            be smaller.
        features: array indexed by example along axis 0.
        labels: array aligned with ``features`` along axis 0.
        rng_key: optional JAX PRNG key for a reproducible shuffle. When None,
            a fresh NumPy permutation is drawn on each call, so successive
            epochs visit the examples in different orders.

    Yields:
        (features_batch, labels_batch) tuples that together cover every
        example exactly once.
    """
    num_examples = len(features)
    # JAX arrays are immutable: permutation returns a new shuffled array
    # instead of shuffling in place, so its result must be kept.
    if rng_key is None:
        indices = jnp.asarray(np.random.permutation(num_examples))
    else:
        indices = jax.random.permutation(rng_key, num_examples)
    for start in range(0, num_examples, batch_size):
        # Slicing past the end is clipped, so the final partial batch is kept.
        batch_indices = indices[start:start + batch_size]
        yield features[batch_indices], labels[batch_indices]

In [None]:
batch_size = 10

# Sanity check: fetch only the first minibatch and display it.
first_batch = next(data_iter(batch_size, features, labels), None)
if first_batch is not None:
    X, y = first_batch
    print(X, '\n', y)

In [None]:
# Initial weights (2 inputs -> 1 output) and bias. fold_in derives a dedicated
# key instead of reusing `key` directly: JAX keys should not be reused across
# independent random draws.
w = jax.random.normal(jax.random.fold_in(key, 1), shape=(2, 1))
b = jnp.zeros(1)

In [None]:
(w, b)  # display the initial parameters

In [None]:
def mean_squared_loss(y_hat, y):
    """Return half the mean squared error between ``y_hat`` and ``y``.

    When both arrays contain the same number of elements, ``y`` is first
    reshaped to ``y_hat``'s shape. For example, predictions of shape (n, 1) and
    targets of shape (n,) are then compared element-wise instead of silently
    broadcasting to an (n, n) matrix. Arrays of different sizes are left to
    ordinary broadcasting.
    """
    y_hat = jnp.asarray(y_hat)
    y = jnp.asarray(y)
    if y.size == y_hat.size:
        y = y.reshape(y_hat.shape)
    return jnp.mean(0.5 * jnp.square(y_hat - y))

### Defining the Model in Flax

In [None]:
# Number of input features (len(w) == 2 for the (2, 1) weights above). Flax's
# Dense layer infers its input width at init time, so this is informational.
feature_dim = len(w)
# Linear regression predicts one scalar per example, matching the (n, 1)
# labels. Dense's `features` argument is the OUTPUT width, so it is 1, not
# the input width.
model = flax.linen.Dense(features=1)

In [None]:
# One example with two features. Flax infers the kernel's input dimension
# from this array's last axis when the parameters are created.
dummy_x = jnp.array([1.0, 2.0])
# fold_in gives parameter initialization its own key rather than reusing
# `key` directly.
params = model.init(jax.random.fold_in(key, 2), dummy_x)
# jax.tree_map is deprecated; jax.tree_util.tree_map is the supported form.
jax.tree_util.tree_map(lambda x: x.shape, params)

In [None]:
model.apply(params, inputs=[1.0, 1.0])  # forward pass on one two-feature example

In [None]:
def loss(params, inputs, y_hat):
    """Return the half-mean-squared-error of ``model`` on one batch.

    Note: despite its name, ``y_hat`` holds the TARGET values; the model's
    predictions are computed here from ``inputs`` using ``params``.
    """
    predictions = model.apply(params, inputs)
    # The squared difference is symmetric, so argument order does not matter.
    return mean_squared_loss(predictions, y_hat)

loss(params, jnp.array([[1.0, 1.0], [2.0, 2.2]]), jnp.array([[2.0, 2.0], [1.0, 1.0]]))

In [None]:
# Compiled function returning (loss value, gradient w.r.t. params), since
# value_and_grad differentiates with respect to loss's first argument.
loss_grad_fn = jax.jit(jax.value_and_grad(loss))
# Minibatch SGD with a fixed step size, plus its initial optimizer state.
tx = optax.sgd(learning_rate=0.01)
opt_state = tx.init(params)

In [None]:
num_epochs = 5
for epoch in range(num_epochs):
    for X, y_hat in data_iter(batch_size, features, labels):
        l, grad = loss_grad_fn(params, X, y_hat)
        # Passing params lets optimizers that need them (e.g. weight decay)
        # be swapped in without changing this loop.
        updates, opt_state = tx.update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
    # Report the loss over the whole training set with the end-of-epoch
    # parameters. Printing the last minibatch's pre-update loss is noisy, and
    # it raises NameError if no batch was produced.
    train_l = loss(params, features, labels)
    print("epoch {0}, loss: {1:f}".format(epoch + 1, float(train_l)))