# Linear Regression with JAX
### Using Supervised learning

## Resources
https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html     
https://www.youtube.com/watch?v=aOsZdf9tiNQ    


In [1]:
import jax
import jax.numpy as jnp # JAX's numpy module
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

In [2]:
EPOCHS = 50 # 500
LEARNING_RATE = 0.1 # 0.01

### Dataset

In [3]:
# The dataset
X, y = make_regression(n_features=3)
print(f"X.shape: {X.shape}") # 100 data points of 3 features.
print(f"y.shape: {y.shape}") # a float number for each data point.
X, X_test, y, y_test = train_test_split(X, y)
print(f"\nX_test.shape: {X_test.shape}")
print(f"y_test.shape: {y_test.shape}")
print(f"\nX.shape: {X.shape}")
print(f"y.shape: {y.shape}")
print(f"\nX[:5]:\n{X[:5]}") # X first 5 rows
print(f"\ny[:5]:\n {y[:5]}") # y first 5 rows

X.shape: (100, 3)
y.shape: (100,)

X_test.shape: (25, 3)
y_test.shape: (25,)

X.shape: (75, 3)
y.shape: (75,)

X[:5]:
[[-0.08558383 -0.09131494  0.92575657]
 [-0.04107166 -0.88512454  1.54184891]
 [ 0.41668555  0.46752467  0.23619388]
 [-0.20825448  0.44464659 -0.53490139]
 [ 0.19027261  0.42155077 -1.20302263]]

y[:5]:
 [ 12.33603052 -17.5296469   45.65653555   6.39202706   3.76626716]


### Model parameters

In [4]:
# model parameters, weights and bias
params = {
    'w': jnp.zeros(X.shape[1:]), # the shape of the input without the batch dimension, initialized to zeros.
    'b': 0. # a float, the y intercept.Initialized to zero.
}
params

# model parameters, weights and bias
# params = {
#     'w': jax.random.uniform(key=jax.random.PRNGKey(0), shape=X.shape[1:]), # the shape of the input without the batch dimension, initialized to zeros.
#     'b': 0. # a float, the y intercept.Initialized to zero.
# }
# params

{'w': Array([0., 0., 0.], dtype=float32), 'b': 0.0}

### Forward pass

In [5]:
def forward(params, X):
    # return jnp.dot(X, params['w'].T) + params['b'] # Transpose the weights
    return jnp.dot(X, params['w']) + params['b']

# Sanity
print(forward(params, X))
print(f"\nforward pass shape: {forward(params, X).shape}")

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.]

forward pass shape: (75,)


### Loss function - MSE (Mean Squared Error)
MSE stands for Mean Squared Error. It is a common loss function used in regression models to measure the average of the squares of the errors—that is, the average squared difference between the estimated values and the actual value.

Here's a brief explanation of how MSE is calculated:

Calculate the error: Subtract the actual value from the predicted value for each data point.
Square the error: Square each of these errors to ensure they are positive and to penalize larger errors more.
Mean of squared errors: Calculate the mean (average) of these squared errors.
The formula for MSE is:

[ \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 ]

Where:

( n ) is the number of data points.
( y_i ) is the actual value.
( \hat{y}_i ) is the predicted value.

In [6]:
@jax.jit # Just-in-time compilation decorator, increase performance
def loss_fn(params, X, y):
    err = forward(params, X) - y # the error/residual is the prediction (forward pass return) - ground truth (y)
    return jnp.mean(jnp.square(err))  # return the MSE (Mean Squared Error)

# Sanity
print(loss_fn(params, X, y))
print(loss_fn(params, X_test, y_test))
loss_fn(params, X, y)

4755.4663
3348.2605


Array(4755.4663, dtype=float32)

### Derivatives

In [7]:
# JAX calculates the gradient for us
# Takes the loss function, the parameters, and the input data as arguments
# Returns the derivative with respect to params
# Returns the same structure as the parameters
grad_fn = jax.grad(loss_fn)

# Sanity
grad_fn(params, X, y)

{'b': Array(12.478823, dtype=float32, weak_type=True),
 'w': Array([ -64.38199 , -119.19191 ,  -27.376173], dtype=float32)}

### Update

In [8]:
def update(params, grads):
    """
    Update the parameters by taking a small step in the negative direction of the gradients.
    Uses a JAX utility of pytrees:
    https://jax.readthedocs.io/en/latest/pytrees.html
    The lambda function acts on the leaves of those pytrees.
    In this case it acts only on the values of the dictionaries and not the keys.
    For each leaf parameter we want to subtract the learning rate times (in the direction of) the gradient.

    """
    return jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Sanity
update(params, grad_fn(params, X, y))

{'b': Array(-1.2478822, dtype=float32, weak_type=True),
 'w': Array([ 6.438199 , 11.919191 ,  2.7376173], dtype=float32)}

### Train

In [9]:
# the main training loop
for _ in range(EPOCHS):
    # print(f"\nparams:\n{forward(params, X)}")
    loss = loss_fn(params, X_test, y_test)
    print(f"\nloss: {loss}")
    if loss < 0.001: # just for me, for output clarity
        break

    grads = grad_fn(params, X, y) # Don't calculate grads on the test set (that would be cheating)

    """
    The following two lines are not optimal.
    Especially if we have many parameters, because then we will have to do this 
    for every parameter in the dictionary.
    Instead I use the 'update' function that uses JAX utility.
    """
    # params['w'] -= LEARNING_RATE * grads['w']
    # params['b'] -= LEARNING_RATE * grads['b']

    params = update(params, grads)


loss: 3348.260498046875

loss: 2097.861572265625

loss: 1320.1396484375

loss: 834.7333374023438

loss: 530.6007080078125

loss: 339.22821044921875

loss: 218.23890686035156

loss: 141.3498077392578

loss: 92.20957946777344

loss: 60.61093521118164

loss: 40.15789794921875

loss: 26.825946807861328

loss: 18.07109260559082

loss: 12.277174949645996

loss: 8.411867141723633

loss: 5.8119120597839355

loss: 4.048447132110596

loss: 2.8423237800598145

loss: 2.010556221008301

loss: 1.432295322418213

loss: 1.0271083116531372

loss: 0.741074800491333

loss: 0.5377066731452942

loss: 0.3921509385108948

loss: 0.2873229384422302

loss: 0.21139657497406006

loss: 0.15611524879932404

loss: 0.11567297577857971

loss: 0.08595991134643555

loss: 0.06404363363981247

loss: 0.04782363772392273

loss: 0.03578297793865204

loss: 0.026820559054613113

loss: 0.020133327692747116

loss: 0.01513282023370266

loss: 0.011386659927666187

loss: 0.008576474152505398

loss: 0.006465078331530094

loss: 0.00