In [1]:
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split


In [7]:
# create our dataset: 3 features with the true coefficients β returned alongside (coef=True)
SEED = 42  # shared seed so both data generation AND the split are reproducible
X, y, β = make_regression(n_features=3, coef=True, random_state=SEED)
# Hold out a test split. X, y are re-bound to the *training* portion, which later cells train on.
# Without random_state the split (and all downstream results) changes on every Restart & Run All.
X, X_test, y, y_test = train_test_split(X, y, random_state=SEED)

β


array([75.06147516, 28.20345726, 17.74395438])

In [9]:
# model parameters: one zero-initialised weight per feature column of X, plus a scalar bias
params = dict(
    w=jnp.zeros(X.shape[1:]),
    b=0.,
)

def forward(params, X):
    """Linear model prediction: ``X · w + b``.

    Args:
        params: dict with ``'w'`` (weight vector) and ``'b'`` (bias).
        X: input matrix whose last axis matches the length of ``params['w']``.

    Returns:
        The predictions ``jnp.dot(X, w) + b``.
    """
    weights, bias = params['w'], params['b']
    linear_part = jnp.dot(X, weights)
    return linear_part + bias

@jax.jit
def loss_fn(params, X, y):
    """Mean squared error of the linear model ``forward(params, X)`` against targets ``y``.

    Args:
        params: dict with ``'w'`` and ``'b'`` as consumed by ``forward``.
        X: input matrix.
        y: target values, one per row of ``X``.

    Returns:
        Scalar MSE.
    """
    err = forward(params, X) - y
    return jnp.mean(jnp.square(err))  # mse

# Gradient of the MSE w.r.t. params (jax.grad differentiates argument 0 by default).
# Jitting the grad itself compiles forward + backward into one computation; jax.grad of an
# already-jitted function otherwise re-applies the autodiff transformation on every call.
grad_fn = jax.jit(jax.grad(loss_fn))
def update(params, grads, lr=0.05):
    """Apply one step of plain gradient descent, leaf by leaf.

    Args:
        params: pytree of model parameters.
        grads: pytree of gradients with the same structure as ``params``.
        lr: learning rate (step size); defaults to 0.05.

    Returns:
        A new pytree with ``p - lr * g`` for every leaf; the inputs are not modified.
    """
    # jax.tree_map is deprecated (and removed in newer JAX); jax.tree_util.tree_map is the stable API.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


In [10]:
# the main training loop: full-batch gradient descent on the training split (X, y).
# NOTE: this rebinds `params`, so re-running this cell continues from the already-trained
# values; re-run the parameter-initialisation cell first for a fresh run.
N_STEPS = 50
test_losses = []  # held-out MSE before each step -- monitoring only, never used for gradients
for _ in range(N_STEPS):
    loss = loss_fn(params, X_test, y_test)
    test_losses.append(float(loss))
    grads = grad_fn(params, X, y)
    params = update(params, grads)
params


{'b': Array(-0.09771829, dtype=float32, weak_type=True),
 'w': Array([74.859276, 26.309761, 17.222357], dtype=float32)}

In [None]:
# Results: learned parameters next to the true coefficients β returned by make_regression.
{
    'learned_w': params['w'],
    'true_w': β,
    'abs_error_w': jnp.abs(params['w'] - β),
    'b': params['b'],
}
