In [1]:
import numpy as np
import torch as th
import dgl
import os
import pandas as pd
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.loss import LabelSmoothingCrossEntropy
import random
import torch
from sklearn.metrics import roc_auc_score,f1_score,accuracy_score
from sklearn.model_selection import KFold,train_test_split
import copy
device = th.device('cuda:0')
from torch.optim import Optimizer


In [2]:
class MyDataset(DGLDataset):
    """
    """
    def __init__(self,func_atlas,datainfo,feat_root,removeAlone=True):
        super(MyDataset,self).__init__(
            name='ABIDE I'
        )
        self.func_atlas = func_atlas
        self.datainfo = datainfo
        self.feat_root = feat_root
        self.getGraphStructual()
        self.removeAlone=removeAlone  
    def process(self):
        pass
    def save(self):
        pass
    def load(self):
        pass
        
    def __getitem__(self,idx):
        try:
            label = self.datainfo.loc[idx,'LABEL']
        except:
            print(idx)
            label = 0 
        label = th.tensor(label).long()
        subid = self.datainfo.loc[idx,'SubID']
        nfeat = os.path.join(self.feat_root,str(subid)+'.npy')
        nfeat = th.from_numpy(np.load(nfeat)).float()
        graph = dgl.heterograph(self.graphS)
        graph.ndata['feat'] = nfeat
        return graph,label
    
    def __len__(self):
        return len(self.datainfo)
    
    def getGraphStructual(self):
        graphdata = {}
        graphnode = []
        for index,func in enumerate(self.func_atlas):
            same_net_i,same_net_j=[],[]
            for i in range(func.shape[0]):
                for j in range(func.shape[0]):
                    if (func[i] == True) and (func[j] == True):
                        same_net_i.append(i)
                        same_net_j.append(j)
                        if i not in graphnode:
                            graphnode.append(i)
            same_net_i = th.tensor(same_net_i,dtype=th.int32)
            same_net_j = th.tensor(same_net_j,dtype=th.int32)
            graphdata[('area','func'+str(index),'area')] = (same_net_i,same_net_j)
        self.graphS  = graphdata
        self.graphN = {'area':graphnode}

In [3]:


class client(object):
    def __init__(self,train_dataset,test_dataset,localBatchsize,device):
        self.train_ds = train_dataset
        self.test_ds = test_dataset
        self.dev = device
        self.train_dl = GraphDataLoader(self.train_ds, batch_size = localBatchsize)
        self.test_dl = GraphDataLoader(self.test_ds,batch_size = localBatchsize)
        self.local_parameters = None

    def localUpdate(self,localEpoch,model,lossfn,optimizer,lr_schedule,global_parameters):
        model.load_state_dict(global_parameters,strict=True)
        result_acc,result_f1,result_auc=[],[],[]
        for epoch in range(localEpoch):
            model.train()
            train_loss,train_acc=.0,.0
            for image,label in self.train_dl:
                image=image.to(self.dev)
                label=label.to(self.dev)
                pred=model(image)
                loss=lossfn(pred,label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss+=loss.item()
                acc = (pred.argmax(dim=1) == label).float().sum()
                train_acc +=acc.item()
            lr_schedule.step(epoch)
            
            print('Epoch: {:2d}  Train Loss: {:.4f}  Train Acc:  {:.4F}  '.format(
                epoch,train_loss/len(self.train_dl),train_acc/len(self.train_ds)))

        return model.state_dict()
        

    def local_validate(self,model):
        with th.no_grad():
            model.eval()
            test_pred,test_true,test_pred_prob=[],[],[]
            for image,label in self.test_dl:
                image=image.to(self.dev)
                label=label.to(self.dev)
                pred=model(image)
                test_pred.extend(pred.argmax(dim=1).tolist())
                test_true.extend(label.tolist())
                test_pred_prob.extend(pred.tolist())
                
            test_acc=accuracy_score(test_true,test_pred)
            test_f1=f1_score(test_true,test_pred)
            test_true=np.eye(pred.shape[1])[test_true]
            test_auc=roc_auc_score(test_true,test_pred_prob)
            return test_acc,test_f1,test_auc


In [4]:
class ClientGroup(object):
    def __init__(self,dir,template,datainfo,device,localBatchsize):
        self.datainfo = pd.read_csv(datainfo)
        self.dir = dir
        self.template = template
        self.dev = device
        self.localBatchsize = localBatchsize
        self.clients_set = {}
        self.dataSetAllocation()

    def dataSetAllocation(self):
        RSN = range(self.template.shape[0])
        func_atlas = [self.template[i] for i in RSN]
        df = self.datainfo.groupby("SITE").nunique()["ID"]
        self.len = len(df)
        self.SITE = df.keys()
        sites = {}
        for k in df.keys():
            sites[k] = self.datainfo[(self.datainfo['SITE']==k)].index.tolist()
            site_info = self.datainfo.loc[sites[k]].reset_index(drop=True)
            train_info,test_info = train_test_split(site_info,train_size=0.7,random_state=2022)
            train_info = train_info.reset_index(drop = True)
            test_info = test_info.reset_index(drop = True)
            train_data = MyDataset(func_atlas,train_info,self.dir,removeAlone=True)
            test_data = MyDataset(func_atlas,test_info,self.dir,removeAlone=True)
            someone = client(train_data,test_data,self.localBatchsize,self.dev)
            self.clients_set["SITE_{}".format(k)] = someone



In [5]:
class server(object):
    def __init__(self,args_model,args_train,beta=0.9):
        self.args_train = args_train
        self.args_model = args_model
        self.template = np.load(args_train['template'])
        self.beta = beta
        if args_model['model'] in ['HGT','SEHGT']:
            self.modeltype = args_model['model']
        else:
            self.modeltype = 'SEHGT'
        self.global_parameters = {}
        self.init_data()
        self.init_weight()

    def init_data(self):
        datainfo = 'DataInfo.csv'
        feat_root='/RawFunc'
        self.clients = ClientGroup(feat_root,self.template,datainfo,device,self.args_train['localbatch'])
        self.num_in_comm = int(max(self.clients.len*self.args_train['cfraction'],0.3))
        self.num_sub = 0
        order = np.random.permutation(self.clients.SITE)
        clients_in_comm = ['SITE_{}'.format(i) for i in order[0:self.num_in_comm]]
        for client in clients_in_comm:
            self.num_sub += len(self.clients.clients_set[client].train_ds)
        self.sites = self.clients.SITE

    def init_model(self):
        self.RSN = range(self.template.shape[0])
        if self.modeltype is 'HGT':
            node_dict = {'area':0}
            edge_dict={'func'+str(i):i for i in range(len(self.RSN))}
            args = self.args_model
            from ModelsForGraph import HGT
            self.model=HGT(
                node_dict,
                edge_dict,
                n_inp=116,
                n_hid=args['n_hid'],
                n_out=2,
                n_layers=args['n_layers'],
                n_heads=args['n_heads']
            )
        elif self.modeltype is 'SEHGT':
            node_dict = {'area':0}
            edge_dict={'func'+str(i):i for i in range(len(self.RSN))}
            args = self.args_model
            from ModelsForGraph import HGT_v2
            self.model=HGT_v2(
                node_dict,
                edge_dict,
                n_inp=116,
                n_hid=args['n_hid'],
                n_out=2,
                n_layers=args['n_layers'],
                n_heads=args['n_heads']
            )
        else:
            pass

        self.model = self.model.to(device)

        for key,var in self.model.state_dict().items():
            self.global_parameters[key] = var.clone()
    
    def init_weight(self):
        self.setup_seed(1234)
        self.init_model()
        self.init_optimizer()
        self.init_lossfn()
    
    def setup_seed(self,seed=2022):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    def init_optimizer(self):
        args = self.args_train
        self.optimizer = th.optim.Adam(self.model.parameters(),lr=args['lr'])
        self.lr_schedule=CosineLRScheduler(optimizer=self.optimizer,t_initial=50,lr_min=1e-5,warmup_t=5,cycle_limit=1)
    
    def init_lossfn(self):
        self.lossfn =  LabelSmoothingCrossEntropy(self.args_train['losspr'])
        self.lossfn = self.lossfn.to(device)

    def train(self):
        for i in range(self.args_train['num_comm']):
            print("communication round {}".format(i+1),file=doc)

            order = np.random.permutation(self.clients.SITE)
            clients_in_comm = ['SITE_{}'.format(i) for i in order[0:self.num_in_comm]]

            sum_parameters = None
            for client in clients_in_comm:
                
                
                local_parameters = self.clients.clients_set[client].localUpdate(self.args_train['localepoch'],
                                                                                self.model,self.lossfn,self.optimizer,self.lr_schedule,
                                                                                self.global_parameters)
                if sum_parameters is None:
                    sum_parameters = {}
                    for key,var in local_parameters.items():
                        sum_parameters[key] = var.clone()*(len(self.clients.clients_set[client].train_ds)/self.num_sub)
                else:
                    for key in sum_parameters:
                        sum_parameters[key] = sum_parameters[key] + local_parameters[key]*(len(self.clients.clients_set[client].train_ds)/self.num_sub)

            for key in self.global_parameters:
                self.global_parameters[key] = sum_parameters[key]


            clients = ['SITE_{}'.format(i) for i in self.clients.SITE]

            print("*******************\nfederated results:",file=doc)
            self.model.load_state_dict(self.global_parameters,strict=True)
            for c in clients:
                acc,f1,auc = self.clients.clients_set[c].local_validate(self.model)
                print('{}:  [Acc:{:.3f}  F1:{:.3f}  AUC:{:.3f}]'.format(c,acc,f1,auc),file=doc)




In [6]:
from datetime import datetime
args_train = {
    'lr':1e-4,
    'losspr':0.2,
    'epoch':100,
    'template':'BM20.npy',
    'cfraction':1,
    'num_comm':50,
    'localepoch':20,
    'localbatch':16
    }
args_model={
    'model':'SEHGT',
    'n_hid':384,
    'n_layers':8,
    'n_heads':12
    }
doc = open('fedavg.txt','a')
print('\n',datetime.now(),file=doc)
print('args_model:\n{}\n args_train:\n{}'.format(args_model,args_train),file=doc)
server = server(args_model,args_train)
server.train()
doc.close()

Epoch:  0  Train Loss: 0.7500  Train Acc:  0.4571  
Epoch:  1  Train Loss: 0.7501  Train Acc:  0.4571  
Epoch:  2  Train Loss: 0.7460  Train Acc:  0.4571  
Epoch:  3  Train Loss: 0.7318  Train Acc:  0.4571  
Epoch:  4  Train Loss: 0.7122  Train Acc:  0.4571  
Epoch:  5  Train Loss: 0.6888  Train Acc:  0.4857  
Epoch:  6  Train Loss: 0.6670  Train Acc:  0.5429  


KeyboardInterrupt: 