In [None]:
import numpy as np
import matplotlib.pylab as plt
from sklearn.preprocessing import StandardScaler
from numba import njit, jit
from numba.experimental import jitclass
import pickle
# NOTE(review): stand_scaler is not referenced in any other visible cell — confirm it is still needed.
stand_scaler = StandardScaler()

In [None]:
# NOTE(review): max_seed is not used in the visible cells; the generator below is seeded with 42.
max_seed = 424242

# Problem sizes: theta has d_theta entries, w has d_w entries, every data
# matrix has n_data rows, and M clients are created.
# zeta scales the per-client random perturbation added to the shared base
# matrices; it is also read as a global when result filenames are built.
# NOTE(review): lmb is not referenced in the visible cells.
d_theta = 100
d_w = 50
n_data = 1000
zeta = 0
lmb = 0
M = 8

# One generator is shared by data generation and by client sampling, so the
# order in which cells draw from it determines every later random value.
rng = np.random.default_rng(42)
# Multiplier applied to the noise term added to the generated targets.
noise_scale = 1e-3
# NOTE(review): b_theta is not used in the visible cells, but drawing it advances
# rng; removing this line would change all subsequently generated data.
b_theta = rng.normal(size=(d_theta,))

In [None]:
# Base matrices handed to every client: uniform [0, 1) entries divided by the
# number of columns. The draw order (H, A, B) fixes the shared rng stream.
H = rng.uniform(size=(n_data, d_theta)) / d_theta
A = rng.uniform(size=(n_data, d_theta)) / d_theta
B = rng.uniform(size=(n_data, d_w)) / d_w

def generate_A_b(H, n, d, zeta, noise_scale, generator=None):
    """Draw one client's least-squares data ``(A, b)``.

    The design matrix is the shared base ``H`` plus ``zeta`` times a fresh
    uniform perturbation (entries in [0, 1) divided by ``d``). The target is
    ``A @ x`` for a freshly drawn standard-normal ``x`` plus independent
    Gaussian noise of standard deviation ``noise_scale`` on each of the ``n``
    entries.

    Parameters
    ----------
    H : ndarray, shape (n, d)
        Base design matrix shared across clients.
    n, d : int
        Number of rows and columns of the design matrix.
    zeta : float
        Weight of the client-specific perturbation.
    noise_scale : float
        Standard deviation of the additive target noise.
    generator : numpy.random.Generator, optional
        Source of randomness. Defaults to the notebook-level ``rng``.

    Returns
    -------
    A : ndarray, shape (n, d)
    b : ndarray, shape (n,)
    """
    gen = rng if generator is None else generator
    perturbation = gen.uniform(size=(n, d)) / d
    A = H + zeta * perturbation
    x = gen.normal(size=d)
    # size must be passed by keyword: the first positional argument of
    # Generator.normal is loc (the mean), so normal(n) would return a single
    # scalar centred at n and shift every target by the same large constant.
    noise = gen.normal(size=n)
    b = A @ x + noise_scale * noise
    return A, b

def generate_A_B_c(A, B, n, d1, d2, zeta, noise_scale, generator=None):
    """Draw one client's coupled least-squares data ``(A, B, y)``.

    Both base matrices receive a fresh uniform perturbation weighted by
    ``zeta`` (entries in [0, 1) divided by the respective column count). The
    target is ``A @ x1 + B @ x2`` for freshly drawn standard-normal ``x1`` and
    ``x2`` plus independent Gaussian noise of standard deviation
    ``noise_scale`` on each of the ``n`` entries.

    Parameters
    ----------
    A : ndarray, shape (n, d1)
        Base matrix multiplying the first variable.
    B : ndarray, shape (n, d2)
        Base matrix multiplying the second variable.
    n, d1, d2 : int
        Number of rows and the two column counts.
    zeta : float
        Weight of the client-specific perturbations.
    noise_scale : float
        Standard deviation of the additive target noise.
    generator : numpy.random.Generator, optional
        Source of randomness. Defaults to the notebook-level ``rng``.

    Returns
    -------
    A : ndarray, shape (n, d1)
    B : ndarray, shape (n, d2)
    y : ndarray, shape (n,)
    """
    gen = rng if generator is None else generator
    A1 = gen.uniform(size=(n, d1)) / d1
    B1 = gen.uniform(size=(n, d2)) / d2
    A = A + zeta * A1
    B = B + zeta * B1
    x1 = gen.normal(size=d1)
    x2 = gen.normal(size=d2)
    # size must be passed by keyword: normal(n) would treat n as the mean and
    # return one scalar, adding the same offset of about n to every target.
    noise = gen.normal(size=n)
    y = A @ x1 + B @ x2 + noise_scale * noise
    return A, B, y

$$\phi_m(\theta) = \frac{1}{2}\|H_m\theta - b_m\|^2$$
$$\nabla \phi_m(\theta) = H_m^\top(H_m\theta - b_m)$$

In [None]:
class Phi_m:
    """Least-squares term phi_m(theta) = 0.5 * ||Hm @ theta - bm||^2 of one client.

    ``Hm`` and ``bm`` are drawn at construction time by ``generate_A_b``.
    """

    def __init__(self, d_theta, n_data, zeta, noise_scale, H):
        """Store the dimension and generate this client's ``(Hm, bm)`` from the base ``H``."""
        self.d = d_theta
        self.Hm, self.bm = generate_A_b(H, n_data, d_theta, zeta, noise_scale)

    def _residual(self, theta):
        """Return ``Hm @ theta - bm``."""
        return self.Hm @ theta - self.bm

    def func(self, theta):
        """Return the scalar value of phi_m at the 1-D vector ``theta``."""
        residual = self._residual(theta)
        # np.dot reduces inside numpy; the builtin sum() would walk the array
        # one Python object at a time.
        return 0.5 * np.dot(residual, residual)

    def grad(self, theta):
        """Return the gradient ``Hm.T @ (Hm @ theta - bm)``."""
        return self.Hm.T @ self._residual(theta)

$$
\begin{align}
f_m(\theta, w) &= \phi_m(\theta) + \frac{1}{2}\|A_m\theta+B_mw-y_m\|^2\\
\nabla_1f_m(\theta, w) &= \nabla \phi_m(\theta) + A_m^\top(A_m\theta+B_mw-y_m)\\
\nabla_2f_m(\theta, w) &= B_m^\top(A_m\theta+B_mw-y_m)\\
w_m^*(\theta) &= (B_m^\top B_m)^{-1}(B_m^\top y_m - B_m^\top A_m\theta)
\end{align}
$$

In [None]:
class F_m:
    """Objective of client ``idx``: f_m(theta, w) = phi_m(theta) + 0.5 * ||Am @ theta + Bm @ w - ym||^2.

    Exposes the value, both partial gradients, the exact minimiser over ``w``
    for a fixed ``theta``, and the resulting operator
    ``theta -> grad_theta f_m(theta, w*(theta))``.
    """

    def __init__(self, idx, d_theta, d_w, n_data, zeta, noise_scale, H, A, B):
        """Generate this client's ``phi_m`` term and its ``(Am, Bm, ym)`` data."""
        self.phi_m = Phi_m(d_theta, n_data, zeta, noise_scale, H)
        self.d_theta = d_theta
        self.d_w = d_w
        self.m = idx
        self.Am, self.Bm, self.ym = generate_A_B_c(A, B, n_data, d_theta, d_w, zeta, noise_scale)

    def _residual(self, theta, w):
        """Return ``Am @ theta + Bm @ w - ym``."""
        return self.Am @ theta + self.Bm @ w - self.ym

    def func(self, theta, w):
        """Return the scalar value f_m(theta, w) for 1-D ``theta`` and ``w``."""
        residual = self._residual(theta, w)
        # np.dot keeps the reduction inside numpy instead of a Python-level sum().
        return self.phi_m.func(theta) + 0.5 * np.dot(residual, residual)

    def grad_theta(self, theta, w):
        """Return the partial gradient of f_m with respect to ``theta``."""
        return self.phi_m.grad(theta) + self.Am.T @ self._residual(theta, w)

    def grad_w(self, theta, w):
        """Return the partial gradient of f_m with respect to ``w``."""
        return self.Bm.T @ self._residual(theta, w)

    def opt_w(self, theta):
        """Return w*(theta), the solution of ``Bm.T @ Bm @ w = Bm.T @ (ym - Am @ theta)``.

        Raises ``numpy.linalg.LinAlgError`` when ``Bm.T @ Bm`` is singular.
        """
        gram = self.Bm.T @ self.Bm
        rhs = self.Bm.T @ (self.ym - self.Am @ theta)
        return np.linalg.solve(gram, rhs)

    def operator(self, theta):
        """Return ``grad_theta f_m(theta, w*(theta))``."""
        w_star = self.opt_w(theta)
        return self.grad_theta(theta, w_star)

    def operator_norm_(self, theta):
        """Return the squared Euclidean norm of ``operator(theta)``."""
        value = self.operator(theta)
        return np.dot(value, value)

In [None]:
class F:
    """Average of the per-client operators over a collection of clients."""

    def __init__(self, clients, d_theta, M):
        """Keep the client collection, the divisor ``M`` and the length of theta."""
        self.clients, self.M, self.d_theta = clients, M, d_theta

    def operator(self, theta):
        """Return the sum of ``client.operator(theta)`` over all clients, divided by ``M``."""
        shape = (self.d_theta,)
        accumulated = np.zeros(shape)
        for client in self.clients:
            np.add(accumulated, client.operator(theta), out=accumulated)
        return accumulated / self.M

    def operator_norm(self, theta):
        """Return the Euclidean norm of ``operator(theta)``."""
        averaged = self.operator(theta)
        return np.linalg.norm(averaged)

In [None]:
def compute_L(clients, M, d_theta):
    """Return the largest per-client smoothness estimate over ``clients``.

    For each client two quantities are computed:

    * ``L_hat``: ``np.linalg.norm`` (Frobenius norm for a matrix) of
      ``Am.T @ (I - Bm @ pinv(Bm)) @ Am``;
    * ``L_phi``: the largest eigenvalue of ``Hm.T @ Hm``.

    The result is the maximum of these values over all clients, or ``0`` when
    ``clients`` is empty.

    ``M`` and ``d_theta`` are not used; they are kept so existing calls stay valid.
    """
    L = 0
    for m in clients:
        # Take the row count from the data so the function does not depend on
        # a notebook-level n_data global.
        n_rows = m.Am.shape[0]
        complement = np.eye(n_rows) - m.Bm @ np.linalg.pinv(m.Bm)
        # NOTE(review): Frobenius norm upper-bounds the spectral norm; switch to
        # ord=2 if the tight constant is wanted.
        L_hat = np.linalg.norm(m.Am.T @ complement @ m.Am)
        # Hm.T @ Hm is symmetric, so eigvalsh applies and always returns real values.
        L_phi = np.linalg.eigvalsh(m.phi_m.Hm.T @ m.phi_m.Hm).max()
        L = max(L, L_hat, L_phi)
    # Return the running maximum, not the value of the last client visited.
    return L

def compute_L_mu(clients):
    """Print, per client, the smallest and largest eigenvalue of ``Bm.T @ Bm``.

    One line is written for each client in the form
    ``client <index> <min eigenvalue> <max eigenvalue>``; nothing is returned.
    """
    for client in clients:
        gram = client.Bm.T @ client.Bm
        spectrum = np.linalg.eigvals(gram)
        smallest, largest = min(spectrum), max(spectrum)
        print(f'client {client.m}', smallest, largest)

In [None]:
# Build the M client objectives; every client receives the same base H, A, B
# and draws its own data from the shared rng inside the F_m constructor.
clients = [F_m(m, d_theta, d_w, n_data, zeta, noise_scale, H, A, B) for m in range(M)]

In [None]:
# All-zero starting points for theta (length d_theta) and w (length d_w).
theta0 = np.zeros((d_theta,))
w0 = np.zeros((d_w,))

In [None]:
# Client-averaged operator; the cell output is its Euclidean norm at theta0.
full_op = F(clients, d_theta, M)
full_op.operator_norm(theta0)

In [None]:
def FFGG(theta0, batch_size, lrout, lrin, T, tau, clients, M, d_theta, d_w, n_data, exact_comp=True):
    """Run ``T`` outer steps on ``theta`` using sampled client operators.

    Every round samples ``batch_size`` clients without replacement from the
    notebook-level ``rng``. Each sampled client contributes either

    * ``client.operator(theta)`` (``exact_comp=True``), or
    * ``client.grad_theta(theta, w)`` where ``w`` starts from a standard-normal
      draw and takes ``tau`` gradient steps of size ``lrin`` on
      ``client.grad_w`` (``exact_comp=False``).

    ``theta`` then moves by ``-lrout / batch_size`` times the summed contributions.

    Every 10th round (``t % 10 == 0``) the norm of the client-averaged operator
    is appended to the history and the history is pickled under ``./results/``.
    The filename also embeds the notebook-level ``zeta``.

    Parameters
    ----------
    theta0 : ndarray, shape (d_theta,)
        Starting point; it is copied, not modified.
    batch_size : int
        Number of clients sampled per round.
    lrout, lrin : float
        Outer and inner step sizes.
    T, tau : int
        Number of outer rounds and of inner steps per client.
    clients : sequence of F_m
    M : int
        Divisor used by the averaged operator; also embedded in the filename.
    d_theta, d_w : int
        Lengths of ``theta`` and ``w``.
    n_data : int
        Not used; kept so existing calls stay valid.
    exact_comp : bool
        Selects the exact or the iterative inner computation.

    Returns
    -------
    theta : ndarray, shape (d_theta,)
    history : dict with keys 'F_norm', 'F_norm_min', 'iter'
    """
    import os

    theta = theta0.copy()
    full_op = F(clients, d_theta, M)

    initial_norm = full_op.operator_norm(theta)
    history = {'F_norm': [initial_norm], 'F_norm_min': [initial_norm]}
    history['iter'] = [0]
    print('iteration 0', history['F_norm'][-1])

    # The path does not change during the run; make sure its directory exists
    # so the first checkpoint cannot fail on a fresh checkout.
    out_path = f'./results/FFGG_cl{M}_lrout{lrout}_lrin{lrin}_dt{d_theta}_dw{d_w}_T{T}_tau{tau}_zeta{zeta}.pkl'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    for t in range(T):
        batch = rng.choice(clients, size=batch_size, replace=False)
        g_theta = np.zeros((d_theta,))
        for m in batch:
            if exact_comp:
                g_theta += m.operator(theta)
            else:
                # Approximate w*(theta) by tau gradient steps from a random start.
                w = rng.normal(size=d_w)
                for _ in range(tau):
                    w -= lrin*m.grad_w(theta, w)
                g_theta += m.grad_theta(theta, w)

        theta -= lrout/batch_size*g_theta

        if t % 10 == 0:
            history['F_norm'].append(full_op.operator_norm(theta))
            history['F_norm_min'].append(min(history['F_norm_min'][-1], history['F_norm'][-1]))
            history['iter'].append(t+1)
            print(f'iteration {t+1}', history['F_norm'][-1])

            # Checkpoint the whole history so an interrupted run keeps its trace.
            with open(out_path, 'wb') as fp:
                pickle.dump(history, fp)

    return theta, history

In [None]:
def Local_SGD(theta0, w0, batch_size, lrout, lrin, T, tau, clients, M, d_theta, d_w, n_data):
    """Run ``T`` rounds of local gradient descent with averaging over all clients.

    In every round each client starts from the current server ``(theta, w)``
    and takes ``tau`` simultaneous gradient steps of size ``lrin`` on both
    variables. The server then replaces ``(theta, w)`` by the mean of the local
    iterates. Local iterates are stored at index ``client.m``.

    Every 10th round (``t % 10 == 0``) the norm of the client-averaged operator
    at ``theta`` is appended to the history and the history is pickled under
    ``./results/``. The filename also embeds the notebook-level ``zeta``.

    Parameters
    ----------
    theta0, w0 : ndarray
        Starting points; they are copied, not modified.
    batch_size : int
        Not used; every client participates in every round.
    lrout : float
        Only embedded in the output filename.
    lrin : float
        Local step size.
    T, tau : int
        Number of rounds and of local steps per round.
    clients : sequence of F_m
    M : int
        Number of local slots and divisor of the averaged operator.
    d_theta, d_w : int
        Lengths of ``theta`` and ``w``.
    n_data : int
        Not used; kept so existing calls stay valid.

    Returns
    -------
    theta : ndarray, shape (d_theta,)
    history : dict with keys 'F_norm', 'F_norm_min', 'iter'
    """
    import os

    theta = theta0.copy()
    w = w0.copy()
    full_op = F(clients, d_theta, M)

    initial_norm = full_op.operator_norm(theta)
    history = {'F_norm': [initial_norm], 'F_norm_min': [initial_norm]}
    history['iter'] = [0]
    print('iteration 0', history['F_norm_min'][-1])

    # The path does not change during the run; make sure its directory exists
    # so the first checkpoint cannot fail on a fresh checkout.
    out_path = f'./results/LocalSGD_cl{M}_lrout{lrout}_lrin{lrin}_dt{d_theta}_dw{d_w}_T{T}_tau{tau}_zeta{zeta}.pkl'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    theta_local = [theta.copy() for _ in range(M)]
    w_local = [w.copy() for _ in range(M)]
    for t in range(T):
        for client in clients:
            idx = client.m
            # Every client restarts from the current server iterate.
            theta_local[idx] = theta.copy()
            w_local[idx] = w.copy()
            for _ in range(tau):
                # Both gradients are evaluated before either variable moves.
                g_theta = client.grad_theta(theta_local[idx], w_local[idx])
                g_w = client.grad_w(theta_local[idx], w_local[idx])
                theta_local[idx] -= lrin*g_theta
                w_local[idx] -= lrin*g_w

        theta = np.mean(theta_local, axis=0)
        w = np.mean(w_local, axis=0)

        if t % 10 == 0:
            history['F_norm'].append(full_op.operator_norm(theta))
            history['F_norm_min'].append(min(history['F_norm_min'][-1], history['F_norm'][-1]))
            history['iter'].append(t+1)
            print(f'iteration {t+1}', history['F_norm'][-1])

            # Checkpoint the whole history so an interrupted run keeps its trace.
            with open(out_path, 'wb') as fp:
                pickle.dump(history, fp)

    return theta, history

In [None]:
def SCAFFOLD(theta0, w0, batch_size, lrout, lrin, T, tau, clients, M, d_theta, d_w, n_data):
    """Run ``T`` rounds of SCAFFOLD on the joint variable ``(theta, w)``.

    Every round samples ``batch_size`` clients without replacement from the
    notebook-level ``rng``. Each sampled client starts from the server iterate
    and takes ``tau`` steps of size ``lrin`` along
    ``gradient - local control variate + global control variate``. Its new
    control variate is its gradient at the server iterate. The server then

    * moves by ``lrout`` times the mean of the local displacements over the
      sampled clients, and
    * adds ``1 / M`` times the summed control-variate changes to the global
      control variate, which keeps it equal to the mean of all local ones.

    Every 10th round (``t % 10 == 0``) the norm of the client-averaged operator
    at ``theta`` is appended to the history and the history is pickled under
    ``./results/``. The filename also embeds the notebook-level ``zeta``.

    Parameters
    ----------
    theta0, w0 : ndarray
        Starting points; they are copied, not modified.
    batch_size : int
        Number of clients sampled per round.
    lrout, lrin : float
        Server and local step sizes.
    T, tau : int
        Number of rounds and of local steps per round.
    clients : sequence of F_m
        Control variates are stored at index ``client.m``.
    M : int
        Total number of clients.
    d_theta, d_w : int
        Lengths of ``theta`` and ``w``.
    n_data : int
        Not used; kept so existing calls stay valid.

    Returns
    -------
    theta : ndarray, shape (d_theta,)
    history : dict with keys 'F_norm', 'F_norm_min', 'iter'
    """
    import os

    theta = theta0.copy()
    w = w0.copy()
    full_op = F(clients, d_theta, M)

    # Global control variates.
    c_theta = np.zeros((d_theta,))
    c_w = np.zeros((d_w,))

    # Per-client control variates, indexed by client.m.
    C_theta = [np.zeros((d_theta,)) for _ in range(M)]
    C_w = [np.zeros((d_w,)) for _ in range(M)]

    initial_norm = full_op.operator_norm(theta)
    history = {'F_norm': [initial_norm], 'F_norm_min': [initial_norm]}
    history['iter'] = [0]
    print('iteration 0', history['F_norm_min'][-1])

    # The path does not change during the run; make sure its directory exists
    # so the first checkpoint cannot fail on a fresh checkout.
    out_path = f'./results/Scaffold_cl{M}_lrout{lrout}_lrin{lrin}_dt{d_theta}_dw{d_w}_T{T}_tau{tau}_zeta{zeta}.pkl'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    for t in range(T):
        batch = rng.choice(clients, size=batch_size, replace=False)

        # Sums over the sampled clients of the control-variate changes and of
        # the local displacements.
        delta_c_theta = np.zeros((d_theta,))
        delta_c_w = np.zeros((d_w,))
        delta_y_theta = np.zeros((d_theta,))
        delta_y_w = np.zeros((d_w,))

        for client in batch:
            idx = client.m
            theta_local = theta.copy()
            w_local = w.copy()
            for _ in range(tau):
                # Both gradients are evaluated before either variable moves.
                g_theta = client.grad_theta(theta_local, w_local)
                g_w = client.grad_w(theta_local, w_local)

                theta_local -= lrin*(g_theta - C_theta[idx] + c_theta)
                w_local -= lrin*(g_w - C_w[idx] + c_w)

            # c_i^+ : gradient at the server iterate.
            c_theta_plus = client.grad_theta(theta, w)
            c_w_plus = client.grad_w(theta, w)

            # Accumulate Delta c_i.
            delta_c_theta += c_theta_plus - C_theta[idx]
            delta_c_w += c_w_plus - C_w[idx]

            # Accumulate Delta y_i; plain assignment would keep only the last
            # sampled client's displacement.
            delta_y_theta += theta_local - theta
            delta_y_w += w_local - w

            # c_i <- c_i^+
            C_theta[idx] = c_theta_plus
            C_w[idx] = c_w_plus

        # Server step with the mean displacement of the sampled clients.
        theta += lrout*delta_y_theta/batch_size
        w += lrout*delta_y_w/batch_size

        # c <- c + (1/M) * sum_i Delta c_i. The deltas are sums, not means, so
        # an extra batch_size factor would inflate the update.
        c_theta += delta_c_theta/M
        c_w += delta_c_w/M

        if t % 10 == 0:
            history['F_norm'].append(full_op.operator_norm(theta))
            history['F_norm_min'].append(min(history['F_norm_min'][-1], history['F_norm'][-1]))
            history['iter'].append(t+1)
            print(f'iteration {t+1}', history['F_norm'][-1])

            # Checkpoint the whole history so an interrupted run keeps its trace.
            with open(out_path, 'wb') as fp:
                pickle.dump(history, fp)

    return theta, history

In [None]:
def Mixture(theta0, w0, lrin, T, tau, clients, lmbd, M, d_theta, d_w, n_data):
    """Run ``T`` randomized steps mixing local gradient steps with pulls toward the mean.

    With ``p = 1 / tau``, every step draws a Bernoulli(``p``) variable from the
    notebook-level ``rng``:

    * with probability ``1 - p`` every client takes one gradient step of size
      ``lrin / (M * (1 - p))`` on its own ``(theta, w)``;
    * with probability ``p`` every client's iterate is replaced by the convex
      combination ``(1 - a) * local + a * mean`` with ``a = lrin * lmbd / (M * p)``.

    NOTE(review): this looks like the loopless local GD scheme for a mixture of
    local and global models — confirm against the intended reference.

    Every 10th step (``t % 10 == 0``) the norm of the client-averaged operator
    at the mean of the local ``theta`` is appended to the history and the
    history is pickled under ``./results/``. The filename also embeds the
    notebook-level ``zeta``.

    Parameters
    ----------
    theta0, w0 : ndarray
        Starting points, copied into every client's local iterate.
    lrin : float
        Step size.
    T : int
        Number of randomized steps.
    tau : int
        Inverse of the aggregation probability.
    clients : sequence of F_m
        Local iterates are stored at index ``client.m``.
    lmbd : float
        Strength of the pull toward the mean.
    M : int
        Total number of clients.
    d_theta, d_w : int
        Lengths of ``theta`` and ``w``.
    n_data : int
        Not used; kept so existing calls stay valid.

    Returns
    -------
    theta : ndarray, shape (d_theta,)
        Mean of the local ``theta`` iterates.
    history : dict with keys 'F_norm', 'F_norm_min', 'iter'
    """
    import os

    full_op = F(clients, d_theta, M)

    p = 1/tau
    theta_local = [theta0.copy() for _ in range(M)]
    w_local = [w0.copy() for _ in range(M)]

    theta = np.mean(theta_local, axis=0)

    initial_norm = full_op.operator_norm(theta)
    history = {'F_norm': [initial_norm], 'F_norm_min': [initial_norm]}
    history['iter'] = [0]
    print('iteration 0', history['F_norm_min'][-1])

    # The path does not change during the run; make sure its directory exists
    # so the first checkpoint cannot fail on a fresh checkout.
    out_path = f'./results/Mixture_cl{M}_lmbd{lmbd}_lrin{lrin}_dt{d_theta}_dw{d_w}_T{T}_tau{tau}_zeta{zeta}.pkl'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # Weight of the mean in the aggregation step; it does not change between steps.
    mix = lrin*lmbd/(M*p)

    for t in range(T):
        # Bernoulli(p) drawn from the seeded notebook generator, so runs are
        # reproducible and no extra import is required.
        aggregate = rng.random() < p
        if aggregate:
            theta = np.mean(theta_local, axis=0)
            w = np.mean(w_local, axis=0)
            for client in clients:
                idx = client.m
                theta_local[idx] = (1-mix)*theta_local[idx] + mix*theta
                w_local[idx] = (1-mix)*w_local[idx] + mix*w
        else:
            for client in clients:
                idx = client.m
                # Both gradients are evaluated before either variable moves.
                g_theta = client.grad_theta(theta_local[idx], w_local[idx])
                g_w = client.grad_w(theta_local[idx], w_local[idx])

                theta_local[idx] -= lrin/(M*(1-p))*g_theta
                w_local[idx] -= lrin/(M*(1-p))*g_w

        if t % 10 == 0:
            theta = np.mean(theta_local, axis=0)
            history['F_norm'].append(full_op.operator_norm(theta))
            history['F_norm_min'].append(min(history['F_norm_min'][-1], history['F_norm'][-1]))
            history['iter'].append(t+1)
            print(f'iteration {t+1}', history['F_norm'][-1])

            # Checkpoint the whole history so an interrupted run keeps its trace.
            with open(out_path, 'wb') as fp:
                pickle.dump(history, fp)

    # axis=0 keeps a vector of length d_theta; without it np.mean would
    # collapse every entry of every client into one scalar.
    return np.mean(theta_local, axis=0), history

In [None]:
# Smoothness estimate from compute_L; the cell output shows L and its
# reciprocal, which is later used as the outer step size lrout.
L = compute_L(clients, M, d_theta)
L, 1/L

In [None]:
# Print the extreme eigenvalues of Bm.T @ Bm for every client.
compute_L_mu(clients)

In [None]:
# NOTE(review): stray scratch expression; it only displays 0.5, the same value
# that is assigned to lrin in the next cell. Consider removing this cell.
1/2

In [None]:
# Settings for the run below.
# NOTE(review): method is not read by any visible cell.
method = 'FFGG'
# Sample every client in each round.
batch_size = M
# Number of inner steps per sampled client.
tau = 100
# Outer step size: reciprocal of the smoothness estimate computed above.
lrout = 1/L
# Inner step size.
lrin = 1/2
# Number of outer rounds.
T = 200

In [None]:
# Run FFGG with exact_comp=False, i.e. each sampled client approximates its
# inner solution with tau gradient steps instead of calling opt_w.
theta_ffgg_full, record_ffgg_full = FFGG(theta0, batch_size, lrout, lrin, 
                           T, tau, clients, M, d_theta, d_w, n_data, exact_comp=False)