In [1]:
# Packages

import networkx as nx
import matplotlib.pyplot as plt
import torch
import copy
import time
import numpy as np
import pandas as pd
import scipy
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch_geometric.utils as utils
from torch_geometric.datasets import Planetoid

In [7]:
class Central_Server:
    """Coordinator of the federated training loop.

    The server keeps the list of nodes, the propagation matrix ``A_tilde`` and
    the shared ("central") parameters. One communication round broadcasts the
    central parameters, collects one representation per node, mixes them with
    ``A_tilde``, triggers local updates on the training nodes and averages the
    resulting local parameters back into the central ones.
    """

    def __init__(self, node_list, A_hat, K, alpha=0.95, A_tilde=None):
        """Store the nodes, the graph matrix and the propagation settings.

        Args:
            node_list: List of node objects. The server uses their ``mlp``
                attribute and their ``upload_h``, ``receive_params`` and
                ``receive_and_update`` methods.
            A_hat: The normalized adjacency matrix with added self-loops.
            K: The number of power iterations to perform.
            alpha: Geometric weight applied per power of ``A_hat``; the
                accumulated series is scaled by ``1 - alpha``.
            A_tilde: Optional pre-computed propagation matrix. When ``None``
                it is computed on first use from ``A_hat``, ``K`` and ``alpha``.
        """
        self.node_list = node_list
        self.A_hat = A_hat
        self.alpha = alpha
        self.K = K
        self.A_tilde = A_tilde  # The matrix used to perform power iterations.
        self.central_params = None  # The central parameters, None until the first round.
        self.cparam_names = None  # Names of the central parameters, filled lazily.
        self.N = len(node_list)  # Number of nodes.
        self.mlp = None  # Evaluation model; must be assigned before testing_accuracy is called.

    def receive_h(self):
        """Collect the representation uploaded by every node.

        Returns:
            A list with one entry per node, in ``node_list`` order, holding
            whatever each node's ``upload_h()`` returns.
        """
        return [v.upload_h() for v in self.node_list]

    def compute_A_tilde(self):
        """Build ``A_tilde = (1 - alpha) * sum_{i=0..K} alpha**i * A_hat**i``.

        The buffers follow the dtype and device of ``A_hat`` so that the
        matrix products do not fail for non-float32 or non-CPU inputs.
        The result is stored in ``self.A_tilde``.
        """
        identity = torch.eye(self.N, dtype=self.A_hat.dtype, device=self.A_hat.device)
        series = torch.zeros_like(identity)
        power = identity  # A_hat ** i, starting with i = 0.
        coeff = 1  # alpha ** i, starting with i = 0.
        for i in range(self.K + 1):
            series = series + coeff * power
            if i < self.K:
                # Only advance when another term follows; the original code
                # paid for one extra (unused) matrix product at the end.
                coeff = coeff * self.alpha
                power = torch.matmul(power, self.A_hat)
        self.A_tilde = (1 - self.alpha) * series

    def init_central_params(self):
        """Create the central parameters with normally distributed values.

        The parameter names and shapes are taken from a deep copy of the
        first node's ``mlp.state_dict()``; the node itself is left untouched.
        """
        params = copy.deepcopy(self.node_list[0].mlp.state_dict())
        self.cparam_names = list(params.keys())  # Name of each parameter in the mlp.

        with torch.no_grad():
            for param_name in params.keys():
                nn.init.normal_(params[param_name])

        self.central_params = params

    def one_time_communication(self, E, num_training=200):
        """Run a single communication round.

        Args:
            E: Number of training epochs in the local updates.
            num_training: The nodes with index ``0 .. num_training - 1`` are
                the ones that train locally and whose parameters are averaged.

        Raises:
            ValueError: If ``num_training`` is not in ``1 .. N``.
        """
        if not 1 <= num_training <= self.N:
            raise ValueError(
                "num_training must be between 1 and the number of nodes "
                "({}), got {}".format(self.N, num_training)
            )

        # Calculate the A_tilde matrix, if it is not calculated already.
        # `is None` is required here: `==` on a tensor is an element-wise op.
        if self.A_tilde is None:
            self.compute_A_tilde()

        # If this is the first communication, initialize the central parameters.
        if self.central_params is None:
            self.init_central_params()
        elif self.cparam_names is None:
            # Central parameters were supplied from outside; derive the names.
            self.cparam_names = list(self.central_params.keys())

        # Broadcast the current central parameters to all nodes.
        for v in self.node_list:
            v.receive_params(self.central_params)

        # Get the representation of each node and propagate it.
        # The propagation is done without gradients for simplification.
        H = self.receive_h()
        H_copy = torch.stack(H).clone().detach()
        Z_K = torch.matmul(self.A_tilde, H_copy)

        # Hand every training node its own coefficient and the aggregated
        # contribution of all other nodes, then let it update locally.
        for i in range(num_training):
            with torch.no_grad():
                self_weight = self.A_tilde[i, i]
                z_u = Z_K[i, :] - H_copy[i, :] * self_weight
            self.node_list[i].receive_and_update(self_weight, z_u, E)

        # Collect the updated local parameters and average them.
        with torch.no_grad():
            local_states = [self.node_list[i].mlp.state_dict() for i in range(num_training)]
            for pname in self.cparam_names:
                total = local_states[0][pname]
                for state in local_states[1:]:
                    total = total + state[pname]
                self.central_params[pname] = total / num_training

    def testing_accuracy(self, graph_data, params, num_training=200):
        """Evaluate ``params`` on the nodes that were not used for training.

        Args:
            graph_data: Object exposing the feature matrix as ``x`` and the
                labels as ``y``.
            params: State dict loaded into ``self.mlp`` before evaluation.
            num_training: Nodes with index ``>= num_training`` are test nodes.

        Returns:
            Fraction of test nodes whose arg-max prediction equals the label.

        Raises:
            ValueError: If there is no test node left.
        """
        num_test = self.N - num_training
        if num_test <= 0:
            raise ValueError(
                "no test nodes left: N={} and num_training={}".format(self.N, num_training)
            )
        if self.A_tilde is None:
            self.compute_A_tilde()

        X = graph_data.x
        Y = graph_data.y
        self.mlp.load_state_dict(params)
        with torch.no_grad():
            H = self.mlp(X)
            Z_K = torch.matmul(self.A_tilde, H)
            preds = torch.max(Z_K, dim=1)[1][num_training:]
            counts = (preds == Y[num_training:]).sum()
        return counts.item() / num_test

    def training_and_testing(self, T, E, data, num_training=200):
        """Alternate communication rounds and evaluation, then plot accuracy.

        Args:
            T: Number of communications.
            E: Number of training epochs in the local updates.
            data: Graph data passed to ``testing_accuracy``.
            num_training: Forwarded to both the communication round and the
                evaluation so that they always use the same split.

        Returns:
            List with the test accuracy measured after every round.
        """
        accuracy = []
        for t in range(T):
            self.one_time_communication(E, num_training)
            accuracy.append(self.testing_accuracy(data, self.central_params, num_training))
            print ("Communication:", t)

        plt.plot(np.arange(T) + 1, accuracy)
        plt.xlabel("Communication round")
        plt.ylabel("Test accuracy")
        return accuracy

In [8]:
class MLP(nn.Module):
    """Two-layer perceptron: linear -> ReLU -> linear, returning raw outputs."""

    def __init__(self, input_dim, hidden_dim, output_dim, bias=False):
        """Create both linear layers; ``bias`` toggles the bias term of each."""
        super().__init__()
        # The attribute names are also the prefixes of the state_dict keys
        # ("linear_1.weight", "linear_2.weight", ...).
        self.linear_1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.linear_2 = nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, x):
        """Map ``x`` through the hidden layer with ReLU and the output layer."""
        hidden = F.relu(self.linear_1(x))
        return self.linear_2(hidden)

In [9]:
class Node:
    """One participant of the network: local features, one label and a private MLP."""

    def __init__(self, X, y, mlp, learning_rate, batchsize=4, num_workers=0):
        """Set up the local data, model, optimizer and mini-batch loader.

        Args:
            X: Tensor with dimension n x input_dim (the local feature vectors).
            y: The class of this node; a single-element tensor that is
                repeated so that every row of ``X`` carries the same label.
            mlp: The mlp for this node.
            learning_rate: Learning rate for the Adam optimizer.
            batchsize: Mini-batch size; the default value is 4.
            num_workers: Worker processes used by the data loader. The data
                is an in-memory tensor, so the default of 0 loads it in the
                calling process instead of starting workers for every pass
                over the loader.
        """
        self.n = X.shape[0]  # Number of feature vectors.
        self.X = X  # Feature matrix.
        self.y = y.view(1).repeat(self.n)  # 1-d tensor, one label per feature vector.
        self.mlp = mlp  # The mlp for this node.
        self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=learning_rate, weight_decay=5e-4)

        # Create a data loader for mini-batch training.
        self.dataset_v = Data.TensorDataset(self.X, self.y)
        self.loader = Data.DataLoader(dataset=self.dataset_v,
                                      shuffle=True,
                                      batch_size=batchsize,
                                      num_workers=num_workers)

    def receive_params(self, parameters):
        """Overwrite the local parameters in place.

        Args:
            parameters: A dict of parameters keyed by parameter name,
                e.g. {"linear_1.weight": tensor, ...}.
        """
        with torch.no_grad():
            for name, param in self.mlp.named_parameters():
                param.copy_(parameters[name])

    def upload_params(self):
        """Return this node's parameters as a dict: {"linear_1.weight": tensor, ...}."""
        # Must read the node's own model; the bare name `mlp` would resolve to
        # a global (or raise NameError) instead of this node's network.
        return self.mlp.state_dict()

    def upload_h(self):
        """Return the mean model output over the local feature vectors.

        Use this function only at the beginning of each communication.
        No gradient is recorded.

        Returns:
            A 1-d tensor.
        """
        with torch.no_grad():
            H = self.mlp(self.X)
            h = H.sum(dim=0) / self.n
        return h

    def receive_and_update(self, Atilde_v, z_u, E):
        """Train the local model for ``E`` epochs on the local mini-batches.

        Args:
            Atilde_v: The linear coefficient for this node.
            z_u: The aggregated neighborhood information for this node; it is
                added to every row of the scaled model output.
            E: Number of training epochs in the local update.
        """
        for epoch in range(E):
            for X, y in self.loader:
                self.optimizer.zero_grad()
                H = self.mlp(X)
                Z = Atilde_v * H + z_u
                y_hat = F.log_softmax(Z, dim=1)
                loss = F.nll_loss(y_hat, y)
                loss.backward()
                self.optimizer.step()

In [10]:
def init_network(data, A_hat, learning_rate, K,
                 num_classes, hidden_dim, output_dim, 
                 data_dict, 
                 batch_size=3, num_features=10, alpha=0.95, A_tilde=None):
    """Create one Node per graph vertex and wrap them in a Central_Server.

    Every vertex gets its own freshly initialised MLP and ``num_features``
    feature vectors drawn without replacement from the pool of its class in
    ``data_dict``. The returned server additionally receives a separate MLP
    in its ``mlp`` attribute.

    Args:
        data: Graph data exposing ``x`` (2-d feature matrix) and ``y`` (labels).
        A_hat, K, alpha, A_tilde: Forwarded to ``Central_Server``.
        learning_rate, batch_size: Forwarded to every ``Node``.
        num_classes: Accepted but not used by this function.
        hidden_dim, output_dim: Layer sizes of every MLP.
        data_dict: Maps a class label to a 2-d array of feature vectors.
        num_features: Number of feature vectors sampled for each node.

    Returns:
        The configured ``Central_Server``.
    """
    n_nodes, n_inputs = data.x.shape

    clients = []
    for idx in range(n_nodes):
        local_model = MLP(n_inputs, hidden_dim, output_dim)

        label = data.y[idx]
        pool = data_dict[label.item()]  # All feature vectors of this node's class.
        picked = np.random.choice(np.arange(pool.shape[0]), replace=False, size=num_features)
        local_features = torch.tensor(pool[picked, :])

        clients.append(Node(local_features, label, local_model, learning_rate, batch_size))

    server = Central_Server(clients, A_hat, K, alpha, A_tilde)
    server.mlp = MLP(n_inputs, hidden_dim, output_dim)
    return server

In [6]:
# Load Cora and make sure every node carries exactly one self-loop.
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
data.edge_index = utils.remove_self_loops(data.edge_index)[0]
data.edge_index = utils.add_remaining_self_loops(data.edge_index)[0]
G = utils.to_networkx(data, to_undirected=True)

# Symmetrically normalized adjacency: A_hat = D^(-1/2) A D^(-1/2).
# The adjacency is built once. "adjacency + laplacian" is the diagonal degree
# matrix, so its square root and inverse are taken element-wise on the degree
# vector instead of through dense O(N^3) sqrtm / inv calls. The values are the
# same up to floating-point rounding.
adjacency = np.asarray(nx.linalg.graphmatrix.adjacency_matrix(G).todense())
A = torch.tensor(adjacency).type(torch.FloatTensor)
degrees = adjacency.sum(axis=1)  # >= 1 for every node because of the self-loops.
D = np.diag(degrees)
Dhalf = np.sqrt(D)
Dnhalf = torch.tensor(np.diag(1.0 / np.sqrt(degrees))).type(torch.FloatTensor)
A_hat = torch.matmul(torch.matmul(Dnhalf, A), Dnhalf)

# Group the feature vectors by class label: {label: 2-d array of feature rows}.
data_dict = {}
N, _ = data.x.shape
for i in range(N):
    data_dict.setdefault(data.y[i].item(), []).append(data.x[i, :].numpy())
data_dict = {label: np.array(rows) for label, rows in data_dict.items()}

# Pre-compute the propagation matrix
#     A_tilde = (1 - alpha) * sum_{i=0..K} alpha**i * A_hat**i.
# Keep these two constants in sync with the K / alpha handed to init_network.
PPR_STEPS = 100
PPR_ALPHA = 0.95

N = data.x.shape[0]
A_tilde = torch.zeros((N, N))  # Accumulator of the series.
A_i = torch.eye(N)  # A_hat ** i, the identity for i = 0.
alpha_i = 1  # alpha ** i, 1 for i = 0.
for i in range(0, PPR_STEPS + 1):
    A_tilde = A_tilde + alpha_i * A_i
    alpha_i = alpha_i * PPR_ALPHA
    A_i = torch.matmul(A_i, A_hat)
A_tilde = (1 - PPR_ALPHA) * A_tilde

In [11]:
# Build the network: one node per graph vertex, reusing the pre-computed A_tilde.
net = init_network(
    data,
    A_hat,
    learning_rate=0.01,
    K=100,
    num_classes=7,
    hidden_dim=200,
    output_dim=7,
    data_dict=data_dict,
    batch_size=10,
    num_features=10,
    A_tilde=A_tilde,
)

In [12]:
# 10 communication rounds with 10 local epochs each; returns the accuracy per round.
accuracy = net.training_and_testing(T=10, E=10, data=data)

KeyboardInterrupt: 