# The ADMM Algorithm

### The consensus problem

Suppose we want to solve the optimization problem
$$
\min_{x} \sum_{i=1}^N f_i(x)+g(x)
$$
where $g(\cdot)$ is a global regularization term. A typical example is the Lasso, whose objective is
$$
\frac{1}{2}\Vert Ax-b\Vert^2+\lambda \Vert x\Vert_1.
$$

It can be rewritten as an ADMM consensus problem
$$\begin{align*}
\min_{x_1,\ldots,x_N,z}\  & \sum_{i=1}^N f_i(x_i)+g(z)\\
\mathrm{s.t.}\  & x_i=z,\ i=1,\ldots,N.
\end{align*}$$

The iterations are
$$
\begin{align*}
x_i^{k+1} & =\underset{x_i}{\arg\min}\ f_i(x_i)+\frac{\rho}{2}\Vert x_i-z^{k}+u_i^{k}\Vert^{2}\\
z^{k+1} & =\underset{z}{\arg\min}\ g(z)+\frac{N\rho}{2}\Vert z-\bar{x}^{k+1}-\bar{u}^{k}\Vert^{2}\\
u_i^{k+1} & =u_i^{k}+x_i^{k+1}-z^{k+1}.
\end{align*}
$$

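Here $u_i$ is the scaled dual variable attached to the constraint $x_i=z$, and $\bar{x}=\frac{1}{N}\sum_{i=1}^N x_i$, $\bar{u}=\frac{1}{N}\sum_{i=1}^N u_i$ denote averages over the blocks. Note that every update carrying the subscript $i$ is performed for all $i=1,\ldots,N$.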
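
The factor $N\rho/2$ in the $z$-update follows from a short calculation. The sketch below assumes the usual scaled-form augmented Lagrangian; the symbols $a_i$ and $\bar{a}$ are introduced only for this derivation. With $a_i = x_i^{k+1}+u_i^{k}$ and $\bar{a}=\frac{1}{N}\sum_{i=1}^N a_i=\bar{x}^{k+1}+\bar{u}^{k}$,

$$
\begin{align*}
z^{k+1} & =\underset{z}{\arg\min}\ g(z)+\frac{\rho}{2}\sum_{i=1}^N\Vert z-a_i\Vert^{2},\\
\sum_{i=1}^N\Vert z-a_i\Vert^{2} & =N\Vert z-\bar{a}\Vert^{2}+\sum_{i=1}^N\Vert a_i-\bar{a}\Vert^{2}.
\end{align*}
$$

The last sum does not depend on $z$, so dropping it leaves $g(z)+\frac{N\rho}{2}\Vert z-\bar{x}^{k+1}-\bar{u}^{k}\Vert^{2}$, which is the objective of the $z$-update above.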

The norm of the overall primal residual is
$$
\Vert r^{k+1}\Vert=\sqrt{\sum_{i=1}^N \Vert x_i^{k+1} - z^{k+1}\Vert^2},
$$

and the norm of the overall dual residual is
$$
\Vert s^{k+1}\Vert=\rho \sqrt{N} \Vert z^{k+1} - z^k\Vert.
$$
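
As a minimal single-machine sketch of the two residual formulas above, the following NumPy snippet evaluates them on a tiny hand-made example. The function name `consensus_residuals` and the toy values are illustrative assumptions, not part of the original notebook.

```python
import numpy as np

def consensus_residuals(x_list, z_new, z_old, rho):
    """Return (||r||, ||s||) for a consensus ADMM step (hypothetical helper)."""
    N = len(x_list)
    # Primal residual: stack all x_i - z and take the Euclidean norm
    r_norm = np.sqrt(sum(np.sum((xi - z_new) ** 2) for xi in x_list))
    # Dual residual: rho * sqrt(N) * ||z_new - z_old||
    s_norm = rho * np.sqrt(N) * np.linalg.norm(z_new - z_old)
    return r_norm, s_norm

# Toy example with N = 2 blocks and p = 2 variables
x_list = [np.array([1.0, 2.0]), np.array([3.0, 2.0])]
z_old = np.array([0.0, 0.0])
z_new = np.array([2.0, 2.0])
r_norm, s_norm = consensus_residuals(x_list, z_new, z_old, rho=1.0)
print(r_norm, s_norm)
```

Here the two blocks differ from $z$ by $(-1,0)$ and $(1,0)$, so $\Vert r\Vert=\sqrt{2}$, while $\Vert s\Vert=\sqrt{2}\cdot\Vert(2,2)\Vert=4$.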

### Lasso

For the Lasso problem, suppose the data are partitioned by rows,

$$
A_{n\times p}=\left[\begin{array}{c}
A_{1}\in\mathbb{R}^{n_{1}\times p}\\
\vdots\\
A_{N}\in\mathbb{R}^{n_{N}\times p}
\end{array}\right],\quad b_{n\times1}=\left[\begin{array}{c}
b_{1}\in\mathbb{R}^{n_{1}}\\
\vdots\\
b_{N}\in\mathbb{R}^{n_{N}}
\end{array}\right].
$$

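With $f_i(x)=\frac{1}{2}\Vert A_ix-b_i\Vert^2$ and $g(z)=\lambda\Vert z\Vert_1$, the iterations become the following, where $A_i'$ is the transpose of $A_i$ and $S_{\kappa}(a)=\operatorname{sign}(a)\max(|a|-\kappa,0)$ is the elementwise soft-thresholding operator: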

$$
\begin{align*}
x_i^{k+1} & =(A_i'A_i+\rho I)^{-1}(A_i'b_i+\rho(z^{k}-u_i^{k}))\\
z^{k+1} & =S_{\lambda/(\rho N)}(\bar{x}^{k+1}+\bar{u}^{k})\\
u_i^{k+1} & =u_i^{k}+x_i^{k+1}-z^{k+1}.
\end{align*}
$$
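
Before moving to Spark, it can help to see the three updates run on a single machine. The following is a minimal NumPy sketch under stated assumptions: the function `consensus_lasso`, the toy data, and the values `lam=20.0` and `rho=50.0` are all illustrative choices, not taken from the original notebook.

```python
import numpy as np

def soft_thresholding(a, k):
    # Elementwise soft-thresholding operator S_k(a)
    return np.sign(a) * np.maximum(0.0, np.abs(a) - k)

def consensus_lasso(blocks, lam, rho=1.0, max_iter=2000, tol=1e-8):
    """Consensus ADMM for the Lasso; blocks is a list of (A_i, b_i) pairs."""
    N = len(blocks)
    p = blocks[0][0].shape[1]
    x = np.zeros((N, p))
    u = np.zeros((N, p))
    z = np.zeros(p)
    # These quantities do not change across iterations, so compute them once
    lhs = [Ai.T @ Ai + rho * np.eye(p) for Ai, _ in blocks]
    Atb = [Ai.T @ bi for Ai, bi in blocks]
    for _ in range(max_iter):
        # x update, one linear solve per block
        for i in range(N):
            x[i] = np.linalg.solve(lhs[i], Atb[i] + rho * (z - u[i]))
        # z update: soft-threshold the average of x_i + u_i
        z_new = soft_thresholding(x.mean(axis=0) + u.mean(axis=0), lam / (rho * N))
        # u update (z_new is broadcast over the N rows)
        u += x - z_new
        # Residual norms
        r_norm = np.linalg.norm(x - z_new)
        s_norm = rho * np.sqrt(N) * np.linalg.norm(z_new - z)
        z = z_new
        if r_norm <= tol and s_norm <= tol:
            break
    return z

# Toy data: 200 observations, 10 variables, 4 row blocks
rng = np.random.default_rng(0)
A = rng.normal(size=(200, 10))
x_true = np.concatenate((rng.normal(size=3), np.zeros(7)))
b = A @ x_true + 0.1 * rng.normal(size=200)
blocks = list(zip(np.array_split(A, 4), np.array_split(b, 4)))
z_hat = consensus_lasso(blocks, lam=20.0, rho=50.0)
print(z_hat)
```

One way to sanity-check the result is the Lasso optimality condition: the gradient $A'(Az-b)$ should equal $-\lambda\,\operatorname{sign}(z_j)$ on the nonzero coordinates and have magnitude at most $\lambda$ elsewhere.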

### PySpark implementation

In [1]:
import os
import numpy as np
np.set_printoptions(linewidth=100)

First, generate simulated data:

In [2]:
np.random.seed(123)
n = 100000
p = 100
nz = 20
A = np.random.normal(size=(n, p))
# Only the first 20 elements of the true x are nonzero; the rest are 0
xtrue = np.random.normal(size=nz)
xtrue = np.concatenate((xtrue, np.zeros(p - nz)))
b = A.dot(xtrue) + np.random.normal(size=n)
dat = np.hstack((b.reshape(n, 1), A))
# exist_ok=True already handles the case where the directory exists
os.makedirs("data", exist_ok=True)
np.savetxt("data/lasso.txt", dat, fmt="%.9f", delimiter="\t")

Configure and start PySpark:

In [4]:
import findspark
findspark.init("/Users/xinby/Library/Spark")

from pyspark.sql import SparkSession
# Local mode
spark = SparkSession.builder.\
    master("local[*]").\
    config("spark.executor.memory", "2g").\
    config("spark.driver.memory", "2g").\
    appName("ADMM").\
    getOrCreate()
sc = spark.sparkContext
# sc.setLogLevel("ERROR")
print(spark)
print(sc)

<pyspark.sql.session.SparkSession object at 0x7feb88612b20>
<SparkContext master=local[*] appName=ADMM>


Read the file with PySpark:

In [5]:
file = sc.textFile("data/lasso.txt")

# Print the first 5 lines, truncating each line string
text = file.map(lambda x: x[:70] + "...").take(5)
print(*text, sep="\n")

file.getNumPartitions()


0.783983964	-1.085630603	0.997345447	0.282978498	-1.506294714	-0.57860...
7.234829271	0.642054689	-1.977887932	0.712264635	2.598303927	-0.024625...
1.353003297	0.703310118	-0.598105331	2.200702099	0.688296930	-0.006307...
-5.014040860	0.765054846	-0.828988834	-0.659151311	0.611123550	-0.1440...
-4.701729193	1.534090289	-0.529914099	-0.490972283	-1.309165314	-0.008...


                                                                                

4

Convert each partition into a matrix:

In [8]:
# str => np.array
def str_to_vec(line):
    # Split the string on tabs
    str_vec = line.split("\t")
    # Convert each element from string to float
    num_vec = map(float, str_vec)
    # Create a NumPy vector
    return np.fromiter(num_vec, dtype=float)

# Iter[str] => Iter[matrix]
def part_to_mat(iterator):
    # Iter[str] => Iter[np.array]
    iter_arr = map(str_to_vec, iterator)

    # Iter[np.array] => list(np.array)
    dat = list(iter_arr)

    # list(np.array) => matrix
    if len(dat) < 1:  # Empty partition: return an empty array
        mat = np.array([])
    else:
        mat = np.vstack(dat)

    # matrix => Iter[matrix]
    yield mat

dat = file.mapPartitions(part_to_mat).filter(lambda x: x.shape[0] > 0)
dat.cache()
print(dat.count())
dat.first()

                                                                                

4


array([[ 0.78398396, -1.0856306 ,  0.99734545, ..., -1.36347154,  0.37940061, -0.37917643],
       [ 7.23482927,  0.64205469, -1.97788793, ..., -0.11085072, -0.34126172, -0.21794626],
       [ 1.3530033 ,  0.70331012, -0.59810533, ...,  0.41569454,  0.16054442,  0.81976061],
       ...,
       [-2.66214793, -0.35893791,  1.1753093 , ..., -0.27385503,  0.64362609,  1.25292122],
       [-0.45562882,  1.89322323,  0.58039023, ..., -1.07129657,  1.40304602, -0.76660851],
       [-8.51940409,  0.27954009, -0.43108834, ...,  0.69615652, -1.01735946, -1.84640876]])
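
Because `mapPartitions` simply passes an iterator of lines to the function, the conversion logic can be checked locally without Spark. The snippet below is a condensed sketch of the two helpers above, applied to plain Python iterators; the two sample lines are made up for illustration.

```python
import numpy as np

def str_to_vec(line):
    # Tab-separated string => NumPy vector
    return np.fromiter(map(float, line.split("\t")), dtype=float)

def part_to_mat(iterator):
    # Stack all rows of one partition; an empty partition gives an empty array
    rows = [str_to_vec(line) for line in iterator]
    yield np.vstack(rows) if rows else np.array([])

# A fake "partition" with two lines, and an empty one
lines = ["1.0\t2.0\t3.0", "4.0\t5.0\t6.0"]
mat = next(part_to_mat(iter(lines)))
empty = next(part_to_mat(iter([])))
print(mat.shape, empty.shape)
```

The empty partition yields an array whose first dimension is 0, which is what the `filter(lambda x: x.shape[0] > 0)` step relies on to drop it.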

Create a new RDD that stores the variables $x_i$ and $u_i$ for each data partition:

In [7]:
def create_xu(block_dat):
    p = block_dat.shape[1] - 1
    xi = np.zeros(p)
    ui = np.zeros(p)
    return xi, ui

params = dat.map(create_xu)
params.first()

(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.]))

In the subsequent updates, we zip the data RDD (`dat`) with the parameter RDD (`params`) so that the parameters can be updated.

In [9]:
dat.zip(params).first()

(array([[ 0.78398396, -1.0856306 ,  0.99734545, ..., -1.36347154,  0.37940061, -0.37917643],
        [ 7.23482927,  0.64205469, -1.97788793, ..., -0.11085072, -0.34126172, -0.21794626],
        [ 1.3530033 ,  0.70331012, -0.59810533, ...,  0.41569454,  0.16054442,  0.81976061],
        ...,
        [-2.66214793, -0.35893791,  1.1753093 , ..., -0.27385503,  0.64362609,  1.25292122],
        [-0.45562882,  1.89322323,  0.58039023, ..., -1.07129657,  1.40304602, -0.76660851],
        [-8.51940409,  0.27954009, -0.43108834, ...,  0.69615652, -1.01735946, -1.84640876]]),
 (array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.

Set up several variables on the driver, including the initial value of $z$:

In [10]:
# Number of observations
n = dat.map(lambda x: x.shape[0]).reduce(lambda x, y: x + y)
print(f"n = {n}")

# Number of variables
p = dat.first().shape[1] - 1
print(f"p = {p}")

# Number of data blocks (non-empty partitions)
N = dat.count()
print(f"N = {N}")

# rho
rho = 10.0

# lambda
lam = 0.001 * n

# The z variable
z = np.zeros(p)

n = 100000
p = 100
N = 4


Update $x_i$ once:

$$x_i^{k+1} =(A_i'A_i+\rho I)^{-1}(A_i'b_i+\rho(z^{k}-u_i^{k}))$$

In [11]:
def update_x(block_dat_and_param, z, rho):
    block_dat, (xi, ui) = block_dat_and_param
    # Separate the response vector from the design matrix
    bi = block_dat[:, 0]
    Ai = block_dat[:, 1:]
    p = Ai.shape[1]
    # Compute the new xi
    xi_new = np.linalg.solve(Ai.T.dot(Ai) + rho * np.eye(p), Ai.T.dot(bi) + rho * (z - ui))
    return xi_new, ui

updated_x = dat.zip(params).map(lambda d: update_x(d, z, rho))
updated_x.first()

(array([-5.90814270e-01, -1.11321607e+00,  1.15095752e+00,  6.87063924e-01,  6.40293561e-01,
        -1.69374030e+00,  8.71403383e-01, -6.94379736e-01,  1.20866282e+00, -1.75238934e-01,
        -5.90611190e-01,  4.51873454e-01,  1.56854930e+00,  9.37139756e-01,  3.30179252e-01,
        -6.21776857e-01, -1.68055436e-01,  1.04836673e+00,  9.84804634e-01, -2.24043441e-01,
        -1.10207886e-02,  4.74571699e-03,  1.99914358e-02, -3.99392232e-03,  2.99631275e-03,
        -3.21884881e-03, -1.00943808e-02, -4.96676839e-03, -4.24626145e-03, -5.95266905e-03,
        -3.72556604e-03,  5.14898158e-03, -3.52804213e-03,  5.04132529e-03,  2.25529949e-03,
        -1.10746048e-02,  1.45531578e-02, -2.68163609e-03, -8.41348115e-03,  5.02697750e-04,
        -4.25471754e-03, -3.12527736e-03,  4.03257522e-03, -1.44902762e-03,  1.63975781e-03,
        -1.44719846e-02,  7.95545957e-03, -3.18808466e-03, -2.90799052e-03, -2.48552651e-03,
        -8.21171927e-03, -4.37358090e-03, -6.20168644e-03, -4.40891713
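
The closed-form $x_i$ update is the solution of a regularized least-squares problem, so one way to check it is to verify that the gradient of $\frac{1}{2}\Vert A_ix-b_i\Vert^2+\frac{\rho}{2}\Vert x-z+u_i\Vert^2$ vanishes at the computed point. The following sketch does this on small random data; the sizes and the random inputs are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
Ai = rng.normal(size=(50, 5))   # one data block
bi = rng.normal(size=50)
z = rng.normal(size=5)
ui = rng.normal(size=5)
rho = 10.0

# Same formula as in update_x
xi_new = np.linalg.solve(Ai.T @ Ai + rho * np.eye(5), Ai.T @ bi + rho * (z - ui))

# Gradient of the x_i subproblem at xi_new; it should be numerically zero
grad = Ai.T @ (Ai @ xi_new - bi) + rho * (xi_new - z + ui)
print(np.max(np.abs(grad)))
```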

Compute $\bar{x}$ and $\bar{u}$:

In [12]:
xbar, ubar = updated_x.reduce(lambda xu1, xu2: (xu1[0] + xu2[0], xu1[1] + xu2[1]))
xbar /= N
ubar /= N
print(xbar)
print(ubar)



[-5.89441106e-01 -1.11140794e+00  1.15667202e+00  6.86230067e-01  6.42184751e-01 -1.70784908e+00
  8.72816079e-01 -6.92350683e-01  1.21201786e+00 -1.76527604e-01 -5.89417484e-01  4.50452142e-01
  1.56229423e+00  9.34678939e-01  3.33633574e-01 -6.25241839e-01 -1.63868904e-01  1.04197325e+00
  9.87340713e-01 -2.27278342e-01 -6.10460651e-03  9.10453738e-05  4.04206210e-03 -4.66167579e-03
  3.07175165e-03 -1.28888144e-04 -2.73214488e-03 -5.06100989e-03 -2.33029018e-03  5.24505576e-04
 -5.51145096e-03  2.31573937e-03 -3.94824355e-03  2.92775513e-04  3.17403469e-03 -9.40286179e-04
  8.16506892e-03 -4.41510032e-03 -3.63106038e-03 -3.18298560e-03  5.21931168e-04 -8.01699057e-03
  1.81022640e-04 -3.57419636e-03  2.44778618e-03 -6.62225030e-03  5.83903026e-03 -4.79766121e-03
  4.87740292e-03 -1.25469889e-03 -3.43332330e-03 -1.29728780e-03 -3.65021625e-03 -2.13285003e-03
  1.78870498e-03  3.93544083e-03 -2.27077489e-03 -3.36240716e-03 -1.39212238e-03  2.69188464e-03
  4.26897165e-03 -3.34218222e-

                                                                                

Update $z$ once:

In [13]:
def soft_thresholding(a, k):
    return np.sign(a) * np.maximum(0.0, np.abs(a) - k)

znew = soft_thresholding(xbar + ubar, lam / (rho * N))
znew

array([-0., -0.,  0.,  0.,  0., -0.,  0., -0.,  0., -0., -0.,  0.,  0.,  0.,  0., -0., -0.,  0.,
        0., -0., -0.,  0.,  0., -0.,  0., -0., -0., -0., -0.,  0., -0.,  0., -0.,  0.,  0., -0.,
        0., -0., -0., -0.,  0., -0.,  0., -0.,  0., -0.,  0., -0.,  0., -0., -0., -0., -0., -0.,
        0.,  0., -0., -0., -0.,  0.,  0., -0., -0.,  0., -0., -0., -0., -0.,  0., -0.,  0.,  0.,
       -0.,  0., -0.,  0., -0., -0., -0., -0.,  0.,  0., -0.,  0., -0., -0., -0.,  0., -0., -0.,
        0.,  0., -0.,  0.,  0.,  0., -0.,  0., -0.,  0.])
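
The all-zero result is expected at the first iteration. With the settings above, the threshold is $\lambda/(\rho N)=100/(10\times 4)=2.5$, and since $\bar{u}=0$ initially, the input is just $\bar{x}$, whose entries are all smaller than 2.5 in magnitude. The small sketch below reproduces the threshold and applies it to a made-up vector (the vector `a` is an illustrative assumption):

```python
import numpy as np

def soft_thresholding(a, k):
    return np.sign(a) * np.maximum(0.0, np.abs(a) - k)

# Same settings as in the notebook
n, rho, N = 100000, 10.0, 4
lam = 0.001 * n
kappa = lam / (rho * N)   # threshold used in the z update

# Entries with |a| <= kappa are set to zero; the others shrink toward zero by kappa
a = np.array([-3.0, -1.7, 0.5, 4.0])
out = soft_thresholding(a, kappa)
print(kappa, out)
```

Only the first and last entries survive, shrunk to about $-0.5$ and $1.5$; the signed zeros (`-0.`) in the output above come from multiplying `np.sign(a)` by zero.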

Update $u_i$ once:

In [14]:
def update_u(param, z):
    xi_new, ui = param
    # Compute the new ui
    ui_new = ui + xi_new - z
    return xi_new, ui_new

updated_u = updated_x.map(lambda d: update_u(d, znew))
updated_u.first()

(array([-5.90814270e-01, -1.11321607e+00,  1.15095752e+00,  6.87063924e-01,  6.40293561e-01,
        -1.69374030e+00,  8.71403383e-01, -6.94379736e-01,  1.20866282e+00, -1.75238934e-01,
        -5.90611190e-01,  4.51873454e-01,  1.56854930e+00,  9.37139756e-01,  3.30179252e-01,
        -6.21776857e-01, -1.68055436e-01,  1.04836673e+00,  9.84804634e-01, -2.24043441e-01,
        -1.10207886e-02,  4.74571699e-03,  1.99914358e-02, -3.99392232e-03,  2.99631275e-03,
        -3.21884881e-03, -1.00943808e-02, -4.96676839e-03, -4.24626145e-03, -5.95266905e-03,
        -3.72556604e-03,  5.14898158e-03, -3.52804213e-03,  5.04132529e-03,  2.25529949e-03,
        -1.10746048e-02,  1.45531578e-02, -2.68163609e-03, -8.41348115e-03,  5.02697750e-04,
        -4.25471754e-03, -3.12527736e-03,  4.03257522e-03, -1.44902762e-03,  1.63975781e-03,
        -1.44719846e-02,  7.95545957e-03, -3.18808466e-03, -2.90799052e-03, -2.48552651e-03,
        -8.21171927e-03, -4.37358090e-03, -6.20168644e-03, -4.40891713

Compute the norm of the primal residual

$$
\Vert r^{k+1}\Vert=\sqrt{\sum_{i=1}^N \Vert x_i^{k+1} - z^{k+1}\Vert^2}.
$$

In [15]:
# Compute ||xi_new - z||^2 on each block, then sum over the blocks
resid_r_norm = updated_u.map(lambda d: np.sum(np.square(d[0] - znew))).reduce(lambda x, y: x + y)
resid_r_norm = np.sqrt(resid_r_norm)
print(resid_r_norm)

7.9930216055938645


                                                                                

Compute the norm of the dual residual:

In [16]:
resid_s_norm = rho * np.sqrt(N) * np.linalg.norm(znew - z)
print(resid_s_norm)

0.0


Next, put the whole procedure into a loop, with a maximum of 100 iterations and a convergence threshold of 0.001.

In [17]:
max_iter = 100
tol = 0.001

for i in range(max_iter):
    # x update
    updated_x = dat.zip(params).map(lambda d: update_x(d, z, rho))
    updated_x.cache()
    # Compute the averages of x_i and u_i
    xbar, ubar = updated_x.reduce(lambda xu1, xu2: (xu1[0] + xu2[0], xu1[1] + xu2[1]))
    xbar /= N
    ubar /= N
    # z update
    znew = soft_thresholding(xbar + ubar, lam / (rho * N))
    # u update
    updated_u = updated_x.map(lambda d: update_u(d, znew))
    updated_u.cache()
    # Compute the residual norms
    resid_r_norm = updated_u.map(lambda d: np.sum(np.square(d[0] - znew))).reduce(lambda x, y: x + y)
    resid_r_norm = np.sqrt(resid_r_norm)
    resid_s_norm = rho * np.sqrt(N) * np.linalg.norm(znew - z)
    # Carry the new x, z and u over to the next iteration
    # (updated_u is already cached, so no extra cache() call is needed)
    params = updated_u
    z = znew
    # Print the residuals and check for convergence
    print(f"Iteration {i}, ||r|| = {resid_r_norm:.6f}, ||s|| = {resid_s_norm:.6f}")
    if resid_r_norm <= tol and resid_s_norm <= tol:
        break

                                                                                

Iteration 0, ||r|| = 7.993022, ||s|| = 0.000000
Iteration 1, ||r|| = 6.960675, ||s|| = 22.149790
Iteration 2, ||r|| = 4.198749, ||s|| = 45.315842


                                                                                

Iteration 3, ||r|| = 2.976602, ||s|| = 25.851012
Iteration 4, ||r|| = 1.368161, ||s|| = 24.406323
Iteration 5, ||r|| = 1.070546, ||s|| = 5.704705
Iteration 6, ||r|| = 0.946671, ||s|| = 4.991149
Iteration 7, ||r|| = 0.751589, ||s|| = 3.304970
Iteration 8, ||r|| = 0.672594, ||s|| = 3.348655
Iteration 9, ||r|| = 0.672322, ||s|| = 0.001349
Iteration 10, ||r|| = 0.672049, ||s|| = 0.000038
Iteration 11, ||r|| = 0.496576, ||s|| = 4.425707
Iteration 12, ||r|| = 0.496273, ||s|| = 0.101481
Iteration 13, ||r|| = 0.496070, ||s|| = 0.000055
Iteration 14, ||r|| = 0.357201, ||s|| = 2.807153
Iteration 15, ||r|| = 0.161806, ||s|| = 2.378260
Iteration 16, ||r|| = 0.128301, ||s|| = 0.985750
Iteration 17, ||r|| = 0.128248, ||s|| = 0.000402
Iteration 18, ||r|| = 0.128196, ||s|| = 0.000017
Iteration 19, ||r|| = 0.128143, ||s|| = 0.000017
Iteration 20, ||r|| = 0.128091, ||s|| = 0.000017
Iteration 21, ||r|| = 0.128038, ||s|| = 0.000017
Iteration 22, ||r|| = 0.127986, ||s|| = 0.000017
Iteration 23, ||r|| = 0.1

                                                                                

Iteration 32, ||r|| = 0.127462, ||s|| = 0.000017
Iteration 33, ||r|| = 0.127409, ||s|| = 0.000017
Iteration 34, ||r|| = 0.127357, ||s|| = 0.000017
Iteration 35, ||r|| = 0.127305, ||s|| = 0.000017
Iteration 36, ||r|| = 0.127253, ||s|| = 0.000017
Iteration 37, ||r|| = 0.127201, ||s|| = 0.000017
Iteration 38, ||r|| = 0.127148, ||s|| = 0.000017
Iteration 39, ||r|| = 0.127096, ||s|| = 0.000017


The run was interrupted manually (`KeyboardInterrupt`) after iteration 39, before the tolerance was reached: $\Vert r\Vert$ was still about 0.127 and decreasing only slowly, while $\Vert s\Vert$ had already dropped below the threshold.
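
The loop above compares both residuals against one fixed absolute threshold. A commonly used alternative in the ADMM literature combines an absolute and a relative tolerance, so that the thresholds scale with the size of the variables. The sketch below is an assumption on my part and is not part of the original notebook; the function name `consensus_tolerances` and the default values of `eps_abs` and `eps_rel` are hypothetical.

```python
import numpy as np

def consensus_tolerances(x_list, u_list, z, rho, eps_abs=1e-4, eps_rel=1e-3):
    """Return (eps_pri, eps_dual) for consensus ADMM (hypothetical helper)."""
    N = len(x_list)
    p = z.shape[0]
    # Norm of the stacked x_i, of N copies of z, and of the stacked duals rho * u_i
    x_norm = np.sqrt(sum(np.sum(xi ** 2) for xi in x_list))
    z_norm = np.sqrt(N) * np.linalg.norm(z)
    y_norm = rho * np.sqrt(sum(np.sum(ui ** 2) for ui in u_list))
    eps_pri = np.sqrt(N * p) * eps_abs + eps_rel * max(x_norm, z_norm)
    eps_dual = np.sqrt(N * p) * eps_abs + eps_rel * y_norm
    return eps_pri, eps_dual

# Toy values chosen so the result is easy to verify by hand
x_list = [np.array([3.0, 4.0]), np.array([0.0, 0.0])]
u_list = [np.array([1.0, 0.0]), np.array([0.0, 0.0])]
z = np.array([1.0, 0.0])
eps_pri, eps_dual = consensus_tolerances(x_list, u_list, z, rho=2.0, eps_abs=0.0, eps_rel=1.0)
print(eps_pri, eps_dual)
```

With such thresholds, the stopping test would read `resid_r_norm <= eps_pri and resid_s_norm <= eps_dual`. In the Spark version, the norms of the stacked $x_i$ and $u_i$ would have to be computed with an extra `map` and `reduce` over `params`.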


In [None]:
z

In [None]:
xtrue

In [None]:
sc.stop()