In [1]:
class CFG:
    """Notebook-wide configuration switches."""

    # Name of the optimizer that JAXLinearRegressor uses by default;
    # the regressor recognizes 'SGD', 'SGD_momentum', 'Nesterov' and 'Adam'.
    optimizer = "SGD_momentum"

# Libs

In [2]:
import jax.numpy as jnp
from jax import grad, jit
from sklearn.base import BaseEstimator, RegressorMixin
from jax import random
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LinearRegression, ElasticNet, Ridge
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, mean_absolute_percentage_error

# Loss

## MSE

In [3]:
class Loss:
    """Mean squared error (MSE) loss of the linear model ``X @ w + b``."""

    @staticmethod
    @jit
    def forward(w, b, X, y):
        """Return the mean squared error of the linear prediction.

        Args:
            w: weights, combined with ``X`` through ``jnp.dot(X, w)``.
            b: bias added to every prediction.
            X: input features.
            y: target values.

        Returns:
            The mean of ``(y - (jnp.dot(X, w) + b)) ** 2``.
        """
        y_pred = jnp.dot(X, w) + b
        return jnp.mean((y - y_pred) ** 2)

    @staticmethod
    @jit
    def backward(w, b, X, y):
        """Return the gradients of the MSE with respect to ``w`` and ``b``.

        Args:
            w: weights at which the gradient is evaluated.
            b: bias at which the gradient is evaluated.
            X: input features.
            y: target values.

        Returns:
            Tuple ``(grad_w, grad_b)``.
        """
        # Differentiate with respect to both parameters in a single
        # transform instead of building two separate grad() functions that
        # each repeat the forward pass.
        grad_w, grad_b = grad(Loss.forward, argnums=(0, 1))(w, b, X, y)
        return grad_w, grad_b

# Optimizer

## SGD

In [4]:
class SGD:
    """Gradient-descent optimizer for the MSE loss (no internal state)."""

    def __init__(self, lr=0.01):
        """Store the learning rate ``lr`` used to scale every step."""
        self.lr = lr

    def initialize(self, w_shape):
        """Do nothing; present so all optimizers share the same interface.

        Args:
            w_shape: shape of the weight vector (unused here).
        """
        pass

    def update(self, w, b, X, y):
        """Take one gradient step and return the new parameters.

        Args:
            w: current weights.
            b: current bias.
            X: input features used to evaluate the gradient.
            y: target values used to evaluate the gradient.

        Returns:
            Tuple ``(w, b)`` with the updated weights and bias.
        """
        grad_w, grad_b = Loss.backward(w, b, X, y)
        # Out-of-place updates: an augmented `-=` would modify the caller's
        # array in place when `w` is a mutable NumPy array, and would differ
        # from how Nesterov and AdamOptimizer update their parameters.
        w = w - self.lr * grad_w
        b = b - self.lr * grad_b
        return w, b

## SGD with momentum

In [5]:
class SGD_momentum:
    """Gradient descent with classical momentum for the MSE loss."""

    def __init__(self, lr=0.01, momentum=0.9):
        """Store hyper-parameters; velocities stay unset until initialize().

        Args:
            lr: learning rate applied to each new gradient.
            momentum: factor applied to the previous velocity.
        """
        self.lr = lr
        self.v_w = None
        self.v_b = None
        self.momentum = momentum

    def initialize(self, w_shape):
        """Reset both velocities to zero.

        Args:
            w_shape: shape of the weight vector the optimizer will update.
        """
        self.v_w = jnp.zeros(w_shape)
        self.v_b = 0.0

    def update(self, w, b, X, y):
        """Take one momentum step and return the new parameters.

        initialize() must have been called first, because the velocities
        start out as ``None``.

        Args:
            w: current weights.
            b: current bias.
            X: input features used to evaluate the gradient.
            y: target values used to evaluate the gradient.

        Returns:
            Tuple ``(w, b)`` with the updated weights and bias.
        """
        grad_w, grad_b = Loss.backward(w, b, X, y)
        self.v_w = self.momentum * self.v_w + self.lr * grad_w
        self.v_b = self.momentum * self.v_b + self.lr * grad_b
        # Out-of-place updates: an augmented `-=` would modify the caller's
        # array in place when `w` is a mutable NumPy array, and would differ
        # from how Nesterov and AdamOptimizer update their parameters.
        w = w - self.v_w
        b = b - self.v_b
        return w, b

## Nesterov

In [6]:
class Nesterov:
    """Momentum optimizer that evaluates the gradient at a look-ahead point."""

    def __init__(self, lr=0.01, momentum=0.9):
        """Keep the hyper-parameters; velocities are created by initialize()."""
        self.lr = lr
        self.v_w = self.v_b = None
        self.momentum = momentum

    def initialize(self, w_shape):
        """Zero the weight velocity (shape ``w_shape``) and the bias velocity."""
        self.v_w, self.v_b = jnp.zeros(w_shape), 0.0

    def update(self, w, b, X, y):
        """Perform one Nesterov step and return the updated ``(w, b)``.

        The gradient is taken at the parameters shifted by ``momentum`` times
        the current velocity; the velocities are then refreshed and
        subtracted from the un-shifted parameters.
        """
        mu, step = self.momentum, self.lr

        # Gradient of the loss at the look-ahead position.
        d_w, d_b = Loss.backward(w - mu * self.v_w, b - mu * self.v_b, X, y)

        # Blend the previous velocity with the freshly scaled gradient.
        self.v_w = mu * self.v_w + step * d_w
        self.v_b = mu * self.v_b + step * d_b

        return w - self.v_w, b - self.v_b

## Adam

In [7]:
# Optimizer class for the Adam optimization algorithm
class AdamOptimizer:
    """Adam optimizer for the weights and bias of the MSE loss."""

    def __init__(self, lr=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-8):
        """Store hyper-parameters; moment estimates stay unset until initialize().

        Args:
            lr: learning rate.
            beta_1: decay rate of the first-moment (mean) estimate.
            beta_2: decay rate of the second-moment (squared gradient) estimate.
            epsilon: constant added to the denominator of the update.
        """
        self.lr = lr
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.t = 0
        self.m_w = None
        self.v_w = None
        self.m_b = None
        self.v_b = None

    def initialize(self, w_shape):
        """Reset the moment estimates and the step counter.

        Args:
            w_shape: shape of the weight vector the optimizer will update.
        """
        # The step counter must restart together with the moments; otherwise
        # the bias correction in update() is computed for a step count that
        # no longer matches the freshly zeroed moments.
        self.t = 0
        self.m_w = jnp.zeros(w_shape)
        self.v_w = jnp.zeros(w_shape)
        self.m_b = 0.0
        self.v_b = 0.0

    def update(self, w, b, X, y):
        """Take one Adam step and return the new parameters.

        initialize() must have been called first, because the moment
        estimates start out as ``None``.

        Args:
            w: current weights.
            b: current bias.
            X: input features used to evaluate the gradient.
            y: target values used to evaluate the gradient.

        Returns:
            Tuple ``(w, b)`` with the updated weights and bias.
        """
        # Compute gradients
        grad_w, grad_b = Loss.backward(w, b, X, y)
        # Increment time step
        self.t += 1

        # Update biased first moment estimate
        self.m_w = self.beta_1 * self.m_w + (1 - self.beta_1) * grad_w
        self.m_b = self.beta_1 * self.m_b + (1 - self.beta_1) * grad_b

        # Update biased second raw moment estimate
        self.v_w = self.beta_2 * self.v_w + (1 - self.beta_2) * (grad_w ** 2)
        self.v_b = self.beta_2 * self.v_b + (1 - self.beta_2) * (grad_b ** 2)

        # Compute bias-corrected first moment estimate
        m_w_hat = self.m_w / (1 - self.beta_1 ** self.t)
        m_b_hat = self.m_b / (1 - self.beta_1 ** self.t)

        # Compute bias-corrected second raw moment estimate
        v_w_hat = self.v_w / (1 - self.beta_2 ** self.t)
        v_b_hat = self.v_b / (1 - self.beta_2 ** self.t)

        # Update parameters
        w = w - self.lr * m_w_hat / (jnp.sqrt(v_w_hat) + self.epsilon)
        b = b - self.lr * m_b_hat / (jnp.sqrt(v_b_hat) + self.epsilon)

        return w, b


# Standard linear regression with SGD

In [8]:
class JAXLinearRegressor(BaseEstimator, RegressorMixin):
    """Linear regressor trained by iteratively minimizing the MSE loss.

    Parameters:
        lr: learning rate handed to the optimizer.
        n_iter: number of optimizer updates performed by fit(); every update
            receives the full ``X`` and ``y``.
        random_seed: seed of the PRNG key used to draw the initial weights.
        optimizer: optimizer name, one of 'SGD', 'SGD_momentum', 'Nesterov'
            or 'Adam'.

    Attributes set by fit():
        w: learned weights, one per feature column.
        b: learned bias.
        optimizer_: the optimizer instance used for the last fit.
    """

    # Accepted `optimizer` names and the classes that implement them.
    _OPTIMIZERS = {
        'SGD': SGD,
        'SGD_momentum': SGD_momentum,
        'Nesterov': Nesterov,
        'Adam': AdamOptimizer,
    }

    def __init__(self, lr=0.01, n_iter=1000, random_seed=42, optimizer=CFG.optimizer):
        # Constructor arguments are stored verbatim. BaseEstimator.get_params()
        # reads them back by name, so replacing `optimizer` with an optimizer
        # object here would break get_params()/clone() and would ignore later
        # set_params(lr=...) calls.
        self.lr = lr
        self.n_iter = n_iter
        self.random_seed = random_seed
        self.optimizer = optimizer
        # Learned parameters; populated by fit().
        self.w = None
        self.b = None

    def _build_optimizer(self):
        """Create a fresh optimizer from the current hyper-parameters.

        Raises:
            ValueError: if ``self.optimizer`` is not a recognized name.
        """
        try:
            optimizer_cls = self._OPTIMIZERS[self.optimizer]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unknown optimizer {self.optimizer!r}; "
                f"expected one of {sorted(self._OPTIMIZERS)}"
            ) from None
        return optimizer_cls(lr=self.lr)

    def fit(self, X, y):
        """Fit weights and bias with ``n_iter`` optimizer updates.

        Args:
            X: 2-D feature matrix; its second dimension sets the number of
                weights.
            y: target values.

        Returns:
            self

        Raises:
            ValueError: if the configured optimizer name is not recognized.
        """
        # Convert once up front so the loop does not re-convert host arrays
        # on every call into the jitted loss.
        X = jnp.asarray(X)
        y = jnp.asarray(y)

        self.initialize(X)
        # A new optimizer per fit guarantees no state from a previous fit
        # leaks into this one.
        self.optimizer_ = self._build_optimizer()
        self.optimizer_.initialize(self.w.shape)

        for _ in range(self.n_iter):
            self.w, self.b = self.optimizer_.update(self.w, self.b, X, y)

        return self

    def predict(self, X):
        """Return ``jnp.dot(X, w) + b`` for the fitted parameters."""
        return jnp.dot(X, self.w) + self.b

    def initialize(self, X):
        """Draw small random initial weights and set the bias to zero.

        Args:
            X: feature matrix; only ``X.shape[1]`` is used.
        """
        key = random.PRNGKey(self.random_seed)
        # Normal draws scaled by sqrt(2 / n_features) and a further 0.01.
        self.w = random.normal(
            key, (X.shape[1],)) * np.sqrt(2 / X.shape[1]) * 0.01
        self.b = 0.0  # initialize bias as a scalar

# Load dataset

In [9]:
# Iris measurements and their class labels as plain arrays; the labels are
# used as the regression target below.
X, y = load_iris(as_frame=False, return_X_y=True)

# Hold out a test split; the fixed random_state keeps it reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42
)

# Train

In [10]:
# Custom JAX estimator: build with default settings, fit, predict the test split
lin_reg = JAXLinearRegressor().fit(X_train, y_train)
y_pred = lin_reg.predict(X_test)

# scikit-learn's LinearRegression on the same split, for comparison
lin_sk = LinearRegression().fit(X_train, y_train)
y_pred_sk = lin_sk.predict(X_test)

# Metrics

In [11]:
def metrics(y_pred, y_true):
    """Print MSE, MAE, MAPE and R^2 of ``y_pred`` against ``y_true``.

    Note the argument order: predictions come first, ground truth second.
    Each score function is called as ``score(y_true, y_pred)``.
    """
    scorers = (
        ('MSE', mean_squared_error),
        ('MAE', mean_absolute_error),
        ('MAPE', mean_absolute_percentage_error),
        ('r2', r2_score),
    )
    for label, score in scorers:
        print(label, score(y_true, y_pred))

In [12]:
# Test-split scores of the JAX model. NOTE: the MAPE value is enormous because
# the iris target contains the class label 0, so the percentage error divides
# by a near-zero denominator; it is not a meaningful score for this target.
metrics(y_pred=y_pred, y_true=y_test)

MSE 0.03607431351681588
MAE 0.1436190750253828
MAPE 99296078283614.4
r2 0.9487794407883166


In [13]:
# Test-split scores of the scikit-learn reference model (same caveat for MAPE:
# targets equal to 0 make the percentage error blow up).
metrics(y_pred=y_pred_sk, y_true=y_test)

MSE 0.03611030626905014
MAE 0.14443140820853126
MAPE 102015956763893.48
r2 0.9487283360348984
