In [13]:
import numpy as np

# Small underdetermined system A x = b: 2 equations, 3 unknowns.
A = np.array([[1, 0, 1],
              [0, 1, 1]])

# Right-hand side as a (2, 1) column vector.
b = np.array([1, 1]).reshape(2, 1)

# N = number of equations (rows), M = number of unknowns (columns).
N, M = A.shape

class ADMM:
    """Augmented-Lagrangian (method of multipliers) solver for

        minimize ||x||^2  subject to  A x = b,

    i.e. the minimum-norm solution of the linear system A x = b.

    ``u`` is the *scaled* dual variable (Lagrange multiplier divided by rho),
    so the x-update and the dual update below use matching scalings.
    """
    max_loop = 100

    def __init__(self, x0, u0, rho, A=None, b=None):
        """
        Parameters
        ----------
        x0 : array-like
            Initial x (a column vector in this notebook).
        u0 : scalar or array-like
            Initial scaled dual variable; a scalar 0 broadcasts against b.
        rho : float
            Augmented-Lagrangian penalty parameter (must be > 0).
        A, b : array-like, optional
            System matrix and right-hand side. When omitted, the notebook-level
            ``A`` and ``b`` are captured here once, so later cells that rebind
            those names do not silently change this solver.
        """
        self.x = x0
        self.u = u0
        self.rho = rho
        self.A = np.asarray(globals()["A"] if A is None else A)
        self.b = np.asarray(globals()["b"] if b is None else b)

    def update_x(self):
        """x-update: solve (2 I + rho A^T A) x = rho A^T (b - u)."""
        A = self.A
        system = 2 * np.identity(A.shape[1]) + self.rho * A.T.dot(A)
        # solve() is cheaper and numerically safer than forming the inverse.
        self.x = np.linalg.solve(system, self.rho * A.T.dot(self.b - self.u))

    def update_u(self):
        """Scaled dual update: u <- u + (A x - b)."""
        # The x-update uses the scaled form rho * A^T (b - u), so the matching
        # dual step is the plain residual. A rho * residual step is only
        # consistent for rho == 1 and can fail to converge for larger rho.
        self.u = self.u + (self.A.dot(self.x) - self.b)

    def update(self):
        """One iteration: primal update followed by dual update."""
        self.update_x()
        self.update_u()

    def fit(self):
        """Run ``max_loop`` iterations in place; results are left in ``self.x``."""
        for _ in range(self.max_loop):
            self.update()

    def result_x(self, a):
        """Return a . x, e.g. the fitted value for one row ``a`` of A."""
        return np.dot(a, self.x)

print(A)
print(b)

# Start from A^T b / N with a zero dual variable and penalty rho = 1.
admm = ADMM(x0=A.T.dot(b) / N, u0=0, rho=1)
admm.fit()
print("最小二乗法 by ADMM")

print("coef_")
print(np.round(admm.x, 10))

print('\n')

# Compare the fitted value of every equation with its right-hand side.
print("Result value : Actual value")
for i, row in enumerate(A[:N]):
    print(admm.result_x(row), " : ", b[i])

print(admm.x)

[[1 0 1]
 [0 1 1]]
[[1]
 [1]]
最小二乗法 by ADMM
coef_
[[0.33333333]
 [0.33333333]
 [0.66666667]]


Result value : Actual value
[1.]  :  [1]
[1.]  :  [1]
[[0.33333333]
 [0.33333333]
 [0.66666667]]


In [1]:
from sklearn.datasets import load_boston
from sklearn import linear_model
import pandas as pd
import numpy as np

# NOTE(review): `load_boston` was deprecated in scikit-learn 1.0 and removed in
# 1.2, so this cell only runs against an older scikit-learn release.
boston = load_boston()
boston_df = pd.DataFrame(boston.data, columns=boston.feature_names)
boston_df['PRICE'] = boston.target

# Design matrix: the 13 feature columns stacked in this fixed order,
# then transposed so there is one row per sample.
A = np.array([
    boston_df[column]
    for column in ('CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE',
                   'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT')
]).T

# Target kept as a pandas Series (b[i] below is label-based indexing).
b = boston_df['PRICE']

# N = number of samples, M = number of features.
N, M = boston.data.shape

class ADMM:
    """Lasso regression by an ADMM-style splitting.

    Minimizes (1 / (2N)) * ||A x - b||^2 + lambd * ||z||_1 subject to x = z,
    where N is the number of rows of A. ``y`` is the (unscaled) dual variable
    of the constraint x = z.
    """

    max_loop = 10000

    def __init__(self, x, z, y, lambd, rho, A=None, b=None):
        """
        Parameters
        ----------
        x : array-like of length M
            Initial primal iterate (overwritten by the first x-update).
        z : array-like of length M
            Initial split variable. It is copied, so the caller's array is
            never modified.
        y : array-like of length M
            Initial dual variable, e.g. ``[0] * M``.
        lambd : float
            L1 penalty weight.
        rho : float
            Augmented-Lagrangian penalty parameter (must be > 0).
        A, b : array-like, optional
            Design matrix (N x M) and target (length N). When omitted, the
            notebook-level ``A`` and ``b`` are captured here once.
        """
        self.x = x
        self.z = np.array(z, dtype=float)
        self.y = np.asarray(y, dtype=float)
        self.lambd = lambd
        self.rho = rho
        self.A = np.asarray(globals()["A"] if A is None else A, dtype=float)
        self.b = np.asarray(globals()["b"] if b is None else b, dtype=float)
        self.N, self.M = self.A.shape
        # A^T b / N does not change across iterations; compute it once.
        self._Atb = np.dot(self.A.T, self.b) / self.N
        # Cached (A^T A / N + rho I)^-1 together with the rho it was built for.
        self._inv_Q = None
        self._inv_Q_rho = None

    def S(self, a, k):
        """Soft-thresholding: k - a if k > a, k + a if k < -a, otherwise 0.

        Works on scalars and elementwise on numpy arrays.
        """
        return np.sign(k) * np.maximum(np.abs(k) - a, 0)

    def fit(self):
        """Run ``max_loop`` iterations in place; coefficients end up in ``self.x``."""
        for _ in range(self.max_loop):
            self.update()

    def update(self):
        """One iteration: x-update, then y-update, then z-update.

        NOTE(review): textbook ADMM updates z before y. The x, y, z order is
        kept so that previously reported results are reproduced. Confirm that
        this ordering is intended.
        """
        self.update_x()
        self.update_y()
        self.update_z()

    def _inverse_system(self):
        """Return (A^T A / N + rho I)^-1, rebuilding it only when rho changes."""
        if self._inv_Q is None or self._inv_Q_rho != self.rho:
            Q = np.dot(self.A.T, self.A) / self.N + self.rho * np.identity(self.M)
            self._inv_Q = np.linalg.inv(Q)
            self._inv_Q_rho = self.rho
        return self._inv_Q

    def update_x(self):
        """x-update: solve (A^T A / N + rho I) x = A^T b / N + rho z - y."""
        P = self._Atb + self.rho * self.z - self.y
        self.x = np.dot(self._inverse_system(), P)

    def update_z(self):
        """z-update: soft-threshold x + y / rho at level lambd / rho."""
        # Builds a new array rather than writing into the previous z in place.
        self.z = self.S(self.lambd / self.rho, self.x + self.y / self.rho)

    def update_y(self):
        """Dual ascent step on the constraint residual x - z."""
        self.y = self.y + self.rho * (self.x - self.z)

    def pridict(self, u):
        """Return the prediction x . u for a feature vector u of length M."""
        return np.dot(self.x, u)

    # Correctly spelled alias; `pridict` is kept for existing callers.
    predict = pridict



# x and z both start from A^T b / N (as separate arrays); the dual starts at zero.
admm = ADMM(
    x=np.dot(A.T, b) / N,
    z=np.dot(A.T, b) / N,
    y=[0] * M,
    lambd=1.0,
    rho=1.0,
)
admm.fit()
print("Lasso by ADMM")

print("coef_")
print(np.round(admm.x, 10))

print('\n')

# Prediction for every sample next to its observed price.
print("Pridicted value : Actual value")
for i, row in enumerate(A[:N]):
    print(admm.pridict(row), " : ", b[i])

Lasso by ADMM
coef_
[-0.05448317  0.04803143  0.          0.         -0.          4.10651622
  0.03184158 -0.2701645   0.0904304  -0.00690278 -0.08542695  0.01601162
 -0.55827292]


Pridicted value : Actual value
29.15065112443969  :  24.0
25.77900427199753  :  21.6
31.137211038634142  :  34.7
30.374922547498876  :  33.4
29.95433835327645  :  36.2
27.177869512603767  :  28.7
22.304119230263964  :  22.9
20.06128002497916  :  27.1
11.742028366359378  :  16.5
19.85317601023866  :  18.9
19.938498978809076  :  15.0
22.1811098182167  :  18.9
19.036706710176052  :  21.7
23.283643171071105  :  20.4
23.289188541080883  :  18.2
22.5614677559602  :  19.9
23.00160970249164  :  23.1
20.457661947418757  :  17.5
17.050730213904476  :  20.2
21.076810661066666  :  18.2
15.645895562666123  :  13.6
21.2179635303117  :  19.6
19.35327381640584  :  15.2
17.56246667099891  :  14.5
19.756417470786804  :  15.6
16.561699238860644  :  13.9
19.65999441412847  :  16.6
18.111697732270237  :  14.8
23.94630723052502 