In [10]:
import random
import numpy as np
import pandas as pd

In [11]:
path = './data/sample_data.csv'
data_df = pd.read_csv(path)

input_size = 5

X = data_df.iloc[:, :input_size].values
y = data_df.iloc[:, input_size:].values.reshape(-1,)

num_samples = X.shape[0]
num_train_samples = int(num_samples * 0.9)

rng = np.random.default_rng()
indices = rng.permutation(num_samples)
train_indices = indices[:num_train_samples]
test_indices = indices[num_train_samples:]

X_train = X[train_indices]
y_train = y[train_indices]
X_test = X[test_indices]
y_test = y[test_indices]

In [12]:
class ADMMLasso():
    def __init__(self):
        self.input_size = None
        self.w = None
        self.z = None
        self.gamma = None
        self._lambda = None
        self.rho = None
    
    def fit(self, X, y, max_iter=10000, p_thre=1e-5, s_thre=1e-5, 
            _lambda = 0.5, rho=0.5, verbose=False):
        X, y = self.__validate_input(X, y)
        self.__init_params(X, _lambda, rho)

        for iter in range(max_iter):
            next_w = self.__next_w(X, y)
            next_z = self.__next_z()
            next_gamma = self.__next_gamma()

            prev_w = self.w
            prev_z = self.z

            self.w = next_w
            self.z = next_z
            self.gamma = next_gamma

            if verbose:
                loss = self.calc_loss(y, self.predict(X))
                print(f'step : {iter} loss : {loss}')

            if self.__converge(p_thre, s_thre, prev_w, prev_z):
                if verbose:
                    print(f'\tearly stop!')
                break
        
        return self
    
    def predict(self, X):
        if X.shape[0] != self.input_size:
            X = X.T
        
        return X.T @ self.w
    
    def calc_loss(self, y, y_prime):
        print(f'{y.shape} {y_prime.shape}')
        right = np.sum((y - y_prime) ** 2) / 2
        mid = self._lambda * np.sum(np.abs(self.z))
        left = self.rho * np.sum((self.w - self.z) ** 2) / 2

        return right + mid + left

    def __validate_input(self, X, y):
        # X.shape should be (input_size, sample_nums)
        # y.shape should be (sample_nums, )
        if X.shape[0] == y.shape[0]:
            X = X.T
        else: assert(X.shape[1] != y.shape[0])

        return X, y
    
    def __init_params(self, X, _lambda, rho):
        self.input_size = X.shape[0]
        self.w = np.zeros((self.input_size))
        self.z = np.zeros_like(self.w)
        self.gamma = np.zeros_like(self.w)
        self._lambda = _lambda
        self.rho = rho

        return
    
    def __next_w(self, X, y):
        left = np.linalg.inv(X @ X.T + self.rho * np.eye(self.input_size))
        right = (X @ y).reshape(-1, )  - self.gamma + self.rho * self.z

        return left @ right

    def __next_z(self):
        next_z = np.zeros_like(self.z)

        for l in range(self.input_size):
            next_z[l] = self.__soft_threshold(
                _lambda=self._lambda / self.rho,
                x=self.w[l] + self.gamma[l] / self.rho)
        
        return next_z
    
    def __next_gamma(self):
        return self.gamma + self.rho * (self.w - self.z)

    def __soft_threshold(self, _lambda, x):
        return np.sign(x) * max((0, abs(x) - _lambda))
    
    def __converge(self, p_thre, s_thre, prev_w, prev_z):
        if np.sum((self.w - self.z) ** 2) > p_thre: return False
        if np.sum((self.w - prev_w) ** 2) > s_thre: return False
        if np.sum((self.z - prev_z) ** 2) > s_thre: return False

        return True

In [13]:
model = ADMMLasso()

print(y_train.shape)

model.fit(X_train, y_train, verbose=True)
y_prime = model.predict(X_test)

loss = model.calc_loss(y_test, y_prime)

print(f'loss: {loss}')

(900,)
(900,) (900,)
step : 0 loss : 3.488312401326015
(900,) (900,)
step : 1 loss : 2.2531223773726037
(900,) (900,)
step : 2 loss : 5.744539350152806
(900,) (900,)
step : 3 loss : 9.449108112783891
(900,) (900,)
step : 4 loss : 5.227629411047202
(900,) (900,)
step : 5 loss : 2.7469769264906887
(900,) (900,)
step : 6 loss : 3.44985343495911
(900,) (900,)
step : 7 loss : 2.253183237102102
(900,) (900,)
step : 8 loss : 5.727317293871363
(900,) (900,)
step : 9 loss : 9.3399848087476
(900,) (900,)
step : 10 loss : 5.177388006247197
(900,) (900,)
step : 11 loss : 2.7385804252039962
(900,) (900,)
step : 12 loss : 3.4011386689827465
(900,) (900,)
step : 13 loss : 2.253913497771773
(900,) (900,)
step : 14 loss : 5.710020279638561
(900,) (900,)
step : 15 loss : 9.233030500608896
(900,) (900,)
step : 16 loss : 5.128327047912128
(900,) (900,)
step : 17 loss : 2.730123801491023
(900,) (900,)
step : 18 loss : 3.3539644796118173
(900,) (900,)
step : 19 loss : 2.2551893714492213
(900,) (900,)
step :