# DGM (Decentralized Gradient Methods) Minimal Environment

In [None]:
import math
import numpy as np
import torch
import networkx as nx

import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.datasets import load_svmlight_file
from torch import autograd
from torch.nn.utils.rnn import pad_sequence

In [None]:
def uniform_decompose(N, m, b=8):
    """
    Decomposes N into m integer terms: a_i, i=1,m;
    so that: sum_i a_i = N and max_i a_i - min_i a_i <= 1
    --------
    N is the integer to split
    m is the number of terms
    b is the width in BYTES of the integer dtype used for the
    intermediate numpy array (it is passed to numpy as f'i{b}',
    so the default b=8 means int64)
    --------
    Returns a list of m Python ints; the first N % m terms are
    larger by one than the remaining ones.
    """
    # Exact integer arithmetic: float division N/m loses precision
    # for large N and then yields terms that do not sum to N.
    base, extra = divmod(N, m)
    terms = np.full(m, base, dtype=f'i{b}')
    terms[:extra] += 1
    return terms.tolist()


def D(y, x):
    """
    Differential operator: gradient of y with respect to x
    --------
    y is seeded with a tensor of ones of its own shape, so for a
    non-scalar y this is the gradient of y.sum().
    The graph of the derivative is kept (create_graph=True) and
    inputs that y does not depend on are allowed (their gradient
    is returned as None by autograd).
    --------
    Returns a single tensor when autograd yields exactly one
    gradient, otherwise the whole tuple of gradients.
    """
    seed = torch.ones_like(y)
    grads = autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=seed,
        create_graph=True,
        allow_unused=True,
    )
    return grads[0] if len(grads) == 1 else grads


def metropolis_weights(A):
    """
    Builds a mixing matrix from the square matrix A
    --------
    Entry (i, j) of A is divided by
        1 + max(i-th row sum of A, j-th column sum of A),
    then diagonal entry i is overwritten with one minus the i-th
    row sum of the scaled matrix (so every row sums to one when
    A has a zero diagonal).
    The input tensor is not modified; a new tensor is returned.
    """
    n = len(A)
    row_sums = A.sum(1, keepdim=True)   # shape (n, 1)
    col_sums = A.sum(0, keepdim=True)   # shape (1, n)
    W = A / (torch.max(row_sums, col_sums) + 1)

    # View over the main diagonal of the freshly allocated W.
    diagonal = W.as_strided([n], [n + 1])
    diagonal.copy_(1 - W.sum(1))
    return W

In [None]:
class Objective:
    """
    Base class for optimization functional
    --------
    The rows of A and the matching entries of b are split into
    num_nodes consecutive chunks whose sizes differ by at most
    one (see uniform_decompose); the chunks are stacked along a
    new leading dimension, shorter chunks being zero-padded at
    the end (pad_sequence with batch_first=True).
    """
    def __init__(self, A, b, num_nodes):
        sizes = uniform_decompose(A.size(0), num_nodes)
        self.A, self.b = (
            pad_sequence(tensor.split(sizes), batch_first=True)
            for tensor in (A, b)
        )

        
class LeastSquares(Objective):
    """
    Sum of squared residuals of the linear model
    --------
    X is either a 1-D vector (the same point is used for every
    chunk of self.A) or a 2-D tensor whose i-th column is the
    point used for the i-th chunk.
    """
    def __call__(self, X):
        if X.ndim < 2:
            spec = 'ijk,k->ij'
        else:
            spec = 'ijk,ki->ij'
        residual = torch.einsum(spec, self.A, X) - self.b
        return residual.square().sum()
    
    
class LogRegression(Objective):
    """
    Mean logistic loss log(1 + exp(-b * <a, x>))
    --------
    X is either a 1-D vector (the same point is used for every
    chunk of self.A) or a 2-D tensor whose i-th column is the
    point used for the i-th chunk.
    The mean runs over every entry of the stacked data, so a row
    that is zero in self.b (e.g. zero padding) adds log(2).
    """
    def __call__(self, X):
        if X.ndim < 2:
            spec = 'ij,ijk,k->ij'
        else:
            spec = 'ij,ijk,ki->ij'
        margins = torch.einsum(spec, self.b, self.A, X)
        # logaddexp(-m, 0) = log(1 + exp(-m)), evaluated stably
        losses = torch.logaddexp(-margins, torch.tensor(0.))
        return losses.mean()

In [None]:
class DGM:
    """
    Base class for decentralized gradient methods
    --------
    W is the square mixing matrix (its first dimension gives the
    number of nodes n), F is the objective callable and alpha is
    the step size.
    Iterates X are 2-D tensors with one column per node.
    --------
    self.logs maps:
        0 -> iteration labels passed to _record
        1 -> F evaluated at the column-average of X
        2 -> ||X - X 11^T / n||  (distance to the average)
        3 -> ||X (I - W)||
    """
    def __init__(self, W, F, alpha=1.):
        self.W = W
        self.F = F
        self.n = W.size(0)
        self.alpha = alpha
        self._initLogs()

    def _initLogs(self):
        """Resets all logged series to empty lists."""
        self.logs = {0: [], 1: [], 2: [], 3: []}

    def _metric2(self, X):
        """Frobenius norm of X minus its column-average."""
        ones = X.new_ones(self.n, 1)
        average = (X @ ones) / self.n      # one column, broadcast below
        return torch.norm(X - average)

    def _metric3(self, X):
        """Frobenius norm of X @ (I - W)."""
        # The full matrix I - W is needed here; adding ones to a
        # diagonal *view* out-of-place only produces the vector
        # 1 - diag(W), which does not vanish at consensus.
        eye = torch.eye(self.n, dtype=self.W.dtype, device=self.W.device)
        return torch.norm(X @ (eye - self.W))

    def _record(self, X, k):
        """Appends the metrics of iterate X under iteration label k."""
        self.logs[1].append(self.F(X.mean(1)).item())
        self.logs[2].append(self._metric2(X).item())
        self.logs[3].append(self._metric3(X).item())
        self.logs[0].append(k)
        
        
class EXTRON(DGM):
    """
    ONe-process EXTRA algorithm
    --------
    Iterates are 2-D tensors with one column per node; mixing is
    done by right-multiplication with self.W.
    """
    def _step1(self, X0):
        """
        First step: X1 = X0 @ W - alpha * G0, G0 = dF/dX at X0.
        Note: X0 is flagged with requires_grad_ in place.
        Returns (G0, X1).
        """
        X0.requires_grad_(True)
        G0 = D(self.F(X0), X0)
        with torch.no_grad():
            X1 = X0@self.W - self.alpha*G0
        return G0, X1

    def _step2(self, X0, G0, X1):
        """
        Generic step:
            X2 = X1 - X0/2 + (X1 - X0/2) @ W - alpha * (G1 - G0)
        with G1 = dF/dX at X1 (X1 is flagged with requires_grad_
        in place).
        Returns the shifted state (X1, G1, X2).
        """
        X1.requires_grad_(True)
        G1 = D(self.F(X1), X1)
        with torch.no_grad():
            X2 = X1 - X0/2 + (X1-X0/2)@self.W - self.alpha*(G1-G0)
        return X1, G1, X2

    def run(self, X0, G0=None, X1=None, n_iters=10, lp=1):
        """
        Runs the method and returns the last state (X0, G0, X1),
        which can be passed back to continue the run.
        --------
        Fresh start (G0 or X1 missing): logs are reset, X0 and
        X1 are logged as iterations 0 and 1, then n_iters-1
        generic steps are made, so the last iterate has index
        n_iters.
        Continuation (G0 and X1 given): n_iters-1 generic steps
        are made; iteration labels continue from the last logged
        label (this matches the true iteration count when that
        iterate was logged, e.g. always for lp=1).
        --------
        lp is the logging period: iterate k is logged when
        k % lp == 0.
        """
        if G0 is None or X1 is None:
            G0, X1 = self._step1(X0)
            self._initLogs()
            self._record(X0, 0)
            self._record(X1, 1)
            k = 1
        else:
            k = self.logs[0][-1] if self.logs[0] else 1

        for _ in range(n_iters-1):
            X0, G0, X1 = self._step2(X0, G0, X1)
            # X1 is now the k-th iterate; labelling it with the bare
            # loop counter would repeat the labels 0 and 1.
            k += 1
            if k%lp == 0: self._record(X1, k)

        return X0, G0, X1


class DIGONing(DGM):
    """
    ONe-process DIGing algorithm
    --------
    Iterates are 2-D tensors with one column per node; mixing is
    done by right-multiplication with self.W.
    """
    def run(self, X0, Y0=None, n_iters=10, lp=1):
        """
        Makes n_iters-1 steps of
            X1 = X0 @ W - alpha * Y0
            Y1 = Y0 @ W + G1 - G0,   G = dF/dX
        --------
        Y0 is the gradient tracker. When it is None the run
        starts fresh: logs are reset, X0 is logged as iteration
        0 and Y0 is initialised with the gradient at X0. When it
        is given, logs are kept and the gradient at X0 is
        recomputed so the tracker update has its G0.
        lp is the logging period: iterate k is logged when
        k % lp == 0.
        --------
        Returns (X0, G0, X1) where, after at least one step, X0
        and X1 are both the last iterate and G0 is its gradient.
        NOTE(review): the tracker Y is not part of the returned
        tuple, so a caller cannot obtain it from here.
        X0 is flagged with requires_grad_ in place.
        """
        fresh = Y0 is None
        if fresh:
            self._initLogs()
            self._record(X0, 0)

        # G0 is needed by the tracker update in both modes.
        X0.requires_grad_(True)
        G0 = D(self.F(X0), X0)
        if fresh:
            Y0 = G0.clone()

        # Keeps the return value defined when no step is made.
        X1 = X0

        for k in range(1, n_iters):
            with torch.no_grad():
                X1 = X0@self.W - self.alpha*Y0

            X1.requires_grad_(True)
            G1 = D(self.F(X1), X1)
            with torch.no_grad():
                Y1 = Y0@self.W + G1 - G0

            X0, Y0, G0 = X1, Y1, G1
            if k%lp == 0: self._record(X1, k)

        return X0, G0, X1

In [None]:
# Read the svmlight-format dataset (sparse features, label vector)
# and turn both parts into dense torch tensors.
A, b = load_svmlight_file('data/a9a.2')
A, b = torch.Tensor(A.toarray()), torch.Tensor(b)

In [None]:
num_nodes = 20

# Communication graph between the nodes. Alternative topologies
# are kept below, disabled:
#G = nx.path_graph(num_nodes)
#G = nx.cycle_graph(num_nodes)
#G = nx.complete_graph(num_nodes)
# NOTE(review): a random G(n, p) sample is not guaranteed to be
# connected -- consider checking nx.is_connected(G).
G = nx.erdos_renyi_graph(n=num_nodes, p=0.2)

# Mixing matrix derived from the adjacency matrix of G.
S = nx.adjacency_matrix(G).todense()
W = metropolis_weights(torch.Tensor(S))

# Sanity check: every column and every row of W sums to one.
assert torch.allclose(torch.ones(len(W)) @ W, torch.ones(len(W)))
assert torch.allclose(W @ torch.ones(len(W)), torch.ones(len(W)))
nx.draw(G)

In [None]:
# Starting point: one zero column per node, one row per feature.
X0 = torch.zeros(A.size(1), num_nodes)
# Objective split across the nodes and the optimizer that uses it.
F = LogRegression(A, b, num_nodes=num_nodes)
opt = EXTRON(W=W, F=F)

In [None]:
# Run the optimizer from X0 with n_iters=100; metrics accumulate in
# opt.logs. The trailing semicolon keeps the notebook from echoing
# the returned tuple.
opt.run(X0, n_iters=100);

In [None]:
# Objective value at the node-average (opt.logs[1]) against the
# position of each entry in the log.
plt.figure(figsize=(8, 8))
plt.title('Optimization Functional Value over Iteration Number', fontsize=20)
plt.xlabel('# k', fontsize=20)
plt.ylabel('f(x)', fontsize=20)
plt.plot(opt.logs[1]);

In [None]:
# Distance of the iterates to their node-average (opt.logs[2]);
# the first logged entry (the starting point) is skipped.
# The logged quantity is the unsquared Frobenius norm of
# X - X 11^T / n, so the title states exactly that.
plt.figure(figsize=(8, 8))
plt.plot(opt.logs[2][1:]);
plt.title(r'$||X(I-\frac{1}{n}11^T)||_F$', size=20)
plt.xlabel('# k', size=20)