In [None]:
import networkx as nx
import numpy as np
import os
import scipy.io as scio
import torch
import torch.nn as nn

import dgl
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader
import dgl.function as fn
from dgl.nn import GATConv
import random
from timm.scheduler.cosine_lr import CosineLRScheduler

In [None]:
class MyDataset(DGLDataset):
    """Map-style dataset of functional-connectivity graphs, consumed by GraphDataLoader.

    ``data`` is a sequence of ``(mat_path, label_name)`` pairs. Each item:
      * loads the ``'fc'`` matrix from the .mat file at ``mat_path``,
      * takes its element-wise absolute value,
      * links every node pair whose entry is >= ``self.thed``,
      * stores row ``i`` of the matrix as node ``i``'s ``'feat'`` feature,
      * maps ``label_name`` to an integer class via ``self.label``.

    NOTE(review): ``DGLDataset.__init__`` is never called, and the ``thed`` argument
    is ignored -- the threshold is always 0.003.
    """

    def __init__(self, data, thed):
        self.data = data
        # NOTE(review): hard-coded; the `thed` parameter is accepted but unused.
        self.thed = 0.003
        self.label = {'T-T': 2, 'P-T': 1, 'P-P': 0}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(graph, label_tensor)`` for sample ``idx``."""
        mat_path, label_name = self.data[idx]
        fc_abs = np.abs(scio.loadmat(mat_path)['fc'])
        # Boolean adjacency computed over the full matrix (diagonal included).
        nx_graph = nx.Graph(fc_abs >= self.thed)
        node_feats = torch.from_numpy(fc_abs).float()
        # Row i of the connectivity matrix becomes node i's feature vector.
        for node_id, row in enumerate(node_feats):
            nx_graph.nodes[node_id]['feat'] = row
        graph = dgl.from_networkx(nx_graph, node_attrs=['feat'])
        return graph, torch.tensor(self.label[label_name])

In [None]:
class GAT(nn.Module):
    """Graph-level classifier: stacked GATConv layers, mean-node readout, linear head.

    Args:
        num_layers: number of GATConv layers applied to the node features.
        in_dim: size of the per-node input feature ``g.ndata['feat']``.
        num_hidden: output size of each attention head.
        num_classes: number of output logits.
        heads: per-layer head counts; only the first ``num_layers`` entries are used.
        activation: activation passed to every GATConv layer.
        feat_drop, attn_drop, negative_slope: forwarded unchanged to every GATConv.
        residual: residual flag for the hidden layers (the input layer never uses one).
    """

    def __init__(self,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual):
        super().__init__()
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GATConv(
            in_dim, num_hidden, heads[0],
            feat_drop, attn_drop, negative_slope, False, self.activation))
        # hidden layers: heads are concatenated, so each input is num_hidden * heads[l-1]
        for l in range(1, num_layers):
            self.gat_layers.append(GATConv(
                num_hidden * heads[l - 1], num_hidden, heads[l],
                feat_drop, attn_drop, negative_slope, residual, self.activation))
        # output projection, sized from the head count of the last GATConv actually
        # built (heads[-1] is wrong if `heads` has more than num_layers entries)
        self.gat_layers.append(nn.Linear(num_hidden * heads[num_layers - 1], num_classes))

    def forward(self, g):
        """Return class logits, one row per graph in the (possibly batched) graph ``g``."""
        h = g.ndata['feat']
        for layer in self.gat_layers[:self.num_layers]:
            # merge the per-head outputs into one feature vector per node
            h = layer(g, h).flatten(1)
        # local_scope keeps the temporary readout feature off the caller's graph
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        # output projection
        return self.gat_layers[-1](hg)

In [None]:
# Each line of fc_list.txt is "<path to .mat file>\t<label>" (unpacked by MyDataset).
# rstrip('\r\n') also drops a Windows '\r' that would otherwise end up in the label,
# and blank lines are skipped instead of becoming [''] entries that fail to unpack.
with open('fc_list.txt', 'r') as f:
    data = [line.rstrip('\r\n').split('\t') for line in f if line.strip()]
# Fixed seed so the shuffled order -- and therefore the CV folds -- is reproducible.
random.seed(2022)
random.shuffle(data)

In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, roc_auc_score, f1_score

# K-fold cross-validation of the GAT classifier on the shuffled `data` list.
# NOTE(review): the "best" epoch of each fold is selected on the test fold itself
# (there is no separate validation split), so the reported metrics are optimistic.
k_fold = 5
epochs = 60
num_classes = 3
class_ids = list(range(num_classes))
# Seed torch as well as `random` so model init and loader shuffling are reproducible.
torch.manual_seed(2022)
# Prefer the second GPU as before, but fall back instead of crashing on smaller machines.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


def _train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over `loader` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for graph, label in loader:
        graph, label = graph.to(device), label.to(device)
        loss = loss_fn(model(graph), label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)


@torch.no_grad()
def _evaluate(model, loader, loss_fn, device):
    """Return (true labels, predicted labels, class probabilities, mean loss) over `loader`."""
    model.eval()
    y_true, y_pred, y_prob = [], [], []
    total_loss = 0.0
    for graph, label in loader:
        graph, label = graph.to(device), label.to(device)
        logits = model(graph)
        total_loss += loss_fn(logits, label).item()
        y_true.extend(label.tolist())
        y_pred.extend(logits.argmax(dim=1).tolist())
        # ROC-AUC is scored on softmax probabilities rather than raw logits.
        y_prob.extend(torch.softmax(logits, dim=1).tolist())
    return y_true, y_pred, y_prob, total_loss / max(len(loader), 1)


def _per_class_rates(c_matrix):
    """Return a (3, C) float tensor: per-class precision, sensitivity and specificity.

    `c_matrix` follows sklearn's convention (rows = true class, columns = predicted
    class). A class with an empty denominator yields NaN/inf, as plain division
    would, but without NumPy's divide-by-zero warning.
    """
    c_matrix = np.asarray(c_matrix, dtype=np.float64)
    total = c_matrix.sum()
    tp = np.diag(c_matrix)
    predicted = c_matrix.sum(axis=0)  # column sums: TP + FP
    actual = c_matrix.sum(axis=1)     # row sums:    TP + FN
    with np.errstate(divide='ignore', invalid='ignore'):
        precision = tp / predicted
        sensitivity = tp / actual
        specificity = (total - predicted - actual + tp) / (total - actual)
    return torch.from_numpy(np.stack([precision, sensitivity, specificity])).float()


result_train_loss, result_loss, result_acc, result_recall = [], [], [], []
result_index_matrix = torch.zeros(k_fold, 3, num_classes)
k_fold_len = len(data) // k_fold
for kk in range(k_fold):
    # The last fold also takes the len(data) % k_fold leftover samples, so every
    # sample lands in exactly one test fold.
    start = k_fold_len * kk
    stop = len(data) if kk == k_fold - 1 else k_fold_len * (kk + 1)
    test_data = data[start:stop]
    train_data = data[:start] + data[stop:]
    # NOTE(review): MyDataset.__init__ hard-codes its threshold to 0.003, so the 16
    # passed here has no effect.
    train_dataset = MyDataset(train_data, 16)
    test_dataset = MyDataset(test_data, 16)
    train_dataloader = GraphDataLoader(train_dataset, batch_size=16, shuffle=True)
    test_dataloader = GraphDataLoader(test_dataset, batch_size=1, shuffle=False)

    # NOTE(review): the 9th positional argument is negative_slope, so False (== 0)
    # makes GATConv's attention LeakyReLU a plain ReLU -- confirm that is intended.
    model = GAT(2, 64, 384, num_classes, [8, 8], torch.nn.ReLU(), 0.2, 0.2, False, False).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # NOTE(review): with t_initial=10 the cosine cycle ends at epoch 10 and the
    # remaining epochs all run at lr_min.
    lr_schedule = CosineLRScheduler(optimizer=optimizer, t_initial=10, lr_min=1e-5, warmup_t=5)
    loss_fn = torch.nn.CrossEntropyLoss().to(device)

    fold_loss, fold_acc, fold_recall, fold_train_loss = 1., 0., 0., 1.
    fold_f1, fold_auc = 0., 0.
    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_dataloader, optimizer, loss_fn, device)
        # timm schedulers are stepped at epoch end with the index of the NEXT epoch;
        # step(epoch) would repeat the previous epoch's learning rate.
        lr_schedule.step(epoch + 1)
        true_list, pred_list, pred_prob_list, test_loss = _evaluate(
            model, test_dataloader, loss_fn, device)

        test_acc = accuracy_score(true_list, pred_list)
        # Micro-averaged recall/F1 are identical to accuracy for single-label
        # multiclass data; macro averaging makes them report something new.
        test_recall = recall_score(true_list, pred_list, labels=class_ids,
                                   average='macro', zero_division=0)
        test_f1 = f1_score(true_list, pred_list, labels=class_ids,
                           average='macro', zero_division=0)
        # labels= keeps the matrix num_classes x num_classes even if a class is absent.
        c_matrix = confusion_matrix(true_list, pred_list, labels=class_ids)
        index_matrix = _per_class_rates(c_matrix)
        test_auc = roc_auc_score(np.eye(num_classes)[true_list], pred_prob_list)

        if test_acc >= fold_acc and (test_f1 + test_auc) > (fold_f1 + fold_auc):
            # Record every metric of the selected epoch so later epochs are
            # compared against it (fold_f1/fold_auc were previously never updated).
            fold_acc, fold_recall, fold_f1, fold_auc = test_acc, test_recall, test_f1, test_auc
            fold_train_loss, fold_loss = train_loss, test_loss
            result_index_matrix[kk] = index_matrix
    print('Fold{:.0f}: test acc:{:.4f}   test recall:{:.4f}'.format(kk + 1, fold_acc, fold_recall))
    result_train_loss.append(fold_train_loss)
    result_loss.append(fold_loss)
    result_acc.append(fold_acc)
    result_recall.append(fold_recall)
    print(result_index_matrix[kk])
print('Result: [train loss:{:.4f}  test loss:{:.4f}]'.format(
    np.mean(result_train_loss), np.mean(result_loss)))
print('Result: recall [mean:{:.4f}  std:{:.4f}]  acc [mean:{:.4f}  std:{:.4f}]'.format(
    np.mean(result_recall), np.std(result_recall), np.mean(result_acc), np.std(result_acc)))
