In [22]:
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid

In [85]:
class MLP(nn.Module):
    """Per-node model: one linear layer, optionally followed by ReLU and a second linear layer.

    Args:
        input_dim: size of the input feature vector.
        hidden_dim: output size of ``linear_1``.
        output_dim: if not None, a second layer ``linear_2`` (hidden_dim -> output_dim)
            is created.
        sigma: when truthy *and* ``output_dim`` is not None, ``forward`` applies ReLU
            after ``linear_1`` and then ``linear_2``; otherwise only ``linear_1`` is applied.
        bias: whether the linear layers use a bias term.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=None, sigma=False, bias=False):
        super(MLP, self).__init__()
        self.linear_1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        if output_dim is not None:
            # NOTE: created whenever output_dim is given, but forward() only uses it
            # when sigma is also truthy; otherwise its parameters stay untouched.
            self.linear_2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.sigma = sigma
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

    def forward(self, x):
        """Apply ``linear_1``, then (if enabled) ReLU followed by ``linear_2``."""
        x = self.linear_1(x)
        if self.sigma and self.output_dim is not None:
            x = F.relu(x)
            x = self.linear_2(x)
        return x
    

    
class Node:
    """A participant that owns one graph node's features, label, local model and optimizer.

    Args:
        x: this node's feature vector (passed to ``mlp`` as-is).
        y: this node's label; reshaped with ``view(1)``, so it must hold exactly one element.
        mlp: the node's local ``nn.Module``.
        learning_rate: SGD step size used by ``local_update``.
        node_id: identifier of this node (stored; not used by the methods below).
    """

    def __init__(self, x, y, mlp, learning_rate, node_id):
        self.x = x
        # nll_loss expects a 1-D target with one entry per row of y_hat (here: 1 row).
        self.y = y.view(1)
        self.mlp = mlp
        self.node_id = node_id
        self.learning_rate = learning_rate
        self.optimizer = torch.optim.SGD(self.mlp.parameters(), lr=self.learning_rate)

    def receive_params(self, parameters):
        """Overwrite this node's model parameters in place.

        Args:
            parameters: dict mapping every name from ``self.mlp.named_parameters()``
                (e.g. ``"linear_1.weight"``) to a tensor of matching shape.
        """
        # no_grad: copying into leaf parameters must not be recorded by autograd.
        with torch.no_grad():
            for name, param in self.mlp.named_parameters():
                param.copy_(parameters[name])

    def upload_params(self):
        """Return a snapshot of this node's model state as ``{name: tensor}``.

        The tensors are cloned. ``state_dict()`` alone returns tensors that share
        storage with the live parameters, so a later in-place ``optimizer.step()``
        would silently change a snapshot the caller is already holding.
        """
        return {name: tensor.clone() for name, tensor in self.mlp.state_dict().items()}

    def forward_upload(self):
        """Run the local model on this node's own features and return the output."""
        return self.mlp(self.x)

    def local_update(self, y_hat):
        """Take one SGD step on the local model using the negative log-likelihood loss.

        Args:
            y_hat: output of a ``log_softmax``, shape ``1 x num_classes``. It must depend
                on this node's model parameters (presumably derived from
                ``forward_upload()``); otherwise no gradient reaches ``self.mlp``.

        Note:
            Only one update is expected per forward pass in this setup.
        """
        loss = F.nll_loss(y_hat, self.y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
class Central_Server:
    """Holds the graph connectivity together with the list of participating nodes.

    Args:
        edge_index: graph connectivity as supplied by the caller (stored unchanged).
        node_list: list of per-node objects, stored by reference.
    """

    def __init__(self, edge_index, node_list):
        # Both attributes are kept as given; no copying or validation happens here.
        self.node_list = node_list
        self.edge_index = edge_index

In [86]:
def init_network(data, num_classes, hidden_dim, learning_rate,
                       output_dim=None, sigma=False):
    """Build one ``Node`` (each with its own ``MLP``) per graph node and wrap them in a server.

    Args:
        data: graph object exposing ``x`` (a ``num_nodes x input_dim`` feature matrix),
            ``y`` (one label per node) and ``edge_index``.
        num_classes: currently unused by this function.
        hidden_dim: hidden size passed to every ``MLP``.
        learning_rate: SGD learning rate for every node's optimizer.
        output_dim: output size of the second MLP layer; only used when ``sigma`` is truthy.
        sigma: when truthy and ``output_dim`` is given, each node gets a two-layer MLP
            with ReLU; otherwise a single linear layer (input_dim -> hidden_dim).

    Returns:
        A ``Central_Server`` holding ``data.edge_index`` and the nodes in index order.
    """
    num_nodes, input_dim = data.x.shape

    # Same architecture for every node, so decide once. The truthiness test on sigma
    # matches how MLP.forward itself interprets sigma.
    two_layer = output_dim is not None and bool(sigma)

    node_list = []
    for v in range(num_nodes):
        if two_layer:
            mlp = MLP(input_dim, hidden_dim, output_dim, sigma)
        else:
            mlp = MLP(input_dim, hidden_dim)
        node_list.append(Node(data.x[v, :], data.y[v], mlp, learning_rate, v))

    return Central_Server(data.edge_index, node_list)
        

In [79]:
# Load the Planetoid "Cora" dataset rooted at CORA_ROOT and take its first graph object.
CORA_ROOT = '/tmp/Cora'
dataset = Planetoid(name='Cora', root=CORA_ROOT)
data = dataset[0]