In [None]:
## Python libs

import math
import torch
import pickle
import warnings
import numpy as np
import random

import matplotlib.pyplot as plt
%matplotlib inline

from pathlib import Path
from sklearn.datasets import load_svmlight_file

warnings.simplefilter('ignore')

In [None]:
## Set to True when running on Google Colab; the block below then copies the
## data from Google Drive and fetches the project sources into the working directory.
COLAB = False

## The archive 'data.zip' must be present on Google Drive
## (in the directory specified by 'GDDIR', relative to 'My Drive')
GDDIR = 'uploads/rogozin'

if COLAB:
    from google.colab import drive
    drive.mount('/content/gdrive')

    ## Copy the Drive folder into the working directory, unpack the archive
    ## quietly without overwriting existing files (-qn), then decompress
    ## every .bz2 file in data/
    !cp -r /content/gdrive/My\ Drive/{GDDIR} .
    !unzip -qn {Path(GDDIR).name}/data.zip
    !bunzip2 data/*.bz2

    ## Clone the project and move its contents into the working directory
    ## (the '*' glob does not match dotfiles, so hidden files are not moved),
    ## then remove the leftover clone directory
    !git clone https://github.com/alexrogozin12/decentralized_methods.git
    !mv decentralized_methods/* .
    !rm decentralized_methods -rf
    ## Deletes line 8 of src/utils.py in place.
    ## NOTE(review): this depends on the exact layout of the upstream file and
    ## will silently remove a different line if that file changes -- confirm
    ## what line 8 is and consider pinning the clone to a specific commit.
    !sed -ri '8d' src/utils.py

# DGM Minimal Environment 

In [None]:
## Local libs

from src.objectives import ( 
    LeastSquares, LogRegression,
    StochLeastSquares, StochLogRegression)
from src.methods import (
    EXTRA, DIGing, DSGD,
    DAccGD, Mudag, APM_C,
    SMudag, SAPM_C)
from src.utils import PythonGraph, lambda_2, expected_lambda2

def name_corrector(names):
    """Build display labels for a collection of method names.

    For each name, a leading ``'S'`` is dropped unless the shortened name is
    itself present in ``names``; in that case the full name is kept so the two
    labels stay distinct. Underscores are then replaced with hyphens.

    Args:
        names: container of method-name strings. It is iterated once and is
            also used for membership tests.

    Returns:
        list of str: one label per input name, in input order.
    """
    corrected_names = []
    for name in names:
        # startswith() instead of name[0]: an empty string is passed through
        # unchanged rather than raising IndexError.
        stripped = name[1:] if name.startswith('S') else name
        # Keep the original name when the shortened form would collide with
        # another entry (this is also always the case when nothing was stripped).
        label = name if stripped in names else stripped
        corrected_names.append(label.replace('_', '-'))
    return corrected_names

In [None]:
## Experiment switches
GRAPH_EVOLUTION = False
TASK = StochLogRegression
# TASK = StochLeastSquares

## Locations: dataset file and the directory with precomputed solutions,
## chosen by the kind of objective named in TASK
DSDIR = Path('data/a9a')
DDIR = Path(
    'logreg_solutions' if 'LogRegression' in TASK.__name__
    else 'least_squares_solutions')
soldir = DDIR / DSDIR.name

num_nodes = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Dataset: densify the loaded matrix, keep CPU tensors (A_cpu, b_cpu)
## and rebind A, b to copies placed on `device`
A, b = load_svmlight_file(str(DSDIR))
A_cpu, b_cpu = torch.Tensor(A.todense()), torch.Tensor(b)
A, b = A_cpu.to(device), b_cpu.to(device)

In [None]:
## Pick the reference solution file for this dataset by index.
## Path.iterdir() yields entries in arbitrary order, so the listing is sorted
## (and restricted to regular files, since the entry is opened below) to make
## `fnum` select the same file on every run and every machine.
fnum = 0
solution_files = sorted(path for path in soldir.iterdir() if path.is_file())
fname = solution_files[fnum].name

## The text after the first '=' in the file name is parsed as sigma
sigma = float(fname.split('=')[1])

## NOTE: pickle.load can execute arbitrary code -- only load solution files
## that come from a trusted source.
with open(soldir/fname, 'rb') as file:
    f_star = pickle.load(file)['func_star']

In [None]:
# For simulating a graph evolution,
# only graphs like 'erdos_renyi' and 'random_geometric' are appropriate

p = .68
graph = 'random_geometric'
# graph = 'erdos_renyi'
# graph = 'path'
# graph = 'cycle'
# graph = 'complete'

avg = 1
static = True
bs = 10  # < ----------- NUMBER OF RANDOMLY SELECTED ROWS IN THE MATRIX

# One batch size per node, as an integer tensor on the same device as b
batch_sizes = b.new_full((num_nodes,), bs).long()

# Two argument lists: one with the tensors on `device` (for F) and one with
# the CPU copies (for F_cpu). Previously F_cpu was built from the device
# tensors as well, so it was not actually a CPU objective when CUDA was used.
args = [A, b, num_nodes, sigma]
args_cpu = [A_cpu, b_cpu, num_nodes, sigma]
if 'Stoch' in TASK.__name__:
    # Stochastic objectives take three extra positional arguments
    args += [avg, batch_sizes, static]
    args_cpu += [avg, batch_sizes.cpu(), static]


F = TASK(*args)
F_cpu = TASK(*args_cpu)
# Starting point: one zero row per node, one column per column of A
X0 = torch.zeros(num_nodes, A.size(1)).to(device)

In [None]:
if not GRAPH_EVOLUTION:
    # Static topology: generate one graph and evaluate lambda_2 on
    # element [1] of what PythonGraph.gen() returns
    W = PythonGraph(F, graph, p).gen()[1]
    s2 = lambda_2(W)
else:
    # Time-varying topology: every call generates a fresh graph
    def gen():
        """Return element [1] of a newly generated PythonGraph built on F."""
        return PythonGraph(F, graph, p).gen()[1]

    def _gen():
        """Same as gen(), but built on F_cpu."""
        return PythonGraph(F_cpu, graph, p).gen()[1]

    # 6000 is presumably the number of sampled graphs -- TODO confirm
    E_s2, _ = expected_lambda2(_gen, 6000)

####
# Fixing seed doesn't really make a difference
####
# torch.manual_seed(123)  #  I don't remember whether I use torch random numbers anywhere
# random.seed(123)  #  networkx depends on lib random
# graphs = [PythonGraph(F, graph, p).gen()[1] for _ in range(int(1e4))]

# class GraphEvolution:
#     def __init__(self, graphs):
#         self.gi = iter(graphs)
        
#     def __call__(self):
#         return next(self.gi)

In [None]:
## Constants derived from singular values
# Largest singular value of A, squared and divided by 4 * (number of rows of A)
L = torch.svd(A)[1][0] ** 2 / (4*len(A))
# Largest singular value of each slice of F.A, averaged over the leading
# dimension, divided by sigma.
# NOTE(review): unlike L, the singular value is not squared here -- confirm intended
kappa_g = torch.svd(F.A)[1][:, 0].mean() / sigma

consensus_iters = 4  # < --------------- HERE IS CONSENSUS ITERS
eta_scale = 130
gamma_scale = 1.5
beta = 1e-6

# `s2` is only defined when GRAPH_EVOLUTION is False; in evolution mode the
# graph cell defines `E_s2` instead, so referencing `s2` unconditionally
# raised NameError on a fresh kernel.
# NOTE(review): using E_s2 as the stand-in for s2 is an assumption -- confirm
lam2 = E_s2 if GRAPH_EVOLUTION else s2
M = (1-1e-5)*L*math.exp(consensus_iters*math.sqrt(1-lam2)) / kappa_g

In [None]:
## Build the optimizer for the selected mode.
if GRAPH_EVOLUTION:
    # SDAccGD is not among the names imported from src.methods at the top of
    # the notebook, so this branch raised NameError. It is imported lazily
    # here so that the static branch keeps working even if the name is
    # unavailable.
    # NOTE(review): assumed to live in src.methods next to DAccGD -- confirm,
    # then move this import into the "Local libs" cell.
    from src.methods import SDAccGD
    # Time-varying topology: the optimizer receives the graph generator `gen`
    opt = SDAccGD(F, gen, L=L, mu=sigma, con_iters=consensus_iters)
else:
    # Static topology: the optimizer receives the fixed matrix W
    opt = DAccGD(F, W, L=L, mu=sigma, M=M, kappa_g=kappa_g, scale=1.)

# State handed to opt.run(); starts with only the initial point
checkpoint = [X0]

In [None]:
#%%time
## Running the cell X times yields X * n_iters optimization steps of each optimizer
## (if n_iters is not redefined during it). To run from scratch,
## execute the cell with optimizers' initialization first

n_iters = 1000

## Unpack into cell-local names. Rebinding X0 / args here would overwrite the
## zero starting point and the TASK argument list, so re-running the
## initialization cell (checkpoint = [X0]) would resume from the last iterate
## instead of restarting from scratch.
X_start, *run_state = checkpoint
## The returned value becomes the state for the next execution of this cell
checkpoint = opt.run(X_start, *run_state, n_iters=n_iters);

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

XLIM = opt.logs['nmix'][-1]
span = np.searchsorted(opt.logs['nmix'], XLIM, 'right')
axes[0].plot(
    opt.logs['nmix'][:span], 
    abs(opt.logs['fn'][:span] - f_star),
    marker=6, markevery=span//10)
    
axes[0].set_ylabel(r'$f(\overline{x}_k) - f^*$', size=15)
axes[0].set_xlabel('communication steps', size=15)


span = np.searchsorted(opt.logs['nmix'], XLIM, 'right')
axes[1].plot(
    opt.logs['nmix'][:span],
    opt.logs['dist2con'][:span],
    marker=6, markevery=span//10)

axes[1].set_ylabel(r'$||(I-\frac{1}{n}11^T)X||^2$', size=15)
axes[1].set_xlabel('communication steps', size=15)


for axis in axes:
    axis.set_yscale('log')
    axis.grid()

# ylim(0, 1.2*opts[0].logs['fn'][0])
plt.tight_layout();

fname = f'{DSDIR.name}-{bs}bs-{consensus_iters}cons.png'
if COLAB:
    plt.savefig(fname)
    !mv {fname} /content/gdrive/My\ Drive/{GDDIR}/figures
else:
    !mkdir -p figures
    plt.savefig(f'figures/{fname}')