# ADMM Algorithm

### The ADMM algorithm

ADMM (the alternating direction method of multipliers) solves optimization problems of the form
$$\begin{align*}
\min_{x,z}\  & f(x)+g(z)\\
\mathrm{s.t.}\  & Ax+Bz=c
\end{align*}$$
where $f$ and $g$ are convex functions.

In scaled form, with scaled dual variable $u$ and penalty parameter $\rho>0$, the ADMM iterations are
$$
\begin{align*}
x^{k+1} & =\underset{x}{\arg\min}\ f(x)+\frac{\rho}{2}\Vert Ax+Bz^{k}-c+u^{k}\Vert^{2}\\
z^{k+1} & =\underset{z}{\arg\min}\ g(z)+\frac{\rho}{2}\Vert Ax^{k+1}+Bz-c+u^{k}\Vert^{2}\\
u^{k+1} & =u^{k}+Ax^{k+1}+Bz^{k+1}-c.
\end{align*}
$$

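Define the primal residual $r^{k+1}=Ax^{k+1}+Bz^{k+1}-c$ and the dual residual $s^{k+1}=\rho A'B(z^{k+1}-z^{k})$. The algorithm is considered converged once both $\Vert r^k\Vert$ and $\Vert s^k\Vert$ fall below a chosen tolerance.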
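As a rough sketch of how the three updates and the stopping rule fit together, the loop below runs scaled-form ADMM given two user-supplied minimization routines. The helper name `admm_scaled` and the callables `x_step`/`z_step` are hypothetical; they are assumed to solve the $x$- and $z$-subproblems exactly. The usage example is a toy problem, $f(x)=\frac12\Vert x-a\Vert^2$, $g(z)=\frac12\Vert z-d\Vert^2$ with constraint $x-z=0$, whose solution is $x=z=(a+d)/2$.

```python
import numpy as np


def admm_scaled(x_update, z_update, A, B, c, z0, rho=1.0, tol=1e-6, max_iter=1000):
    """Scaled-form ADMM (hypothetical helper, not a library API).

    x_update(z, u) is assumed to return argmin_x f(x) + rho/2 ||Ax + Bz - c + u||^2;
    z_update(x, u) is assumed to return argmin_z g(z) + rho/2 ||Ax + Bz - c + u||^2.
    """
    z = z0
    u = np.zeros(c.shape)                  # scaled dual variable, u^0 = 0
    for k in range(max_iter):
        x = x_update(z, u)                 # x^{k+1}
        z_old = z
        z = z_update(x, u)                 # z^{k+1}
        r = A @ x + B @ z - c              # primal residual r^{k+1}
        u = u + r                          # u^{k+1} = u^k + r^{k+1}
        s = rho * A.T @ B @ (z - z_old)    # dual residual s^{k+1}
        if np.linalg.norm(r) <= tol and np.linalg.norm(s) <= tol:
            break
    return x, z, u, k + 1


# Toy problem: f(x) = 1/2 ||x - a||^2, g(z) = 1/2 ||z - d||^2, s.t. x - z = 0
a = np.array([1.0, 2.0])
d = np.array([3.0, -2.0])
rho = 1.0
I2 = np.eye(2)


def x_step(z, u):
    # gradient condition: (x - a) + rho * (x - z + u) = 0
    return (a + rho * (z - u)) / (1 + rho)


def z_step(x, u):
    # gradient condition: (z - d) - rho * (x - z + u) = 0
    return (d + rho * (x + u)) / (1 + rho)


x, z, u, iters = admm_scaled(x_step, z_step, I2, -I2, np.zeros(2), np.zeros(2), rho=rho)
print(x, z, iters)  # x and z should both be close to (a + d) / 2 = [2, 0]
```

Here both subproblems have closed forms obtained by setting the gradient to zero; for other $f$ and $g$ only the two callables would change while the loop stays the same.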

### Lasso

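The Lasso is a regression method with a built-in variable-selection effect. It is similar to ordinary least squares (OLS) in linear regression, but adds an $L^1$-norm penalty on the regression coefficients. To match the ADMM notation, let $M\in\mathbb{R}^{n\times p}$ denote the design matrix, $b\in\mathbb{R}^n$ the response vector, and $x\in\mathbb{R}^p$ the regression coefficients to be estimated. The Lasso objective is then $$\frac{1}{2}\Vert Mx-b\Vert^2+\lambda \Vert x\Vert_1,$$ where $\lambda\ge 0$ is the penalty parameter and $\Vert v\Vert_1$ denotes the $L^1$ norm of a vector $v=(v_1,\ldots,v_n)'$, i.e. $\Vert v\Vert_1=|v_1|+\cdots+|v_n|$.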
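For checking values by hand, the objective can be evaluated directly. This is only a sketch; the helper name `lasso_objective` is hypothetical.

```python
import numpy as np


def lasso_objective(M, b, x, lam):
    """(1/2)||Mx - b||^2 + lam * ||x||_1 (hypothetical helper name)."""
    resid = M @ x - b
    return 0.5 * resid @ resid + lam * np.abs(x).sum()


# Tiny check: M = I, b = (1, 2), x = (1, 0), lam = 0.5
# 0.5 * ||(0, -2)||^2 + 0.5 * |1| = 2.0 + 0.5 = 2.5
val = lasso_objective(np.eye(2), np.array([1.0, 2.0]), np.array([1.0, 0.0]), 0.5)
print(val)  # 2.5
```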

The Lasso fits the ADMM template with $f(x)=\frac{1}{2}\Vert Mx-b\Vert^2$, $g(z)=\lambda \Vert z\Vert_1$, and the constraint $x-z=0$, i.e. $A=I_p$, $B=-I_p$, $c=0$. The iterations become

$$
\begin{align*}
x^{k+1} & =(M'M+\rho I)^{-1}(M'b+\rho(z^{k}-u^{k}))\\
z^{k+1} & =S_{\lambda/\rho}(x^{k+1}+u^{k})\\
u^{k+1} & =u^{k}+x^{k+1}-z^{k+1},
\end{align*}
$$

where $S_{\kappa}(a)$ is the soft-thresholding operator, applied elementwise and defined as

$$
S_{\kappa}(a)=\begin{cases}
a-\kappa, & a>\kappa\\
0, & |a|\le\kappa\\
a+\kappa, & a<-\kappa
\end{cases},
$$

or, more compactly, $S_{\kappa}(a)=\mathrm{sign}(a)\cdot\max\{0,|a|-\kappa\}$.
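For completeness, the $x$- and $z$-updates follow from the general iterations by substituting $A=I_p$, $B=-I_p$, $c=0$, so that $Ax+Bz-c+u=x-z+u$. The $x$-subproblem is a strictly convex quadratic ($M'M+\rho I$ is positive definite for $\rho>0$), so setting its gradient to zero gives

$$
\begin{align*}
x^{k+1} & =\underset{x}{\arg\min}\ \frac{1}{2}\Vert Mx-b\Vert^{2}+\frac{\rho}{2}\Vert x-z^{k}+u^{k}\Vert^{2},\\
0 & =M'(Mx^{k+1}-b)+\rho(x^{k+1}-z^{k}+u^{k})\\
\Longrightarrow\quad x^{k+1} & =(M'M+\rho I)^{-1}(M'b+\rho(z^{k}-u^{k})).
\end{align*}
$$

With $v=x^{k+1}+u^{k}$, the $z$-subproblem separates across coordinates:

$$
\begin{align*}
z^{k+1} & =\underset{z}{\arg\min}\ \lambda\Vert z\Vert_{1}+\frac{\rho}{2}\Vert z-v\Vert^{2},\\
z_{i}^{k+1} & =\underset{t\in\mathbb{R}}{\arg\min}\ \frac{\lambda}{\rho}|t|+\frac{1}{2}(t-v_{i})^{2}=S_{\lambda/\rho}(v_{i}).
\end{align*}
$$

Checking the cases $t>0$, $t<0$ and $t=0$ of the scalar problem gives exactly the three branches of $S_{\kappa}$ with $\kappa=\lambda/\rho$.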

Correspondingly, the primal residual is $r^{k+1}=x^{k+1}-z^{k+1}$ and the dual residual is $s^{k+1}=-\rho (z^{k+1}-z^{k})$.
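A minimal sketch of the two residual norms for the Lasso case; the helper name `lasso_residual_norms` is hypothetical, and the example vectors are arbitrary.

```python
import numpy as np


def lasso_residual_norms(x_new, z_new, z_old, rho):
    """Primal and dual residual norms for Lasso-ADMM (hypothetical helper)."""
    r = x_new - z_new            # r^{k+1} = x^{k+1} - z^{k+1}
    s = -rho * (z_new - z_old)   # s^{k+1} = -rho (z^{k+1} - z^k)
    return np.linalg.norm(r), np.linalg.norm(s)


r_norm, s_norm = lasso_residual_norms(np.array([3.0, 4.0]), np.zeros(2), np.zeros(2), rho=1.0)
print(r_norm, s_norm)  # 5.0 0.0
```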

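Which variable should be returned as the estimate: $x$, $z$, or $u$? Since $u$ is the scaled dual variable, the candidates are $x$ and $z$. At convergence $x\approx z$, but only $z$, being the output of soft-thresholding, contains exact zeros, so $z$ is the natural sparse estimate.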

### Solving the Lasso with ADMM

```python
import numpy as np
np.set_printoptions(linewidth=100)
```

Generate simulated data:

```python
np.random.seed(123)
n = 1000
p = 30
nz = 10
M = np.random.normal(size=(n, p))
# Only the first 10 entries of the true x are nonzero; the rest are 0
xtrue = np.random.normal(size=nz)
xtrue = np.concatenate((xtrue, np.zeros(p - nz)))
b = M.dot(xtrue) + np.random.normal(size=n)
xtrue
```

```
array([-1.05417044, -0.78301134,  1.82790084,  1.7468072 ,  1.3282585 , -0.43277314, -0.6686141 ,
       -0.47208845,  1.05554064,  0.67905585,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ])
```

Set $\lambda=0.01n$ and implement ADMM with the Lasso iterations above.


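**Note:** every iteration requires computing $(M'M+\rho I)^{-1}v$ for some vector $v$. Calling `np.linalg.solve()` in every iteration refactorizes the matrix each time, which is wasteful. Since $M'M+\rho I$ does not change across iterations and is positive definite, it is better to compute its Cholesky factorization once and reuse it to solve each linear system.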
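A small illustration of the factor-once pattern with SciPy's `cho_factor`/`cho_solve`. The random matrix is made up for the demonstration (it is not the notebook's data), and the short loop only stands in for ADMM iterations with a changing right-hand side.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

rng = np.random.default_rng(0)
n, p, rho = 200, 30, 1.0
M = rng.normal(size=(n, p))
K = M.T @ M + rho * np.eye(p)   # M'M + rho I, symmetric positive definite

factor = cho_factor(K)          # O(p^3) factorization, computed once before the loop
for _ in range(3):              # stands in for ADMM iterations
    v = rng.normal(size=p)      # a new right-hand side each iteration
    x_chol = cho_solve(factor, v)        # two triangular solves, O(p^2)
    x_direct = np.linalg.solve(K, v)     # refactorizes K on every call
    assert np.allclose(x_chol, x_direct)
```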

```python
# Initialize x, z and the scaled dual variable u at zero
x = np.zeros(p)
z = np.zeros(p)
u = np.zeros(p)
```

```python
rho = 1.0     # ADMM penalty parameter
l = 0.01 * n  # Lasso penalty lambda = 0.01 n
```

```python
from scipy.linalg import cho_factor, cho_solve
# Factor M'M + rho*I once; the factor is reused in every x-update
c, lower = cho_factor(M.T.dot(M) + rho * np.eye(p))
Mtb = M.T.dot(b)
```

```python
def soft_thresholding(a, k):
    # Elementwise S_k(a) = sign(a) * max(0, |a| - k)
    return np.sign(a) * np.maximum(0.0, np.abs(a) - k)
```
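A quick check that the compact form used in `soft_thresholding` agrees with the three-case definition of $S_\kappa$. The piecewise version `soft_thresholding_cases` is a hypothetical helper written only for this comparison.

```python
import numpy as np


def soft_thresholding(a, k):
    return np.sign(a) * np.maximum(0.0, np.abs(a) - k)


def soft_thresholding_cases(a, k):
    # Three cases: a - k if a > k, a + k if a < -k, 0 if |a| <= k
    return np.where(a > k, a - k, np.where(a < -k, a + k, 0.0))


a = np.array([3.0, 1.0, 0.3, -0.3, -1.0, -2.5])
st_compact = soft_thresholding(a, 1.0)
st_cases = soft_thresholding_cases(a, 1.0)
print(st_compact)
print(st_cases)
```

The compact form may print `-0.` for negative inputs that are thresholded to zero; numerically these equal `0.`, which is also why `-0.` entries appear in the notebook's outputs.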

```python
xnew = cho_solve((c, lower), Mtb + rho * (z - u))
xnew
```

```
array([-1.12672886, -0.78338621,  1.81712861,  1.76873077,  1.31667535, -0.41171736, -0.62774124,
       -0.49119182,  1.01766884,  0.69460945, -0.00379547, -0.03955283,  0.03138226,  0.00605927,
        0.03499658, -0.02064702,  0.03643848, -0.04799411,  0.01243125,  0.02168373,  0.04973613,
        0.0236359 ,  0.02654542,  0.07400461, -0.0323586 ,  0.02629139,  0.00622242, -0.02120399,
        0.0170705 , -0.00448481])
```

```python
znew = soft_thresholding(xnew + u, l / rho)
znew
```

```
array([-0., -0.,  0.,  0.,  0., -0., -0., -0.,  0.,  0., -0., -0.,  0.,  0.,  0., -0.,  0., -0.,
        0.,  0.,  0.,  0.,  0.,  0., -0.,  0.,  0., -0.,  0., -0.])
```

```python
unew = u + xnew - znew
unew
```

```
array([-1.12672886, -0.78338621,  1.81712861,  1.76873077,  1.31667535, -0.41171736, -0.62774124,
       -0.49119182,  1.01766884,  0.69460945, -0.00379547, -0.03955283,  0.03138226,  0.00605927,
        0.03499658, -0.02064702,  0.03643848, -0.04799411,  0.01243125,  0.02168373,  0.04973613,
        0.0236359 ,  0.02654542,  0.07400461, -0.0323586 ,  0.02629139,  0.00622242, -0.02120399,
        0.0170705 , -0.00448481])
```

Using the residual definitions above, the primal and dual residual norms after this first step are:

```python
resid_r_norm = np.linalg.norm(xnew - znew)
resid_r_norm
```

```
3.5200059369810295
```

```python
resid_s_norm = np.linalg.norm(-rho * (znew - z))
resid_s_norm
```

```
0.0
```

```python
max_iter = 10000
tol = 0.001

for i in range(max_iter):
    xnew = cho_solve((c, lower), Mtb + rho * (z - u))
    znew = soft_thresholding(xnew + u, l / rho)
    unew = u + xnew - znew
    resid_r_norm = np.linalg.norm(xnew - znew)
    resid_s_norm = np.linalg.norm(-rho * (znew - z))
    x = xnew
    z = znew
    u = unew
    if i % 100 == 0:
        print(f"Iteration{i}:||r|| = {resid_r_norm:.6f}, ||s|| = {resid_s_norm:.6f}")
    if resid_r_norm <= tol and resid_s_norm <= tol:
        print(f"Iteration{i}:||r|| = {resid_r_norm:.6f}, ||s|| = {resid_s_norm:.6f}")
        break
```

```
Iteration0:||r|| = 3.520006, ||s|| = 0.000000
Iteration100:||r|| = 0.129868, ||s|| = 0.000014
Iteration200:||r|| = 0.100437, ||s|| = 0.000012
Iteration300:||r|| = 0.069206, ||s|| = 0.000011
Iteration400:||r|| = 0.041848, ||s|| = 0.000007
Iteration500:||r|| = 0.030335, ||s|| = 0.000006
Iteration600:||r|| = 0.023931, ||s|| = 0.000215
Iteration700:||r|| = 0.011848, ||s|| = 0.000002
Iteration800:||r|| = 0.006923, ||s|| = 0.000001
Iteration900:||r|| = 0.006265, ||s|| = 0.000001
Iteration1000:||r|| = 0.005670, ||s|| = 0.000001
Iteration1100:||r|| = 0.005132, ||s|| = 0.000001
Iteration1200:||r|| = 0.004645, ||s|| = 0.000001
Iteration1300:||r|| = 0.004204, ||s|| = 0.000001
Iteration1400:||r|| = 0.003805, ||s|| = 0.000001
Iteration1500:||r|| = 0.003445, ||s|| = 0.000001
Iteration1600:||r|| = 0.003118, ||s|| = 0.000001
Iteration1700:||r|| = 0.002822, ||s|| = 0.000000
Iteration1800:||r|| = 0.002555, ||s|| = 0.000000
Iteration1900:||r|| = 0.001568, ||s|| = 0.000000
Iteration2000:||r|| = 0.001416, 
```
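The cells above can be collected into one reusable function. This is a sketch: `admm_lasso` is a hypothetical name, it repeats the notebook's updates and stopping rule with the same defaults (`rho=1.0`, `tol=0.001`, `max_iter=10000`), and it also returns the iteration count so that hitting `max_iter` without converging is visible. The check uses $M=I$, for which the Lasso minimizer is known in closed form: $S_{\lambda}(b)$.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve


def soft_thresholding(a, k):
    return np.sign(a) * np.maximum(0.0, np.abs(a) - k)


def admm_lasso(M, b, lam, rho=1.0, tol=1e-3, max_iter=10000):
    """Hypothetical wrapper around the loop above; returns (x, z, u, n_iter)."""
    p = M.shape[1]
    x, z, u = np.zeros(p), np.zeros(p), np.zeros(p)
    factor = cho_factor(M.T @ M + rho * np.eye(p))  # factor M'M + rho I once
    Mtb = M.T @ b
    for i in range(max_iter):
        x = cho_solve(factor, Mtb + rho * (z - u))   # x-update
        z_old = z
        z = soft_thresholding(x + u, lam / rho)      # z-update
        u = u + x - z                                # scaled dual update
        r_norm = np.linalg.norm(x - z)               # primal residual norm
        s_norm = np.linalg.norm(-rho * (z - z_old))  # dual residual norm
        if r_norm <= tol and s_norm <= tol:
            break
    return x, z, u, i + 1


# Check against the closed form for M = I: the minimizer is S_lam(b)
b_toy = np.array([3.0, -0.5, 1.2, -2.0])
x_hat, z_hat, u_hat, iters = admm_lasso(np.eye(4), b_toy, lam=1.0, tol=1e-8)
print(z_hat, iters)  # z_hat should be close to [2, 0, 0.2, -1]
```

In the notebook's setting this would be called as `admm_lasso(M, b, l)`.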

```python
x
```

```
array([-1.11994625e+00, -7.75665313e-01,  1.81095466e+00,  1.75936757e+00,  1.31168505e+00,
       -4.01678107e-01, -6.19517515e-01, -4.85216828e-01,  1.00954251e+00,  6.85583755e-01,
       -2.25490602e-04, -2.97250900e-02,  2.40073271e-02,  7.38205063e-04,  2.71655144e-02,
       -1.06604280e-02,  2.65468210e-02, -3.70551913e-02,  1.83798119e-03,  1.35241530e-02,
        3.83247661e-02,  1.21210117e-02,  1.70710398e-02,  6.45411795e-02, -2.12076991e-02,
        1.62502629e-02,  6.00482416e-04, -1.08612299e-02,  7.94383401e-03, -2.07260485e-04])
```

```python
z
```

```
array([-1.11994625, -0.77566531,  1.81095466,  1.75936757,  1.31168505, -0.40167811, -0.61951751,
       -0.48521683,  1.00954251,  0.68558375, -0.        , -0.02972509,  0.02400733,  0.        ,
        0.02716551, -0.01066043,  0.02654682, -0.03705519,  0.00183798,  0.01352415,  0.03832477,
        0.01212101,  0.01707104,  0.06454118, -0.0212077 ,  0.01625026,  0.        , -0.01086123,
        0.00794383, -0.        ])
```
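One way to sanity-check a returned `z` is to measure how far it is from satisfying the Lasso subgradient optimality conditions $M'(b-Mz)\in\lambda\,\partial\Vert z\Vert_1$. This is a sketch: `lasso_kkt_violation` and its `zero_tol` threshold are hypothetical. In the notebook it would be called as `lasso_kkt_violation(M, b, z, l)`; since ADMM only solves to a finite tolerance, a small but nonzero violation should be expected there.

```python
import numpy as np


def lasso_kkt_violation(M, b, z, lam, zero_tol=1e-10):
    """Max violation of the Lasso optimality conditions (hypothetical helper).

    At a minimizer z of (1/2)||Mz - b||^2 + lam * ||z||_1, g = M'(b - Mz) satisfies
      g_j = lam * sign(z_j)  if z_j != 0,
      |g_j| <= lam           if z_j == 0.
    """
    g = M.T @ (b - M @ z)
    active = np.abs(z) > zero_tol                      # treat tiny entries as zero
    viol_active = np.abs(g[active] - lam * np.sign(z[active]))
    viol_zero = np.maximum(0.0, np.abs(g[~active]) - lam)
    return max(viol_active.max(initial=0.0), viol_zero.max(initial=0.0))


# With M = I the exact minimizer is S_lam(b), so its violation should be ~0
b_toy = np.array([3.0, -0.5, 1.2, -2.0])
lam = 1.0
z_exact = np.sign(b_toy) * np.maximum(0.0, np.abs(b_toy) - lam)
viol_opt = lasso_kkt_violation(np.eye(4), b_toy, z_exact, lam)
viol_zero_vec = lasso_kkt_violation(np.eye(4), b_toy, np.zeros(4), lam)  # not optimal
print(viol_opt, viol_zero_vec)
```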