# CH-7: Linear Regression

In [2]:
import jax
import numpy as np
import jax.numpy as jnp

print(jax.__version__)

0.4.31


In [3]:
# Seed the random number generator for reproducibility
jax.random.key(0)

EPOCHS = 10

learning_rate = 0.01 # Hyperparameter for gradient descent

In [4]:
# Sample feature (house size) and target (house price) data
house_sizes = jnp.array([1000., 1500., 2000., 2500., 3000.])
prices = jnp.array([150000., 200000., 250000., 300000., 350000.])

In [5]:
# Initialize random weights with a small range to avoid large initial errors
weights = jax.random.uniform(key=jax.random.PRNGKey(0), shape=(5,))
bias = jax.random.uniform(key=jax.random.PRNGKey(0), shape=(1,))
print("Initial weights: ", weights)
print("Initial bias: ", bias)

Initial weights:  [0.57450044 0.09968603 0.39316022 0.8941783  0.59656656]
Initial bias:  [0.41845703]


In [4]:
def linear_regression_model(weights, inputs):
    """
    Calculates the predicted value using the linear regression model
    Args: 
      weights: a numpy array
      inputs: a numpy array containing the feature values for a single data point.
    Returns:
      A 1D numpy array containing the predicted value.
    """
    return jnp.dot(weights.T,inputs) + bias

In [5]:
# Loss function - MSE (Mean Squared Error)
def mse_loss(predicted_prices, actual_prices):
  """
  Calculates the mean squared error (MSE) loss between the predicted and actual prices
  Args:
    predicted_prices: a numpy array containing the predicted prices
    actual_prices: a numpy array containing the actual prices
  Returns:
    A Numpy array containing the MSE loss.
  """
  squared_errors = jnp.square(predicted_prices - actual_prices)
  return jnp.mean(squared_errors)

In [7]:
# Get the gradient function of the loss function with respect to the weights
loss_grad_fn = jax.grad(mse_loss)

In [8]:
# Training loop with manual gradient descent computation
for _ in range(EPOCHS): # Train for EPOCHS epochs
  # Forward pass: Calculates predicted prices using current weights
  predicted_prices = linear_regression_model(weights, house_sizes)
  print(f"predicted_prices: {predicted_prices}")
  # Calculates the loss (MSE - Mean Square Error)
  loss = mse_loss(predicted_prices=predicted_prices, actual_prices=prices)
  # Calculates the gradients of the loss with respect to the weights (manual calculation)
  # This is where automatic differentiation (autodiff) from jax would be useful.
  gradients = 2*jnp.dot(house_sizes.T, (predicted_prices - prices)) / len(house_sizes)
  #gradients = loss_grad_fn(predicted_prices, prices)
  # Update the weights using gradient descent
  weights = weights - learning_rate * gradients
  # Print the loss after each epoch (optional) 
  # print(f"Epoch: {_}, Loss: {loss:.4f}, gradients: {gradients}, weights: {weights}")

predicted_prices: 5536.06982421875
predicted_prices: 107796357120.0
predicted_prices: -4.312263988294451e+16
predicted_prices: 1.7250740009167372e+22
predicted_prices: -6.900968278681468e+27
predicted_prices: 2.760656424177989e+33
predicted_prices: -inf
predicted_prices: inf
predicted_prices: nan
predicted_prices: nan
