In [23]:
import torch
import scipy
import copy
import scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import calculate_Atilde, cSBM
from GFL import MLP, LR, Node

In [24]:
def create_node_list(csbm, A_tilde, num_train, hidden_dim, output_dim=2, n_k=40, nn_type="MLP"):
    
    N = A_tilde.shape[0]
    
    input_dim= csbm.p
    
    node_list = []
    
    for i in range(N):
        
        X = []
        
        if (nn_type == "MLP"):
            model_i = MLP(input_dim, hidden_dim, output_dim)
            
        elif (nn_type == "LR"):
            model_i = LR(input_dim, output_dim)
            
        else:
            raise ValueError("Type of neural network must be either LR or MLP!")
            
        
        for j in range(n_k):
            
            x_j = np.sqrt(csbm.mu/N)*csbm.v[i]*csbm.u + np.random.normal(loc=0, scale=1, size=csbm.p)/np.sqrt(csbm.p)
            
            X.append(x_j)
            
        X = torch.tensor(np.array(X))
        
        if csbm.v[i] == -1:
            
            y = np.zeros(n_k)
            
        elif csbm.v[i] == 1:
            
            y = np.ones(n_k)

        y = torch.tensor(y).type(torch.LongTensor)
        
        node_i = Node(local_model=model_i, node_idx=i, X=X, y=y)
        
        node_list.append(node_i)
        
    class1_train = np.random.choice(a=csbm.class1_ids, size=int(num_train/2), replace=False)
    
    class2_train = np.random.choice(a=csbm.class2_ids, size=int(num_train/2), replace=False)
    
    train_indices = np.concatenate((class1_train, class2_train), axis=0)
    
    test_indices = list(set(np.arange(N)) - set(train_indices))
    
    return node_list, train_indices, test_indices

In [28]:
N = 100
p = 10
d = 5
mu = 1
l = 2
csbm = cSBM(N, p, d, mu, l)
A_tilde = calculate_Atilde(csbm.A, 100, 0.95)

In [78]:
class Central_Machine():
    
    def __init__(self, node_list, A_tilde):
        
        """
        A_tilde: PageRank matrix
        node_list: A list contains objects from Node class
        """
        
        self.A_tilde = A_tilde
        self.node_list = node_list
        self.N = len(node_list)
        self.cmodel = None
        
    def init_central_parameters(self, input_dim, hidden_dim, output_dim, nn_type):
        
        """
        Initialize the central server parameter dictonary
        """
        
        if (nn_type == "MLP"):
            self.cmodel = MLP(input_dim, hidden_dim, output_dim)
            
        elif (nn_type == "LR"):
            self.cmodel = LR(input_dim, output_dim)
        
            
    def collect_data(self, m):
        
        Xs = []
        
        ys = []
        
        for node in self.node_list:
            
            X, y = node.upload_data(m)
            
            Xs.append(X)
            ys.append(y)
            
            
        # Xs; [m, N, p]
        # ys: [m, N] 
        Xs = torch.cat(Xs, dim=1)
        
        ys = torch.cat(ys, dim=1)
        
        return Xs, ys
            
            
        
    def train_one_epoch(self, train_indices, batch_size=10):
        
        # Xs: [batch_size, N, p], ys: [batch_size, N]
        Xs, ys = self.collect_data(batch_size)
        
        # Hs: [batch_size, N, num_class]
        Hs = self.cmodel(Xs)
        
        # Zs: [m, N, num_class], ys: [m, N]
        Zs = torch.matmul(self.A_tilde, Hs)
        
        # train_Zs: [m, num_train, num_class]
        # train_ys: [m, num_train]
        train_Zs = Zs[:,train_indices,:]
        train_ys = ys[:,train_indices]
        

        num_train = len(train_indices)

        train_loss = F.cross_entropy(train_Zs.view(batch_size*num_train, -1), train_ys.view(batch_size*num_train))
        
        return train_loss    

In [85]:
def train_GML_cSBM(node_list, train_indices, test_indices, A_tilde, input_dim, hidden_dim,
               batch_size=20, learning_rate=0.1, opt="Adam", num_epochs=10,
               nn_type="MLP", output_dim=2):
    
    N = A_tilde.shape[0]
    
    cm = Central_Machine(node_list, A_tilde)
    
    cm.init_central_parameters(input_dim, hidden_dim, output_dim, nn_type)
    
    if (opt == "Adam"):
        optimizer = optim.Adam(cm.cmodel.parameters())
            
    else:
        optimizer = optim.SGD(cm.cmodel.parameters(), lr=learning_rate)
    
    train_loss = []
    
    for ith in range(num_epochs):
        
        optimizer.zero_grad()
        
        average_train_loss = cm.train_one_epoch(train_indices, batch_size)
        
        average_train_loss.backward()
        
        optimizer.step()
        
        train_loss.append(average_train_loss.item())
        
        if (num_epochs <= 30):
            print ("Communication:", ith+1, "Average train loss:", average_train_loss.item())
            
        elif (num_epochs > 30 and num_epochs <= 100):
            if (ith % 5 == 0):
                print ("Communication:", ith+1, "Average train loss:", average_train_loss.item())
                
        elif (num_epochs > 1000):
            if (ith % 100 == 0):
                print ("Communication:", ith+1, "Average train loss:", average_train_loss.item())
                
        else:
            if (ith % 10 == 0):
                print ("Communication:", ith+1, "Average train loss:", average_train_loss.item())
                
    return train_loss

In [89]:
nl, t1, t2 = create_node_list(csbm, A_tilde, num_train=10, hidden_dim=100, output_dim=2, n_k=40, nn_type="MLP")
tl = train_GML_cSBM(nl, t1, t2, A_tilde, input_dim=csbm.p, hidden_dim=200,
               batch_size=20, learning_rate=0.1, opt="Adam", num_epochs=10000,
               nn_type="LR", output_dim=2)

Communication: 1 Average train loss: 0.6858515739440918
Communication: 101 Average train loss: 0.6733407378196716
Communication: 201 Average train loss: 0.6616483926773071
Communication: 301 Average train loss: 0.6474541425704956
Communication: 401 Average train loss: 0.6341202259063721
Communication: 501 Average train loss: 0.6303312182426453
Communication: 601 Average train loss: 0.6141489744186401
Communication: 701 Average train loss: 0.6116967797279358
Communication: 801 Average train loss: 0.5889756083488464
Communication: 901 Average train loss: 0.5748032331466675
Communication: 1001 Average train loss: 0.567499041557312
Communication: 1101 Average train loss: 0.5623959898948669
Communication: 1201 Average train loss: 0.5395956635475159
Communication: 1301 Average train loss: 0.5409956574440002
Communication: 1401 Average train loss: 0.527929961681366
Communication: 1501 Average train loss: 0.5320008397102356
Communication: 1601 Average train loss: 0.5118868350982666


KeyboardInterrupt: 