In [None]:
import numpy as np
from sklearn.datasets import make_regression

# 生成随机数据
X, y = make_regression(n_samples=100, n_features=10, noise=0.1)

# 定义ADMM算法参数
rho = 1
max_iter = 1000
epsilon = 1e-4

# 初始化变量
n, m = X.shape
x = np.zeros(m)
z = np.zeros(m)
u = np.zeros(m)

# 定义soft-threshold函数
def soft_threshold(x, threshold):
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)

# 定义目标函数f和g的更新函数
def update_x(x, z, u, rho):
    A = X.T.dot(X) + rho * np.eye(m)
    b = X.T.dot(y) + rho * (z - u)
    return np.linalg.solve(A, b)

def update_z(x, z, u, rho, lam):
    return soft_threshold(x + u, lam/rho)

# 迭代更新
for i in range(max_iter):
    x = update_x(x, z, u, rho)
    z_prev = z.copy()
    z = update_z(x, z, u, rho, lam=rho)
    u = u + x - z
    # 判断是否收敛
    if np.linalg.norm(z - z_prev) < epsilon:
        break

print("Coefficients: ", x)


In [11]:
import numpy as np
'''
最小化（x-1）^2+(z-2)^2，约束条件是3x+2z=5
'''

rho = 1

x = np.zeros(1)
z = np.zeros(1)
y = np.zeros(1)
for i in range(100):
    x = (15*rho-6*rho*z-3*y+2)/(2+9*rho)
    print(x)
    z = (-2*rho*(3*x-5)+4)/(2*y+2+4*rho)
    print(z)
    y = y+rho*(3*x+2*z-5)
print(x,z)
print(3*x+2*z)

[1.54545455]
[0.78787879]
[0.78512397]
[1.10268149]
[0.73321298]
[1.27234142]
[0.71040163]
[1.38429028]
[0.67666986]
[1.45450506]
[0.65500187]
[1.5003039]
[0.63939885]
[1.52992802]
[0.62922589]
[1.54935337]
[0.62234359]
[1.5620563]
[0.61783016]
[1.57040948]
[0.61482585]
[1.57589611]
[0.61285048]
[1.57950854]
[0.61154319]
[1.58188577]
[0.61068259]
[1.5834518]
[0.61011438]
[1.58448319]
[0.60974012]
[1.58516278]
[0.60949327]
[1.58561053]
[0.60933063]
[1.58590558]
[0.6092234]
[1.58610001]
[0.60915275]
[1.58622813]
[0.60910618]
[1.58631257]
[0.60907549]
[1.58636821]
[0.60905526]
[1.58640488]
[0.60904193]
[1.58642905]
[0.60903314]
[1.58644497]
[0.60902736]
[1.58645547]
[0.60902354]
[1.58646238]
[0.60902103]
[1.58646694]
[0.60901937]
[1.58646995]
[0.60901828]
[1.58647193]
[0.60901756]
[1.58647323]
[0.60901708]
[1.58647409]
[0.60901677]
[1.58647466]
[0.60901656]
[1.58647503]
[0.60901643]
[1.58647528]
[0.60901634]
[1.58647544]
[0.60901628]
[1.58647555]
[0.60901624]
[1.58647562]
[0.60901622]
[1.