In [None]:
import numpy as np
import seaborn as sns

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from random import random, seed
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split


class Regression:
    """Linear regression (OLS or Ridge) evaluated on a train/test split.

    Intended call order: construct with the design matrix and targets,
    call ``split_scale`` to build the scaled train/test partitions, then
    call ``fit`` to get predictions on both partitions.
    """

    def __init__(self, X, y):
        """Store the data and the number of features.

        Parameters
        ----------
        X : array-like, 2-D
            Design matrix; rows are samples and columns are features.
        y : array-like
            Target variable, one entry (or row) per sample.
        """
        self.X, self.y = X, y
        self.p = np.shape(X)[1]  # number of columns in the design matrix

    def split_scale(self, only_centering=True, test_size=.2):
        """Split the data into train/test sets and scale them.

        All scaling statistics are computed on the training partition only
        and then applied to both partitions. The targets are always
        centered with the training mean and never rescaled.

        Parameters
        ----------
        only_centering : bool, default True
            If True, subtract the training column means from ``X``.
            If False, additionally divide each column by its training
            standard deviation (full standardization).
        test_size : float, default 0.2
            Forwarded to ``train_test_split``.

        Sets ``X_train_scaled``, ``X_test_scaled``, ``y_train_scaled``,
        ``y_train_mean``, ``y_test`` and ``only_centering`` on the instance.
        """
        self.only_centering = only_centering

        # Fixed random_state keeps the split reproducible between runs.
        X_train, X_test, y_train, self.y_test = train_test_split(
            self.X, self.y, test_size=test_size, random_state=42
        )

        # 1-D array holding the mean of every training column.
        X_train_mean = np.mean(X_train, axis=0)
        X_train_scaled = X_train - X_train_mean
        X_test_scaled = X_test - X_train_mean  # centered with training data

        if not only_centering:
            X_train_std = np.std(X_train, axis=0)
            # A constant column (e.g. an intercept column) has zero spread;
            # divide by 1 there so it stays at zero instead of becoming NaN.
            X_train_std = np.where(X_train_std == 0, 1.0, X_train_std)
            X_train_scaled = X_train_scaled / X_train_std
            X_test_scaled = X_test_scaled / X_train_std

        self.X_train_scaled = X_train_scaled
        self.X_test_scaled = X_test_scaled

        self.y_train_mean = np.mean(y_train)
        self.y_train_scaled = y_train - self.y_train_mean

    def fit(self, llambda=0, test_size=.2):
        """Fit the model on the scaled training data and predict.

        ``split_scale`` must have been called first; otherwise the scaled
        attributes read here do not exist and an ``AttributeError`` is raised.

        Parameters
        ----------
        llambda : number, default 0
            Regularization parameter: zero for OLS, positive for Ridge,
            negative one for Lasso (not implemented).
        test_size : float, default 0.2
            Unused; kept only so existing callers that pass it keep working.

        Returns
        -------
        y_tilde : ndarray
            Predictions on the training data.
        y_predict : ndarray
            Predictions on the test data.

        Raises
        ------
        NotImplementedError
            If ``llambda == -1`` (Lasso).
        ValueError
            If ``llambda`` is any other negative number.
        """
        X_train, X_test, y_train = self.X_train_scaled, self.X_test_scaled, self.y_train_scaled
        y_train_mean = self.y_train_mean

        if not llambda:
            # OLS (equivalent to Ridge with parameter 0). The pseudo-inverse
            # is used because centering turns any constant column into
            # zeros, which makes X^T X singular and a plain inverse fail.
            beta = np.linalg.pinv(X_train) @ y_train

        elif llambda > 0:
            # Ridge: X^T X + llambda * I is positive definite for
            # llambda > 0, so solve the system instead of inverting it.
            gram = X_train.T @ X_train
            penalty = llambda * np.eye(gram.shape[0])
            beta = np.linalg.solve(gram + penalty, X_train.T @ y_train)

        elif llambda == -1:
            raise NotImplementedError  # Lasso has to be solved numerically

        else:
            raise ValueError("The hyperparameter should be a non-negative number or -1.")

        # Keep the coefficients available for inspection after fitting.
        self.beta = beta

        # The targets were centered in split_scale (in every scaling mode),
        # so the training mean is added back to return to the original scale.
        y_tilde = X_train @ beta + y_train_mean  # predictor on training data
        y_predict = X_test @ beta + y_train_mean  # predictor on test data

        return y_tilde, y_predict
    