# The ADMM Algorithm

### The Consensus Problem

Suppose we want to solve the optimization problem
$$
\min_{x} \sum_{i=1}^N f_i(x)+g(x)
$$
where $g(\cdot)$ is a global regularization term. A typical example is the Lasso,
$$
\frac{1}{2}\Vert Ax-b\Vert^2+\lambda \Vert x\Vert_1.
$$

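It can be rewritten as an ADMM consensus problem, in which each block keeps its own copy $x_i$ of the variable and all copies are constrained to agree with a global variable $z$:
$$\begin{align*}
\min_{x_1,\ldots,x_N,z}\  & \sum_{i=1}^N f_i(x_i)+g(z)\\
\mathrm{s.t.}\  & x_i=z,\ i=1,\ldots,N.
\end{align*}$$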

The iterations are
$$
\begin{align*}
x_i^{k+1} & =\underset{x_i}{\arg\min}\ f_i(x_i)+\frac{\rho}{2}\Vert x_i-z^{k}+u_i^{k}\Vert^{2}\\
z^{k+1} & =\underset{z}{\arg\min}\ g(z)+\frac{N\rho}{2}\Vert z-\bar{x}^{k+1}-\bar{u}^{k}\Vert^{2}\\
u_i^{k+1} & =u_i^{k}+x_i^{k+1}-z^{k+1}.
\end{align*}
$$

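Here $u_i$ is the scaled dual variable of block $i$, and $\bar{x}^{k+1}=\frac{1}{N}\sum_{i=1}^N x_i^{k+1}$ and $\bar{u}^{k}=\frac{1}{N}\sum_{i=1}^N u_i^{k}$ are averages over the blocks. Note that the updates carrying the subscript $i$ are performed for every $i=1,\ldots,N$.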
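
The averaged form of the $z$-update may look unmotivated, so here is a short derivation sketch. Written without averages, the $z$-step minimizes $g(z)+\sum_{i=1}^N \frac{\rho}{2}\Vert x_i^{k+1}-z+u_i^{k}\Vert^{2}$. With $\bar{x}^{k+1}$ and $\bar{u}^{k}$ denoting the averages of $x_i^{k+1}$ and $u_i^{k}$ over the $N$ blocks, completing the square gives

$$
\begin{align*}
\sum_{i=1}^N \frac{\rho}{2}\Vert x_i^{k+1}-z+u_i^{k}\Vert^{2}
 & =\frac{N\rho}{2}\Vert z-\bar{x}^{k+1}-\bar{u}^{k}\Vert^{2}
 +\frac{\rho}{2}\sum_{i=1}^N \Vert x_i^{k+1}+u_i^{k}-\bar{x}^{k+1}-\bar{u}^{k}\Vert^{2}.
\end{align*}
$$

The second term on the right does not depend on $z$, so it can be dropped from the minimization. This is why only the two averages have to be collected from the blocks before $z$ is updated.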

The norm of the overall primal residual is
$$
\Vert r^{k+1}\Vert=\sqrt{\sum_{i=1}^N \Vert x_i^{k+1} - z^{k+1}\Vert^2},
$$

and the norm of the overall dual residual is
$$
\Vert s^{k+1}\Vert=\rho \sqrt{N} \Vert z^{k+1} - z^k\Vert.
$$
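
The two residual norms can be sketched in plain NumPy as follows. The function name `consensus_residuals` and the toy inputs are illustrative assumptions and do not appear in the original notebook.

```python
import numpy as np

def consensus_residuals(x_list, z_new, z_old, rho):
    """Return (||r||, ||s||) for consensus ADMM (hypothetical helper)."""
    N = len(x_list)
    # Primal residual: stack x_i - z over all blocks and take the Euclidean norm
    r_norm = np.sqrt(sum(np.sum((xi - z_new) ** 2) for xi in x_list))
    # Dual residual: rho * sqrt(N) * ||z_new - z_old||
    s_norm = rho * np.sqrt(N) * np.linalg.norm(z_new - z_old)
    return r_norm, s_norm

# Toy example with N = 2 blocks in dimension 2
x_list = [np.array([1.0, 2.0]), np.array([1.0, 0.0])]
z_old = np.array([0.0, 0.0])
z_new = np.array([1.0, 1.0])
r_norm, s_norm = consensus_residuals(x_list, z_new, z_old, rho=2.0)
print(r_norm, s_norm)
```

The primal residual measures how far the local copies are from consensus, while the dual residual measures how much $z$ moved in the last iteration.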

### Lasso

For the Lasso problem, suppose the data are split into row blocks,

$$
A_{n\times p}=\left[\begin{array}{c}
A_{1}\in\mathbb{R}^{n_{1}\times p}\\
\vdots\\
A_{N}\in\mathbb{R}^{n_{N}\times p}
\end{array}\right],\quad b_{n\times1}=\left[\begin{array}{c}
b_{1}\in\mathbb{R}^{n_{1}}\\
\vdots\\
b_{N}\in\mathbb{R}^{n_{N}}
\end{array}\right].
$$

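With $f_i(x_i)=\frac{1}{2}\Vert A_ix_i-b_i\Vert^2$ and $g(z)=\lambda\Vert z\Vert_1$, and with $S_\kappa(a)=\mathrm{sign}(a)\max(0,|a|-\kappa)$ denoting elementwise soft-thresholding, the iterations become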

$$
\begin{align*}
x_i^{k+1} & =(A_i'A_i+\rho I)^{-1}(A_i'b_i+\rho(z^{k}-u_i^{k}))\\
z^{k+1} & =S_{\lambda/(\rho N)}(\bar{x}^{k+1}+\bar{u}^{k})\\
u_i^{k+1} & =u_i^{k}+x_i^{k+1}-z^{k+1}.
\end{align*}
$$
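
Before moving to Spark, the three updates can be tried serially on a small problem. The following is a minimal NumPy sketch under stated assumptions: the function name `consensus_lasso`, the toy data, and the values of `lam` and `rho` are all illustrative choices, not part of the original notebook.

```python
import numpy as np

def soft_thresholding(a, k):
    # Elementwise soft-thresholding S_k(a)
    return np.sign(a) * np.maximum(0.0, np.abs(a) - k)

def consensus_lasso(A_blocks, b_blocks, lam, rho=10.0, max_iter=500, tol=1e-6):
    """Serial consensus ADMM for the Lasso (illustrative sketch)."""
    N = len(A_blocks)
    p = A_blocks[0].shape[1]
    u = [np.zeros(p) for _ in range(N)]
    z = np.zeros(p)
    for _ in range(max_iter):
        # x update: one ridge-type linear system per block
        x = [np.linalg.solve(Ai.T @ Ai + rho * np.eye(p), Ai.T @ bi + rho * (z - ui))
             for Ai, bi, ui in zip(A_blocks, b_blocks, u)]
        # z update: soft-threshold the sum of the block averages
        xbar = np.mean(x, axis=0)
        ubar = np.mean(u, axis=0)
        z_new = soft_thresholding(xbar + ubar, lam / (rho * N))
        # u update
        u = [ui + xi - z_new for ui, xi in zip(u, x)]
        # Residual norms and stopping rule
        r_norm = np.sqrt(sum(np.sum((xi - z_new) ** 2) for xi in x))
        s_norm = rho * np.sqrt(N) * np.linalg.norm(z_new - z)
        z = z_new
        if r_norm <= tol and s_norm <= tol:
            break
    return z

# Toy data: 200 rows, 5 variables, split into 4 row blocks
rng = np.random.default_rng(0)
n, p, N = 200, 5, 4
A = rng.normal(size=(n, p))
x_true = np.array([2.0, -1.0, 0.0, 0.0, 0.0])
b = A @ x_true + 0.1 * rng.normal(size=n)
z_hat = consensus_lasso(np.array_split(A, N), np.array_split(b, N), lam=5.0)
print(np.round(z_hat, 3))
```

The list comprehensions over blocks are the part that a distributed implementation would run in parallel. Only `xbar`, `ubar` and the residual sums need to be gathered centrally.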

### PySpark Implementation

In [2]:
import os
import numpy as np
np.set_printoptions(linewidth=100)

First, generate simulated data:

In [3]:
np.random.seed(123)
n = 100000
p = 100
nz = 20
A = np.random.normal(size=(n, p))
# Only the first 20 entries of the true x are nonzero; the rest are 0
xtrue = np.random.normal(size=nz)
xtrue = np.concatenate((xtrue, np.zeros(p - nz)))
b = A.dot(xtrue) + np.random.normal(size=n)
dat = np.hstack((b.reshape(n, 1), A))
# exist_ok=True already handles the case where the directory exists
os.makedirs("data", exist_ok=True)
np.savetxt("data/lasso.txt", dat, fmt="%.9f", delimiter="\t")

Configure and start PySpark:

In [4]:
import findspark
findspark.init()

from pyspark.sql import SparkSession
# Local mode; the two memory settings raise the memory available to the executors and the driver
spark = (
    SparkSession.builder
    .master("local[*]")
    .config("spark.executor.memory", "2g")
    .config("spark.driver.memory", "2g")
    .appName("ADMM")
    .getOrCreate()
)
sc = spark.sparkContext
# sc.setLogLevel("ERROR")
print(spark)
print(sc)

<pyspark.sql.session.SparkSession object at 0x0000017AB4E38340>
<SparkContext master=local[*] appName=ADMM>


Read the file with PySpark:

In [7]:
file = sc.textFile("data/lasso.txt")

# Print the first 5 lines, truncating each line
text = file.map(lambda x: x[:70] + "...").take(5)
print(*text, sep="\n")

file.getNumPartitions()

0.783983964	-1.085630603	0.997345447	0.282978498	-1.506294714	-0.57860...
7.234829271	0.642054689	-1.977887932	0.712264635	2.598303927	-0.024625...
1.353003297	0.703310118	-0.598105331	2.200702099	0.688296930	-0.006307...
-5.014040860	0.765054846	-0.828988834	-0.659151311	0.611123550	-0.1440...
-4.701729193	1.534090289	-0.529914099	-0.490972283	-1.309165314	-0.008...


4

Convert each partition into a matrix:

In [8]:
# str => np.array
def str_to_vec(line):
    # Split the string on tabs
    str_vec = line.split("\t")
    # Convert each element from a string to a float
    num_vec = map(float, str_vec)
    # Build a NumPy vector
    return np.fromiter(num_vec, dtype=float)

# Iter[str] => Iter[matrix]
def part_to_mat(iterator):
    # Iter[str] => Iter[np.array]
    iter_arr = map(str_to_vec, iterator)

    # Iter[np.array] => list(np.array)
    dat = list(iter_arr)

    # list(np.array) => matrix
    if len(dat) < 1:  # Empty partition
        mat = np.array([])
    else:
        mat = np.vstack(dat)

    # matrix => Iter[matrix]
    yield mat

dat = file.mapPartitions(part_to_mat).filter(lambda x: x.shape[0] > 0)
dat.cache()
print(dat.count())

4
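
The partition conversion can be checked locally without Spark, since `mapPartitions` simply calls the function once per partition with an iterator over that partition's elements. The sketch below uses condensed stand-ins for `str_to_vec` and `part_to_mat`, and two made-up input lines.

```python
import numpy as np

def str_to_vec(line):
    # Tab-separated string => 1-D float array
    return np.array([float(s) for s in line.split("\t")])

def part_to_mat(iterator):
    # All lines of one partition => a single matrix (or an empty array)
    rows = [str_to_vec(line) for line in iterator]
    yield np.vstack(rows) if rows else np.array([])

# A non-empty "partition" with two lines, and an empty one
lines = ["1.0\t2.0\t3.0", "4.0\t5.0\t6.0"]
mat = next(part_to_mat(iter(lines)))
empty = next(part_to_mat(iter([])))
print(mat.shape, empty.shape)
```

An empty partition yields an array whose first dimension is 0, which is what the `filter(lambda x: x.shape[0] > 0)` step relies on to discard it.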


Create a new RDD that stores the variables $x_i$ and $u_i$ for each data partition:

In [9]:
def create_xu(block_dat):
    p = block_dat.shape[1] - 1
    xi = np.zeros(p)
    ui = np.zeros(p)
    return xi, ui

params = dat.map(create_xu)
params.first()

(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.]))

In the updates that follow, we zip the data RDD (`dat`) with the parameter RDD (`params`) so that each data block is paired with its own parameters.

In [10]:
dat.zip(params).first()
# zip pairs each data block with its (xi, ui) tuple

(array([[ 0.78398396, -1.0856306 ,  0.99734545, ..., -1.36347154,  0.37940061, -0.37917643],
        [ 7.23482927,  0.64205469, -1.97788793, ..., -0.11085072, -0.34126172, -0.21794626],
        [ 1.3530033 ,  0.70331012, -0.59810533, ...,  0.41569454,  0.16054442,  0.81976061],
        ...,
        [-4.12818502,  2.13326782,  0.47922157, ...,  1.02750214,  0.40875237, -1.6814463 ],
        [-1.40995119,  0.80706604,  1.56112148, ...,  0.87865259, -0.3035723 ,  0.81765085],
        [ 7.00499946, -0.87059335,  1.00231986, ..., -1.37447687,  0.36551111, -1.3087634 ]]),
 (array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.

Set up several variables on the driver, including the initial value of $z$:

In [11]:
# Sample size
n = dat.map(lambda x: x.shape[0]).reduce(lambda x, y: x + y)
print(f"n = {n}")

# Number of variables
p = dat.first().shape[1] - 1
print(f"p = {p}")

# Number of partitions
N = dat.count()
print(f"N = {N}")

# Penalty parameter rho
rho = 10.0

# Regularization parameter lambda
lam = 0.001 * n

# The z variable
z = np.zeros(p)

n = 100000
p = 100
N = 4


Perform one update of $x_i$:

In [None]:
def update_x(block_dat_and_param, z, rho):
    block_dat, (xi, ui) = block_dat_and_param
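    # Split the response vector from the design matrix
    bi = block_dat[:, 0]
    Ai = block_dat[:, 1:]
    p = Ai.shape[1]
    # Compute the new xi
    xi_new = np.linalg.solve(Ai.T.dot(Ai) + rho * np.eye(p), Ai.T.dot(bi) + rho * (z - ui))
    # ui is passed through unchanged so that it is at hand for the z update
    return xi_new, ui
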
updated_x = dat.zip(params).map(lambda d: update_x(d, z, rho))
updated_x.first()
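
As a sanity check on the closed-form $x_i$ update, the gradient of $\frac{1}{2}\Vert A_ix-b_i\Vert^2+\frac{\rho}{2}\Vert x-z+u_i\Vert^2$ should vanish at the computed solution. The sketch below verifies this on a small random block; all data here are made up for illustration.

```python
import numpy as np

# Small random block standing in for (Ai, bi, z, ui)
rng = np.random.default_rng(0)
Ai = rng.normal(size=(50, 3))
bi = rng.normal(size=50)
z = rng.normal(size=3)
ui = rng.normal(size=3)
rho = 10.0

# Closed-form x update: solve (Ai'Ai + rho I) x = Ai'bi + rho (z - ui)
xi_new = np.linalg.solve(Ai.T @ Ai + rho * np.eye(3), Ai.T @ bi + rho * (z - ui))

# Gradient of the x-subproblem objective at xi_new
grad = Ai.T @ (Ai @ xi_new - bi) + rho * (xi_new - z + ui)
print(np.max(np.abs(grad)))
```

The printed value should be at the level of floating-point round-off, confirming that `np.linalg.solve` returns the minimizer of the subproblem.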

Compute $\bar{x}$ and $\bar{u}$:

In [13]:
xbar, ubar = updated_x.reduce(lambda xu1, xu2: (xu1[0] + xu2[0], xu1[1] + xu2[1]))
xbar /= N
ubar /= N
print(xbar)
print(ubar)

[-5.89419532e-01 -1.11139323e+00  1.15671062e+00  6.86271561e-01  6.42181771e-01 -1.70783174e+00
  8.72773966e-01 -6.92354816e-01  1.21199997e+00 -1.76569237e-01 -5.89449899e-01  4.50429521e-01
  1.56230003e+00  9.34682099e-01  3.33636913e-01 -6.25236206e-01 -1.63888142e-01  1.04196377e+00
  9.87372576e-01 -2.27247585e-01 -6.12036878e-03  9.37547158e-05  4.06635022e-03 -4.68312732e-03
  3.04138379e-03 -1.26594489e-04 -2.72747461e-03 -5.03764642e-03 -2.26035160e-03  4.81320144e-04
 -5.51422542e-03  2.33178691e-03 -3.88393506e-03  3.07193508e-04  3.19566625e-03 -9.99453330e-04
  8.18758609e-03 -4.43382232e-03 -3.59589761e-03 -3.16625572e-03  4.96804996e-04 -8.00671634e-03
  1.80715612e-04 -3.58980874e-03  2.45181207e-03 -6.61332168e-03  5.85209277e-03 -4.74902604e-03
  4.83088487e-03 -1.31876861e-03 -3.40183489e-03 -1.32325194e-03 -3.63968247e-03 -2.17380192e-03
  1.71696772e-03  3.91057921e-03 -2.26421944e-03 -3.29266131e-03 -1.38675100e-03  2.69352785e-03
  4.27023878e-03 -3.36995794e-

Perform one update of $z$:

In [14]:
def soft_thresholding(a, k):
    return np.sign(a) * np.maximum(0.0, np.abs(a) - k)

znew = soft_thresholding(xbar + ubar, lam / (rho * N))
znew

array([-0., -0.,  0.,  0.,  0., -0.,  0., -0.,  0., -0., -0.,  0.,  0.,  0.,  0., -0., -0.,  0.,
        0., -0., -0.,  0.,  0., -0.,  0., -0., -0., -0., -0.,  0., -0.,  0., -0.,  0.,  0., -0.,
        0., -0., -0., -0.,  0., -0.,  0., -0.,  0., -0.,  0., -0.,  0., -0., -0., -0., -0., -0.,
        0.,  0., -0., -0., -0.,  0.,  0., -0., -0.,  0., -0., -0., -0., -0.,  0., -0.,  0.,  0.,
       -0.,  0., -0.,  0., -0., -0., -0., -0.,  0.,  0., -0.,  0., -0., -0., -0.,  0., -0., -0.,
        0.,  0., -0.,  0.,  0.,  0., -0.,  0., -0.,  0.])
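
A small standalone example may clarify why the first `znew` is entirely zero. With `lam = 0.001 * n = 100`, `rho = 10` and `N = 4`, the threshold is `lam / (rho * N) = 2.5`, and the entries of `xbar` printed above are all smaller than that in magnitude. The input vector below is made up for illustration.

```python
import numpy as np

def soft_thresholding(a, k):
    return np.sign(a) * np.maximum(0.0, np.abs(a) - k)

# Entries with |a| <= 2.5 are set to zero; larger ones are shrunk toward zero by 2.5
a = np.array([-3.0, -1.0, 0.5, 2.5, 4.0])
out = soft_thresholding(a, 2.5)
print(out)
```

Only the first and last entries survive, as -0.5 and 1.5. Negative inputs that are thresholded away show up as `-0.` because `np.sign(a)` multiplies a floating-point zero by -1, which is also why the output above mixes `0.` and `-0.`.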

Perform one update of $u_i$:

In [15]:
def update_u(param, z):
    xi_new, ui = param
    # Compute the new ui
    ui_new = ui + xi_new - z
    return xi_new, ui_new

updated_u = updated_x.map(lambda d: update_u(d, znew))
updated_u.first()

(array([-5.90891230e-01, -1.11316101e+00,  1.15118520e+00,  6.87119086e-01,  6.40461335e-01,
        -1.69365779e+00,  8.71350796e-01, -6.94468582e-01,  1.20887153e+00, -1.75076360e-01,
        -5.90574730e-01,  4.51854802e-01,  1.56860343e+00,  9.37088007e-01,  3.29918988e-01,
        -6.21764363e-01, -1.67827119e-01,  1.04850460e+00,  9.84924270e-01, -2.24025282e-01,
        -1.10339965e-02,  4.80186788e-03,  1.98148799e-02, -4.02297944e-03,  3.05191710e-03,
        -3.09676813e-03, -1.02057223e-02, -4.79883136e-03, -4.54899791e-03, -5.83288153e-03,
        -3.67947030e-03,  5.09693034e-03, -3.38483744e-03,  4.93337600e-03,  2.31644906e-03,
        -1.11660789e-02,  1.47914464e-02, -2.71281059e-03, -8.34385522e-03,  4.41063100e-04,
        -4.45645877e-03, -3.11051955e-03,  3.95695611e-03, -1.30166500e-03,  1.65346429e-03,
        -1.45887472e-02,  7.72474580e-03, -3.11627620e-03, -3.10905486e-03, -2.67056160e-03,
        -8.22430437e-03, -4.24525820e-03, -6.19090214e-03, -4.37564448

Compute the norm of the primal residual

$$
\Vert r^{k+1}\Vert=\sqrt{\sum_{i=1}^N \Vert x_i^{k+1} - z^{k+1}\Vert^2}.
$$

In [16]:
# Compute ||xi_new - z||^2 on each block, then sum over the blocks
resid_r_norm = updated_u.map(lambda d: np.sum(np.square(d[0] - znew))).reduce(lambda x, y: x + y)
resid_r_norm = np.sqrt(resid_r_norm)
print(resid_r_norm)

7.993021782824685


Compute the norm of the dual residual:

In [17]:
resid_s_norm = rho * np.sqrt(N) * np.linalg.norm(znew - z)
print(resid_s_norm)

0.0


Next, wrap the whole procedure in a loop, with a maximum of 100 iterations and a convergence tolerance of 0.001.

In [18]:
max_iter = 100
tol = 0.001

for i in range(max_iter):
    # x update
    updated_x = dat.zip(params).map(lambda d: update_x(d, z, rho))
    updated_x.cache()
    # Compute the averages
    xbar, ubar = updated_x.reduce(lambda xu1, xu2: (xu1[0] + xu2[0], xu1[1] + xu2[1]))
    xbar /= N
    ubar /= N
    # z update
    znew = soft_thresholding(xbar + ubar, lam / (rho * N))
    # u update
    updated_u = updated_x.map(lambda d: update_u(d, znew))
    updated_u.cache()
    # Compute the residual norms
    resid_r_norm = updated_u.map(lambda d: np.sum(np.square(d[0] - znew))).reduce(lambda x, y: x + y)
    resid_r_norm = np.sqrt(resid_r_norm)
    resid_s_norm = rho * np.sqrt(N) * np.linalg.norm(znew - z)
    # Carry the new x, u and z into the next iteration (updated_u is already cached)
    params = updated_u
    z = znew
    # Print the residuals and check for convergence
    print(f"Iteration {i}, ||r|| = {resid_r_norm:.6f}, ||s|| = {resid_s_norm:.6f}")
    if resid_r_norm <= tol and resid_s_norm <= tol:
        break

Iteration 0, ||r|| = 7.993022, ||s|| = 0.000000
Iteration 1, ||r|| = 6.960693, ||s|| = 22.149352
Iteration 2, ||r|| = 4.198804, ||s|| = 45.316189
Iteration 3, ||r|| = 2.976524, ||s|| = 25.852161
Iteration 4, ||r|| = 1.368126, ||s|| = 24.405553
Iteration 5, ||r|| = 1.070668, ||s|| = 5.702382
Iteration 6, ||r|| = 0.946690, ||s|| = 4.993388
Iteration 7, ||r|| = 0.751586, ||s|| = 3.305523
Iteration 8, ||r|| = 0.672615, ||s|| = 3.348174
Iteration 9, ||r|| = 0.672343, ||s|| = 0.001349
Iteration 10, ||r|| = 0.672071, ||s|| = 0.000038
Iteration 11, ||r|| = 0.496673, ||s|| = 4.418458
Iteration 12, ||r|| = 0.496356, ||s|| = 0.108130
Iteration 13, ||r|| = 0.496154, ||s|| = 0.000057
Iteration 14, ||r|| = 0.357007, ||s|| = 2.819662
Iteration 15, ||r|| = 0.161443, ||s|| = 2.380831
Iteration 16, ||r|| = 0.128294, ||s|| = 0.979881
Iteration 17, ||r|| = 0.128241, ||s|| = 0.000399
Iteration 18, ||r|| = 0.128189, ||s|| = 0.000017
Iteration 19, ||r|| = 0.128136, ||s|| = 0.000017
Iteration 20, ||r|| = 0.12


In [None]:
z

In [None]:
xtrue

In [None]:
sc.stop()