In [3]:
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split


In [4]:
# Create a small synthetic linear-regression problem; make_regression(coef=True)
# also returns the true coefficients `β` used to generate y. Then hold out a
# test split. Both the generator AND the split are seeded: without
# `random_state`, train_test_split shuffles differently on every run, so the
# displayed rows and the fitted coefficients would not be reproducible under
# "Restart & Run All".
SEED = 42
X, y, β = make_regression(n_features=3, coef=True, random_state=SEED)
X, X_test, y, y_test = train_test_split(X, y, random_state=SEED)
# Prepend a column of ones to both splits so the first model coefficient acts
# as an intercept term.
X, X_test = jnp.c_[jnp.ones(X.shape[0]), X], jnp.c_[jnp.ones(X_test.shape[0]), X_test]
β, X[:2, :], y[:2]

(array([75.06147516, 28.20345726, 17.74395438]),
 Array([[ 1.        ,  0.17136829,  0.19686124,  0.73846656],
        [ 1.        , -1.4247482 , -0.2257763 ,  0.0675282 ]],      dtype=float32),
 array([  31.51864074, -112.11315545]))

In [5]:
# Model parameters, kept as a pytree so jax.grad and jax.tree.map can treat
# them uniformly: a single coefficient vector with one entry per column of X,
# initialised to zero.
params = dict(b=jnp.zeros(X.shape[1]))


def forward(params, X):
    """Forward pass of the linear model: predictions = X · b.

    Args:
        params: parameter dict holding the coefficient vector under key "b".
        X: design matrix; presumably (n_samples, n_features) with n_features
            equal to len(params["b"]) -- TODO confirm for other callers.

    Returns:
        The result of ``jnp.dot(X, params["b"])``.
    """
    coefficients = params["b"]
    predictions = jnp.dot(X, coefficients)
    return predictions


@jax.jit
def loss_fn(params, X, y):
    """Mean squared error of the model's predictions against targets ``y``.

    Compiled with ``jax.jit``. ``params`` is the parameter pytree consumed by
    ``forward``; the return value is a scalar array.
    """
    residual = forward(params, X) - y
    return jnp.square(residual).mean()


# Gradient of the MSE loss with respect to `params` (argument 0 of loss_fn).
# The whole grad transform is wrapped in jax.jit so forward + backward are
# compiled once into a single XLA computation; an un-jitted jax.grad re-runs the
# differentiation trace in Python on every call, i.e. on every training step.
grad_fn = jax.jit(jax.grad(loss_fn))


def update(params, grads, lr=0.05):
    """Apply one plain gradient-descent step to every leaf of ``params``.

    Args:
        params: parameter pytree.
        grads: pytree with the same structure as ``params`` (e.g. from grad_fn).
        lr: learning rate / step size.

    Returns:
        A new pytree whose leaves are ``param - lr * grad``; inputs are not mutated.
    """
    def _descend(param, grad):
        return param - lr * grad

    return jax.tree.map(_descend, params, grads)

In [6]:
# The main training loop: full-batch gradient descent on the training split.
# Train and test MSE are recorded after every step so convergence (train) and
# generalisation (test) can be inspected afterwards, instead of computing a
# test loss each step and throwing it away.
# NOTE: the loop rebinds `params`, so re-running this cell continues training
# from the current parameters rather than from the initialisation.
N_STEPS = 50

train_losses, test_losses = [], []
for _ in range(N_STEPS):
    grads = grad_fn(params, X, y)
    params = update(params, grads)
    train_losses.append(loss_fn(params, X, y))
    test_losses.append(loss_fn(params, X_test, y_test))

# Held-out MSE of the final parameters (kept under the name `loss`).
loss = loss_fn(params, X_test, y_test)


In [7]:
# Compare the true coefficients with the learned ones. Entry 0 of params["b"]
# multiplies the prepended column of ones (the intercept), so it is skipped.
true_coef, learned_coef = β, params["b"][1:]
true_coef, learned_coef

(array([75.06147516, 28.20345726, 17.74395438]),
 Array([74.25435 , 26.689371, 16.658165], dtype=float32))