In [1]:
import numpy as np
import torch as th
import dgl
import os
import pandas as pd
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.loss import LabelSmoothingCrossEntropy
import random
import torch
from sklearn.metrics import roc_auc_score,f1_score,accuracy_score
from sklearn.model_selection import KFold
import datetime
device = th.device('cuda:0')

In [None]:
oridir = os.path.dirname(__file__)
datainfo_dir = os.path.join(oridir,'DataInfo.csv')
funcdata_dir = os.path.join(oridir,'data/RawFunc')

In [2]:
class MyDataset(DGLDataset):
    """
    """
    def __init__(self,func_atlas,datainfo,feat_root,removeAlone=True):
        super(MyDataset,self).__init__(
            name='ABIDE I'
        )
        self.func_atlas = func_atlas
        self.datainfo = datainfo
        self.feat_root = feat_root
        self.getGraphStructual()
        self.removeAlone=removeAlone  
    def process(self):
        pass
    def save(self):
        pass
    def load(self):
        pass
        
    def __getitem__(self,idx):
        try:
            label = self.datainfo.loc[idx,'LABEL']
        except:
            print(idx)
            label = 0 
        label = th.tensor(label).long()
        subid = self.datainfo.loc[idx,'SubID']
        nfeat = os.path.join(self.feat_root,str(subid)+'.npy')
        nfeat = th.from_numpy(np.load(nfeat)).float()
        graph = dgl.heterograph(self.graphS)
        graph.ndata['feat'] = nfeat
        return graph,label
    
    def __len__(self):
        return len(self.datainfo)
    
    def getGraphStructual(self):
        graphdata = {}
        graphnode = []
        for index,func in enumerate(self.func_atlas):
            same_net_i,same_net_j=[],[]
            for i in range(func.shape[0]):
                for j in range(func.shape[0]):
                    if (func[i] == True) and (func[j] == True):
                        same_net_i.append(i)
                        same_net_j.append(j)
                        if i not in graphnode:
                            graphnode.append(i)
            same_net_i = th.tensor(same_net_i,dtype=th.int32)
            same_net_j = th.tensor(same_net_j,dtype=th.int32)
            graphdata[('area','func'+str(index),'area')] = (same_net_i,same_net_j)
        self.graphS  = graphdata
        self.graphN = {'area':graphnode}

In [3]:

class Runer(object):
    def __init__(self,args_model,args_train=None):
        self.args_train = args_train
        self.args_model = args_model
        if args_model['model_name'] in ['HGT','HAN','SEHGT']:
            self.modeltype= args_model['model_name']
        else:
            self.modeltype = 'HGT'
        self.template = np.load(args_train['template'])
        self.dataload()

    def init_weight(self):
        self.setup_seed(2022)
        self.init_model()
        self.init_optimizer()
        self.init_lossfn()

    def dataload(self):
        self.RSN=range(self.template.shape[0])
        self.func_atlas = [self.template[i] for i in self.RSN]
        self.datainfo = pd.read_csv(datainfodir)
        self.KFold = KFold(n_splits=5,random_state=2022,shuffle=True).split(self.datainfo)


    def init_data(self,train_split,test_split):
        train_info = self.datainfo.loc[train_split]
        test_info = self.datainfo.loc[test_split]
        train_info = train_info.reset_index(drop=True)
        test_info = test_info.reset_index(drop=True)
        feat_root = funcdata_dir
        self.traindata = MyDataset(self.func_atlas,train_info,feat_root,removeAlone=True)
        self.testdata = MyDataset(self.func_atlas,test_info,feat_root,removeAlone=True)
        self.train_loader = GraphDataLoader(self.traindata,batch_size = 16)
        self.test_loader = GraphDataLoader(self.testdata,batch_size = 16)
        print('Train:{}  Test:{}  Nodes:{}'.format(len(self.traindata),len(self.testdata),len(self.traindata.graphN['area'])))

    def setup_seed(self,seed=42):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        
    def init_model(self):
        if self.modeltype =='SEHGT':
            node_dict = {'area':0}
            edge_dict={'func'+str(i):i for i in range(len(self.RSN))}
            args = self.args_model
            from ModelsForGraph import SEHGT
            self.model=SEHGT(
                node_dict,
                edge_dict,
                n_inp=116,
                n_hid=args['n_hid'],
                n_out=2,
                n_layers=args['n_layers'],
                n_heads=args['n_heads']
            )
        elif self.modeltype == 'HGT':
            node_dict = {'area':0}
            edge_dict={'func'+str(i):i for i in range(len(self.RSN))}
            args = self.args_model
            from ModelsForGraph import HGT_v1
            self.model=HGT_v1(
                node_dict,
                edge_dict,
                n_inp=116,
                n_hid=args['n_hid'],
                n_out=2,
                n_layers=args['n_layers'],
                n_heads=args['n_heads']
            )
        elif self.modeltype == 'HAN':
            from ModelsForGraph import HAN
            self.model = HAN(
                meta_paths=self.args_model['metapath'],
                in_size=116,
                hidden_size=self.args_model['n_hid'],
                out_size=2,
                num_heads=[self.args_model['num_heads']],
                dropout=self.args_model['dropout']
            )
        self.model = self.model.to(device)

    def init_optimizer(self):
        args = self.args_train
        self.optimizer = th.optim.Adam(self.model.parameters(),lr=args['lr'])
        self.lr_schedule=CosineLRScheduler(optimizer=self.optimizer,t_initial=50,lr_min=1e-5,warmup_t=5,cycle_limit=1)
    
    def init_lossfn(self):
        self.lossfn =  LabelSmoothingCrossEntropy(self.args_train['losspr'])
        self.lossfn = self.lossfn.to(device)

    def validate(self):
        with th.no_grad():
            self.model.eval()
            test_pred,test_true,test_pred_prob=[],[],[]
            for image,label in self.test_loader:
                image=image.to(device)
                label=label.to(device)
                pred=self.model(image)
                test_pred.extend(pred.argmax(dim=1).tolist())
                test_true.extend(label.tolist())
                test_pred_prob.extend(pred.tolist())
            test_acc=accuracy_score(test_true,test_pred)
            test_f1=f1_score(test_true,test_pred)
            test_true=np.eye(pred.shape[1])[test_true]
            test_auc=roc_auc_score(test_true,test_pred_prob)
            return test_acc,test_f1,test_auc
    
    def train(self):
        result_acc,result_f1,result_auc=[],[],[]
        for train_split,test_split in self.KFold:
            self.init_data(train_split,test_split)
            self.init_weight()
            optimizer = self.optimizer
            max_acc,max_f1,max_auc=.0,.0,.0
            for epoch in range(self.args_train['epoch']):
                self.model.train()
                train_loss,train_acc=.0,.0
                for image,label in self.train_loader:
                    image=image.to(device)
                    label=label.to(device)
                    pred=self.model(image)
                    loss=self.lossfn(pred,label)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    train_loss+=loss.item()
                    acc = (pred.argmax(dim=1) == label).float().sum()
                    train_acc +=acc.item()
                self.lr_schedule.step(epoch)
                acc,f1,auc=self.validate()
                if (epoch+1)%10==0:
                    print('Epoch: {:2d}  Train Loss: {:.4f}  Train Acc:  {:.4F}  Test [Acc:{:.4f} F1:{:.4f} AUC:{:.4f}]'.format(
                        epoch+1,train_loss/len(self.train_loader),train_acc/len(self.traindata),acc,f1,auc))
                    max_acc=max(acc,max_acc)
                    max_f1=max(f1,max_f1)
                    max_auc=max(auc,max_auc)
            result_acc.append(max_acc)
            result_f1.append(max_f1)
            result_auc.append(max_auc)
            if acc == max_acc or f1 == max_f1:
                key=datetime.datetime.strftime(datetime.datetime.now(),'%Y-%m-%d-%H-%M-%S')
                path = './checkpoint/'+key+'.pt'
                torch.save(self.model.state_dict(),path)
        
        result_acc=np.array(result_acc)
        result_f1=np.array(result_f1)
        result_auc=np.array(result_auc)
        result='[Acc:{:.3f}+{:.3f}  F1:{:.3f}+{:.3f}  AUC:{:.3f}+{:.3f}]'.format(
            np.mean(result_acc),np.std(result_acc),np.mean(result_f1),np.std(result_f1),np.mean(result_auc),np.std(result_auc)
        )
        print('Final Result',result)
        return result,result_acc,result_f1,result_auc

In [4]:
import json
import datetime
args_train = {
    'lr':1e-4,
    'losspr':0.2,
    'epoch':100,
    'template':'BM20.npy'
    }
args_model={
    'model_name':'SEHGT',
    'n_hid':384,
    'n_layers':8,
    'n_heads':12
    }


# metapath=[]
# for i in range(10):
#     metapath.append(['func'+str(i)])
#     if i == 9:
#         break
#     for j in range(i+1,10):
#         metapath.append(['func'+str(i),'func'+str(j)])
# args_HAN={
#     'model_name':'HAN',
#     'metapath':metapath,
#     'n_hid':256,
#     'num_heads':8,
#     'dropout':0.2
# }
#args_model = args_HAN


for model in ['SEHGT']:
    for rsn in ['BM20.npy']:
        args_train['template']=rsn
        args_model['model_name']=model
        main = Runer(args_model=args_model,args_train=args_train)
        modelresult,acc,f1,auc=main.train()
        result={
            'args_train':args_train,
            'args_model':args_model,
            'Result':modelresult}
        if os.path.exists('result.json'):
            with open('result.json','r',encoding='utf8')as fp:
                json_data = json.load(fp)
        else:
            json_data={}
        key=datetime.datetime.strftime(datetime.datetime.now(),'%Y-%m-%d %H:%M:%S')
        json_data[key]=result
        with open('result.json','w') as f:
            f.write(json.dumps(json_data,ensure_ascii=False,indent=4,separators=(',',':')))

Train:696  Test:175  Nodes:116
Epoch: 10  Train Loss: 0.6915  Train Acc:  0.4899  Test [Acc:0.6000 F1:0.3137 AUC:0.6144]
Epoch: 20  Train Loss: 0.6470  Train Acc:  0.6624  Test [Acc:0.5829 F1:0.4672 AUC:0.6135]
Epoch: 30  Train Loss: 0.5292  Train Acc:  0.8204  Test [Acc:0.5886 F1:0.5325 AUC:0.6482]
Epoch: 40  Train Loss: 0.4431  Train Acc:  0.9210  Test [Acc:0.6571 F1:0.5775 AUC:0.6592]
Epoch: 50  Train Loss: 0.4070  Train Acc:  0.9497  Test [Acc:0.6514 F1:0.5734 AUC:0.6583]
Epoch: 60  Train Loss: 0.3845  Train Acc:  0.9698  Test [Acc:0.6514 F1:0.5734 AUC:0.6610]
Epoch: 70  Train Loss: 0.3667  Train Acc:  0.9856  Test [Acc:0.6457 F1:0.5694 AUC:0.6593]
Epoch: 80  Train Loss: 0.3511  Train Acc:  0.9928  Test [Acc:0.6457 F1:0.5634 AUC:0.6600]
Epoch: 90  Train Loss: 0.3404  Train Acc:  0.9943  Test [Acc:0.6400 F1:0.5532 AUC:0.6468]
Epoch: 100  Train Loss: 0.3347  Train Acc:  0.9957  Test [Acc:0.6457 F1:0.5571 AUC:0.6441]
Train:697  Test:174  Nodes:116


KeyboardInterrupt: 

In [None]:
acc

array([0.62857143, 0.70689655, 0.67241379, 0.67816092, 0.67816092])

In [None]:
model.gcs[-1].selayer.register_forward_hook(hook)