In [1]:
import numpy as np
from oracle import Oracle

In [2]:
def decentralizedSGD(x0, lr, max_iter, W, oracle):
    
    x_t = x0
    
    for t in range(max_iter):
         
        noise = np.random.normal(0, 1, size=x_t.shape)
        
        f, df = oracle(x_t + noise)
        
        y_t = x_t - lr * df
        x_t = np.matmul(y_t, W)
        
    return x_t

def MetropolisHastings(W):
    
    degrees = np.sum(W, axis=1)
    
    for i in range(W.shape[0]):
        for j in range(i, W.shape[0]):
            
            if W[i, j] != 0:
                
                weight = min(1 / (degrees[i] + 1), 1 / (degrees[j] + 1))
                
                W[i, j] = weight
                W[j, i] = weight
                
    return W


def toMatrix(size, adj_list):
    
    matrix = np.zeros((size, size))
    
    for node in adj_list:
        
        neighbors = adj_list[node]
        
        for neighbor in neighbors:
            
            matrix[node, neighbor] = 1
            matrix[neighbor, node] = 1
            
    return matrix.astype(np.float64)

In [3]:
nodes = 4
params = 6
shape = (params, nodes)
max_iter = 100
lr = 1e-4

x0 = np.random.randn(params, nodes)

connections = {0: [1, 2], 1: [3], 2: [1]}

adjacency_matrix = toMatrix(nodes, connections)

W = MetropolisHastings(adjacency_matrix)

func_type = "strongly convex"
oracle = Oracle(func_type)

decentralizedSGD(x0, lr, max_iter, W, oracle)

array([[ 2.27174964e-05, -6.53774581e-05,  3.15044267e-05,
         3.35007479e-05],
       [ 6.12463941e-05,  2.32537985e-05, -1.08651099e-04,
        -4.03365511e-05],
       [ 1.12331802e-04, -7.19710396e-05,  4.45478723e-05,
         8.77153228e-05],
       [-6.37883738e-05, -9.11599587e-05, -1.49428642e-04,
        -4.81727584e-05],
       [-6.49307955e-06,  1.01094894e-04,  4.64995411e-05,
         6.94645577e-06],
       [-7.34495323e-05, -6.34099193e-05, -4.04980444e-05,
        -1.57343894e-05]])