# Linear Regression with JAX
### Using Supervised learning

## Resources
https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html     
https://www.youtube.com/watch?v=aOsZdf9tiNQ    


In [5]:
import jax
import jax.numpy as jnp # JAX's numpy module
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

In [32]:
EPOCHS = 50
LEARNING_RATE = 0.01

### Dataset

In [12]:
# The dataset
X, y = make_regression(n_features=3)
print(f"X.shape: {X.shape}") # 100 data points of 3 features.
print(f"y.shape: {y.shape}") # a float number for each data point.
X, X_test, y, y_test = train_test_split(X, y)
print(f"\nX_test.shape: {X_test.shape}")
print(f"y_test.shape: {y_test.shape}")
print(f"\nX.shape: {X.shape}")
print(f"y.shape: {y.shape}")
print(f"\nX[:5]:\n{X[:5]}") # X first 5 rows
print(f"\ny[:5]:\n {y[:5]}") # y first 5 rows

X.shape: (100, 3)
y.shape: (100,)

X_test.shape: (25, 3)
y_test.shape: (25,)

X.shape: (75, 3)
y.shape: (75,)

X[:5]:
[[-0.63314246  1.72607764  0.99368579]
 [ 1.26313396  0.45627519  2.04240246]
 [-1.81818504 -0.18394181 -1.12144175]
 [-0.76760346  0.21830801 -2.06919072]
 [-0.47773342  0.23747435  0.74183241]]

y[:5]:
 [  94.91726361  229.47039647 -191.33008049 -175.01757296   29.57724779]


### Model parameters

In [14]:
# model parameters, weights and bias
params = {
    'w': jnp.zeros(X.shape[1:]), # the shape of the input without the batch dimension, initialized to zeros.
    'b': 0. # a float, the y intercept.Initialized to zero.
}
params

{'w': Array([0., 0., 0.], dtype=float32), 'b': 0.0}

### Forward pass

In [19]:
def forward(params, X):
    return jnp.dot(X, params['w']) + params['b']

# Sanity
print(forward(params, X))
print(f"\nforward pass shape: {forward(params, X).shape}")

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.]

forward pass shape: (75,)


### Loss function - MSE (Mean Squared Error)
MSE stands for Mean Squared Error. It is a common loss function used in regression models to measure the average of the squares of the errors—that is, the average squared difference between the estimated values and the actual value.

Here's a brief explanation of how MSE is calculated:

Calculate the error: Subtract the actual value from the predicted value for each data point.
Square the error: Square each of these errors to ensure they are positive and to penalize larger errors more.
Mean of squared errors: Calculate the mean (average) of these squared errors.
The formula for MSE is:

[ \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 ]

Where:

( n ) is the number of data points.
( y_i ) is the actual value.
( \hat{y}_i ) is the predicted value.

In [29]:
@jax.jit # Just-in-time compilation decorator, increase performance
def loss_fn(params, X, y):
    err = forward(params, X) - y # the error/residual is the prediction (forward pass return) - ground truth (y)
    return jnp.mean(jnp.square(err))  # return the MSE (Mean Squared Error)

# Sanity
print(loss_fn(params, X, y))
print(loss_fn(params, X_test, y_test))
loss_fn(params, X, y)

11719.045
16463.441


Array(11719.045, dtype=float32)

### Derivatives

In [30]:
# JAX calculates the gradient for us
# Takes the loss function, the parameters, and the input data as arguments
# Returns the derivative with respect to params
# Returns the same structure as the parameters
grad_fn = jax.grad(loss_fn)

# Sanity
grad_fn(params, X, y)

{'b': Array(14.363653, dtype=float32, weak_type=True),
 'w': Array([-110.79656, -105.89562, -190.3433 ], dtype=float32)}

### Update

In [31]:
def update(params, grads):
    """
    Update the parameters by taking a small step in the negative direction of teh gradients.
    """
    # return jax.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)

### Train

In [4]:
# the main training loop
for _ in range(EPOCHS):
    loss = loss_fn(params, X_test, y_test)
    print(loss)
    if loss < 0.001: # just for me, for output clarity
        break

    grads = grad_fn(params, X, y) # Don't calculate grads on the test set (that would be cheating)
    params = update(params, grads)

25933.582
20410.049
16064.913
12646.394
9956.582
7839.903
6174.0356
4862.8076
3830.5942
3017.922
2378.0134
1874.0739
1477.1584
1164.4943
918.1621
724.06116
571.0939
450.5239
355.47495
280.5327
221.43373
174.82013
138.04834
109.034424
86.13783
68.06493
53.796814
42.53005
33.631374
26.60159
21.046825
16.656607
13.185933
10.44149
8.270813
6.553452
5.194402
4.118541
3.2666554
2.591905
2.057276
1.6335433
1.2975923
1.0311568
0.8197599
0.65198505
0.5187718
0.41296473
0.32888302
0.26204428
0.20888999
0.16660687
0.1329463
0.10613999
0.08478629
0.06776581
0.05419131
0.04335986
0.03471623
0.027810927
0.022293288
0.017880581
0.014349177
0.011523679
0.009259203
0.007445076
0.005990444
0.00482298
0.0038861656
0.0031330797
0.0025274695
0.0020402547
0.001648122
0.0013322071
0.0010774925
0.00087196496
