In [None]:
import matplotlib
import numpy as np
import psutil
import ray
import os
import seaborn as sns
import time

import matplotlib.pyplot as plt
import numpy.linalg as la

# Global figure style: seaborn theme first, then matplotlib overrides on top of it.
sns.set(
    style="whitegrid",
    context="talk",
    font_scale=1.2,
    palette=sns.color_palette("bright"),
    color_codes=False,
)
matplotlib.rcParams.update({
    # Font type 42 for both the PDF and PostScript backends.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'DejaVu Sans',
    'mathtext.fontset': 'cm',
    'figure.figsize': (9, 6),
})

$$\phi_m(\theta) = \frac{1}{2}\|H_m\theta - b_m\|^2$$
$$\nabla \phi_m(\theta) = H_m^\top(H_m\theta - b_m)$$
$$
\begin{aligned}
f_m(\theta, w) &= \phi_m(\theta) + \frac{1}{2}\|A_m\theta+B_mw-y_m\|^2\\
\nabla_1f_m(\theta, w) &= \nabla \phi_m(\theta) + A_m^\top(A_m\theta+B_mw-y_m)\\
\nabla_2f_m(\theta, w) &= B_m^\top(A_m\theta+B_mw-y_m)\\
w_m^*(\theta) &= (B_m^\top B_m)^{-1}(B_m^\top y_m - B_m^\top A_m\theta)
\end{aligned}
$$

In [None]:
def generate_A_b(H, n, d, zeta, noise_scale, generator=None):
    """Draw one worker's least-squares data (A, b) around the shared matrix H.

    A is H plus a Gaussian perturbation normalised to unit Frobenius norm, and
    b = A @ x + noise_scale * noise for a random ground-truth vector x.

    Arguments:
        H (ndarray): shared (n, d) base matrix; not modified.
        n (int): number of rows / length of b.
        d (int): number of columns / length of x.
        zeta: accepted so the signature matches generate_A_B_c; not used here.
        noise_scale (float): multiplier applied to the additive noise in b.
        generator (numpy.random.Generator, optional): source of randomness.
            When None, the notebook-level `rng` is used, which must exist at call time.

    Returns:
        tuple: (A, b) with A of shape (n, d) and b of shape (n,).
    """
    gen = rng if generator is None else generator
    perturbation = gen.normal(size=(n, d)) / d
    A = H + perturbation / np.linalg.norm(perturbation)
    x = gen.normal(size=d)
    # `size=n` is required: a positional `n` would be read as the mean (`loc`)
    # and yield a single scalar offset instead of per-sample noise.
    b = A @ x + noise_scale * gen.normal(size=n)
    return A, b


def generate_A_B_c(A, B, n, d1, d2, zeta, noise_scale, generator=None):
    """Draw one worker's coupled data (A_m, B_m, y_m) around the shared matrices A and B.

    Each matrix is the shared one plus zeta times a Gaussian perturbation of unit
    Frobenius norm, and y = A_m @ x1 + B_m @ x2 + noise_scale * noise for random
    ground-truth vectors x1 and x2.

    Arguments:
        A (ndarray): shared (n, d1) base matrix; not modified.
        B (ndarray): shared (n, d2) base matrix; not modified.
        n (int): number of rows / length of y.
        d1 (int): number of columns of A.
        d2 (int): number of columns of B.
        zeta (float): magnitude of the per-worker perturbation.
        noise_scale (float): multiplier applied to the additive noise in y.
        generator (numpy.random.Generator, optional): source of randomness.
            When None, the notebook-level `rng` is used, which must exist at call time.

    Returns:
        tuple: (A_m, B_m, y) with shapes (n, d1), (n, d2) and (n,).
    """
    gen = rng if generator is None else generator
    A1 = gen.normal(size=(n, d1)) / d1
    B1 = gen.normal(size=(n, d2)) / d2
    # Rebinding (not in-place addition) keeps the caller's A and B untouched.
    A = A + zeta * A1 / np.linalg.norm(A1)
    B = B + zeta * B1 / np.linalg.norm(B1)
    x1 = gen.normal(size=d1)
    x2 = gen.normal(size=d2)
    # `size=n` is required: a positional `n` would be read as the mean (`loc`)
    # and yield a single scalar offset instead of per-sample noise.
    y = A @ x1 + B @ x2 + noise_scale * gen.normal(size=n)
    return A, B, y

In [None]:
@ray.remote
class ParameterServer:
    """Ray actor that owns the shared iterate and applies gradient steps to it.

    Arguments:
        lr (float): stepsize used when applying updates.
        asynchronous (bool): if True, each call applies a single update;
            otherwise the updates passed in one call are summed and applied together.
        d_theta (int): length of the iterate, which starts at zero.
    """

    def __init__(self, lr, asynchronous, d_theta):
        self.asynchronous = asynchronous
        self.lr = lr
        self.x = np.zeros(d_theta)

    def apply_gradients(self, update, *updates):
        """Take one step on the iterate and return it.

        In asynchronous mode only `update` is used; in synchronous mode `update`
        is ignored and the sum of `updates` is used instead.
        """
        step = update if self.asynchronous else np.sum(updates, axis=0)
        self.x -= self.lr * step
        return self.x

    def get_x(self):
        """Return the current iterate."""
        return self.x

    def update_lr(self, lr_coef_mul=1, lr_new=None):
        """Set the stepsize to `lr_new` if given, otherwise scale it by `lr_coef_mul`."""
        if lr_new is None:
            self.lr *= lr_coef_mul
            return
        self.lr = lr_new

    def get_hyperparams(self):
        """Return the pair (stepsize, asynchronous flag)."""
        return self.lr, self.asynchronous

In [None]:
def phi_grad(idx, list_Hm_bm, theta):
    """Gradient of phi_m(theta) = 0.5 * ||H_m theta - b_m||^2 for worker `idx`.

    `list_Hm_bm[idx]` must be a pair (H_m, b_m).
    """
    H_local, b_local = list_Hm_bm[idx]
    residual = H_local @ theta - b_local
    return H_local.T @ residual

def grad_theta(idx, list_Hm_bm, list_Am_Bm_ym, theta, w):
    """Gradient of f_m with respect to theta for worker `idx`.

    Adds the gradient of phi_m to A_m^T (A_m theta + B_m w - y_m), where
    (A_m, B_m, y_m) is `list_Am_Bm_ym[idx]`.
    """
    A_local, B_local, y_local = list_Am_Bm_ym[idx]
    coupling_residual = A_local @ theta + B_local @ w - y_local
    return phi_grad(idx, list_Hm_bm, theta) + A_local.T @ coupling_residual

def grad_w(idx, list_Am_Bm_ym, theta, w):
    """Gradient of f_m with respect to w for worker `idx`: B_m^T (A_m theta + B_m w - y_m)."""
    A_local, B_local, y_local = list_Am_Bm_ym[idx]
    coupling_residual = A_local @ theta + B_local @ w - y_local
    return B_local.T @ coupling_residual

def opt_w(idx, list_Am_Bm_ym, theta):
    """Minimiser of f_m over w for fixed theta, for worker `idx`.

    Solves the normal equations (B_m^T B_m) w = B_m^T (y_m - A_m theta).
    """
    A_local, B_local, y_local = list_Am_Bm_ym[idx]
    gram = B_local.T @ B_local
    rhs = B_local.T @ (y_local - A_local @ theta)
    return np.linalg.solve(gram, rhs)

def operator(idx, list_Hm_bm, list_Am_Bm_ym, theta):
    """Theta-gradient of f_m for worker `idx`, evaluated at the w that minimises f_m for this theta."""
    best_w = opt_w(idx, list_Am_Bm_ym, theta)
    return grad_theta(idx, list_Hm_bm, list_Am_Bm_ym, theta, best_w)

def evaluate(d_theta, list_Hm_bm, list_Am_Bm_ym, theta):
    """Squared Euclidean norm of the per-worker operators averaged over all workers.

    The number of workers is taken from `len(list_Hm_bm)`; `d_theta` is the
    length of the accumulator and must match the operator's output length.
    """
    worker_count = len(list_Hm_bm)
    total = np.zeros((d_theta,))
    for worker_idx in range(worker_count):
        total += operator(worker_idx, list_Hm_bm, list_Am_Bm_ym, theta)
    mean_operator = total / worker_count
    return sum(mean_operator ** 2)

In [None]:
def local_operator(idx, theta):
    """Operator of worker `idx` at `theta`, using the notebook-level data lists.

    Thin wrapper around `operator` that reads the module-level `list_Hm_bm`
    and `list_Am_Bm_ym`, which must be defined before this function is called.
    The previous body called `opt_sol` and `grad_1`, which exist only as
    `DataWorker` methods, so any call raised NameError.
    """
    return operator(idx, list_Hm_bm, list_Am_Bm_ym, theta)

def operator_norm_(theta):
    """Squared norm of the worker-averaged operator at `theta`.

    Delegates to `evaluate` with the module-level `list_Hm_bm` and
    `list_Am_Bm_ym`, which must be defined before this function is called.
    The previous body called `local_operator(theta)` without the required
    worker index, so any call raised TypeError.
    NOTE(review): averaging over all workers is an assumption about the
    intended meaning; confirm against how this helper is meant to be used.
    """
    return evaluate(len(theta), list_Hm_bm, list_Am_Bm_ym, theta)

@ray.remote
class DataWorker(object):
    """
    Ray actor holding one worker's local data (Hm, bm, Am, Bm, ym) and computing
    its gradient with respect to theta.

    Arguments:
        idx (int): position of this worker's data in list_Hm_bm and list_Am_Bm_ym
        lr (float): stepsize of the inner gradient steps on w
        n (int): accepted for call compatibility; not used by the worker
        d_theta (int): stored on the worker; not used in its computations
        d_w (int): stored on the worker; not used in its computations
        tau (int): number of inner gradient steps on w per compute_gradients call
        list_Hm_bm (sequence): (Hm, bm) pairs, indexed by idx
        list_Am_Bm_ym (sequence): (Am, Bm, ym) triples, indexed by idx
        w0 (ndarray): starting point of the inner loop; copied on every call and never modified
        exact_comp (bool, optional): if True, w is obtained by solving the normal equations
            instead of running the inner loop (default: False)
        bad_worker (bool, optional): if True, the worker will be forced to be slower than others (default: False)
        seed (int, optional): seed of the worker's private random generator (default: 0)
    """
    def __init__(self, idx, lr, n, d_theta, d_w, tau, list_Hm_bm, list_Am_Bm_ym, w0, exact_comp=False, bad_worker=False, seed=0):
        self.m = idx
        # Single source of truth for the inner stepsize, so update_lr really
        # affects the inner loop in compute_gradients.
        self.lr = lr
        self.tau = tau
        self.d_theta = d_theta
        self.d_w = d_w
        self.bad_worker = bad_worker
        self.exact_comp = exact_comp
        self.rng = np.random.default_rng(seed)
        (self.Hm, self.bm) = list_Hm_bm[idx]
        (self.Am, self.Bm, self.ym) = list_Am_Bm_ym[idx]
        self.w0 = w0

    def grad_1(self, theta, w):
        """Gradient of the local objective with respect to theta."""
        return self.Hm.T@(self.Hm@theta - self.bm) + self.Am.T@(self.Am@theta + self.Bm@w - self.ym)

    def grad_2(self, theta, w):
        """Gradient of the local objective with respect to w."""
        return self.Bm.T@(self.Am@theta + self.Bm@w - self.ym)

    def opt_sol(self, theta):
        """Minimiser over w for fixed theta, from the normal equations (Bm^T Bm) w = Bm^T (ym - Am theta)."""
        A = self.Bm.T@self.Bm
        b = self.Bm.T@(self.ym - self.Am@theta)
        return np.linalg.solve(A, b)

    def compute_gradients(self, theta):
        """Return the theta-gradient at `theta`, with w found exactly or by `tau` inner steps from w0."""
        t0 = time.perf_counter()
        if self.exact_comp:
            w = self.opt_sol(theta)
        else:
            # Every call restarts from w0; the copy keeps w0 itself unchanged.
            w = self.w0.copy()
            for _ in range(self.tau):
                w -= self.lr * self.grad_2(theta, w)
        grad_theta = self.grad_1(theta, w)

        if self.bad_worker:
            # Sleep 100x the measured compute time to emulate a slow worker.
            dt = time.perf_counter() - t0
            time.sleep(100 * dt)

        return grad_theta

    def update_lr(self, lr_coef_mul=1, lr_new=None):
        """Set the inner stepsize to `lr_new` if given, otherwise scale it by `lr_coef_mul`."""
        if lr_new is not None:
            self.lr = lr_new
        else:
            self.lr *= lr_coef_mul

    def get_hyperparams(self):
        """Return the pair (inner stepsize, number of inner steps)."""
        # The worker has no batch size; the attribute previously returned here
        # was never set and raised AttributeError.
        return self.lr, self.tau

    def get_lr(self):
        """Return the current inner stepsize."""
        return self.lr

In [None]:
def run(seeds, num_workers, lrout, lrin, list_Hm_bm, list_Am_Bm_ym, lr_decay=0, iterations=200, asynchronous=True, delay_adaptive=False, it_check=20, 
        n=1000, d_theta=100, d_w=50, tau=20, exact_comp=False,
        one_bad_worker=False):
    """Run one synchronous or asynchronous experiment on Ray and return its trace.

    Arguments:
        seeds: accepted for call compatibility; not used (worker seeds are drawn
            from a generator seeded with 42, bounded by the notebook-level `max_seed`).
        num_workers (int): number of DataWorker actors.
        lrout (float): base stepsize of the parameter server.
        lrin (float): inner stepsize passed to every worker.
        list_Hm_bm, list_Am_Bm_ym (sequence): per-worker data, indexed by worker number.
        lr_decay (float): the server stepsize is reset after every iteration to
            lrout / (1 + lr_decay * number of gradients used so far).
        iterations (int): number of passes; the asynchronous loop performs
            iterations * num_workers single-gradient server updates.
        asynchronous (bool): apply each gradient as soon as it is ready (True) or
            wait for all workers and apply the summed gradients (False).
        delay_adaptive (bool): asynchronous only; scale the stepsize for the
            incoming gradient by num_workers / max(num_workers, delay).
        it_check (int): the iterate is recorded every it_check loop iterations; the
            synchronous loop also records every max(it_check // num_workers, 1).
        n, d_theta, d_w, tau, exact_comp: forwarded to the workers.
        one_bad_worker (bool): if True, worker 0 is created with bad_worker=True.

    Returns:
        tuple of ndarrays: (iterations at which the iterate was recorded,
        elapsed seconds at those points, `evaluate` of each recorded iterate,
        per-update delays; the last is empty for synchronous runs).
    """
    # Worker seeds come from a fixed-seed generator, so they do not depend on `seeds`.
    seed_rng = np.random.default_rng(42)
    seeds_workers = [seed_rng.choice(max_seed, size=1, replace=False)[0] for _ in range(num_workers)]

    its = []
    ts = []
    delays = []
    trace = []
    delay = 0
    grads_per_it = 1 if asynchronous else num_workers
    sync_check = max(it_check // num_workers, 1)

    ray.init(ignore_reinit_error=True)
    try:
        # Store the data once and hand every actor a reference, rather than
        # re-serialising both full lists for each of the num_workers actors.
        list_Hm_bm_ref = ray.put(list_Hm_bm)
        list_Am_Bm_ym_ref = ray.put(list_Am_Bm_ym)

        ps = ParameterServer.remote(lrout, asynchronous, d_theta)
        workers = [
            DataWorker.remote(idx=i, lr=lrin, n=n, d_theta=d_theta, d_w=d_w, tau=tau,
                              list_Hm_bm=list_Hm_bm_ref, list_Am_Bm_ym=list_Am_Bm_ym_ref, w0=np.zeros((d_w,)),
                              seed=seeds_workers[i], exact_comp=exact_comp,
                              bad_worker=(one_bad_worker and i == 0))
            for i in range(num_workers)
        ]

        x = ps.get_x.remote()
        if asynchronous:
            # In-flight gradient reference -> worker that is computing it.
            gradients = {}
            worker_last_it = [0 for _ in range(num_workers)]
            worker_id_to_num = {}
            for e, worker in enumerate(workers):
                gradients[worker.compute_gradients.remote(x)] = worker
                worker_id_to_num[worker] = e

        t0 = time.perf_counter()
        for it in range(iterations * (num_workers if asynchronous else 1)):
            n_grads = it * grads_per_it
            if asynchronous:
                ready_gradient_list, _ = ray.wait(list(gradients))
                ready_gradient_id = ready_gradient_list[-1]
                worker = gradients.pop(ready_gradient_id)

                # The worker is restarted on the iterate from before this
                # update, which is what `delay` below measures.
                gradients[worker.compute_gradients.remote(x)] = worker
                worker_num = worker_id_to_num[worker]
                delay = it - worker_last_it[worker_num]
                if delay_adaptive:
                    # Was `lr`, an undefined name in this function.
                    lr_new = lrout * num_workers / max(num_workers, delay)
                    ps.update_lr.remote(lr_new=lr_new)
                x = ps.apply_gradients.remote(update=ready_gradient_id)
                worker_last_it[worker_num] = it
            else:
                gradients = [
                    worker.compute_gradients.remote(x) for worker in workers
                ]
                # Calculate update after all gradients are available.
                x = ps.apply_gradients.remote(None, *gradients)

            if it % it_check == 0 or (not asynchronous and it % sync_check == 0):
                # Record the current iterate; it is evaluated after the run.
                x = ray.get(ps.get_x.remote())
                trace.append(x.copy())
                its.append(it)
                ts.append(time.perf_counter() - t0)

            lr_new = lrout / (1 + lr_decay * n_grads)
            ps.update_lr.remote(lr_new=lr_new)
            if asynchronous:
                delays.append(delay)
    finally:
        # Release the Ray runtime even when the experiment fails midway.
        ray.shutdown()

    losses = np.asarray([evaluate(d_theta, list_Hm_bm, list_Am_Bm_ym, x) for x in trace])
    return np.asarray(its), np.asarray(ts), losses, np.asarray(delays)

# Parameters

In [None]:
# Number of logical CPUs on this machine, for comparison with num_workers below.
psutil.cpu_count(logical=True)

In [None]:
# Run length and number of workers.
iterations = 800
num_workers = 40

# Problem dimensions.
d_theta = 400
d_w = 50
n_data = 10000

# Data-generation and inner-loop settings.
lmb = 0
zeta = 10
tau = 10
noise_scale = 1e-3

# Bookkeeping: M aliases the worker count; it_check is the recording period.
M = num_workers
it_check = 40
n_seeds = 5
max_seed = 424242

# Notebook-level generator; the data-generation code draws from it as well.
rng = np.random.default_rng(42)
seeds = [rng.choice(max_seed, size=1, replace=False)[0] for _ in range(n_seeds)]
# Map each seed to the index of its run.
seed_to_run = {seed: r for r, seed in enumerate(seeds)}

In [None]:
# Shared base matrices, each drawn with rng.uniform and divided by its column count.
# Every worker's data below is a perturbation of these.
H = rng.uniform(size=(n_data, d_theta)) / d_theta
A = rng.uniform(size=(n_data, d_theta)) / d_theta
B = rng.uniform(size=(n_data, d_w)) / d_w

# One (Hm, bm) pair and one (Am, Bm, ym) triple per worker. All draws come from the
# notebook-level `rng`, so the order of these statements determines the generated data.
list_Hm_bm = [generate_A_b(H, n_data, d_theta, zeta, noise_scale) for m in range(M)]
list_Am_Bm_ym = [generate_A_B_c(A, B, n_data, d_theta, d_w, zeta, noise_scale) for m in range(M)]

In [None]:
# Synchronous run: the server waits for all workers' gradients before each update.
# lrout_sync is the server stepsize, lrin_sync the workers' inner stepsize on w.
lrout_sync = 1 / 325
lrin_sync = 1 / 180

# lr_decay = 0 keeps the server stepsize constant over the run.
lr_decay = 0
# exact_comp=False: each worker approximates w with tau inner gradient steps.
# The fourth return value (delays) is discarded; run() only fills it for asynchronous runs.
its_, ts_, losses_, _ = run(seeds, num_workers, lrout=lrout_sync, lrin=lrin_sync, list_Hm_bm=list_Hm_bm, list_Am_Bm_ym=list_Am_Bm_ym, 
                            lr_decay=lr_decay, iterations=iterations, 
                            asynchronous=False, delay_adaptive=False, it_check=it_check, 
                            n=n_data, d_theta=d_theta, d_w=d_w, tau=tau,
                            one_bad_worker=False, exact_comp=False)

In [None]:
# Convergence of the synchronous run on a log scale, plotted against the
# iteration at which each iterate was recorded rather than its list position.
fig, ax = plt.subplots()
ax.semilogy(its_, losses_)
ax.set(
    xlabel="iteration",
    ylabel="squared norm of averaged operator",
    title="Synchronous run",
);

In [None]:
# Asynchronous run: the server applies each worker's gradient as soon as it is ready.
lrout_async = 1 / 325
lrin_async = 1 / 180

# lr_decay = 0 keeps the server stepsize constant over the run.
lr_decay = 0
# exact_comp=True: each worker solves for w exactly, so lrin_async and tau are
# passed through but the workers' inner gradient loop is not executed.
# NOTE(review): the synchronous run above uses exact_comp=False; confirm the two
# runs are meant to differ in this setting before comparing their curves.
its_as_, ts_as_, losses_as_, delays = run(seeds, num_workers, lrout=lrout_async, lrin=lrin_async, list_Hm_bm=list_Hm_bm, list_Am_Bm_ym=list_Am_Bm_ym, 
                            lr_decay=lr_decay, iterations=iterations, 
                            asynchronous=True, delay_adaptive=False, it_check=it_check, 
                            n=n_data, d_theta=d_theta, d_w=d_w, tau=tau, exact_comp=True,
                            one_bad_worker=False)

In [None]:
# Convergence of the asynchronous run on a log scale. Here one iteration is one
# single-gradient server update, so the x-axis is not directly comparable to a
# synchronous iteration, which uses one gradient from every worker.
fig, ax = plt.subplots()
ax.semilogy(its_as_, losses_as_)
ax.set(
    xlabel="server update",
    ylabel="squared norm of averaged operator",
    title="Asynchronous run",
);