### $\textbf{Task: Create a simple linear regression model and fit it using JAX optimisers}$

In [16]:
import jax.numpy as jnp
import jax

# Seed once, then derive separate keys for the inputs, the true coefficients
# and the observation noise; rng_key itself is carried forward for later cells.
rng_key = jax.random.PRNGKey(42)
rng_key, data_rng, true_rng, noise_rng = jax.random.split(rng_key, 4)

N = 128

# Design matrix: 10 standard-normal feature columns followed by a column of
# ones, so the last coefficient acts as the intercept.
xs = jnp.concatenate(
    [jax.random.normal(data_rng, shape=(N, 10)), jnp.ones((N, 1))],
    axis=1,
)
# Ground-truth coefficients: uniform draws rescaled by *20 and shifted by -10.
true_beta = (jax.random.uniform(true_rng, shape=(11, 1)) * 20) - 10
# Targets: linear signal plus standard-normal noise, one value per row.
ys = jnp.matmul(xs, true_beta) + jax.random.normal(noise_rng, shape=(N, 1))

In [17]:
import optax
from jax.nn.initializers import lecun_uniform

# Split off a dedicated key for parameter initialisation; rng_key is carried
# forward so later cells keep drawing fresh randomness.
rng_key, rng_init = jax.random.split(rng_key, 2)

initialiser = lecun_uniform()
# One weight per column of xs (the feature columns plus the appended ones
# column), derived from the data rather than repeating the literal 11.
weights = initialiser(rng_init, (xs.shape[1], 1), dtype=jnp.float32)

learning_rate = 5e-2
optimiser = optax.adam(learning_rate)
# Adam keeps per-parameter moment estimates; they are initialised from the
# parameter pytree and threaded through every update step.
optimiser_state = optimiser.init(weights)

@jax.jit
def fwd(weights, batch_in):
    """Linear model forward pass.

    Args:
        weights: coefficient array that is right-multiplied onto the inputs.
        batch_in: input array whose last axis matches the first axis of
            ``weights``.

    Returns:
        The matrix product ``batch_in @ weights``.
    """
    predictions = jnp.matmul(batch_in, weights)
    return predictions

@jax.jit
def mse_loss(weights, batch_in, batch_out, l2_weight=1.0):
    """Mean squared error of the linear model plus an L2 penalty on the weights.

    Despite the name, this is a ridge-style objective: the penalty term pulls
    every coefficient (including the one multiplying the constant-ones column,
    if the inputs have one) toward zero, so the minimiser is biased relative to
    the plain least-squares solution.

    Args:
        weights: model coefficients passed to ``fwd``.
        batch_in: batch of inputs passed to ``fwd``.
        batch_out: targets, same shape as ``fwd(weights, batch_in)``.
        l2_weight: multiplier on the mean-squared-weight penalty. The default
            of 1.0 reproduces the original hard-coded objective; pass 0.0 for
            a pure MSE loss.

    Returns:
        Scalar loss: ``mean((fwd(weights, batch_in) - batch_out)**2)
        + l2_weight * mean(weights**2)``.
    """
    residuals = fwd(weights, batch_in) - batch_out
    data_term = jnp.mean(jnp.square(residuals))
    penalty = jnp.mean(jnp.square(weights))
    return data_term + l2_weight * penalty

In [18]:
num_epochs = 100
batch_size = 32

# Build the loss-and-gradient function once: it is the same transformation on
# every step, so re-creating it inside the inner loop was redundant work.
loss_and_grad = jax.value_and_grad(mse_loss, argnums=0)

# NOTE: this cell updates `weights`, `optimiser_state` and `rng_key` in place,
# so running it again continues training from the current state rather than
# starting over; re-run the initialisation cell first for a fresh fit.
for epoch in range(num_epochs):
    batch_losses = []
    for batch in range(N // batch_size):
        # Fresh subkey per step so every minibatch draw is different.
        rng_key, rng_batch = jax.random.split(rng_key, 2)
        # Each minibatch is sampled independently (no repeats within a batch),
        # so an "epoch" here is N // batch_size random batches rather than an
        # exact pass over every row.
        batch_indices = jax.random.choice(rng_batch, N, shape=(batch_size, ), replace=False)
        batch_in = xs[batch_indices, :]
        batch_out = ys[batch_indices]
        loss, grads = loss_and_grad(weights, batch_in, batch_out)
        # Pass the current params as well: Adam ignores them, but optimisers
        # that need them (e.g. with weight decay) would fail without it.
        updates, optimiser_state = optimiser.update(grads, optimiser_state, weights)
        weights = optax.apply_updates(weights, updates)
        batch_losses.append(loss)
    # Report the mean minibatch loss every 10th epoch.
    if ((epoch + 1) % 10 == 0):
        print(jnp.array(batch_losses).mean())

print(f"True beta: {true_beta}")
print(f"Learned beta: {weights}")

199.41626
115.69794
68.3955
39.97103
33.384377
29.50718
25.952251
25.229612
24.513538
24.743563
True beta: [[ 1.2040024 ]
 [ 9.90509   ]
 [ 3.9389248 ]
 [-0.37490082]
 [ 7.4107857 ]
 [-4.6295214 ]
 [ 2.2635794 ]
 [ 5.3508472 ]
 [ 7.32815   ]
 [ 2.3363113 ]
 [-1.0534973 ]]
Learned beta: [[ 1.084684  ]
 [ 8.980898  ]
 [ 3.7051468 ]
 [-0.28055766]
 [ 6.8415504 ]
 [-4.022745  ]
 [ 1.7459644 ]
 [ 4.877995  ]
 [ 6.72596   ]
 [ 1.7789909 ]
 [-0.94609976]]
