# Project 2

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

ReLU activation function

In [9]:
def relu(x):
    return jnp.maximum(0.0, x)

Sigmoid activation function

In [27]:
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

Flexible neural network architecture with initial random weights and biases

In [28]:
in_dim = 1
out_dim = 1
hidden_dims = [2]          
num_layers = len(hidden_dims) + 1

weights = []
biases = []

class NeuralNetwork:
    def __init__(self, in_dim, out_dim, hidden_dims, key):
        self.layer_dims = [in_dim] + hidden_dims + [out_dim]
        self.num_layers = len(self.layer_dims) - 1

        self.params = self.init_params(key)

    def init_params(self, key):
        params = []
        keys = jax.random.split(key, self.num_layers)

        for i in range(self.num_layers):
            k = keys[i]
            W = jax.random.normal(k, (self.layer_dims[i], self.layer_dims[i + 1]))
            b = jnp.zeros((self.layer_dims[i + 1],))
            params.append((W, b))

        return params


def forward(params, x, activation):
    h = x
    for i, (W, b) in enumerate(params):
        z = h @ W + b
        if i < len(params) - 1:
            h = activation(z)
        else:
            h = z
    return h


MSE Loss

In [29]:
def mse_loss(params,  x, y_true, activation):
    y_pred = forward(params, x, activation)
    return jnp.mean((y_pred - y_true) ** 2)


Training step using automatic differentation from JAX

In [34]:
from functools import partial

@partial(jax.jit,  static_argnames=("activation",))
def train_step(params, x, y, lr, activation):
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y, activation)

    new_params = [
        (W - lr * dW, b - lr * db)
        for (W, b), (dW, db) in zip(params, grads)
    ]

    return new_params, loss

Simplest initial example, $ y = Ax + b$.

In [35]:
# Dimensions
num_samples = 10

# Input data
x = jnp.linspace(0, 1, num_samples).reshape(-1, in_dim)

# Random ground-truth linear model
key = jax.random.PRNGKey(42)
key_A, key_b = jax.random.split(key)

A_true = jax.random.normal(key_A, (in_dim, out_dim))
b_true = jax.random.normal(key_b, (out_dim,))

# Generate targets
y = x @ A_true + b_true

print(y)


[[0.60576403]
 [0.6142002 ]
 [0.6226364 ]
 [0.6310725 ]
 [0.6395087 ]
 [0.64794487]
 [0.6563811 ]
 [0.6648172 ]
 [0.6732534 ]
 [0.68168956]]


Initialise and train network

In [43]:
lr = 1e-1
model = NeuralNetwork(in_dim, out_dim, hidden_dims, key)
params = model.params
activation = relu

losses = []

for epoch in range(2000):
    params, loss = train_step(params, x, y, lr, activation)

    losses.append(loss)

    if epoch % 100 == 0:
        print(f"epoch {epoch}, loss: {loss:.12f}")

epoch 0, loss: 0.385397374630
epoch 100, loss: 0.000082268067
epoch 200, loss: 0.000031653235
epoch 300, loss: 0.000012800811
epoch 400, loss: 0.000005217569
epoch 500, loss: 0.000002131869
epoch 600, loss: 0.000000872130
epoch 700, loss: 0.000000357049
epoch 800, loss: 0.000000146241
epoch 900, loss: 0.000000059920
epoch 1000, loss: 0.000000024555
epoch 1100, loss: 0.000000010062
epoch 1200, loss: 0.000000004128
epoch 1300, loss: 0.000000001692
epoch 1400, loss: 0.000000000694
epoch 1500, loss: 0.000000000284
epoch 1600, loss: 0.000000000116
epoch 1700, loss: 0.000000000048
epoch 1800, loss: 0.000000000020
epoch 1900, loss: 0.000000000008


Final prediction

In [44]:
y_pred = forward(params, x, activation)
print("Predictions:", y_pred)
print("True targets:", y)

Predictions: [[0.605766  ]
 [0.61419713]
 [0.622634  ]
 [0.63107085]
 [0.6395077 ]
 [0.64794457]
 [0.6563814 ]
 [0.6648183 ]
 [0.6732552 ]
 [0.68169206]]
True targets: [[0.60576403]
 [0.6142002 ]
 [0.6226364 ]
 [0.6310725 ]
 [0.6395087 ]
 [0.64794487]
 [0.6563811 ]
 [0.6648172 ]
 [0.6732534 ]
 [0.68168956]]
