In [20]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt


In [21]:
def relu(x):
    return jnp.maximum(0.0, x)

def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

In [22]:
in_dim = 1
out_dim = 1
hidden_dims = [2]          
num_layers = len(hidden_dims) + 1

weights = []
biases = []

class NeuralNetwork:
    def __init__(self, in_dim, out_dim, hidden_dims, key):
        self.layer_dims = [in_dim] + hidden_dims + [out_dim]
        self.num_layers = len(self.layer_dims) - 1
        self.params = self.init_params(key)

    def init_params(self, key):
        params = []
        keys = jax.random.split(key, self.num_layers)

        for i in range(self.num_layers):
            k = keys[i]
            W = jax.random.normal(k, (self.layer_dims[i], self.layer_dims[i + 1]))
            b = jnp.zeros((self.layer_dims[i + 1],))
            params.append((W, b))

        return params

def forward(params, x, activation):
    h = x
    for i, (W, b) in enumerate(params):
        z = h @ W + b
        if i < len(params) - 1:
            h = activation(z)
        else:
            h = z
    return h


In [23]:
def mse_loss(params,  x, y_true, activation):
    y_pred = forward(params, x, activation)
    return jnp.mean((y_pred - y_true) ** 2)

In [33]:
from functools import partial

@partial(jax.jit,  static_argnames=("activation",))

def train_step(params, x, y, lr, activation):
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y, activation)

    new_params = [
        (W - lr * dW, b - lr * db)
        for (W, b), (dW, db) in zip(params, grads)
    ]

    return new_params, loss

In [25]:
def make_train_step(activation):
    @jax.jit
    def train_step_jit(params, x, y, lr):
        loss, grads = jax.value_and_grad(mse_loss)(
            params, x, y, activation
        )
        new_params = [
            (W - lr * dW, b - lr * db)
            for (W, b), (dW, db) in zip(params, grads)
        ]
        return new_params, loss

    return train_step_jit

In [26]:
# Dimensions
num_samples = 10

# Input data
x = jnp.linspace(0, 1, num_samples).reshape(-1, in_dim)

# Random ground-truth linear model
key = jax.random.PRNGKey(42)
key_A, key_b = jax.random.split(key)

A_true = jax.random.normal(key_A, (in_dim, out_dim))
b_true = jax.random.normal(key_b, (out_dim,))

# Generate targets
y = x @ A_true + b_true

print(y)

[[1.3694694]
 [1.384792 ]
 [1.4001145]
 [1.4154371]
 [1.4307597]
 [1.4460824]
 [1.4614049]
 [1.4767275]
 [1.49205  ]
 [1.5073726]]


In [27]:
lr = 1e-1
model = NeuralNetwork(in_dim, out_dim, hidden_dims, key)
params = model.params
activation = relu

losses = []

for epoch in range(2000):
    params, loss = make_train_step(params, x, y, lr, activation)

    losses.append(loss)

    if epoch % 100 == 0:
        print(f"epoch {epoch}, loss: {loss:.12f}")

TypeError: make_train_step() takes 1 positional argument but 5 were given

In [28]:
y_pred = forward(params, x, activation)
print("Predictions:", y_pred)
print("True targets:", y)

Predictions: [[0.        ]
 [0.1230425 ]
 [0.246085  ]
 [0.36912754]
 [0.49217   ]
 [0.61521256]
 [0.7382551 ]
 [0.8612976 ]
 [0.98434   ]
 [1.1073825 ]]
True targets: [[1.3694694]
 [1.384792 ]
 [1.4001145]
 [1.4154371]
 [1.4307597]
 [1.4460824]
 [1.4614049]
 [1.4767275]
 [1.49205  ]
 [1.5073726]]


# Exercise 3a

Generating the data and splitting it into 80/20 train/test datasets, where we shuffle to avoid the model from learning orders. 

In [29]:
import numpy as np


# generate training data

def training_data(k,N):
    x = jnp.linspace(-1.0, 1.0, N).reshape(-1, 1)
    y = jnp.sin(k * jnp.pi * x)
    return x,y

x, y = training_data(1,100)

def train_test_split(x, y, test_ratio=0.2, key=jax.random.PRNGKey(0)):
    N = x.shape[0]
    perm = jax.random.permutation(key, N)

    x_shuffled = x[perm]
    y_shuffled = y[perm]

    test_size = int(test_ratio * N)

    x_test = x_shuffled[:test_size]
    y_test = y_shuffled[:test_size]

    x_train = x_shuffled[test_size:]
    y_train = y_shuffled[test_size:]

    return x_train, y_train, x_test, y_test




Training function using batch gradient descent, returning epoch losses and model parameters

In [30]:
def train_nn(key, hidden_dims, activation, lr, batch_size, x_train, y_train, num_epochs=2000):
    model = NeuralNetwork(
        in_dim = 1,
        out_dim = 1,
        hidden_dims = hidden_dims, 
        key = key
    )

    params = model.params

    num_samples = x_train.shape[0]
    num_batches = num_samples // batch_size

    epoch_losses = []
    
    for epoch in range(num_epochs):
        key, subkey = jax.random.split(key)
        perm = jax.random.permutation(subkey, num_samples)

        x_shuffled = x_train[perm]
        y_shuffled = y_train[perm]

        total_loss = 0
        for i in range(0, num_samples, batch_size):
            xb = x_shuffled[i: i + batch_size]
            yb = y_shuffled[i: i + batch_size]
                
            params, loss = train_step(
                params, 
                xb, 
                yb, 
                lr, 
                activation
            )
            total_loss += loss
        epoch_mse = mse_loss(params, x_train, y_train, activation)
        epoch_losses.append(epoch_mse)

    return params, epoch_losses

Implement n_folds cross validation function

In [31]:
def make_folds(key, N, n_folds=5):
    perm = jax.random.permutation(key, N)
    return jnp.array_split(perm, n_folds)

def cross_val_score(key, hidden_dims, activation, lr, batch_size, x_train, y_train, num_epochs, n_folds = 5):

    N = x.shape[0]
    folds = make_folds(key, N, n_folds)
    val_losses = []

    for i in range(n_folds):
        val_idx = folds[i]
        train_idx = jnp.concatenate(
            [folds[j] for j in range(n_folds) if j != i]
        )

        x_train = x[train_idx]
        y_train = y[train_idx]
        x_val = x[val_idx]
        y_val = y[val_idx]

        key, subkey = jax.random.split(key)

        params, loss = train_nn(
            key=subkey,
            hidden_dims=hidden_dims,
            activation=activation,
            lr=lr,
            batch_size=batch_size,
            x_train=x_train,
            y_train=y_train,
            num_epochs=num_epochs
        )

        val_loss = mse_loss(params, x_val, y_val, activation)
        val_losses.append(val_loss)
    return jnp.mean(jnp.array(val_losses))

    

Implementing grid search function and printing best configuration and associated MSE.

In [32]:
# grid search for hyperparameter optimization for fitting f1
# Hyperparameters to vary (Question 2a)

param_grid = {
    'architectures': [[128], [128, 64]],
    'learning_rates': [0.1, 0.01],
    'batch_sizes': [32],
    'activations': [relu, sigmoid],
    'init_type': ["xavier"]
}
    

def grid_search_cv(
    param_grid,
    key,
    x_train,
    y_train,
    num_epochs=300,
    n_folds=5
):
    best_loss = jnp.inf
    best_config = None

    keys = jax.random.split(
        key,
        len(param_grid['learning_rates'])
        * len(param_grid['architectures'])
        * len(param_grid['batch_sizes'])
        * len(param_grid['activations'])
    )

    idx = 0

    for lr in param_grid['learning_rates']:
        for arch in param_grid['architectures']:
            for batch_size in param_grid['batch_sizes']:
                for activation in param_grid['activations']:
                    subkey = keys[idx]
                    idx += 1

                    cv_loss = cross_val_score(
                        key=subkey,
                        hidden_dims=arch,
                        activation=activation,
                        lr=lr,
                        batch_size=batch_size,
                        x_train=x_train,
                        y_train=y_train,
                        num_epochs=num_epochs,
                        n_folds=n_folds
                    )

                    print(
                        f"lr={lr}, "
                        f"arch={arch}, "
                        f"batch_size={batch_size}, "
                        f"activation={activation.__name__}, "
                        f"CV MSE={cv_loss:.6f}"
                    )

                    if cv_loss < best_loss:
                        best_loss = cv_loss
                        best_config = {
                            'learning_rate': lr,
                            'architecture': arch,
                            'batch_size': batch_size,
                            'activation': activation
                        }

    return best_config, best_loss




    

In [None]:
key = jax.random.PRNGKey(0)
x_train, y_train, x_test, y_test = train_test_split(x,y)

best_config, best_loss = grid_search_cv(
    param_grid=param_grid,
    key=key,
    x_train=x_train,
    y_train=y_train,
    num_epochs=1500,
    n_folds=5
)

print("\nBest hyperparameters:")
print(best_config)
print("Best CV MSE:", best_loss)


Using the best hyperparameters to train the NN and plotting the predictions versus the true function. Note that since we took random test points in $[-1,1]$, the plot does not look like the sine function.

In [12]:
best_params, best_epoch_losses = train_nn(key, best_config['architecture'],best_config['activation'], best_config['learning_rate'], best_config['batch_size'], x_train, y_train, num_epochs=2000)

y_pred = forward(
    best_params,
    x_test,
    best_config['activation']
)

test_mse = jnp.mean((y_pred - y_test) ** 2)
print("Test MSE:", test_mse)



# sort test points for a clean plot
idx = jnp.argsort(x_test[:, 0])
x_sorted = x_test[idx]
y_true_sorted = y_test[idx]
y_pred_sorted = y_pred[idx]

plt.figure()
plt.plot(x_sorted, y_true_sorted, label="True function", linewidth=2)
plt.plot(x_sorted, y_pred_sorted, "--", label="Model prediction", linewidth=2)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Model prediction vs true function")
plt.legend()
plt.grid(True)
plt.show()


NameError: name 'best_config' is not defined

# Exercise 3b)

The grid search with 5-fold cross validation outputs optimal parameters: 
For a large learning rate of 0.1, we see that the neural network architecture of two hidden layers with 128, 64 dimensions outputs NaN values. We see that this problem doesnt appear for lower learning rates which indicates that the gradient might be the problem. Due to the multiplication of the learning rate with the gradient, the gradient might explode leading to NaNs.
For smaller learning rates this architecture actually performs better than the shallow neural network. This may be due to the flexibility of the neural network, as the increasement of weights allows for more learning capacity. 

The following is for k = 3




In [None]:
# generate training data varying k 
def generate_data_k(k, N):
    xk, yk = training_data(k, N)
    xk_train, yk_train, xk_test, yk_test = train_test_split(xk, yk)
    return xk_train, yk_train, xk_test, yk_test

key = jax.random.PRNGKey(0)

param_grid = {
    'architectures': [[128], [128, 64]],
    'learning_rates': [0.01, 0.001],
    'batch_sizes': [32],
    'activations': [relu, sigmoid],
    'init_type': ["xavier"]
}

def find_best_config(param_grid, xk_train, yk_train, key):
    # find best configuration 
    best_config, best_loss = grid_search_cv(
        param_grid=param_grid,
        key=key,
        x_train=xk_train,
        y_train=yk_train,
        num_epochs=1500,
        n_folds=5
    )

    print("\nBest hyperparameters:")
    print(best_config)
    print("Best CV MSE:", best_loss)
    return best_config, best_loss

def predict_and_plot_test(best_config, xk_train, yk_train, xk_test, yk_test, key):
    best_params, best_epoch_losses = train_nn(key, best_config['architecture'],best_config['activation'], best_config['learning_rate'], best_config['batch_size'], xk_train, yk_train, num_epochs=2000)

    x_plot = jnp.linspace(-1.0, 1.0, 2000).reshape(-1, 1)

    y_pred = forward(
        best_params,
        xk_test,
        best_config['activation']
    )

    test_mse = jnp.mean((y_pred - yk_test) ** 2)
    print("Test MSE:", test_mse)
    
    
    
    # sort test points for a clean plot
    idx = jnp.argsort(xk_test[:, 0])
    x_sorted = xk_test[idx]
    y_true_sorted = yk_test[idx]
    y_pred_sorted = y_pred[idx]
    
    plt.figure()
    plt.plot(x_sorted, y_true_sorted, label="True function", linewidth=2)
    plt.plot(x_sorted, y_pred_sorted, "--", label="Model prediction", linewidth=2)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title("Model prediction vs true function")
    plt.legend()
    plt.grid(True)
    plt.show()
    return test_mse

x3_train, y3_train, x3_test, y3_test = generate_data_k(3, 1000)
best_config_3, best_loss_3 =  find_best_config(param_grid, x3_train, y3_train, key)
test_mse_3 = predict_and_plot_test(best_config_3, x3_train, y3_train, x3_test, y3_test, key)


lr=0.01, arch=[128], batch_size=32, activation=relu, CV MSE=0.000702
lr=0.01, arch=[128], batch_size=32, activation=sigmoid, CV MSE=0.022855
lr=0.01, arch=[128, 64], batch_size=32, activation=relu, CV MSE=nan
lr=0.01, arch=[128, 64], batch_size=32, activation=sigmoid, CV MSE=0.001967
lr=0.001, arch=[128], batch_size=32, activation=relu, CV MSE=0.019760
lr=0.001, arch=[128], batch_size=32, activation=sigmoid, CV MSE=0.153035
lr=0.001, arch=[128, 64], batch_size=32, activation=relu, CV MSE=0.004252
lr=0.001, arch=[128, 64], batch_size=32, activation=sigmoid, CV MSE=0.016322

Best hyperparameters:
{'learning_rate': 0.01, 'architecture': [128], 'batch_size': 32, 'activation': <function relu at 0x00000155BE23F520>}
Best CV MSE: 0.0007018956
Test MSE: 0.10310117


## k = 5

In [None]:
x5_train, y5_train, x5_test, y5_test = generate_data_k(5, 1000)
best_config_5, best_loss_5 =  find_best_config(param_grid, x5_train, y5_train, key)
test_mse_5 = predict_and_plot_test(best_config_5, x5_train, y5_train, x5_test, y5_test, key)


## k = 10