In [1]:
import torch
import scipy
import copy
import scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from GFL import MLP, LR, Node, Central_Server
from utils import calculate_Atilde, mean_agg, cSBM

In [2]:
# Fix the random seeds so that Restart Kernel -> Run All reproduces the same
# run. The training function in this notebook draws from NumPy's global RNG
# (np.random.normal / np.random.choice); torch is seeded as well because the
# model weights are presumably initialised from torch's RNG -- TODO confirm
# against GFL.MLP / GFL.LR.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# cSBM (contextual stochastic block model) configuration, passed positionally
# to utils.cSBM below.
# NOTE(review): the meanings given here are assumptions based on the usual cSBM
# notation -- verify against utils.cSBM.
N = 100   # presumably the number of nodes in the graph
p = 10    # presumably the feature dimension (read back later as csbm.p)
d = 5     # presumably the average node degree
mu = 1    # presumably the feature signal strength (read back later as csbm.mu)
l = 2     # presumably the lambda (graph signal strength) parameter
csbm = cSBM(N, p, d, mu, l)

# PageRank-style matrix built from the adjacency matrix csbm.A.
# NOTE(review): 100 and 0.95 are undocumented here -- presumably an iteration
# count and a damping/teleport factor; confirm in utils.calculate_Atilde.
A_tilde = calculate_Atilde(csbm.A, 100, 0.95)

In [26]:
def train_GFL_cSBM(csbm, A_tilde, hidden_dim, num_train, I, n_k, 
               num_communication=20, aggregation=mean_agg,
               learning_rate=0.1, opt="Adam", num_epochs=10,
               gradient=True, m=10, gradient_clipping=None,
               nn_type="MLP"):
    
    """
    Build one Node per graph vertex with synthetic cSBM features, then run
    `num_communication` rounds of server.communication and collect the losses.
    
    csbm: A cSBM object (contextual stochastic block model). The attributes read
          here are p, mu, u, v, class1_ids and class2_ids.
    
    A_tilde: pageRank matrix. Its first dimension defines the number of nodes N.
    
    hidden_dim: hidden dimension passed to MLP(...) and to
                server.init_central_parameters(...).
    
    num_train: Number of nodes used in training. int(num_train/2) nodes are sampled
               without replacement from each class, so an odd value yields
               num_train - 1 training nodes. All remaining nodes are test nodes.
    
    I: number of local updates for each node, so batch size = n_k/I for each node k.
    
    n_k: number of feature vectors each node has.
    
    num_communication: Number of communications. Default: 20
    
    aggregation: aggregation method, for now, only mean aggregation is implemented. Default: mean_agg. 
    
    learning_rate: Learning rate for the optimizer. Default: 0.1
    
    opt: optimization method: Adam or SGD. Default: "Adam"
    
    num_epochs: forwarded unchanged to server.communication (presumably the number of
                local epochs per communication -- TODO confirm in GFL). Default: 10
    
    gradient: boolean, whether to include the "fake gradient" or not. Default: True
    
    m: The number of feature vectors used for training loss evaluation in the end of each communication for each node. 
       Default: 10
       
    gradient_clipping: Whether to perform gradient clipping during the training process. None means no gradient clipping,
                       if a number (int or float) is given, 
                       then the maximum norm is determined by this number. Default: None.
                       
    nn_type: The type of neural network. either "MLP" or "LR" (i.e. MLP or Logistic Regression). Default:"MLP".
    
    Returns: list with one entry per communication round, holding the value returned
             by server.communication (printed as the average train loss).
    
    Raises: ValueError if nn_type is not "MLP"/"LR", or if some csbm.v[i] is not -1 or 1.
    """
    
    # Validate once, up front, instead of on every pass through the node loop.
    if nn_type not in ("MLP", "LR"):
        raise ValueError("Type of neural network must be either LR or MLP!")
    
    N = A_tilde.shape[0]
    input_dim = csbm.p
    output_dim = 2
    
    # Loop-invariant scalings of the feature model
    #   x = sqrt(mu/N) * v_i * u + z / sqrt(p),   z ~ N(0, I_p)
    signal_scale = np.sqrt(csbm.mu/N)
    noise_scale = np.sqrt(csbm.p)
    
    node_list = []
    
    for i in range(N):
        
        v_i = csbm.v[i]
        
        # Map the community indicator to a class label. Previously an unexpected
        # value left `y` unassigned: a NameError on the first node, or a silent
        # reuse of the previous node's labels afterwards.
        if v_i == -1:
            label = 0
        elif v_i == 1:
            label = 1
        else:
            raise ValueError("csbm.v[%d] must be -1 or 1, got %r" % (i, v_i))
        
        if nn_type == "MLP":
            model_i = MLP(input_dim, hidden_dim, output_dim)
        else:
            model_i = LR(input_dim, output_dim)
        
        # Class-dependent mean shared by all n_k samples of this node.
        mean_i = signal_scale*v_i*csbm.u
        
        # One noise draw per sample, in the same order as before, so a seeded
        # run yields the same features.
        X = torch.tensor(np.array([
            mean_i + np.random.normal(loc=0, scale=1, size=csbm.p)/noise_scale
            for _ in range(n_k)
        ]))
        
        y = torch.tensor(np.full(n_k, label, dtype=np.int64)).type(torch.LongTensor)
        
        node_list.append(Node(local_model=model_i, node_idx=i, X=X, y=y))
        
    server = Central_Server(node_list, A_tilde)
    
    server.init_central_parameters(input_dim, hidden_dim, output_dim, nn_type)
    
    # Balanced training set: the same number of nodes from each class.
    per_class = int(num_train/2)
    
    class1_train = np.random.choice(a=csbm.class1_ids, size=per_class, replace=False)
    
    class2_train = np.random.choice(a=csbm.class2_ids, size=per_class, replace=False)
    
    train_indices = np.concatenate((class1_train, class2_train), axis=0)
    
    test_indices = list(set(np.arange(N)) - set(train_indices))
    
    # Progress is printed every `print_every` rounds; longer runs print less often.
    if num_communication <= 30:
        print_every = 1
    elif num_communication <= 100:
        print_every = 5
    elif num_communication >= 10000:
        print_every = 100
    else:
        print_every = 10
    
    train_loss = []
    
    for ith in range(num_communication):
        
        average_train_loss = server.communication(train_indices, test_indices, 
                                                  I, aggregation, opt, learning_rate, num_epochs,
                                                  gradient, m, gradient_clipping)
        train_loss.append(average_train_loss)
        
        if ith % print_every == 0:
            print("Communication:", ith+1, "Average train loss:", average_train_loss)
            
    return train_loss

In [31]:
# Train logistic-regression nodes on the cSBM data for 20000 communication
# rounds; `tl` receives the list of per-round losses returned by
# train_GFL_cSBM. With this many rounds the function prints progress only
# every 100th round.
tl = train_GFL_cSBM(
    csbm=csbm,
    A_tilde=A_tilde,
    nn_type="LR",
    hidden_dim=100,
    num_train=10,
    n_k=40,
    I=1,
    num_communication=20000,
    num_epochs=1,
    aggregation=mean_agg,
    gradient=True,
    m=20,
    gradient_clipping=None,
)

Communication: 1 Average train loss: 0.6935615539550781
Communication: 101 Average train loss: 0.6922706365585327
Communication: 201 Average train loss: 0.6947137713432312
Communication: 301 Average train loss: 0.6900880336761475
Communication: 401 Average train loss: 0.6885721683502197
Communication: 501 Average train loss: 0.6888320446014404
Communication: 601 Average train loss: 0.6833720207214355
Communication: 701 Average train loss: 0.68278968334198
Communication: 801 Average train loss: 0.6803514361381531
Communication: 901 Average train loss: 0.6806635856628418
Communication: 1001 Average train loss: 0.6792182326316833
Communication: 1101 Average train loss: 0.6773951649665833
Communication: 1201 Average train loss: 0.6787551641464233
Communication: 1301 Average train loss: 0.6777995228767395
Communication: 1401 Average train loss: 0.6763874888420105
Communication: 1501 Average train loss: 0.6725531220436096
Communication: 1601 Average train loss: 0.6692662835121155
Communicati

KeyboardInterrupt: 