In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import optax
import functools

jax.config.update("jax_enable_x64", True)
key = jax.random.PRNGKey(42)


In [2]:
n, p = 1000, 10

print(β := jax.random.uniform(key, (p,)))
X = jax.random.normal(key, (n, p))
y = X @ β + jax.random.normal(key, (n,))


[0.24195682 0.57397975 0.43901027 0.20791509 0.37068355 0.97989601
 0.97685815 0.36242998 0.32092    0.54494161]


## single-layer neural network with linear activation function (OLS)

In [3]:
@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
    return jnp.dot(params, x)

def compute_loss(params, x, y):
    y_pred = network(params, x)
    loss = jnp.mean(optax.l2_loss(y_pred, y))
    return loss


Optimisation

In [4]:

optimizer = optax.adam(1e-1)
# Initialize parameters of the model + optimizer.
params = jnp.repeat(0.0, p)
opt_state = optimizer.init(params)
# A simple update loop.
for _ in range(1000):
    grads = jax.grad(compute_loss)(params, X, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)


In [10]:
print(np.c_[
        β,          # truth
        params,        # opt estimate
        np.linalg.lstsq(X, y, rcond=None)[0] # closed form
        ])


[[0.24195682 0.26143229 0.26143229]
 [0.57397975 0.55450794 0.55450794]
 [0.43901027 0.44335113 0.44335113]
 [0.20791509 0.18611259 0.18611259]
 [0.37068355 0.39626627 0.39626627]
 [0.97989601 1.01610712 1.01610712]
 [0.97685815 0.98650555 0.98650555]
 [0.36242998 0.37793706 0.37793706]
 [0.32092    0.2965661  0.2965661 ]
 [0.54494161 0.54387632 0.54387632]]


Works well.