In [77]:
import numpy as np
# from SpaRSA import SpaRSA
from scipy.optimize import minimize

def sigmoid(x):
    positive_mask = x >= 0
    sigmoid_values = np.zeros_like(x, dtype=np.float64)

    sigmoid_values[positive_mask] = 1 / (1 + np.exp(-x[positive_mask]))
    sigmoid_values[~positive_mask] = np.exp(x[~positive_mask]) / (1 + np.exp(x[~positive_mask]))

    return sigmoid_values


def admm(w0, X0, X1, rhofull, tau):
    d, _, n = X0.shape
    q = 0.9

    u0 = np.zeros((d, n))
    for i in range(n):
        u0[:, i] = w0

    lmda = np.zeros((d, n))
    for i in range(n):
        def DPi(w):
            inner_prod_X0 = np.dot(w.T, X0[:, :, i])
            inner_prod_X1 = np.dot(w.T, X1[:, :, i])
            sigmoid_term0 = sigmoid(inner_prod_X0)
            sigmoid_term1 = sigmoid(-inner_prod_X1)
            first_term = np.dot(X0[:, :, i], sigmoid_term0.T)
            second_term = -np.dot(X1[:, :, i], sigmoid_term1.T)
            return (1/(X0.shape[1]+X1.shape[1])) * (first_term + second_term)

        lmda[:, i] = -DPi(w0)

    T = 10000
    eps = 1e-50
    wc = w0
    uc = u0
    mod_uc = uc + np.dot(lmda, np.diag(1.0 / rhofull))
    tepst_full = np.zeros(n)

    for t in range(T + 1):
        print(t)
        vareps = max(1e-20, q ** t)

        # perform aggregation
        wc = np.dot(mod_uc, rhofull) / np.sum(rhofull)

        # perform local updates
        for i in range(n):

            def Pi(w):
                inner_prod_X0 = np.dot(w.T, X0[:, :, i])
                inner_prod_X1 = np.dot(w.T, X1[:, :, i])
                first_term = np.sum(inner_prod_X0-np.log(sigmoid(inner_prod_X0)+eps))
                second_term = np.sum(-np.log(sigmoid(inner_prod_X1)+eps))
                third_term = np.dot(lmda[:, i], w - wc)
                fourth_term = np.dot(rhofull[i], np.linalg.norm(w - wc, 2) ** 2) / 2
                return (1/(X0.shape[1]+X1.shape[1])) * (first_term + second_term) + third_term + fourth_term


            def DPi(w):
                inner_prod_X0 = np.dot(w.T, X0[:, :, i])
                inner_prod_X1 = np.dot(w.T, X1[:, :, i])
                sigmoid_term0 = sigmoid(inner_prod_X0)
                sigmoid_term1 = sigmoid(-inner_prod_X1)
                first_term = np.dot(X0[:, :, i], sigmoid_term0.T)
                second_term = -np.dot(X1[:, :, i], sigmoid_term1.T)
                third_term = lmda[:, i] + rhofull[i] * (w - wc)
                return (1/(X0.shape[1]+X1.shape[1])) * (first_term + second_term) + third_term

            DPic = DPi(wc)
            up = uc.copy()
            # uc[:, i] = SpaRSA(up[:, i], Pi, DPi, vareps)
            uc[:, i] = minimize(Pi, up[:, i], tol=vareps, method="L-BFGS-B")['x']
            lmda[:, i] = lmda[:, i] + rhofull[i] * (uc[:, i] - wc)
            tepst_full[i] = np.linalg.norm(DPic - rhofull[i] * (wc - up[:, i]), np.inf)

        mod_uc = uc + np.dot(lmda, np.diag(1.0 / rhofull))

        print("tau: ", np.mean(tepst_full))
        # termination criterion
        if np.mean(tepst_full) <= tau:
            break

    in_iter = t + 1
    w = wc
    return w, in_iter


In [84]:
d = 30
n = 5
Ni0 = 500
Ni1 = 100

X0 = (np.random.rand(d, Ni0, n) - 0.5) * 0.1
X1 = (np.random.rand(d, Ni1, n) + 0.5) * 0.1


w0 = np.ones(d)
tau = 1e-5
rhofull = np.ones(n) * 0.01

admm(w0, X0, X1, rhofull, tau)

0
tau:  0.0008661582616194998
1
tau:  0.0010073496328731266
2
tau:  0.0010073496328731257
3
tau:  0.0010073496328731257
4
tau:  0.0010073496328731257
5
tau:  0.0010073496328731257
6
tau:  0.0010073496328731257
7
tau:  0.0010073496328731257
8
tau:  0.0010073496328731257
9
tau:  0.0010073496328731257
10
tau:  0.0010073496328731257
11
tau:  0.0010073496328731257
12
tau:  0.0010073496328731257
13
tau:  0.0010073496328731257
14
tau:  0.0010073496328731257
15
tau:  0.0010073496328731257
16
tau:  0.0010073496328731257
17
tau:  0.0010073496328731257
18
tau:  0.0010073496328731257
19
tau:  0.0010073496328731257
20
tau:  0.0010073496328731257
21
tau:  0.0010073496328731257
22
tau:  0.0010073496328731257
23
tau:  0.0010073496328731257
24
tau:  0.0010073496328731257
25
tau:  0.0010073496328731257
26
tau:  0.0010073496328731257
27
tau:  0.0010073496328731257
28
tau:  0.0010073496328731257
29
tau:  0.0010073496328731257
30
tau:  0.0010073496328731257
31
tau:  0.0010073496328731257
32
tau:  0.0010073

(array([ 2.87621163,  0.55397113,  1.06135669,  2.44056995,  0.77516151,
         3.7704301 , -0.23551158,  1.63686149,  0.05182695,  1.98845751,
         0.78395132,  3.27373892, -0.6434171 ,  1.40237492,  1.23316069,
         1.62678256,  1.52402097,  3.06040266,  3.45455727, -1.66654065,
         1.23430958,  0.68202342,  1.82361908,  0.48965377,  2.55295494,
         0.06650959,  0.39725149,  2.39473486,  0.39713762,  3.71502831]),
 312)