# Importing packages

In [1]:
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
from sklearn.datasets import load_svmlight_file
import numpy as np
import pickle

In [2]:
from matplotlib.colors import LogNorm

In [3]:
import pandas as pd

In [4]:
import time

In [5]:
import seaborn as sns

# Functions

In [6]:
def loss(w, X, y, lambda_):
    z = - X * y[:, None]
    exp = jnp.exp(z @ w)
    log = jnp.log(1 + exp)
    loss_1 = log.mean()
    reg = 1. - 1. / (w ** 2 + 1.)
    reg = reg.sum()
    return loss_1 + lambda_ * reg

In [7]:
# def _sigmoid(x):
#     return 1 / (1 + jnp.exp(-x))


# def get_h(x, w):
#     z = x.dot(w)
#     h = _sigmoid(z)
#     return h

In [8]:
# def f_value(w, X, y):
#     h = get_h(X, w)
#     zeros = jnp.zeros_like(h)
#     return -(jnp.where(y == 1, jnp.log(h), zeros) +
#              jnp.where(y == 0, jnp.log(1-h), zeros)).mean()


# def loss(w, X, y, lambda_):
#     d = X.shape[1]
#     return f_value(w, X, y) - lambda_ * (1 / (1 + w ** 2)).sum() + d * lambda_

In [9]:
# def grad(w, X, y, lambda_):
#     z = - X.toarray() * y[:, None]
#     coef = 1. - 1. / (1 + jnp.exp(z @ w))
#     grad_loss_1 = (z * coef[:, None]).mean(axis=0)
#     grad_loss_2 = 2 * w * (1. / (1 + w ** 2) ** 2)
#     return grad_loss_1 + lambda_ * grad_loss_2

In [10]:
grad = jax.jit(jax.grad(loss))

In [11]:
def grad_compute(w, data, lambda_):
    return jnp.array([grad(w, feats, labels, lambda_)
                      for (feats, labels) in data])

In [12]:
def smoothness(features, lambda_):
    xtx = features.T.dot(features)
    # xtx = xtx.toarray()
    n = features.shape[0]
    return np.max(np.linalg.eigvalsh(xtx)) / (4 * n) + 2 * lambda_

In [13]:
def top_k(vec, k):
    assert len(vec) >= k
    _, inds = jax.lax.top_k(jnp.abs(vec), k)
    inds = list(inds)
    mask = np.zeros_like(vec)
    mask[inds] = 1
    return np.multiply(vec, mask)

In [14]:
def beta_top_k(k, d):
    alpha = float(k) / float(d)
    return (1 - alpha) / (1 - np.sqrt(1 - alpha))

In [15]:
def theta_top_k(k, d):
    alpha = float(k) / float(d)
    return 1 - np.sqrt(1 - alpha)

In [16]:
def print_and_return(hist_list):
    min_ind = 0
    curr_ind = 0
    min_val = hist_list[0][0][-1]
    for hist in hist_list:
        if hist[0][-1] < min_val:
            min_ind = curr_ind
            min_val = hist[0][-1]
        curr_ind += 1
        print(hist[0][-1], hist[1][-1])
    return min_ind

# Algorithms

In [17]:
def LAG_step(x_k, g_k, comm, data, lambda_, grad_k_prev, zeta, stepsize):
    x_k -= stepsize * g_k.mean(axis=0)
    grad_k = np.array(grad_compute(x_k, data, lambda_), dtype=np.float64)
    trigger_rhs = zeta * \
        np.linalg.norm(grad_k - grad_k_prev, ord=2, axis=1) ** 2
    trigger_lhs = np.linalg.norm(g_k - grad_k, ord=2, axis=1) ** 2
    trigger = jnp.expand_dims(jnp.array(trigger_lhs > trigger_rhs), 1)
    g_k = jnp.multiply(grad_k, trigger) + jnp.multiply(g_k, 1 - trigger)
    comm += len(x_k) * trigger.sum()
    return x_k, g_k, grad_k, comm

In [18]:
def LAG(x_0, data, lambda_, zeta, max_comm):
    g_k = jnp.array(grad_compute(x_0, data, lambda_), dtype=np.float64)
    grad_k_prev = jnp.array(np.copy(g_k))
    x_k = jnp.array(np.copy(x_0))
    history = [np.linalg.norm(g_k.mean(axis=0))]
    comm = len(data) * len(x_0)
    history_comm = [0]
    stepsize = 1. / (L + L_tilde * jnp.sqrt(zeta))
    while(history_comm[-1] < max_comm):
        x_k, g_k, grad_k_prev, comm = LAG_step(
            x_k, g_k, comm, data, lambda_, grad_k_prev, zeta, stepsize
        )
        history.append(np.linalg.norm(grad_k_prev.mean(axis=0)))
        history_comm.append(comm)
        print('Currently communicated {} float numbers'.format(comm), end='\r')
    return history, history_comm

In [19]:
def GD(x_0, data, lambda_, num_clients, max_iter):
    g_k = None
    x_k = jnp.array(np.copy(x_0), dtype=np.float64)
    history = []
    history_comm = []
    stepsize = 1. / L
    comm = 0
    for i in range(max_iter):
        g_k = jnp.array(grad_compute(x_k, data, lambda_), dtype=np.float64)
        print('Iteration {} / {}'.format(i + 1, max_iter), end='\r')
        grad_ = g_k.mean(axis=0)
        history.append(np.linalg.norm(grad_))
        history_comm.append(comm)
        x_k -= stepsize * grad_
        comm += num_clients * len(x_0)
    return history, history_comm

In [20]:
def CLAG_step(x_k, g_k, comm, data, lambda_,
              trigger_beta, grad_k_prev, k, stepsize):
    x_k -= stepsize * g_k.mean(axis=0, dtype=np.float64)
    grad_k = np.array(grad_compute(x_k, data, lambda_), dtype=np.float64)
    trigger_rhs = trigger_beta * np.linalg.norm(
        grad_k - grad_k_prev, ord=2, axis=1) ** 2
    trigger_lhs = np.linalg.norm(g_k - grad_k, ord=2, axis=1) ** 2
    trigger = jnp.expand_dims(jnp.array(trigger_lhs > trigger_rhs), 1)
    compressed = jnp.vstack([top_k(grad_k[i] - g_k[i], k)
                            for i in range(len(g_k))])
    g_k += jnp.multiply(compressed, trigger)
    comm += trigger.sum() * k
    return x_k, g_k, grad_k, comm

In [21]:
def CLAG(x_0, data, lambda_, k, trigger_beta, max_comm):
    assert trigger_beta > 0
    alpha = float(k) / len(x_0)
    beta = (1 - alpha) / (1 - np.sqrt(1 - alpha))
    theta = 1 - np.sqrt(1 - alpha)
    g_k = np.array(grad_compute(x_0, data, lambda_), dtype=np.float64)
    grad_k_prev = jnp.array(np.copy(g_k))
    x_k = jnp.array(np.copy(x_0))
    history = [np.linalg.norm(g_k.mean(axis=0))]
    history_comm = [0]
    comm = len(x_0) * len(data)
    stepsize = 1. / (L + L_tilde * jnp.sqrt(max(beta, trigger_beta)/theta))
    while(history_comm[-1] < max_comm):
        print('Currently communicated {} float numbers'.format(comm), end='\r')
        history_comm.append(int(comm))
        x_k, g_k, grad_k_prev, comm = CLAG_step(
            x_k, g_k, comm, data, lambda_, trigger_beta,
            grad_k_prev, k, stepsize
        )
        history.append(np.linalg.norm(grad_k_prev.mean(axis=0)))
    return history, history_comm

In [22]:
def CLAG_it(x_0, data, lambda_, k, trigger_beta, tol, stepsize, time_budget):
    assert trigger_beta >= 0
    g_k = jnp.array(grad_compute(x_0, data, lambda_), dtype=np.float64)
    grad_k_prev = jnp.array(np.copy(g_k))
    x_k = jnp.array(np.copy(x_0))
    comm = len(x_0) * len(data)
    begin_time = time.time()
    success_flag = True
    while np.linalg.norm(grad_k_prev.mean(axis=0)) > tol:
        x_k, g_k, grad_k_prev, comm = CLAG_step(
            x_k, g_k, comm, data, lambda_, trigger_beta,
            grad_k_prev, k, stepsize
        )
        print('Tolerance = ', np.linalg.norm(grad_k_prev.mean(axis=0)),
              end='\r')
        if time.time() - begin_time > time_budget:
            success_flag = False
            break
    print('')
    return comm, success_flag

In [23]:
def heatmap_CLAG(x_0,
                 data,
                 lambda_,
                 ks,
                 trigger_betas,
                 tol,
                 stepsize_coefs,
                 time_budget,
                 file):
    heatmap = jnp.zeros(shape=(len(ks), len(trigger_betas)))
    d = len(x_0)
    assert trigger_betas[0] == 0
    for k_id, k in enumerate(ks):
        print('k = ', k)
        # choosing the best stepsize for EF21
        beta = beta_top_k(k, d)
        theta = theta_top_k(k, d)
        theoretical_stepsize = 1. / (L + L_tilde * np.sqrt(beta / theta))
        best_comm = float('inf')
        best_stepsize = 0.
        for coef in stepsize_coefs:
            stepsize = coef * theoretical_stepsize
            comm, flag = CLAG_it(x_0, data, lambda_, k, 0, tol, stepsize, time_budget)
            if flag and comm < best_comm:
                best_comm = comm
                best_stepsize = stepsize
        print(best_stepsize)
        for beta_id, trigger_beta in enumerate(trigger_betas):
            print('trigger_beta = ', trigger_beta)
            comm, flag = CLAG_it(x_0, data, lambda_, k, trigger_beta, tol, best_stepsize, time_budget)
            if flag:
                heatmap = heatmap.at[k_id, beta_id].set(comm)
            else:
                heatmap = heatmap.at[k_id, beta_id].set(-1)
            jnp.save(file, heatmap)
    return heatmap

In [24]:
def heatmap_CLAG_full(x_0,
                      data,
                      lambda_,
                      ks,
                      trigger_betas,
                      tol,
                      stepsize_coefs,
                      time_budget,
                      file):
    heatmap = jnp.zeros(shape=(len(ks), len(trigger_betas)))
    d = len(x_0)
    for k_id, k in enumerate(ks):
        print('k = ', k)
        beta = beta_top_k(k, d)
        theta = theta_top_k(k, d)
        for beta_id, trigger_beta in enumerate(trigger_betas):
            print('trigger_beta = ', trigger_beta)
            # choosing the best stepsize for CLAG
            theoretical_stepsize = 1. / (L + L_tilde * np.sqrt(max(beta, trigger_beta) / theta))
            best_comm = float('inf')
            best_stepsize = 0.
            for coef in stepsize_coefs:
                stepsize = coef * theoretical_stepsize
                comm, flag = CLAG_it(x_0, data, lambda_, k, trigger_beta, tol, stepsize, time_budget)
                if flag and comm < best_comm:
                    best_comm = comm
                    best_stepsize = stepsize
            print(best_stepsize)
            if best_comm != float('inf'):
                heatmap = heatmap.at[k_id, beta_id].set(best_comm)
            else:
                heatmap = heatmap.at[k_id, beta_id].set(-1)
            jnp.save(file, heatmap)
    return heatmap

In [25]:
def EF21(x_0, data, lambda_, k, max_iter):
    alpha = float(k) / len(x_0)
    beta = (1 - alpha) / (1 - np.sqrt(1 - alpha))
    theta = 1 - np.sqrt(1 - alpha)
    g_k = jnp.array(grad_compute(x_0, data, lambda_), dtype=np.float64)
    x_k = jnp.array(np.copy(x_0))
    history = [np.linalg.norm(g_k.mean(axis=0))]
    history_comm = [0]
    comm = len(x_0) * len(data)
    stepsize = 1. / (L + L_tilde * np.sqrt(beta/theta))
    for i in range(max_iter):
        print('Currently communicated {} float numbers'.format(comm), end='\r')
        x_k -= stepsize * g_k.mean(axis=0)
        history_comm.append(int(comm))
        grad_k = jnp.array(grad_compute(x_k, data, lambda_), dtype=np.float64)
        history.append(np.linalg.norm(grad_k.mean(axis=0)))
        compressed = jnp.vstack([top_k(grad_k[i] - g_k[i], k)
                                 for i in range(len(g_k))])
        g_k += compressed
        comm += len(data) * k
        print(comm, end='\r')
    return history, history_comm

# Phishing dataset

## Setup

In [26]:
num_clients = 20

In [27]:
lambda_ = 0.1

In [28]:
dataset_name = 'phishing'

In [29]:
raw_data = load_svmlight_file('../data/phishing')

In [30]:
X, y = raw_data

In [31]:
y = 2 * y - 1

In [32]:
residual = X.shape[0] % num_clients

In [33]:
X = X[:-residual].toarray()
y = y[:-residual]

In [34]:
n = X.shape[0]

In [35]:
d = X.shape[1]

In [36]:
y

array([-1., -1., -1., ..., -1., -1., -1.])

In [37]:
inds = np.array_split(np.arange(n), num_clients)

In [38]:
data = []
for i in range(num_clients):
    data.append((X[inds[i]][:], y[inds[i]]))

In [39]:
L = smoothness(X, lambda_)

In [40]:
L

0.3625653679762313

In [41]:
L_i = [smoothness(data[i][0], lambda_) for i in range(num_clients)]

In [42]:
L_tilde = np.sqrt((np.array(L_i) ** 2).mean())

In [43]:
L_tilde

0.36609956525744064

In [44]:
n

11040

## Experiments

In [45]:
max_comm = 20000

### 1. LAG tuning

In [None]:
zetas = np.geomspace(1e-2, 1e1, 4)

In [None]:
zetas

In [None]:
def LAG_grid(x_0, data, lambda_, zetas, max_comm):
    LAG_histories = []
    for zeta in zetas:
        print('zeta = {}'.format(zeta), end='\r')
        hist = LAG(x_0, data, lambda_, zeta, max_comm)
        LAG_histories.append(hist)
    return LAG_histories

In [None]:
x_0 = np.zeros(d)

In [None]:
LAG_histories = LAG_grid(x_0, data, lambda_, zetas, max_comm)

In [None]:
with open('../results/lag_phishing.pickle', 'wb') as file:
    pickle.dump(LAG_histories, file)

In [None]:
with open('../results/lag_phishing.pickle', 'rb') as file:
    LAG_histories = pickle.load(file)

In [None]:
min_ind = print_and_return(LAG_histories)

In [None]:
print(min_ind)

In [None]:
h_LAG, h_LAG_comm = LAG_histories[min_ind]

In [None]:
new_h_LAG, new_h_LAG_comm = LAG(x_0, data, lambda_, 1., max_comm)

In [None]:
# preliminary plot
plt.plot(h_LAG_comm, h_LAG)
plt.plot(new_h_LAG_comm, new_h_LAG)
plt.yscale('log')

### 2. EF21 tuning

In [None]:
ks = np.linspace(1, d, 10, endpoint=False, dtype=int)

In [None]:
ks

In [None]:
def EF21_grid(x_0, data, d, lambda_, ks, max_comm):
    EF21_histories = []
    for k in ks:
        print('k=', k)
        max_iter = 2 + int((max_comm - d * len(data)) / (k * num_clients))
        hist = EF21(x_0, data, lambda_, k, max_iter)
        EF21_histories.append(hist)
    return EF21_histories

In [None]:
EF21_histories = EF21_grid(x_0, data, d, lambda_, ks, max_comm)

In [None]:
with open('../results/ef21_phishing.pickle', 'wb') as file:
    pickle.dump(EF21_histories, file)

In [None]:
min_ind = print_and_return(EF21_histories)

In [None]:
min_ind

In [None]:
h_EF21, h_EF21_comm = EF21_histories[min_ind]

In [None]:
# preliminary plot
plt.plot(h_LAG_comm, h_LAG, label='LAG')
plt.plot(h_EF21_comm, h_EF21, label='EF21')
plt.legend()
plt.yscale('log')

### 3. CLAG tuning

In [None]:
x_0 = np.zeros(d)

In [None]:
ks = np.flip(np.linspace(1, d, 6, endpoint=True, dtype=int))
# beta_multipliers = np.geomspace(1e-3, 10, num=5)

In [None]:
ks = ks[:-1]

In [None]:
ks

In [None]:
trigger_betas = [beta_top_k(k, d) for k in ks]

In [None]:
# trigger_betas = trigger_betas[:-1] + [10]

In [None]:
trigger_betas

In [None]:
def CLAG_grid(x_0, data, lambda_, ks, beta_multipliers, max_comm):
    CLAG_histories = []
    for k in ks:
        curr_list = []
        for mult in beta_multipliers:
            trigger_beta = mult * beta_top_k(k, d)
            print('k = ', k, 'trigger_beta = ', trigger_beta)
            hist = CLAG(x_0, data, lambda_, k, trigger_beta, max_comm)
            curr_list.append(hist)
        CLAG_histories.append(curr_list)
    return CLAG_histories

In [None]:
CLAG_histories = CLAG_grid(x_0, data, lambda_, ks, beta_multipliers, max_comm)

In [None]:
with open('../results/clag_phishing.pickle', 'wb') as file:
    pickle.dump(CLAG_histories, file)

In [None]:
with open('../results/clag_phishing.pickle', 'rb') as file:
    CLAG_histories = pickle.load(file)

In [None]:
len(CLAG_histories)

In [None]:
min_inds = []
for hist_k in CLAG_histories:
    ind = print_and_return(hist_k)
    min_inds.append(ind)

In [None]:
min_inds

In [None]:
min_list = [CLAG_histories[ind][min_inds[ind]] for ind in range(len(ks))]

In [None]:
min_ind = print_and_return(min_list)

In [None]:
min_ind

In [None]:
h_CLAG, h_CLAG_comm = min_list[min_ind]

In [None]:
new_h_CLAG, new_h_CLAG_comm = CLAG(x_0, data, lambda_, 7, beta_top_k(7, d), max_comm)

In [None]:
# preliminary plot
# plt.plot(h_LAG_comm, h_LAG, label='LAG')
# plt.plot(h_EF21_comm, h_EF21, label='EF21')
plt.plot(h_CLAG_comm, h_CLAG, label='CLAG')
plt.plot(new_h_CLAG_comm, new_h_CLAG, label='check')
plt.legend()
plt.yscale('log')

#### heatmap

In [46]:
x_0 = np.zeros(d)

In [47]:
ks = np.flip(np.linspace(1, d, 8, endpoint=True, dtype=int))

In [48]:
ks

array([68, 58, 48, 39, 29, 20, 10,  1])

In [49]:
trigger_betas = 2 ** jnp.arange(0, 10, dtype=np.float32)

In [50]:
trigger_betas = jnp.insert(trigger_betas, 0, 0)

In [51]:
trigger_betas

DeviceArray([  0.,   1.,   2.,   4.,   8.,  16.,  32.,  64., 128., 256.,
             512.], dtype=float32)

In [52]:
stepsize_coefs = 2 ** jnp.arange(-1, 6, dtype=np.float32)

In [None]:
heatmap = heatmap_CLAG(x_0, data, lambda_, ks, trigger_betas, 1e-5,
                       stepsize_coefs, 60, '../results/heatmap_phishing.npy')

In [None]:
df = pd.DataFrame(heatmap, index=ks, columns=trigger_betas, dtype=int)

In [None]:
df

In [None]:
heatmap_full = heatmap_CLAG_full(x_0,
                                 data,
                                 lambda_,
                                 ks,
                                 trigger_betas,
                                 1e-5,
                                 stepsize_coefs,
                                 60,
                                 '../results/heatmap_phishing_full_prec5.npy'
                                )

k =  68
trigger_beta =  0.0


  lax._check_user_dtype_supported(dtype, "array")
  lax._check_user_dtype_supported(dtype, "mean")


Tolerance =  9.746220532931637e-065
Tolerance =  9.978962961153305e-065
Tolerance =  9.9402590469785e-06-05
Tolerance =  0.75090637486074092
Tolerance =  0.46627686953622326
Tolerance =  0.36426654598057673
Tolerance =  0.17627613777418158
2.7581234
trigger_beta =  1.0
Tolerance =  9.798792480104275e-065
Tolerance =  8.727729864762357e-065
Tolerance =  6.1114705120435e-06-05
Tolerance =  0.131763513268707428
Tolerance =  0.31222854650531195
Tolerance =  0.45163644199769126
Tolerance =  0.28218313630478644
2.7447457
trigger_beta =  2.0
Tolerance =  9.006027888877896e-065
Tolerance =  9.421737446335354e-065
Tolerance =  9.836016700166097e-065
Tolerance =  0.73079541740209247
Tolerance =  0.50068256308904882
Tolerance =  0.32459127928518396
Tolerance =  0.11759352559325248
2.2719312
trigger_beta =  4.0
Tolerance =  9.880356606211296e-065
Tolerance =  8.525264340247027e-065
Tolerance =  5.5884601440550065e-06
Tolerance =  0.723969104330168567
Tolerance =  0.35994302796546546
Tolerance =  0

In [None]:
heatmap = jnp.load('../results/heatmap_phishing_full.npy')

In [None]:
trigger_betas_legend = [round(x, 2) for x in trigger_betas]

In [None]:
trigger_betas_legend

In [None]:
df = pd.DataFrame(heatmap, index=ks, columns=trigger_betas_legend, dtype=int)

In [None]:
df

In [None]:
plt.figure(figsize=(20, 10))
log_norm = LogNorm(vmin=heatmap.min().min(), vmax=heatmap.max().max())
ax = sns.heatmap(df, annot=True, fmt="d", cmap="YlGnBu", norm=log_norm)
ax.set_xlabel('zeta')
ax.set_ylabel('compression level')
plt.title('heatmap {}'.format(dataset_name))
plt.tight_layout()
plt.savefig('../plots/heatmap_phishing_full.pdf')

### GD

In [None]:
max_iter = int(max_comm / (num_clients * len(x_0)) + 1)
h_GD, h_GD_comm = GD(x_0, X, y, lambda_, num_clients, max_iter)

## Plot

In [None]:
len(h_CLAG_comm)

In [None]:
plt.plot(h_EF21_comm, h_EF21, label='EF21', marker='+')
plt.plot(h_CLAG_comm, h_CLAG, label='CLAG', marker='D',
         markevery=30)
plt.plot(h_GD_comm, h_GD, label='GD', marker='*')
plt.plot(h_LAG_comm, h_LAG, label='LAG', marker='s',
         markevery=3)
plt.legend()
plt.xlabel('# of floats')
plt.ylabel(r'$||\nabla f(x^k)||^2$')
plt.title('phishing')
plt.xlim(left=0)
plt.grid()
plt.yscale('log')
plt.tight_layout()
plt.savefig('../results/phishing.pdf')

# a9a dataset

## Setup

In [None]:
num_clients = 20

In [None]:
lambda_ = 0.1

In [None]:
dataset_name = 'a9a'

In [None]:
raw_data = load_svmlight_file('../data/' + dataset_name)

In [None]:
X, y = raw_data

In [None]:
n = X.shape[0]

In [None]:
d = X.shape[1]

In [None]:
y

In [None]:
inds = np.array_split(np.arange(n), num_clients)

In [None]:
data = []
for i in range(num_clients):
    data.append((X[inds[i]][:], y[inds[i]]))

In [None]:
L = smoothness(X, lambda_)

In [None]:
L

In [None]:
L_i = [smoothness(data[i][0], lambda_) for i in range(num_clients)]

In [None]:
L_tilde = np.sqrt((np.array(L_i) ** 2).mean())

In [None]:
L_tilde

## Experiments

In [None]:
max_comm = 20000

### 1. LAG tuning

In [None]:
zetas = np.geomspace(1e-2, 1e1, 4)

In [None]:
x_0 = np.zeros(d)
LAG_histories = LAG_grid(x_0, data, lambda_, zetas, max_comm)

In [None]:
with open('../results/lag_{}.pickle'.format(dataset_name), 'wb') as file:
    pickle.dump(LAG_histories, file)

In [None]:
min_ind = print_and_return(LAG_histories)

In [None]:
print(min_ind)

In [None]:
h_LAG, h_LAG_comm = LAG_histories[min_ind]

In [None]:
# preliminary plot
plt.plot(h_LAG_comm, h_LAG)
plt.yscale('log')

### 2. EF21 tuning

In [None]:
ks = np.linspace(1, d, 10, endpoint=False, dtype=int)

In [None]:
ks

In [None]:
EF21_histories = EF21_grid(x_0, data, d, lambda_, ks, max_comm)

In [None]:
with open('../results/ef21_{}.pickle'.format(dataset_name), 'wb') as file:
    pickle.dump(EF21_histories, file)

In [None]:
min_ind = print_and_return(EF21_histories)

In [None]:
min_ind

In [None]:
h_EF21, h_EF21_comm = EF21_histories[min_ind]

In [None]:
# preliminary plot
plt.plot(h_LAG_comm, h_LAG, label='LAG')
plt.plot(h_EF21_comm, h_EF21, label='EF21')
plt.legend()
plt.yscale('log')

### 3. CLAG tuning

In [None]:
ks = np.linspace(1, d, 6, endpoint=True, dtype=int)

In [None]:
ks = np.flip(ks)

In [None]:
ks

In [None]:
trigger_betas = np.linspace(0, 50, 20, endpoint=True)

In [None]:
trigger_betas

In [None]:
beta_multipliers = np.geomspace(1e-1, 10, num=3)

In [None]:
beta_multipliers

In [None]:
ks

In [None]:
CLAG_histories = CLAG_grid(x_0, data, lambda_, ks, beta_multipliers, max_comm)

In [None]:
with open('../results/clag_{}.pickle'.format(dataset_name), 'wb') as file:
    pickle.dump(CLAG_histories, file)

In [None]:
len(CLAG_histories)

In [None]:
min_inds = []
for hist_k in CLAG_histories:
    ind = print_and_return(hist_k)
    min_inds.append(ind)

In [None]:
min_inds

In [None]:
min_list = [CLAG_histories[ind][min_inds[ind]] for ind in range(len(ks))]

In [None]:
min_ind = print_and_return(min_list)

In [None]:
min_ind

In [None]:
h_CLAG, h_CLAG_comm = min_list[min_ind]

In [None]:
# preliminary plot
plt.plot(h_LAG_comm, h_LAG, label='LAG')
plt.plot(h_EF21_comm, h_EF21, label='EF21')
plt.plot(h_CLAG_comm, h_CLAG, label='CLAG')
plt.legend()
plt.yscale('log')

In [None]:
heatmap = heatmap_CLAG(x_0, data, lambda_, ks, trigger_betas, 1e-3, '../results/heatmap_a9a.npy')

In [None]:
df = pd.DataFrame(heatmap, index=ks, columns=trigger_betas)

In [None]:
df

### GD

In [None]:
max_iter = int(max_comm / (num_clients * len(x_0)) + 1)
h_GD, h_GD_comm = GD(x_0, X, y, lambda_, num_clients, max_iter)

## Plot

In [None]:
len(h_CLAG_comm)

In [None]:
plt.plot(h_EF21_comm, h_EF21, label='EF21', marker='+')
plt.plot(h_CLAG_comm, h_CLAG, label='CLAG', marker='D',
         markevery=1500)
plt.plot(h_GD_comm, h_GD, label='GD', marker='*')
plt.plot(h_LAG_comm, h_LAG, label='LAG', marker='s',
         markevery=3)
plt.legend()
plt.xlabel('# of floats')
plt.ylabel(r'$||\nabla f(x^k)||^2$')
plt.title(dataset_name)
plt.xlim(left=0)
plt.grid()
plt.yscale('log')
plt.tight_layout()
plt.savefig('../results/{}.pdf'.format(dataset_name))

# w6a dataset

## Setup

In [None]:
num_clients = 20

In [None]:
lambda_ = 0.1

In [None]:
dataset_name = 'w6a'

In [None]:
raw_data = load_svmlight_file('../data/' + dataset_name)

In [None]:
X, y = raw_data

In [None]:
n = X.shape[0]

In [None]:
d = X.shape[1]

In [None]:
y

In [None]:
inds = np.array_split(np.arange(n), num_clients)

In [None]:
data = []
for i in range(num_clients):
    data.append((X[inds[i]][:], y[inds[i]]))

In [None]:
L = smoothness(X, lambda_)

In [None]:
L

In [None]:
L_i = [smoothness(data[i][0], lambda_) for i in range(num_clients)]

In [None]:
L_tilde = np.sqrt((np.array(L_i) ** 2).mean())

In [None]:
L_tilde

## Experiments

In [None]:
max_comm = 100000

### 1. LAG tuning

In [None]:
zetas = np.geomspace(1e-2, 1e1, 4)

In [None]:
x_0 = np.zeros(d)


In [None]:
LAG_histories = LAG_grid(x_0, data, lambda_, zetas, max_comm)

In [None]:
with open('../results/lag_{}.pickle'.format(dataset_name), 'wb') as file:
    pickle.dump(LAG_histories, file)

In [None]:
min_ind = print_and_return(LAG_histories)

In [None]:
print(min_ind)

In [None]:
h_LAG, h_LAG_comm = LAG_histories[min_ind]

In [None]:
# preliminary plot
plt.plot(h_LAG_comm, h_LAG)
plt.yscale('log')

### 2. EF21 tuning

In [None]:
ks = np.linspace(1, d, 10, endpoint=False, dtype=int)

In [None]:
ks

In [None]:
EF21_histories = EF21_grid(x_0, data, d, lambda_, ks, max_comm)

In [None]:
with open('../results/ef21_{}.pickle'.format(dataset_name), 'wb') as file:
    pickle.dump(EF21_histories, file)

In [None]:
min_ind = print_and_return(EF21_histories)

In [None]:
min_ind

In [None]:
h_EF21, h_EF21_comm = EF21_histories[min_ind]

In [None]:
# preliminary plot
plt.plot(h_LAG_comm, h_LAG, label='LAG')
plt.plot(h_EF21_comm, h_EF21, label='EF21')
plt.legend()
plt.yscale('log')

### 3. CLAG tuning

In [None]:
ks = np.flip(np.linspace(1, d, 6, endpoint=True, dtype=int))

In [None]:
ks = ks[:-1]

In [None]:
ks

In [None]:
trigger_betas = np.linspace(0, 50, 6, endpoint=True)

In [None]:
trigger_betas

In [None]:
beta_multipliers = np.geomspace(1e-1, 10, num=3)

In [None]:
beta_multipliers

In [None]:
ks

In [None]:
CLAG_histories = CLAG_grid(x_0, data, lambda_, ks, beta_multipliers, max_comm)

In [None]:
with open('../results/clag_{}.pickle'.format(dataset_name), 'wb') as file:
    pickle.dump(CLAG_histories, file)

In [None]:
len(CLAG_histories)

In [None]:
min_inds = []
for hist_k in CLAG_histories:
    ind = print_and_return(hist_k)
    min_inds.append(ind)

In [None]:
min_inds

In [None]:
min_list = [CLAG_histories[ind][min_inds[ind]] for ind in range(len(ks))]

In [None]:
min_ind = print_and_return(min_list)

In [None]:
min_ind

In [None]:
h_CLAG, h_CLAG_comm = min_list[min_ind]

In [None]:
# preliminary plot
plt.plot(h_LAG_comm, h_LAG, label='LAG')
plt.plot(h_EF21_comm, h_EF21, label='EF21')
plt.plot(h_CLAG_comm, h_CLAG, label='CLAG')
plt.legend()
plt.yscale('log')

In [None]:
heatmap = heatmap_CLAG(x_0, data, lambda_, ks, trigger_betas, 1e-2, '../results/heatmap_w6a.npy')

In [None]:
df = pd.DataFrame(heatmap, index=ks, columns=trigger_betas)

In [None]:
df

### GD

In [None]:
max_iter = int(max_comm / (num_clients * len(x_0)) + 1)
h_GD, h_GD_comm = GD(x_0, data, lambda_, num_clients, max_iter)

## Plot

In [None]:
len(h_CLAG_comm)

In [None]:
plt.plot(h_EF21_comm, h_EF21, label='EF21', marker='+')
plt.plot(h_CLAG_comm, h_CLAG, label='CLAG', marker='D',
         markevery=700)
plt.plot(h_GD_comm, h_GD, label='GD', marker='*')
plt.plot(h_LAG_comm, h_LAG, label='LAG', marker='s',
         markevery=10)
plt.legend()
plt.xlabel('# of floats')
plt.ylabel(r'$||\nabla f(x^k)||^2$')
plt.title(dataset_name)
plt.xlim(left=0)
plt.grid()
plt.yscale('log')
plt.tight_layout()
plt.savefig('../results/{}.pdf'.format(dataset_name))

In [None]:
h_check, h_check_comm = CLAG(x_0, data, lambda_, 30, 10, 100000)

In [None]:
plt.plot(h_check_comm, h_check)
plt.yscale('log')