## Importing libraries

In [1]:
#generic
from pathlib import Path
import os, sys
import argparse
import random
import copy
from random import choices


#torch
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv, global_add_pool, SAGEConv
from torch_geometric.transforms import OneHotDegree
from torch_geometric.utils import to_networkx, degree, to_dense_adj, to_scipy_sparse_matrix
from sklearn.model_selection import train_test_split
from scipy import sparse as sp
import torch_geometric
from torch_geometric.data import Data, Dataset, Batch
from torch_geometric.utils import to_networkx, subgraph
import torch_geometric.utils as utils
from torch.nn.functional import one_hot


#utility
import networkx as nx
from dtaidistance import dtw
from tensorboardX import SummaryWriter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymetis
from ogb.nodeproppred import PygNodePropPredDataset

# Hyper parameters

In [2]:
num_clients = 3
device = "cuda" if torch.cuda.is_available() else "cpu"
alg = 'fedstar'
num_rounds = 20
local_epoch = 10
lr = 0.01
weight_decay = 5e-4
nlayer = 3 # number of GINConv layers
hidden = 64
dropout = 0.5
batch_size = 128  # not used
seed = 69
datapath = '.Data'
outbase = 'outputs'
data_group = 'arxiv'
n_rw = 16
n_dg = 16
n_ones = 16
type_init = 'rw_dg' #options are rw, dg and rw_dg
print(device)

cpu


In [3]:
seed_dataSplit = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Utils

In [4]:
def get_numGraphLabels(g):
    s = set(g.y.flatten().tolist())
    return len(s)

def init_structure_encoding(  g, type_init = 'rw_dg'):

    if type_init == 'rw':
        # Geometric diffusion features with Random Walk
        A = to_scipy_sparse_matrix(g.edge_index, num_nodes=g.num_nodes)
        D = (degree(g.edge_index[0], num_nodes=g.num_nodes) ** -1.0).numpy()

        Dinv=sp.diags(D)
        RW=A*Dinv
        M=RW

        SE_rw=[torch.from_numpy(M.diagonal()).float()]
        M_power=M
        for _ in range(n_rw-1):
            M_power=M_power*M
            SE_rw.append(torch.from_numpy(M_power.diagonal()).float())
        SE_rw=torch.stack(SE_rw,dim=-1)

        g['stc_enc'] = SE_rw

    elif type_init == 'dg':
        # PE_degree
        g_dg = (degree(g.edge_index[0], num_nodes=g.num_nodes)).numpy().clip(1, n_dg)
        SE_dg = torch.zeros([g.num_nodes, n_dg])
        for i in range(len(g_dg)):
            SE_dg[i,int(g_dg[i]-1)] = 1

        g['stc_enc'] = SE_dg

    elif type_init == 'rw_dg':
        # SE_rw
        A = to_scipy_sparse_matrix(g.edge_index, num_nodes=g.num_nodes)
        D = (degree(g.edge_index[0], num_nodes=g.num_nodes) ** -1.0).numpy()

        Dinv=sp.diags(D)
        RW=A*Dinv
        M=RW

        SE=[torch.from_numpy(M.diagonal()).float()]
        M_power=M
        for _ in range(n_rw-1):
            M_power=M_power*M
            SE.append(torch.from_numpy(M_power.diagonal()).float())
        SE_rw=torch.stack(SE,dim=-1)

        # PE_degree
        g_dg = (degree(g.edge_index[0], num_nodes=g.num_nodes)).numpy().clip(1, n_dg)
        SE_dg = torch.zeros([g.num_nodes, n_dg])
        for i in range(len(g_dg)):
            SE_dg[i,int(g_dg[i]-1)] = 1

        g['stc_enc'] = torch.cat([SE_rw, SE_dg], dim=1)

    return g

def get_stats(df, ds, graph_train, graph_val=None, graph_test=None):
    from collections import Counter
    labels_train = graph_train.y.flatten().tolist()
    df.loc[ds, '#Nodes_train'] = graph_train.num_nodes
    df.loc[ds, '#Edges_train'] = graph_train.num_edges
    df.loc[ds, 'Avg_degree_train'] = graph_train.num_edges/graph_train.num_nodes
    df.loc[ds, '#Labels_train'] = len(set(labels_train))
    df.loc[ds, 'Class_dist_train'] = str(dict(Counter(labels_train)))
    
    if graph_test:
        labels_test = graph_test.y.flatten().tolist()
        df.loc[ds, '#Nodes_test'] = graph_test.num_nodes
        df.loc[ds, '#Edges_test'] = graph_test.num_edges
        df.loc[ds, 'Avg_degree_test'] = graph_test.num_edges/graph_test.num_nodes
        df.loc[ds, '#Labels_test'] = len(set(labels_test))
        df.loc[ds, 'Class_dist_test'] = str(dict(Counter(labels_test)))
        
    if graph_val:
        labels_val = graph_val.y.flatten().tolist()
        df.loc[ds, '#Nodes_val'] = graph_val.num_nodes
        df.loc[ds, '#Edges_val'] = graph_val.num_edges
        df.loc[ds, 'Avg_degree_val'] = graph_val.num_edges/graph_val.num_nodes
        df.loc[ds, '#Labels_val'] = len(set(labels_val))
        df.loc[ds, 'Class_dist_val'] = str(dict(Counter(labels_val)))
        
    return df

# Making data

In [5]:
def prepareData_multiDS(num_clients, datapath,  batchSize=32, seed=None):

    num_clients = num_clients
    splitedData = {}
    df = pd.DataFrame()
    dataset = PygNodePropPredDataset(name='ogbn-arxiv')
    graph = dataset[0]
    num_nodes = graph.num_nodes
    num_edges = graph.num_edges
    nx_graph = utils.to_networkx(graph)
    partitions = pymetis.part_graph(num_clients, adjacency=nx.to_dict_of_lists(nx_graph))
    partitions_np = np.array(partitions[1])
    partition_tensor = torch.from_numpy(partitions_np)
    subgraphs = []
    print(f'Number of nodes in the orignal graph = {num_nodes}\nNumber of edges in the orignal graph = {num_edges}\nAverage degree of the orignal graph = {num_edges/num_nodes}\n')
    train_ratio = 0.2
    val_ratio = 0.2
    test_ratio = 0.6
    for i in range(num_clients):
        data = f'Client{i+1}'
        nodes = (partition_tensor == i).nonzero(as_tuple=True)[0]
        subgraph = graph.subgraph(nodes)
        subgraphs.append(subgraph)
        print(f'Number of nodes = {subgraph.num_nodes} and Number of edges = {subgraph.num_edges}')
    

# #     for subgraph in subgraphs:
        num_nodes = subgraph.num_nodes
        train_idx, test_idx = train_test_split(range(num_nodes), test_size=test_ratio)
        train_idx, val_idx = train_test_split(train_idx, test_size=val_ratio/(1-test_ratio))

        train_idx = torch.tensor(train_idx, dtype=torch.long)
        val_idx = torch.tensor(val_idx, dtype=torch.long)
        test_idx = torch.tensor(test_idx, dtype=torch.long)

        # Create new subgraphs with the same structure, but with the nodes split and init their struct encodings
        train_subgraph = init_structure_encoding(subgraph.subgraph(train_idx))
        val_subgraph = init_structure_encoding(subgraph.subgraph(val_idx))
        test_subgraph = init_structure_encoding(subgraph.subgraph(test_idx))

#         split_subgraph = []
#         partitions_split = pymetis.part_graph(3, adjacency=nx.to_dict_of_lists(nx_graph), ubvec = (train_ratio, val_ratio, test_ratio))
#         partitions_np_split = np.array(partitions[1])
#         partition_tensor_split = torch.from_numpy(partitions_np_split)
#         nodes_split = (partition_tensor_split == i).nonzero(as_tuple=True)[0]
#         split = graph.subgraph(nodes_split)
#         split_subgraph.append(split)

        
#         train_subgraph = split_subgraph[0]
#         val_subgraph = split_subgraph[1]
#         test_subgraph = split_subgraph[2]
        
        num_node_features = train_subgraph.num_node_features
        num_graph_labels = get_numGraphLabels(train_subgraph)#rewritten
        
        
        
        splitedData[data] = ({'train': train_subgraph, 'val': val_subgraph, 'test': test_subgraph},
                             num_node_features, num_graph_labels, train_subgraph.num_nodes)

        df = get_stats(df, data, train_subgraph, graph_val=val_subgraph, graph_test=test_subgraph)# rewritten
        
        train_subgraph.y = one_hot(train_subgraph.y).squeeze(dim=1)
        val_subgraph.y = one_hot(val_subgraph.y).squeeze(dim=1)
        test_subgraph.y = one_hot(test_subgraph.y).squeeze(dim=1)

    return splitedData, df

In [6]:
print("Preparing data ...")
splitedData, df_stats = prepareData_multiDS( num_clients, datapath, batch_size)
print("Done")

Preparing data ...
Number of nodes in the orignal graph = 169343
Number of edges in the orignal graph = 1166243
Average degree of the orignal graph = 6.886868663009395

Number of nodes = 56447 and Number of edges = 170543
Number of nodes = 56448 and Number of edges = 185667
Number of nodes = 56448 and Number of edges = 304555
Done


In [7]:
df_stats

Unnamed: 0,#Nodes_train,#Edges_train,Avg_degree_train,#Labels_train,Class_dist_train,#Nodes_test,#Edges_test,Avg_degree_test,#Labels_test,Class_dist_test,#Nodes_val,#Edges_val,Avg_degree_val,#Labels_val,Class_dist_val
Client1,11289.0,7068.0,0.626096,40.0,"{8: 670, 28: 2143, 27: 203, 10: 717, 31: 203, ...",33869.0,60862.0,1.796982,40.0,"{18: 177, 36: 960, 28: 6429, 33: 445, 34: 2602...",11289.0,6872.0,0.608734,40.0,"{27: 208, 39: 217, 16: 249, 28: 2123, 34: 801,..."
Client2,11289.0,7202.0,0.637966,40.0,"{24: 2652, 6: 131, 16: 1457, 8: 281, 31: 293, ...",33869.0,67300.0,1.987068,40.0,"{30: 5658, 34: 703, 31: 885, 4: 1097, 24: 7866...",11290.0,7551.0,0.668822,39.0,"{30: 1821, 4: 391, 24: 2674, 16: 1531, 8: 286,..."
Client3,11289.0,11377.0,1.007795,40.0,"{16: 3665, 11: 43, 28: 1647, 14: 68, 2: 229, 3...",33869.0,109247.0,3.225575,40.0,"{16: 11151, 3: 250, 28: 4951, 24: 3643, 10: 86...",11290.0,13168.0,1.166342,40.0,"{28: 1571, 36: 260, 27: 381, 2: 238, 1: 73, 34..."


In [8]:
outf = os.path.join(outbase, 'stats_trainData.csv')
df_stats.to_csv(outf)
print(f"Wrote to {outf}")

Wrote to outputs/stats_trainData.csv


In [9]:
splitedData

{'Client1': ({'train': Data(num_nodes=11289, edge_index=[2, 7068], x=[11289, 128], node_year=[11289, 1], y=[11289, 40], stc_enc=[11289, 32]),
   'val': Data(num_nodes=11289, edge_index=[2, 6872], x=[11289, 128], node_year=[11289, 1], y=[11289, 40], stc_enc=[11289, 32]),
   'test': Data(num_nodes=33869, edge_index=[2, 60862], x=[33869, 128], node_year=[33869, 1], y=[33869, 40], stc_enc=[33869, 32])},
  128,
  40,
  11289),
 'Client2': ({'train': Data(num_nodes=11289, edge_index=[2, 7202], x=[11289, 128], node_year=[11289, 1], y=[11289, 40], stc_enc=[11289, 32]),
   'val': Data(num_nodes=11290, edge_index=[2, 7551], x=[11290, 128], node_year=[11290, 1], y=[11290, 40], stc_enc=[11290, 32]),
   'test': Data(num_nodes=33869, edge_index=[2, 67300], x=[33869, 128], node_year=[33869, 1], y=[33869, 40], stc_enc=[33869, 32])},
  128,
  40,
  11289),
 'Client3': ({'train': Data(num_nodes=11289, edge_index=[2, 11377], x=[11289, 128], node_year=[11289, 1], y=[11289, 40], stc_enc=[11289, 32]),
   'v

In [10]:
n_se = n_rw + n_dg
n_se

32

# Making client and server

In [11]:
class Client_GC():
    def __init__(self, model, client_id, client_name, train_size, graphs, optimizer, device):
        self.model = model.to(device)
        self.id = client_id
        self.name = client_name
        self.train_size = train_size
        self.graphs = graphs
        self.device = device
        self.optimizer = optimizer

        self.W = {key: value for key, value in self.model.named_parameters()}
        self.dW = {key: torch.zeros_like(value) for key, value in self.model.named_parameters()}
        self.W_old = {key: value.data.clone() for key, value in self.model.named_parameters()}

        self.gconvNames = None

        self.train_stats = ([0], [0], [0], [0])
        self.weightsNorm = 0.
        self.gradsNorm = 0.
        self.convGradsNorm = 0.
        self.convWeightsNorm = 0.
        self.convDWsNorm = 0.

    def download_from_server(self, server):
        self.gconvNames = server.W.keys()
        for k in server.W:
            if '_s' in k:
                self.W[k].data = server.W[k].data.clone()
                

    def cache_weights(self):
        for name in self.W.keys():
            self.W_old[name].data = self.W[name].data.clone()

    def reset(self):
        copy(target=self.W, source=self.W_old, keys=self.gconvNames)

    def local_train(self, local_epoch):
        """ For self-train & FedAvg """
        train_stats = train_gc(self.model, self.graphs, self.optimizer, local_epoch, self.device)

        self.train_stats = train_stats
        self.weightsNorm = torch.norm(flatten(self.W)).item()

        weights_conv = {key: self.W[key] for key in self.gconvNames}
        self.convWeightsNorm = torch.norm(flatten(weights_conv)).item()

        grads = {key: value.grad for key, value in self.W.items()}
        self.gradsNorm = torch.norm(flatten(grads)).item()

        grads_conv = {key: self.W[key].grad for key in self.gconvNames}
        self.convGradsNorm = torch.norm(flatten(grads_conv)).item()

    def compute_weight_update(self, local_epoch):
        """ For GCFL """
        copy(target=self.W_old, source=self.W, keys=self.gconvNames)

        train_stats = train_gc(self.model, self.graphs, self.optimizer, local_epoch, self.device)

        subtract_(target=self.dW, minuend=self.W, subtrahend=self.W_old)

        self.train_stats = train_stats

        self.weightsNorm = torch.norm(flatten(self.W)).item()

        weights_conv = {key: self.W[key] for key in self.gconvNames}
        self.convWeightsNorm = torch.norm(flatten(weights_conv)).item()

        dWs_conv = {key: self.dW[key] for key in self.gconvNames}
        self.convDWsNorm = torch.norm(flatten(dWs_conv)).item()

        grads = {key: value.grad for key, value in self.W.items()}
        self.gradsNorm = torch.norm(flatten(grads)).item()

        grads_conv = {key: self.W[key].grad for key in self.gconvNames}
        self.convGradsNorm = torch.norm(flatten(grads_conv)).item()

    def evaluate(self):
        return eval_gc(self.model, self.graphs['test'], self.device)

    

def copy(target, source, keys):
    for name in keys:
        target[name].data = source[name].data.clone()

def subtract_(target, minuend, subtrahend):
    for name in target:
        target[name].data = minuend[name].data.clone() - subtrahend[name].data.clone()

def flatten(w):
    return torch.cat([v.flatten() for v in w.values()])

def calc_gradsNorm(gconvNames, Ws):
    grads_conv = {k: Ws[k].grad for k in gconvNames}
    convGradsNorm = torch.norm(flatten(grads_conv)).item()
    return convGradsNorm

def train_gc(model, graphs, optimizer, local_epoch, device):
    losses_train, accs_train, losses_val, accs_val, losses_test, accs_test = [], [], [], [], [], []
    train_subgraph, val_subgraph, test_subgraph = graphs['train'], graphs['val'], graphs['test']
    model.to(device)
    for epoch in range(local_epoch):
        model.train()

        total_loss = 0.
        num_nodes = train_subgraph.num_nodes

        acc_sum = 0
        loss = torch.tensor([0.0])
        optimizer.zero_grad()
        out = model(train_subgraph)
        label = train_subgraph.y
        acc_sum = out.max(dim=1)[1].eq(label.max(dim=1)[1]).sum().item()
        
        for i in range(out.shape[0]): loss += model.loss(out[i], label[i])
        total_loss = loss.item()/num_nodes
#         print(total_loss, epoch)
        loss.backward()
        optimizer.step()
        acc = acc_sum / num_nodes

        loss_v, acc_v = eval_gc(model, val_subgraph, device)
        loss_tt, acc_tt = eval_gc(model, test_subgraph, device)

        losses_train.append(total_loss)
        accs_train.append(acc)
        losses_val.append(loss_v)
        accs_val.append(acc_v)
        losses_test.append(loss_tt)
        accs_test.append(acc_tt)

    return {'trainingLosses': losses_train, 'trainingAccs': accs_train, 'valLosses': losses_val, 'valAccs': accs_val,
            'testLosses': losses_test, 'testAccs': accs_test}



def eval_gc(model, test_graph, device):
    model.eval()
    loss = torch.tensor([0.0])
    total_loss = 0.
    acc_sum = 0.
    num_nodes = test_graph.num_nodes
    out = model(test_graph)
    label = test_graph.y
    
    for i in range(out.shape[0]): loss += model.loss(out[i], label[i])
    
    total_loss = loss.item()
#     print(out.max(dim=1)[1].eq(label.max(dim=1)[1]))
    acc_sum = out.max(dim=1)[1].eq(label.max(dim=1)[1]).sum().item()
#     print(out.max(dim=1)[1], label.max(dim=1)[1])
    return total_loss/num_nodes, acc_sum/num_nodes


class Server():
    def __init__(self, model, device):
        self.model = model.to(device)
        self.W = {key: value for key, value in self.model.named_parameters()}
        self.model_cache = []

    def randomSample_clients(self, all_clients, frac):
        return random.sample(all_clients, int(len(all_clients) * frac))

    def aggregate_weights(self, selected_clients):
        # pass train_size, and weighted aggregate
        total_size = 0
        for client in selected_clients:
            total_size += client.train_size
        for k in self.W.keys():
            self.W[k].data = torch.div(torch.sum(torch.stack([torch.mul(client.W[k].data, client.train_size) for client in selected_clients]), dim=0), total_size).clone()

    def aggregate_weights_per(self, selected_clients):
        # pass train_size, and weighted aggregate
        total_size = 0
        for client in selected_clients:
            total_size += client.train_size
        for k in self.W.keys():
            if 'graph_convs' in k:
                self.W[k].data = torch.div(torch.sum(torch.stack([torch.mul(client.W[k].data, client.train_size) for client in selected_clients]), dim=0), total_size).clone()

    def aggregate_weights_se(self, selected_clients):
        # pass train_size, and weighted aggregate
        total_size = 0
        for client in selected_clients:
            total_size += client.train_size
        for k in self.W.keys():
            if '_s' in k:
                self.W[k].data = torch.div(torch.sum(torch.stack([torch.mul(client.W[k].data, client.train_size) for client in selected_clients]), dim=0), total_size).clone()

    def aggregate_weights_fe(self, selected_clients):
        # pass train_size, and weighted aggregate
        total_size = 0
        for client in selected_clients:
            total_size += client.train_size
        for k in self.W.keys():
            if '_s' not in k:
                self.W[k].data = torch.div(torch.sum(torch.stack([torch.mul(client.W[k].data, client.train_size) for client in selected_clients]), dim=0), total_size).clone()


    def compute_pairwise_similarities(self, clients):
        client_dWs = []
        for client in clients:
            dW = {}
            for k in self.W.keys():
                dW[k] = client.dW[k]
            client_dWs.append(dW)
        return pairwise_angles(client_dWs)

    def compute_pairwise_distances(self, seqs, standardize=False):
        """ computes DTW distances """
        if standardize:
            # standardize to only focus on the trends
            seqs = np.array(seqs)
            seqs = seqs / seqs.std(axis=1).reshape(-1, 1)
            distances = dtw.distance_matrix(seqs)
        else:
            distances = dtw.distance_matrix(seqs)
        return distances

    def min_cut(self, similarity, idc):
        g = nx.Graph()
        for i in range(len(similarity)):
            for j in range(len(similarity)):
                g.add_edge(i, j, weight=similarity[i][j])
        cut, partition = nx.stoer_wagner(g)
        c1 = np.array([idc[x] for x in partition[0]])
        c2 = np.array([idc[x] for x in partition[1]])
        return c1, c2

    def aggregate_clusterwise(self, client_clusters):
        for cluster in client_clusters:
            targs = []
            sours = []
            total_size = 0
            for client in cluster:
                W = {}
                dW = {}
                for k in self.W.keys():
                    W[k] = client.W[k]
                    dW[k] = client.dW[k]
                targs.append(W)
                sours.append((dW, client.train_size))
                total_size += client.train_size
            # pass train_size, and weighted aggregate
            reduce_add_average(targets=targs, sources=sours, total_size=total_size)

    def compute_max_update_norm(self, cluster):
        max_dW = -np.inf
        for client in cluster:
            dW = {}
            for k in self.W.keys():
                dW[k] = client.dW[k]
            update_norm = torch.norm(flatten(dW)).item()
            if update_norm > max_dW:
                max_dW = update_norm
        return max_dW
        # return np.max([torch.norm(flatten(client.dW)).item() for client in cluster])

    def compute_mean_update_norm(self, cluster):
        cluster_dWs = []
        for client in cluster:
            dW = {}
            for k in self.W.keys():
                dW[k] = client.dW[k]
            cluster_dWs.append(flatten(dW))

        return torch.norm(torch.mean(torch.stack(cluster_dWs), dim=0)).item()

    def cache_model(self, idcs, params, accuracies):
        self.model_cache += [(idcs,
                              {name: params[name].data.clone() for name in params},
                              [accuracies[i] for i in idcs])]

def flatten(source):
    return torch.cat([value.flatten() for value in source.values()])

def pairwise_angles(sources):
    angles = torch.zeros([len(sources), len(sources)])
    for i, source1 in enumerate(sources):
        for j, source2 in enumerate(sources):
            s1 = flatten(source1)
            s2 = flatten(source2)
            angles[i, j] = torch.true_divide(torch.sum(s1 * s2), max(torch.norm(s1) * torch.norm(s2), 1e-12)) + 1

    return angles.numpy()

def reduce_add_average(targets, sources, total_size):
    for target in targets:
        for name in target:
            tmp = torch.div(torch.sum(torch.stack([torch.mul(source[0][name].data, source[1]) for source in sources]), dim=0), total_size).clone()
            target[name].data += tmp

# Defining models

In [12]:
class GIN_dc(torch.nn.Module):
    def __init__(self, nfeat, n_se, nhid, nclass, nlayer, dropout):
        super(GIN_dc, self).__init__()
        self.num_layers = nlayer
        self.dropout = dropout

        self.pre = torch.nn.Sequential(torch.nn.Linear(nfeat, nhid))

        self.embedding_s = torch.nn.Linear(n_se, nhid)

        self.graph_convs = torch.nn.ModuleList()
        self.nn1 = torch.nn.Sequential(torch.nn.Linear(nhid + nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
        self.graph_convs.append(GINConv(self.nn1))
        self.graph_convs_s_gcn = torch.nn.ModuleList()
        self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))

        for l in range(nlayer - 1):
            self.nnk = torch.nn.Sequential(torch.nn.Linear(nhid + nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
            self.graph_convs.append(GINConv(self.nnk))
            self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))

        self.Whp = torch.nn.Linear(nhid + nhid, nhid)
        self.post = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU())
        self.readout = torch.nn.Sequential(torch.nn.Linear(nhid, nclass))

    def forward(self, data):
        x, edge_index, s = data.x, data.edge_index, data.stc_enc
        x = self.pre(x)
        s = self.embedding_s(s)
        for i in range(len(self.graph_convs)):
            x = torch.cat((x, s), -1)
            x = self.graph_convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            s = self.graph_convs_s_gcn[i](s, edge_index)
            s = torch.tanh(s)
        x = self.Whp(torch.cat((x, s), -1))
        x = self.post(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.readout(x)
        x = F.log_softmax(x, dim=1)
        return x
    def loss(self, pred, label):
#         print(pred, label)
        return F.nll_loss(pred, label)


class GIN(torch.nn.Module):
    def __init__(self, nfeat, nhid, nclass, nlayer, dropout):
        super(GIN, self).__init__()
        self.num_layers = nlayer
        self.dropout = dropout

        self.pre = torch.nn.Sequential(torch.nn.Linear(nfeat, nhid))

        self.graph_convs = torch.nn.ModuleList()
        self.nn1 = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
        self.graph_convs.append(GINConv(self.nn1))
        for l in range(nlayer - 1):
            self.nnk = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
            self.graph_convs.append(GINConv(self.nnk))

        self.post = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU())
        self.readout = torch.nn.Sequential(torch.nn.Linear(nhid, nclass))

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.pre(x)
        for i in range(len(self.graph_convs)):
            x = self.graph_convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
        x = self.post(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.readout(x)
        x = F.log_softmax(x, dim=1)
        return x

    def loss(self, pred, label):
        return F.nll_loss(pred, label)
    
    
class serverGIN_dc(torch.nn.Module):
    def __init__(self, n_se, nlayer, nhid):
        super(serverGIN_dc, self).__init__()

        self.embedding_s = torch.nn.Linear(n_se, nhid)
        self.Whp = torch.nn.Linear(nhid + nhid, nhid)

        self.graph_convs = torch.nn.ModuleList()
        self.nn1 = torch.nn.Sequential(torch.nn.Linear(nhid + nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
        self.graph_convs.append(GINConv(self.nn1))
        self.graph_convs_s_gcn = torch.nn.ModuleList()
        self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))

        for l in range(nlayer - 1):
            self.nnk = torch.nn.Sequential(torch.nn.Linear(nhid + nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
            self.graph_convs.append(GINConv(self.nnk))
            self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))
            
            
class serverGIN(torch.nn.Module):
    def __init__(self, nlayer, nhid):
        super(serverGIN, self).__init__()
        self.graph_convs = torch.nn.ModuleList()
        self.nn1 = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(),
                                       torch.nn.Linear(nhid, nhid))
        self.graph_convs.append(GINConv(self.nn1))
        for l in range(nlayer - 1):
            self.nnk = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
            self.graph_convs.append(GINConv(self.nnk))

In [13]:
def setup_devices(splitedData) :
    idx_clients = {}
    clients = []
    for idx, ds in enumerate(splitedData.keys()):
        idx_clients[idx] = ds
        graphs, num_node_features, num_graph_labels, train_size = splitedData[ds]
        if alg == 'fedstar':
            cmodel_gc = GIN_dc(num_node_features, n_se, hidden, num_graph_labels, nlayer, dropout)
        else:
            cmodel_gc = GIN(num_node_features, hidden, num_graph_labels, nlayer, dropout)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, cmodel_gc.parameters()), lr=lr, weight_decay=weight_decay)
        clients.append(Client_GC(cmodel_gc, idx, ds, train_size, graphs, optimizer, device ))

    if alg == 'fedstar':
        smodel = serverGIN_dc(n_se=n_se, nlayer=nlayer, nhid=hidden)
    else:
        smodel = serverGIN(nlayer=nlayer, nhid=hidden)
    server = Server(smodel, device)
    return clients, server, idx_clients

In [14]:
init_clients, init_server, init_idx_clients = setup_devices(splitedData)
print("\nDone setting up devices.")


Done setting up devices.


In [15]:
init_clients[0].device

'cpu'

# Training

In [16]:
def run_selftrain_GC(clients, server, local_epoch):
    # all clients are initialized with the same weights
    for client in clients:
        client.download_from_server(server)

    allAccs = {}
    for client in clients:
        client.local_train(local_epoch)

        loss, acc = client.evaluate()
        allAccs[client.name] = [client.train_stats['trainingAccs'][-1], client.train_stats['valAccs'][-1], acc]
        print("  > {} done.".format(client.name))

    return allAccs


def run_fedavg(clients, server, COMMUNICATION_ROUNDS, local_epoch, samp=None, frac=1.0, summary_writer=None):
    for client in clients:
        client.download_from_server(server)

    if samp is None:
        sampling_fn = server.randomSample_clients
        frac = 1.0

    for c_round in range(1, COMMUNICATION_ROUNDS + 1):
        if (c_round) % 50 == 0:
            print(f"  > round {c_round}")

        if c_round == 1:
            selected_clients = clients
        else:
            selected_clients = sampling_fn(clients, frac)

        for client in selected_clients:
            # only get weights of graphconv layers
            client.local_train(local_epoch)

        server.aggregate_weights(selected_clients)
        for client in selected_clients:
            client.download_from_server(server)

        # write to log files
        if c_round % 5 == 0:
            for idx in range(len(clients)):
                loss, acc = clients[idx].evaluate()
                summary_writer.add_scalar('Test/Acc/user' + str(idx + 1), acc, c_round)
                summary_writer.add_scalar('Test/Loss/user' + str(idx + 1), loss, c_round)

    frame = pd.DataFrame()
    for client in clients:
        loss, acc = client.evaluate()
        frame.loc[client.name, 'test_acc'] = acc

    def highlight_max(s):
        is_max = s == s.max()
        return ['background-color: yellow' if v else '' for v in is_max]

    fs = frame.style.apply(highlight_max).data
    print(fs)
    return frame


def run_fedstar(clients, server, COMMUNICATION_ROUNDS, local_epoch, samp=None, frac=1.0, summary_writer=None):
    for client in clients:
        client.download_from_server(server)

    if samp is None:
        sampling_fn = server.randomSample_clients
        frac = 1.0

    for c_round in range(1, COMMUNICATION_ROUNDS + 1):
        if (c_round) % 50 == 0:
            print(f"  > round {c_round}")

        if c_round == 1:
            selected_clients = clients
        else:
            selected_clients = sampling_fn(clients, frac)

        for client in selected_clients:
            # only get weights of graphconv layers
            client.local_train(local_epoch)

        server.aggregate_weights_se(selected_clients)
        for client in selected_clients:
            client.download_from_server(server)

        # write to log files
        if c_round % 5 == 0:
            for idx in range(len(clients)):
                loss, acc = clients[idx].evaluate()
                summary_writer.add_scalar('Test/Acc/user' + str(idx + 1), acc, c_round)
                summary_writer.add_scalar('Test/Loss/user' + str(idx + 1), loss, c_round)
                print(acc, loss, c_round)

    frame = pd.DataFrame()
    for client in clients:
        loss, acc = client.evaluate()
        frame.loc[client.name, 'test_acc'] = acc

    def highlight_max(s):
        is_max = s == s.max()
        return ['background-color: yellow' if v else '' for v in is_max]

    fs = frame.style.apply(highlight_max).data
    print(fs)
    return frame

In [17]:
def process_selftrain( clients, server, local_epoch):
    print("Self-training ...")
    df = pd.DataFrame()
    allAccs = run_selftrain_GC( clients, server, local_epoch)
    for k, v in allAccs.items():
        df.loc[k, [f'train_acc', f'val_acc', f'test_acc']] = v
    print(df)
    outfile = os.path.join(outbase, f'accuracy_'+alg+'_GC.csv')
    df.to_csv(outfile)
    print(f"Wrote to file: {outfile}")

def process_fedstar( clients, server, summary_writer):
    print("\nDone setting up FedStar devices.")

    print("Running FedStar ...")
    frame = run_fedstar( clients, server, num_rounds, local_epoch, samp=None, summary_writer=summary_writer)
    outfile = os.path.join(outbase, f'accuracy_fedstar_{type_init}_GC.csv')
    frame.to_csv(outfile)
    print(f"Wrote to file: {outfile}")

In [18]:
import copy
sw_path = os.path.join(outbase, 'raw', 'tensorboard', f'{data_group}_{alg}_{type_init}')
summary_writer = SummaryWriter(sw_path)
process_fedstar( clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server), summary_writer=summary_writer)


Done setting up FedStar devices.
Running FedStar ...
0.0037792671764740617 0.025551088946912184 5
0.002804924857539343 0.3713111835767959 5
0.003306858779414804 0.02268377015347737 5
0.0037792671764740617 0.02554325523985798 10
0.002804924857539343 3.805158063051758 10
0.003306858779414804 0.022406735888067814 10
0.0037792671764740617 0.02544148552776414 15
0.002804924857539343 4.869108273790191 15
0.003306858779414804 0.02230782970496158 15
0.0037792671764740617 0.025059046724604357 20
0.002804924857539343 0.033697302704907145 20
0.003306858779414804 0.022432062531486205 20
         test_acc
Client1  0.003779
Client2  0.002805
Client3  0.003307
Wrote to file: outputs/accuracy_fedstar_rw_dg_GC.csv


In [19]:
process_selftrain(clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server), local_epoch=200)

Self-training ...
  > Client1 done.
  > Client2 done.
  > Client3 done.
         train_acc   val_acc  test_acc
Client1   0.003189  0.003543  0.003779
Client2   0.002835  0.002923  0.002805
Client3   0.003366  0.004517  0.003307
Wrote to file: outputs/accuracy_fedstar_GC.csv


In [24]:
def process_fedavg( clients, server, summary_writer):
    print("\nDone setting up FedAvg devices.")

    print("Running FedAvg ...")
    frame = run_fedavg( clients, server, num_rounds, local_epoch, samp=None, summary_writer=summary_writer)
    outfile = os.path.join(outbase, f'accuracy_fedavg_{type_init}_GC.csv')
    frame.to_csv(outfile)
    print(f"Wrote to file: {outfile}")

In [25]:
sw_path = os.path.join(outbase, 'raw', 'tensorboard', f'{data_group}_fedavg_{type_init}')
summary_writer_avg = SummaryWriter(sw_path)
process_fedavg(clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server),summary_writer=summary_writer_avg)


Done setting up FedAvg devices.
Running FedAvg ...
         test_acc
Client1  0.003779
Client2  0.002805
Client3  0.003307
Wrote to file: outputs/accuracy_fedavg_rw_dg_GC.csv


I don't understand why its giving the same accuracy in all cases even though the loss was decreasing. 