In [18]:
import networkx as nx
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.utils as utils
from torch_geometric.datasets import Planetoid

In [62]:
class MLP(nn.Module):
    """One- or two-layer perceptron used as each node's local model.

    With only ``input_dim`` and ``hidden_dim`` this is a single linear map
    ``input_dim -> hidden_dim``. When ``output_dim`` is given *and* ``sigma`` is
    truthy, ``forward`` additionally applies ReLU followed by a second linear map
    ``hidden_dim -> output_dim``.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: output size of the first linear layer.
        output_dim: output size of the optional second linear layer (``None`` = no second layer).
        sigma: if truthy (and ``output_dim`` is set), apply ReLU + the second layer.
        bias: whether the linear layers carry a bias term.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=None, sigma=False, bias=False):
        super().__init__()
        self.linear_1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        if output_dim is not None:
            # NOTE(review): this layer is created whenever output_dim is set, but forward()
            # only uses it when sigma is truthy — confirm that is intended.
            self.linear_2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.sigma = sigma
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

    def forward(self, x):
        """Apply ``linear_1`` and, when enabled, ReLU followed by ``linear_2``."""
        x = self.linear_1(x)
        if self.sigma and self.output_dim is not None:
            x = self.linear_2(F.relu(x))
        return x
    

    
class Node:
    """A graph node holding its own features, label and local model.

    Args:
        x: this node's feature tensor (``init_network`` passes one row of ``data.x``).
        y: this node's label; stored reshaped to a 1-element tensor.
        mlp: the node's local ``nn.Module``.
        learning_rate: SGD learning rate for ``mlp``'s parameters.
        node_id: integer index of the node.
    """

    def __init__(self, x, y, mlp, learning_rate, node_id):
        self.x = x
        self.y = y.view(1)  # scalar label -> shape (1,)
        self.mlp = mlp
        self.node_id = node_id
        self.learning_rate = learning_rate
        self.optimizer = torch.optim.SGD(self.mlp.parameters(), lr=self.learning_rate)

    def receive_params(self, parameters):
        """Overwrite the local model's parameters in place.

        Args:
            parameters: mapping from parameter name (as yielded by
                ``self.mlp.named_parameters()``, e.g. ``"linear_1.weight"``) to a
                tensor of matching shape. Every parameter of ``self.mlp`` must be present.

        Raises:
            KeyError: if a parameter name of ``self.mlp`` is missing from ``parameters``.
        """
        # The in-place copy into parameters must not be recorded by autograd.
        with torch.no_grad():
            for name, param in self.mlp.named_parameters():
                param.copy_(parameters[name])

    def upload_params(self):
        """Return this node's model ``state_dict``: ``{"linear_1.weight": tensor, ...}``.

        The returned tensors are references to the model's parameters, not copies.
        """
        # Must read this node's own model, not a module-level ``mlp``.
        return self.mlp.state_dict()

    def upload_h(self):
        """Return the local model applied to this node's features."""
        return self.mlp(self.x)
        
        
class Central_Server:
    """Server that gathers node outputs and builds a decayed sum of adjacency powers.

    Args:
        edge_index: graph connectivity (stored; not used by the methods below).
        node_list: sequence of node objects exposing ``upload_h()``.
        A: square ``(N, N)`` adjacency matrix tensor (integer or floating dtype).
        K: highest power of ``A`` included in ``A_bar`` (must be >= 0).
        alpha: geometric decay factor applied to successive powers of ``A``.
    """

    def __init__(self, edge_index, node_list, A, K, alpha=0.95):
        self.edge_index = edge_index
        self.node_list = node_list
        self.A = A
        self.H = None      # filled by receive_h()
        self.A_bar = None  # filled by compute_Abar()
        self.alpha = alpha
        self.K = K

    def receive_h(self):
        """Collect ``upload_h()`` from every node and stack them into ``self.H``.

        Entry ``i`` along dim 0 of ``self.H`` is the output of ``node_list[i]``.
        """
        self.H = torch.stack([v.upload_h() for v in self.node_list])

    def compute_Abar(self):
        """Store ``A_bar = (1 - alpha) * sum_{i=0}^{K} alpha**i * A**i`` in ``self.A_bar``.

        Note the coefficients sum to ``1 - alpha**(K + 1)``, not exactly 1.

        Raises:
            ValueError: if ``self.K`` is negative.
        """
        if self.K < 0:
            raise ValueError("K must be non-negative")
        # Integer adjacency matrices cannot be matmul'ed with float tensors, so cast;
        # keep an existing floating dtype to avoid losing precision.
        dtype = self.A.dtype if self.A.is_floating_point() else torch.get_default_dtype()
        A = self.A.to(dtype)
        N = A.shape[0]

        term = torch.eye(N, dtype=dtype, device=A.device)  # alpha**0 * A**0
        A_bar = term.clone()
        for _ in range(self.K):
            term = self.alpha * (term @ A)  # alpha**i * A**i
            A_bar = A_bar + term
        self.A_bar = (1 - self.alpha) * A_bar

In [103]:
def init_network(data, A, num_classes, hidden_dim, learning_rate, K,
                 output_dim=None, sigma=False):
    """Create one ``Node`` (each with a fresh ``MLP``) per graph node and wrap them in a server.

    Args:
        data: graph object exposing ``x`` (2-D node features), ``y`` (per-node labels)
            and ``edge_index``.
        A: adjacency matrix passed through to ``Central_Server``.
        num_classes: currently unused. NOTE(review): possibly meant to feed
            ``output_dim`` — confirm.
        hidden_dim: output size of each MLP's first layer.
        learning_rate: SGD learning rate for every node's optimizer.
        K: propagation depth passed to ``Central_Server``.
        output_dim: output size of the MLP's second layer; only used when ``sigma`` is truthy.
        sigma: if truthy (and ``output_dim`` is set), build two-layer MLPs with ReLU.

    Returns:
        A ``Central_Server`` holding the list of nodes, in node-index order.
    """
    num_nodes, input_dim = data.x.shape
    # Truthiness of sigma matches how MLP.forward decides whether to use the second layer.
    two_layer = output_dim is not None and bool(sigma)

    node_list = []
    for v in range(num_nodes):
        if two_layer:
            mlp = MLP(input_dim, hidden_dim, output_dim, sigma)
        else:
            mlp = MLP(input_dim, hidden_dim)
        node_list.append(Node(data.x[v, :], data.y[v], mlp, learning_rate, v))

    return Central_Server(data.edge_index, node_list, A, K)
        

In [104]:
# Load Cora and normalise self-loops: drop any existing ones, then add exactly one per node.
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
data.edge_index = utils.remove_self_loops(data.edge_index)[0]
data.edge_index = utils.add_remaining_self_loops(data.edge_index)[0]
G = utils.to_networkx(data, to_undirected=True)
# Dense adjacency with rows/columns in node-index order (row v <-> Node v), cast to float
# so it can be multiplied with float tensors (the networkx matrix is otherwise
# typically integer-typed for unweighted graphs).
A = torch.tensor(
    nx.adjacency_matrix(G, nodelist=list(range(data.num_nodes))).toarray(),
    dtype=torch.float32,
)

In [105]:
# Cora has 7 classes; 40 hidden units, SGD lr 0.1, adjacency powers up to K = 10.
server = init_network(
    data,
    A,
    num_classes=7,
    hidden_dim=40,
    learning_rate=0.1,
    K=10,
)