# Linear Regression with JAX
### Using Supervised learning

## Resources
https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html     
https://www.youtube.com/watch?v=aOsZdf9tiNQ    


In [1]:
import jax
import jax.numpy as jnp # JAX's numpy module
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

In [2]:
EPOCHS = 50 # 500
LEARNING_RATE = 0.1 # 0.01

### Dataset

In [3]:
# The dataset
X, y = make_regression(n_features=3)
print(f"X.shape: {X.shape}") # 100 data points of 3 features.
print(f"y.shape: {y.shape}") # a float number for each data point.
X, X_test, y, y_test = train_test_split(X, y)
print(f"\nX_test.shape: {X_test.shape}")
print(f"y_test.shape: {y_test.shape}")
print(f"\nX.shape: {X.shape}")
print(f"y.shape: {y.shape}")
print(f"\nX[:5]:\n{X[:5]}") # X first 5 rows
print(f"\ny[:5]:\n {y[:5]}") # y first 5 rows

X.shape: (100, 3)
y.shape: (100,)

X_test.shape: (25, 3)
y_test.shape: (25,)

X.shape: (75, 3)
y.shape: (75,)

X[:5]:
[[ 0.55192182 -0.91468396  0.56304891]
 [ 0.2496638   1.33390755  0.55998413]
 [ 0.76567745  0.148808    2.10802447]
 [ 0.23562138 -0.43986716  0.41129841]
 [-0.69494925 -0.82299718  0.1765742 ]]

y[:5]:
 [ -25.22183307  167.56078302  157.39915714   -7.97071023 -116.64319789]


### Model parameters

In [4]:
# model parameters, weights and bias
params = {
    'w': jnp.zeros(X.shape[1:]), # the shape of the input without the batch dimension, initialized to zeros.
    'b': 0. # a float, the y intercept.Initialized to zero.
}
params

{'w': Array([0., 0., 0.], dtype=float32), 'b': 0.0}

### Forward pass

In [5]:
def forward(params, X):
    return jnp.dot(X, params['w']) + params['b']

# Sanity
print(forward(params, X))
print(f"\nforward pass shape: {forward(params, X).shape}")

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.]

forward pass shape: (75,)


### Loss function - MSE (Mean Squared Error)
MSE stands for Mean Squared Error. It is a common loss function used in regression models to measure the average of the squares of the errors—that is, the average squared difference between the estimated values and the actual value.

Here's a brief explanation of how MSE is calculated:

Calculate the error: Subtract the actual value from the predicted value for each data point.
Square the error: Square each of these errors to ensure they are positive and to penalize larger errors more.
Mean of squared errors: Calculate the mean (average) of these squared errors.
The formula for MSE is:

[ \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 ]

Where:

( n ) is the number of data points.
( y_i ) is the actual value.
( \hat{y}_i ) is the predicted value.

In [6]:
@jax.jit # Just-in-time compilation decorator, increase performance
def loss_fn(params, X, y):
    err = forward(params, X) - y # the error/residual is the prediction (forward pass return) - ground truth (y)
    return jnp.mean(jnp.square(err))  # return the MSE (Mean Squared Error)

# Sanity
print(loss_fn(params, X, y))
print(loss_fn(params, X_test, y_test))
loss_fn(params, X, y)

13203.245
12876.607


Array(13203.245, dtype=float32)

### Derivatives

In [7]:
# JAX calculates the gradient for us
# Takes the loss function, the parameters, and the input data as arguments
# Returns the derivative with respect to params
# Returns the same structure as the parameters
grad_fn = jax.grad(loss_fn)

# Sanity
grad_fn(params, X, y)

{'b': Array(24.404274, dtype=float32, weak_type=True),
 'w': Array([-160.68765, -126.04015,  -85.49282], dtype=float32)}

### Update

In [8]:
def update(params, grads):
    """
    Update the parameters by taking a small step in the negative direction of the gradients.
    Uses a JAX utility of pytrees:
    https://jax.readthedocs.io/en/latest/pytrees.html
    The lambda function acts on the leaves of those pytrees.
    In this case it acts only on the values of the dictionaries and not the keys.
    For each leaf parameter we want to subtract the learning rate times (in the direction of) the gradient.

    """
    return jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Sanity
update(params, grad_fn(params, X, y))

{'b': Array(-2.4404275, dtype=float32, weak_type=True),
 'w': Array([16.068766, 12.604015,  8.549282], dtype=float32)}

### Train

In [9]:
# the main training loop
for _ in range(EPOCHS):
    print(f"\nparams:\n{forward(params, X)}")
    loss = loss_fn(params, X_test, y_test)
    print(f"\nloss: {loss}")
    if loss < 0.001: # just for me, for output clarity
        break

    grads = grad_fn(params, X, y) # Don't calculate grads on the test set (that would be cheating)

    """
    The following two lines are not optimal.
    Especially if we have many parameters, because then we will have to do this 
    for every parameter in the dictionary.
    Instead I use the 'update' function that uses JAX utility.
    """
    # params['w'] -= LEARNING_RATE * grads['w']
    # params['b'] -= LEARNING_RATE * grads['b']

    params = update(params, grads)


params:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.]

loss: 12876.607421875

params:
[ -0.28675103  23.171415    29.760736    -0.6820693  -22.470892
   5.287875   -38.988083   -13.195957   -15.183517   -19.503387
  31.319578    14.002565   -19.800957   -11.193906   -46.817005
  49.46776    -15.689999    -6.3046913   -0.9364116    8.747227
  32.027504    25.626486   -22.24257    -16.8154       1.4848778
 -38.23557     -4.3935556   -9.063895    -9.727563   -20.701487
 -38.02827      0.34953904  -0.30937696  -7.7313414  -21.404425
  27.971693   -25.910416   -15.61331    -26.905554   -28.363007
 -27.401587    -7.124913    -4.8257866  -37.007214    -4.2177315
   5.145053   -54.257355   -10.307273    -7.693754   -14.490046
  -2.1491182  -17.698475    -7.5010185   17.917505    17.975458
  23.342659    -4.68065