In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score 
# import seaborn as sns

In [86]:
class RidgeRegression:
  def __init__(self, data):
    self.data = data
    self.alpha = 0.1 
    self.w = None
  
  def fit(self, X, y):
    X = np.c_[np.ones(X.shape[0]), X]
    xTranspose = X.T
    A = xTranspose @ X
    I = np.identity(A.shape[0])
    c = xTranspose @ y
    self.w = np.linalg.inv(A + self.alpha * I) @ c
  
  def predict(self, X):
    X = np.c_[np.ones(X.shape[0]), X]
    return X @ self.w
    

In [None]:
class LassoRegression:

  def soft_threshold(rho, alpha):
    if rho < alpha:
      return rho + alpha
    elif rho > alpha:
      return rho - alpha
    return 0
  
  def coordinate_descent(self,
                         X,
                         y,
                         theta,
                         alpha = 0.1,
                         n_iterations = 1000,
                         intercept = False):
    # The intercept parameter allows to specify whether or not we regularize theta_0
    m, n = X.shape
    X /= (np.linalg.norm(X, axis=0))
    for _ in range(n_iterations):
      for j in range(n):
        X_j = X[:,j].reshape(-1, 1)
        y_pred = X @ theta
        rho = X_j.T @ (y - y_pred + theta[j] * X_j)

        if intercept:
          if j == 0:
            theta[j] = rho
          else:
            theta[j] = self.soft_threshold(rho, alpha)
        else:
          theta[j] = self.soft_threshold(rho, alpha)
    return theta.flatten()


In [None]:
data = pd.read_csv("../data/diabetes.csv")
# sns.heatmap(data.corr(), annot=True)
target = data["target"]
data.drop(columns=["target"], inplace=True)
ridge = RidgeRegression(data)
ridge.fit(data, target)
y_pred = ridge.predict(data)
mse = mean_squared_error(target, y_pred)
r2 = r2_score(target, y_pred)
print("[Ridge] Mean Squared Error:", mse)
print("[Ridge] R2 Score:", r2)

lasso = LassoRegression()
lasso.coordinate_descent(data, target, )

Mean Squared Error: 2890.4524762078877
R2 Score: 0.5125617905814859
