# Linear Regression with JAX

## Resources
https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html     
https://www.youtube.com/watch?v=aOsZdf9tiNQ    


In [2]:
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

In [6]:
# The dataset
# Synthetic regression problem with 3 input features. Both the generator and
# the split are seeded so the data -- and therefore the loss curve printed by
# the training loop -- are identical on every run.
X, y = make_regression(n_features=3, random_state=0)
print(f"X.shape: {X.shape}, y.shape: {y.shape}")
# NOTE: X / y are rebound here to the *training* portion only; the held-out
# rows go to X_test / y_test.
X, X_test, y, y_test = train_test_split(X, y, random_state=0)
print(f"X_test.shape: {X_test.shape}, y_test.shape: {y_test.shape}")
print(f"X[:5]:\n{X[:5]}") # first 5 training rows
print(f"y[:5]:\n {y[:5]}") # first 5 training targets

X.shape: (100, 3), y.shape: (100,)
X_test.shape: (25, 3), y_test.shape: (25,)
X[:5]:
[[-1.02693272 -0.61965346  0.79635043]
 [ 1.31763649  0.0146806  -0.19200161]
 [-0.80849802 -0.27264047 -0.19800573]
 [-0.06227385  0.51198102  0.77566437]
 [ 0.13687259 -0.23795419  0.50641083]]
y[:5]:
 [ 28.89022652  25.63730128 -42.84691442  63.86981837  43.92067942]


## Model and training loop (from 2:18 in the video)
https://www.youtube.com/watch?v=aOsZdf9tiNQ

In [1]:
# Initial model parameters, both starting at zero:
#   'w' -- weight array shaped like a single row of X (X.shape minus its
#          leading axis)
#   'b' -- scalar bias
params = dict(
    w=jnp.zeros(X.shape[1:]),
    b=0.,
)


def forward(params, X):
    """Return the linear model's predictions ``X . w + b``.

    Args:
        params: Mapping with the weights under ``'w'`` and the bias under
            ``'b'``.
        X: Input features; multiplied with ``params['w']`` via ``jnp.dot``.

    Returns:
        The dot product of ``X`` and the weights, shifted by the bias.
    """
    weights, bias = params['w'], params['b']
    linear_term = jnp.dot(X, weights)
    return linear_term + bias


def loss_fn(params, X, y):
    """Return the mean squared error of the model's predictions on ``(X, y)``.

    Args:
        params: Model parameters, passed straight through to ``forward``.
        X: Input features.
        y: Target values the predictions are compared against.

    Returns:
        The mean of the squared differences between prediction and target.
    """
    residual = forward(params, X) - y
    squared = jnp.square(residual)
    return jnp.mean(squared)  # mse


# Gradient of the loss with respect to its first argument (the params);
# called with the same (params, X, y) arguments as loss_fn.
grad_fn = jax.grad(loss_fn, argnums=0)


def update(params, grads, learning_rate=0.05):
    """Apply one gradient-descent step to every parameter.

    Args:
        params: Pytree (here a dict) of current parameter values.
        grads: Pytree with the same structure as ``params`` holding the
            gradient for each parameter.
        learning_rate: Step size multiplied with each gradient. Defaults to
            0.05.

    Returns:
        A new pytree with ``p - learning_rate * g`` for each leaf; the inputs
        are not modified.
    """
    # jax.tree_util.tree_map is used instead of jax.tree.map because it is
    # also available in JAX releases that predate the jax.tree namespace.
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )


# the main training loop: 100 gradient-descent steps
for _ in range(100):
    # Gradient comes from the training split ...
    grads = grad_fn(params, X, y)

    # ... while the reported loss is measured on the held-out split, using
    # the parameters as they are *before* this step's update.
    loss = loss_fn(params, X_test, y_test)
    print(loss)

    params = update(params, grads)

6324.3447
4989.4507
3943.51
3122.6614
2477.3916
1969.2719
1568.4434
1251.6764
1000.876
801.92834
643.8087
517.8933
417.4257
337.1038
272.76065
221.11526
179.58002
146.11037
119.088
97.22909
79.5139
65.13061
53.431522
43.89912
36.11886
29.758179
24.549862
20.278547
16.770483
13.885185
11.508829
9.549117
7.93093
6.593188
5.48601
4.568626
3.807761
3.176043
2.6510909
2.2144372
1.850931
1.5480798
1.2955558
1.0848532
0.90891266
0.7619066
0.6389973
0.5361677
0.45008776
0.3779931
0.31757948
0.2669247
0.22443938
0.1887836
0.1588459
0.13370165
0.112573184
0.0948128
0.07987784
0.067315154
0.056745518
0.047846146
0.040352724
0.03404078
0.028723419
0.02424269
0.020465327
0.01727943
0.014593503
0.012327169
0.010414496
0.008800519
0.007437313
0.0062866476
0.005315062
0.0044939225
0.0038006522
0.0032144056
0.00271894
0.0023002422
0.0019464453
0.0016471745
0.0013938341
0.0011799347
0.0009987504
0.0008457442
0.00071599
0.00060615357
0.0005134157
0.00043485928
0.00036836835
0.00031201774
0.00026431205
0.