# 🔢 JAX Linear Regression on CO2 Emissions Dataset
This notebook fits a simple linear regression of CO2 emissions on engine size with JAX gradient descent, using NumPy and Matplotlib for array handling and plotting. It is adapted from the IBM ML with Python course.

In [None]:
!pip install --upgrade jax jaxlib matplotlib

📝 **Explanation:**
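This installs (or upgrades) JAX, jaxlib, and Matplotlib. The notebook also uses pandas and scikit-learn, which come preinstalled on Colab; if you run it locally, install them with pip as well.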

In [None]:
import jax.numpy as jnp
from jax import grad, jit
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score


📝 **Explanation:**
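We import JAX (`jax.numpy` plus the `grad` and `jit` transformations), NumPy, Matplotlib, and scikit-learn's train/test split and evaluation metrics. pandas, used to load the CSV, is imported in the next cell.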

In [None]:
import pandas as pd

url = 'https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBMDeveloperSkillsNetwork-ML0101EN-SkillsNetwork/labs/Module%202/data/FuelConsumptionCo2.csv'
df = pd.read_csv(url)

# Extract and prepare data
cdf = df[['ENGINESIZE','CYLINDERS','FUELCONSUMPTION_COMB','CO2EMISSIONS']]
X = cdf['ENGINESIZE'].values
y = cdf['CO2EMISSIONS'].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


📝 **Explanation:**
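We load the fuel consumption dataset, keep a few relevant columns, use `ENGINESIZE` as the single feature and `CO2EMISSIONS` as the target, and hold out 20% of the rows as a test set (`random_state=42` makes the split reproducible).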

In [None]:
plt.scatter(X_train, y_train, color='blue')
plt.xlabel("Engine Size")
plt.ylabel("CO2 Emissions")
plt.title("Training Data")
plt.show()


📝 **Explanation:**
This visualizes the training data — Engine Size vs CO2 Emissions — as a scatter plot.

In [None]:
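# Convert to JAX arrays; the feature matrix needs shape (n_samples, 1)
X_train_jax = jnp.array(X_train).reshape(-1, 1)
y_train_jax = jnp.array(y_train)

# Add a bias column of ones so theta[0] is the intercept and theta[1] the slope
X_b = jnp.hstack([jnp.ones_like(X_train_jax), X_train_jax])

# Initialize weights at zero (no random key is needed for a zero init)
theta = jnp.zeros((2,))

# Prediction function: X @ theta
def predict(X, theta):
    return jnp.dot(X, theta)

# Loss function (mean squared error)
def mse_loss(theta, X, y):
    preds = predict(X, theta)
    return jnp.mean((preds - y) ** 2)

# Build the gradient function once and JIT-compile it, instead of calling grad() every iteration
grad_fn = jit(grad(mse_loss))

# Gradient descent
learning_rate = 0.01
n_iterations = 1000

for i in range(n_iterations):
    grads = grad_fn(theta, X_b, y_train_jax)
    theta = theta - learning_rate * grads  # JAX arrays are immutable, so rebind theta

print("Learned parameters (intercept, slope):", theta)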


📝 **Explanation:**
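We convert the training data into JAX arrays, prepend a bias column of ones, and initialize the two weights (intercept and slope) to zero. We then define the prediction and MSE loss functions and run 1000 steps of gradient descent with a learning rate of 0.01, using `jax.grad` to differentiate the loss with respect to `theta`. Because engine size is not standardized, this many steps at this learning rate may stop short of the exact least-squares fit.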
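The gradient that `jax.grad` produces for the MSE loss has a closed form, `(2/n) * X.T @ (X @ theta - y)`. The sketch below checks `jax.grad` against that formula and wraps a single update in `jax.jit`, so the step is compiled once and reused on every iteration. It runs on synthetic data drawn around a hypothetical line `y = 120 + 40x` (not the CO2 dataset), and the helper names `analytic_grad` and `gd_step` are illustrative, not part of the original notebook.

```python
import jax
import jax.numpy as jnp

def predict(X, theta):
    return X @ theta

def mse_loss(theta, X, y):
    return jnp.mean((predict(X, theta) - y) ** 2)

def analytic_grad(theta, X, y):
    # Hand-derived gradient of mean((X theta - y)^2): (2/n) X^T (X theta - y)
    n = X.shape[0]
    return (2.0 / n) * X.T @ (predict(X, theta) - y)

@jax.jit
def gd_step(theta, X, y, lr):
    # One gradient-descent update; traced and compiled on the first call
    return theta - lr * jax.grad(mse_loss)(theta, X, y)

# Hypothetical synthetic data: y = 120 + 40 x + Gaussian noise (std 5)
key = jax.random.PRNGKey(0)
k_x, k_noise = jax.random.split(key)
x = jax.random.uniform(k_x, (200,), minval=1.0, maxval=8.0)
y = 120.0 + 40.0 * x + 5.0 * jax.random.normal(k_noise, (200,))
X_b = jnp.column_stack([jnp.ones_like(x), x])  # bias column + feature

# Autodiff gradient and hand-derived gradient should agree (up to float32 rounding)
theta0 = jnp.zeros(2)
grads_match = bool(jnp.allclose(jax.grad(mse_loss)(theta0, X_b, y),
                                analytic_grad(theta0, X_b, y),
                                rtol=1e-4, atol=1e-3))
print("gradients match:", grads_match)

# Run enough steps for this data to converge close to the true line
theta = theta0
for _ in range(5000):
    theta = gd_step(theta, X_b, y, 0.01)
print("intercept, slope:", theta)
```

Putting `jax.grad` inside the jitted function (rather than calling `grad(mse_loss)` afresh in a Python loop) lets XLA fuse the forward and backward passes into one compiled step.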

In [None]:
# Evaluate on test set
X_test_jax = jnp.array(X_test).reshape(-1, 1)
X_test_b = jnp.hstack([jnp.ones_like(X_test_jax), X_test_jax])
y_test_pred = predict(X_test_b, theta)

# Metrics
mse = mean_squared_error(y_test, y_test_pred)
r2 = r2_score(y_test, y_test_pred)

print("MSE:", mse)
print("R²:", r2)


📝 **Explanation:**
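We apply the same preprocessing to the test set (reshape and add a bias column), generate predictions with the learned `theta`, and score them with scikit-learn's mean squared error and R² (coefficient of determination).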
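Both metrics are short formulas: MSE is the mean of the squared residuals, and R² is `1 - SS_res / SS_tot`, the fraction of the target's variance around its mean that the model explains. A minimal sketch computing them directly in `jax.numpy` and cross-checking against scikit-learn; the helper names `mse_jax` and `r2_jax` and the four values are hypothetical, not taken from the dataset.

```python
import jax.numpy as jnp
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

def mse_jax(y_true, y_pred):
    # Mean of squared residuals
    return jnp.mean((y_true - y_pred) ** 2)

def r2_jax(y_true, y_pred):
    # R² = 1 - SS_res / SS_tot
    ss_res = jnp.sum((y_true - y_pred) ** 2)
    ss_tot = jnp.sum((y_true - jnp.mean(y_true)) ** 2)
    return 1.0 - ss_res / ss_tot

# Small hypothetical example
y_true = jnp.array([250.0, 300.0, 200.0, 350.0])
y_pred = jnp.array([260.0, 290.0, 210.0, 330.0])

mse_val = float(mse_jax(y_true, y_pred))  # (100 + 100 + 100 + 400) / 4 = 175.0
r2_val = float(r2_jax(y_true, y_pred))    # 1 - 700 / 12500 ≈ 0.944

# Cross-check against scikit-learn on NumPy copies of the same arrays
print(mse_val, mean_squared_error(np.asarray(y_true), np.asarray(y_pred)))
print(r2_val, r2_score(np.asarray(y_true), np.asarray(y_pred)))
```

An R² of 1 means a perfect fit, while 0 means the model does no better than always predicting the mean of the test targets.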

In [None]:
plt.scatter(X_test, y_test, color='blue', label='Actual')
plt.plot(X_test, np.array(y_test_pred), color='red', label='Predicted')
plt.xlabel("Engine Size")
plt.ylabel("CO2 Emissions")
plt.title("JAX Linear Regression Predictions")
plt.legend()
plt.show()


📝 **Explanation:**
This plots the actual test points (blue) together with the fitted regression line (red), showing how closely a single straight line in engine size tracks CO2 emissions.
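As a sanity check on the fitted line, the gradient-descent parameters can be compared with the exact least-squares solution from `jnp.linalg.lstsq`. With an unstandardized feature the loss surface is elongated, so a fixed budget of 1000 steps at a learning rate of 0.01 may leave a measurable gap. The sketch below illustrates this on hypothetical synthetic data (not the CO2 dataset); in the notebook you would pass `X_b` and `y_train_jax` instead.

```python
import jax
import jax.numpy as jnp

# Hypothetical synthetic data standing in for the CO2 training set
key = jax.random.PRNGKey(1)
k_x, k_noise = jax.random.split(key)
x = jax.random.uniform(k_x, (300,), minval=1.0, maxval=8.0)
y = 125.0 + 39.0 * x + 10.0 * jax.random.normal(k_noise, (300,))
X_b = jnp.column_stack([jnp.ones_like(x), x])  # bias column + feature

def mse_loss(theta, X, y):
    return jnp.mean((X @ theta - y) ** 2)

# Exact least-squares solution (intercept, slope); lstsq returns 4 values
theta_exact, _, _, _ = jnp.linalg.lstsq(X_b, y)

# Same gradient-descent settings as the notebook: lr = 0.01, 1000 steps
grad_fn = jax.jit(jax.grad(mse_loss))
theta_gd = jnp.zeros(2)
for _ in range(1000):
    theta_gd = theta_gd - 0.01 * grad_fn(theta_gd, X_b, y)

print("closed form     :", theta_exact)
print("gradient descent:", theta_gd)
print("loss gap        :", mse_loss(theta_gd, X_b, y) - mse_loss(theta_exact, X_b, y))
```

Standardizing the feature (subtracting its mean and dividing by its standard deviation) before gradient descent is a common way to shrink this gap; the learned intercept and slope then have to be mapped back to the original units.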