# Project 2

In [29]:
import numpy as np
import jax
import jax.numpy as jnp

ReLU activation function

In [30]:
def relu(x):
    return jnp.maximum(0.0, x)

Flexible neural network architecture with initial random weights and biases

In [31]:
in_dim = 1
out_dim = 1
hidden_dims = []          
num_layers = len(hidden_dims) + 1

weights = []
biases = []

class NeuralNetwork:
    def __init__(self, in_dim, out_dim, hidden_dims, key):
        self.layer_dims = [in_dim] + hidden_dims + [out_dim]
        self.num_layers = len(self.layer_dims) - 1

        self.params = self.init_params(key)

    def init_params(self, key):
        params = []
        keys = jax.random.split(key, self.num_layers)

        for i in range(self.num_layers):
            k = keys[i]
            W = jax.random.normal(k, (self.layer_dims[i], self.layer_dims[i + 1]))
            b = jnp.zeros((self.layer_dims[i + 1],))
            params.append((W, b))

        return params


def forward(params, x):
    h = x
    for i, (W, b) in enumerate(params):
        z = h @ W + b
        if i < len(params) - 1:
            h = jnp.maximum(0.0, z)
        else:
            h = z
    return h


MSE Loss

In [32]:
def mse_loss(params, x, y_true):
    y_pred = forward(params, x)
    return jnp.mean((y_pred - y_true) ** 2)


Training step using automatic differentation from JAX

In [33]:
@jax.jit
def train_step(params, x, y, lr):
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)

    new_params = [
        (W - lr * dW, b - lr * db)
        for (W, b), (dW, db) in zip(params, grads)
    ]

    return new_params, loss

Simplest initial example, $ y = Ax + b$.

In [34]:
# Dimensions
num_samples = 10

# Input data
x = jnp.linspace(0, 1, num_samples).reshape(-1, in_dim)

# Random ground-truth linear model
key = jax.random.PRNGKey(42)
key_A, key_b = jax.random.split(key)

A_true = jax.random.normal(key_A, (in_dim, out_dim))
b_true = jax.random.normal(key_b, (out_dim,))

# Generate targets
y = x @ A_true + b_true


Initialise and train network

In [35]:
lr = 1e-3
model = NeuralNetwork(in_dim, out_dim, hidden_dims, key)
params = model.params

for epoch in range(100):
    params, loss = train_step(params, x, y, lr)

Final prediction

In [36]:
y_pred = forward(params, x)
print("Predictions:", y_pred)
print("True targets:", y)

Predictions: [[0.10733137]
 [0.12166817]
 [0.13600495]
 [0.15034175]
 [0.16467854]
 [0.17901534]
 [0.19335213]
 [0.20768893]
 [0.22202572]
 [0.23636252]]
True targets: [[0.60576403]
 [0.6142002 ]
 [0.6226364 ]
 [0.6310725 ]
 [0.6395087 ]
 [0.64794487]
 [0.6563811 ]
 [0.6648172 ]
 [0.6732534 ]
 [0.68168956]]
