In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
import copy
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.nn.parameter import Parameter
from os.path import join as pjoin
import pandas as pd

# Notebook-wide matplotlib defaults: figure size (width, height in inches) and resolution (dots per inch).
plt.rcParams['figure.figsize'] = (15.0, 9.0)
plt.rcParams['figure.dpi'] = 100

In [3]:
# Hyper-parameters
# Everything a reader may want to tune lives in this cell; later cells read these names as globals.
dataset = 'PROTEINS2'
dataset = './graph_nn-master/graph_nn-master/data/%s'%dataset  # directory that holds the dataset text files
n_folds = 1  # 1 -> single random train/test split, >1 -> k-fold cross validation
training_size_p = 0.8  # fraction of graphs used for training (only read when n_folds == 1)
balance = False  # True -> sample the same number of training graphs from every class
batch_size = 64
epochs = 30
lr = 0.001
wdecay = 1e-4
model_name = 'GAT'
#GAT
n_hidden = 8  # output features of every first-layer attention head
dropout_gat = 0.6
alpha_leakyReLU = 0.2  # negative slope of the LeakyReLU applied to the attention logits
n_att = 8  # number of first-layer attention heads
gat_out_dim = 8  # output features of the final attention layer
#information pooling
pooling_method = 'max,sum,mean' #max,sum,mean
#fc
n_hidden_fc = '32,16' # 'None' or '32,16'
if n_hidden_fc != 'None':
    # '32,16' -> [32, 16]: one entry per hidden fully connected layer
    n_hidden_fc = [int(width) for width in n_hidden_fc.strip().split(',')]
# Look the activation class up by name instead of eval()-ing a source string.
fc_activation_name = 'ReLU'  # 'ELU' # 'ReLU' # 'Identity'
fc_activation = getattr(nn, fc_activation_name)(inplace=True)
dropout_fc = 0.3
fc_bias = True

#device
device = 'cpu'  # 'cuda', 'cpu'
seed = 'Random'  # 'Random' -> unseeded RandomState; anything else is cast to int and used as the seed
threads = 0  # number of DataLoader worker processes (0 -> load in the main process)
log_interval = 1  # print training progress every `log_interval` batches

#output folder
output_folder = dataset


In [4]:
# Data loader and reader
class GraphData(torch.utils.data.Dataset):
    """Torch dataset exposing one split ('train' or 'test') of one fold of a ``DataReader``.

    Every item is the list
    ``[node_features, adjacency, node_mask, N_nodes, label, original_graph_index]``
    where node features and adjacency are zero-padded up to ``N_nodes_max`` so that
    graphs of different sizes can be stacked into a batch.
    """

    def __init__(self,datareader,fold_id,split):
        """
        Args:
            datareader: object exposing ``rnd_state`` and the ``data`` dict built by ``DataReader``.
            fold_id: index into ``datareader.data['splits']``.
            split: key of the split to expose, "train" or "test".
        """
        self.fold_id = fold_id
        self.split = split  # "train" or "test"
        self.rnd_state = datareader.rnd_state
        self.set_fold(datareader.data, fold_id)  # fills every per-split attribute below

    def set_fold(self, data, fold_id):
        """Copy the graphs that belong to ``self.split`` of fold ``fold_id`` out of ``data``."""
        self.total = len(data['labels'])  # number of graphs in the whole dataset
        self.N_nodes_max = data['N_nodes_max']  # node count of the largest graph
        self.n_classes = data['n_classes']  # number of graph classes
        self.features_dim = data['features_dim']  # number of features per node
        self.idx = data['splits'][fold_id][self.split]  # dataset-wide indices of this split
        # Deep copies so that this split never aliases the reader's data.
        self.labels = copy.deepcopy([data['labels'][i] for i in self.idx])
        self.adj_list = copy.deepcopy([data['adj_list'][i] for i in self.idx])
        self.features_onehot = copy.deepcopy([data['features_onehot'][i] for i in self.idx])
        print('%s: %d/%d' % (self.split.upper(), len(self.labels), len(data['labels'])))
        self.indices = np.arange(len(self.idx))  # positions inside this split (0..len-1)
        self.label_to_target = data['label_to_target']
        self.node_idx_to_id = data['node_idx_to_id']
        self.targets = data['targets']

    def pad(self, mtx, desired_dim1, desired_dim2=None, value=0):
        """Pad a 2-D array at the bottom (and on the right when ``desired_dim2`` is given) with ``value``.

        Args:
            mtx: 2-D numpy array.
            desired_dim1: number of rows of the result.
            desired_dim2: number of columns of the result; ``None`` keeps the column count.
            value: constant used for the padding cells.

        Returns:
            The padded array.
        """
        sz = mtx.shape
        assert len(sz) == 2, ('only 2d arrays are supported', sz)
        if desired_dim2 is not None:
            mtx = np.pad(mtx, ((0, desired_dim1 - sz[0]), (0, desired_dim2 - sz[1])), 'constant', constant_values=value)
        else:
            mtx = np.pad(mtx, ((0, desired_dim1 - sz[0]), (0, 0)), 'constant', constant_values=value)
        return mtx

    def nested_list_to_torch(self, data):
        """Convert, in place, every numpy array inside ``data`` to a float32 tensor.

        ``data`` may be a list or a dict; nested lists are converted recursively and
        every other value (ints, labels, ...) is left as it is.

        Returns:
            The same (mutated) container that was passed in.
        """
        keys = list(data.keys()) if isinstance(data, dict) else range(len(data))
        for key in keys:
            item = data[key]
            if isinstance(item, np.ndarray):
                data[key] = torch.from_numpy(item).float()
            elif isinstance(item, list):
                # The original called an undefined `list_to_torch` here (NameError on nested lists).
                data[key] = self.nested_list_to_torch(item)
        return data

    def __len__(self):
        """Number of graphs in this split."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return the padded tensors and metadata of the graph at position ``index`` of this split."""
        index = self.indices[index]
        N_nodes = self.adj_list[index].shape[0]
        # 1 for real nodes, 0 for the dummy nodes introduced by zero padding.
        graph_support = np.zeros(self.N_nodes_max)
        graph_support[:N_nodes] = 1
        return self.nested_list_to_torch([self.pad(self.features_onehot[index].copy(), self.N_nodes_max),  # node features, rows padded to N_nodes_max
                                          self.pad(self.adj_list[index], self.N_nodes_max, self.N_nodes_max),  # adjacency, padded to N_nodes_max x N_nodes_max
                                          graph_support,  # node mask
                                          N_nodes,  # real number of nodes of this graph
                                          int(self.labels[index]),
                                          self.idx[index]])  # dataset-wide index of this graph

class DataReader():
    """Read a graph-classification dataset from text files and build train/test splits.

    ``data_dir`` must contain one file whose name contains each of:

    * ``graph_indicator`` - one integer graph id per line; the line number (from 0) is the node index.
    * ``node_features``   - one node per line: ``<node ID> <feature> <feature> ...``
      (comma or whitespace separated).
    * ``_A``              - one edge per line, given as two node IDs (comma or whitespace separated).
    * ``graph_labels``    - one class label per line.

    The parsed result is stored in ``self.data`` (a dict).
    NOTE(review): adjacency matrices and node features are ordered by sorted graph id while labels
    keep file order - presumably the label file lists graphs in ascending id order; confirm.
    """

    def __init__(self,
                 data_dir, 
                 rnd_state=None, 
                 training_size_p=None,
                 folds=None,
                 balance=None):
        """
        Args:
            data_dir: directory holding the dataset files.
            rnd_state: ``'Random'`` or ``None`` for an unseeded generator, otherwise a value
                convertible to ``int`` that is used as the seed.
            training_size_p: fraction of graphs used for training when ``folds == 1``.
            folds: 1 for a single split, >1 for k-fold cross validation.
            balance: when true (and ``folds == 1``) draw the same number of training graphs per class.
        """
        self.data_dir = data_dir
        # `None` used to crash in int(None); treat it like 'Random'.
        if rnd_state is None or rnd_state == 'Random':
            self.rnd_state = np.random.RandomState()
        else:
            self.rnd_state = np.random.RandomState(int(rnd_state))

        files = os.listdir(self.data_dir)

        data = {}
        # nodes:  {node index: graph id}
        # graphs: {graph id: np.array([node index, ...])}
        nodes, graphs = self.read_graph_nodes_relations(self._find_file(files, 'graph_indicator'))
        node_features_file = self._find_file(files, 'node_features')
        node_id_to_idx, node_idx_to_id = self.read_node_ID(node_features_file)
        data['node_id_to_idx'] = node_id_to_idx
        data['node_idx_to_id'] = node_idx_to_id
        data['features_onehot'] = self.read_node_features(node_features_file, nodes, graphs)
        data['adj_list'] = self.read_graph_adj(self._find_file(files, '_A'), nodes, graphs, node_id_to_idx)

        # Raw label strings ("targets") are mapped to consecutive integers starting at 0 ("labels").
        targets = np.array(self.parse_txt_file(self._find_file(files, 'graph_labels'),
                                               line_parse_fn=lambda s: s.strip()))
        data['targets'] = targets
        target_category = sorted(list(set(targets)))
        target_to_label = {}
        label_to_target = {}
        for l, t in enumerate(target_category):
            target_to_label[t] = l
            label_to_target[l] = t
        data['labels'] = np.array([target_to_label[t] for t in targets])
        data['target_to_label'] = target_to_label
        data['label_to_target'] = label_to_target

        # Dataset statistics (printed only).
        n_edges, degrees = [], []
        for adj in data['adj_list']:
            n_edges.append(int(np.sum(adj)/2))  # symmetric matrix: every edge is counted twice
            degrees.extend(list(np.sum(adj, 1)))
        features_dim = len(data['features_onehot'][0][0])
        shapes = [len(adj) for adj in data['adj_list']]
        N_nodes_max = np.max(shapes)
        classes = target_category
        n_classes = len(target_category)

        print('N nodes avg/std/min/max: \t%.2f/%.2f/%d/%d' % (np.mean(shapes), np.std(shapes), np.min(shapes), np.max(shapes)))
        print('N edges avg/std/min/max: \t%.2f/%.2f/%d/%d' % (np.mean(n_edges), np.std(n_edges), np.min(n_edges), np.max(n_edges)))
        print('Node degree avg/std/min/max: \t%.2f/%.2f/%d/%d' % (np.mean(degrees), np.std(degrees), np.min(degrees), np.max(degrees)))
        print('Node features dim: \t\t%d' % features_dim)
        print('N classes: \t\t\t%d' % n_classes)
        print('Classes: \t\t\t%s' %(', '.join(classes)))
        for lbl in classes:
            print('Class %s: \t\t\t%s samples' % (lbl, np.sum(targets == lbl)))

        # Every per-graph collection must describe the same number of graphs.
        N_graphs = len(data['labels'])
        assert N_graphs == len(data['adj_list']) == len(data['features_onehot']), 'invalid data'

        # Use the `folds` argument; the original read the notebook-level global `n_folds` here.
        train_ids, test_ids = self.split_ids(data['labels'], rnd_state=self.rnd_state, training_size_p=training_size_p,
                                             folds=folds, balance=balance)
        splits = []  # one {'train': indices, 'test': indices} dict per fold
        for fold in range(folds):
            splits.append({'train': train_ids[fold],
                           'test': test_ids[fold]})

        data['splits'] = splits
        data['N_nodes_max'] = N_nodes_max
        data['features_dim'] = features_dim
        data['n_classes'] = n_classes

        self.data = data

    def _find_file(self, files, pattern):
        """Return the first name in ``files`` containing ``pattern``; raise FileNotFoundError if none does."""
        for fname in files:
            if fname.find(pattern) >= 0:
                return fname
        raise FileNotFoundError('no file containing %r in %s' % (pattern, self.data_dir))

    def split_ids(self, labels_all, rnd_state=None,folds=1, training_size_p=None, balance=False):
        """Split graph indices into train/test sets.

        Args:
            labels_all: sequence with one label per graph.
            rnd_state: ``np.random.RandomState`` used for shuffling.
            folds: 1 -> one random split driven by ``training_size_p``; >1 -> k-fold split.
            training_size_p: training fraction (``folds == 1`` only). With ``balance`` it is
                applied to the size of the smallest class.
            balance: draw the same number of training graphs from every class (``folds == 1`` only).

        Returns:
            ``(train_ids, test_ids)``: two lists with one index array per fold.

        Raises:
            ValueError: if ``folds`` is smaller than 1.
        """
        if folds == 1:
            if balance == True:
                classes = list(set(labels_all))
                classes_dict = dict()
                for i in classes:
                    classes_dict[i] = []
                for idx,l in enumerate(labels_all):
                    classes_dict[l].append(idx)
                min_classes_n = min([len(labels_all)] + [len(classes_dict[i]) for i in classes])
                training_size_per_class = int(np.round(min_classes_n*training_size_p))
                ids_all = np.arange(len(labels_all))
                ids = ids_all[rnd_state.permutation(len(ids_all))]
                train_ids = []
                for i in classes:
                    class_ls = np.array(classes_dict[i])
                    sampling = class_ls[rnd_state.permutation(len(class_ls))][0:training_size_per_class]
                    train_ids.extend(sampling)
                train_set = set(train_ids)  # O(1) membership instead of scanning the list
                test_ids = [np.array([e for e in ids if e not in train_set])]
                train_ids = [np.array(train_ids)]
            else:
                ids_all = np.arange(len(labels_all))
                n = len(ids_all)  # number of graphs
                ids = ids_all[rnd_state.permutation(n)]
                testing_size = int(np.round(n*(1-training_size_p)))
                test_ids = ids[0:testing_size]
                test_set = set(test_ids.tolist())
                train_ids = [np.array([e for e in ids if e not in test_set])]
                test_ids = [test_ids]
        elif folds > 1:
            ids_all = np.arange(len(labels_all))
            n = len(ids_all)
            ids = ids_all[rnd_state.permutation(n)]
            stride = int(np.ceil(n / float(folds)))
            test_ids = [ids[i: i + stride] for i in range(0, n, stride)]
            assert np.all(np.unique(np.concatenate(test_ids)) == sorted(ids_all)), 'some graphs are missing in the test sets'
            assert len(test_ids) == folds, 'invalid test sets'
            train_ids = []
            for fold in range(folds):
                test_set = set(test_ids[fold].tolist())
                train_ids.append(np.array([e for e in ids if e not in test_set]))
                assert len(train_ids[fold]) + len(test_ids[fold]) == len(np.unique(list(train_ids[fold]) + list(test_ids[fold]))) == n, 'invalid splits'
        else:
            # Previously fell through to an UnboundLocalError on `train_ids`.
            raise ValueError('folds must be a positive integer, got %r' % (folds,))

        return train_ids, test_ids

    def parse_txt_file(self, fpath, line_parse_fn=None):
        """Read ``data_dir/fpath`` and return one entry per line.

        Each line is passed through ``line_parse_fn`` when it is given, otherwise the raw
        line (including its newline) is kept.
        """
        with open(pjoin(self.data_dir, fpath), 'r') as f:
            lines = f.readlines()
        data = [line_parse_fn(s) if line_parse_fn is not None else s for s in lines]
        return data

    def read_graph_adj(self, fpath, nodes, graphs, node_id_to_idx):
        """Build one symmetric 0/1 adjacency matrix per graph from the edge file.

        Args:
            fpath: edge file name; every line holds two node IDs.
            nodes: ``{node index: graph id}``.
            graphs: ``{graph id: array of node indices}``.
            node_id_to_idx: ``{node ID string: node index}``.

        Returns:
            List of ``(n, n)`` arrays ordered by sorted graph id. Graphs without any edge
            get an all-zero matrix (the original raised KeyError for them).
        """
        def fn_read_graph_adj(s):
            if ',' in s:
                return s.strip().split(',')
            else:
                return s.strip().split()
        edges = self.parse_txt_file(fpath, line_parse_fn=fn_read_graph_adj)
        adj_dict = {}
        for graph_id in graphs:
            n = len(graphs[graph_id])
            adj_dict[graph_id] = np.zeros((n, n))
        for edge in edges:
            node1 = node_id_to_idx[edge[0].strip()]
            node2 = node_id_to_idx[edge[1].strip()]
            graph_id = nodes[node1]
            assert graph_id == nodes[node2], ('invalid data', graph_id, nodes[node2])

            # Position of each endpoint inside its graph's node array.
            ind1 = np.where(graphs[graph_id] == node1)[0]
            ind2 = np.where(graphs[graph_id] == node2)[0]
            assert len(ind1) == len(ind2) == 1, (ind1, ind2)
            adj_dict[graph_id][ind1, ind2] = 1
            adj_dict[graph_id][ind2, ind1] = 1
        adj_list = [adj_dict[graph_id] for graph_id in sorted(list(graphs.keys()))]
        return adj_list

    #graph_indicator
    def read_graph_nodes_relations(self, fpath):
        """Parse the graph-indicator file (line ``k`` holds the integer graph id of node index ``k``).

        Returns:
            ``(nodes, graphs)`` with ``nodes = {node index: graph id}`` and
            ``graphs = {graph id: np.array of node indices}``.
        """
        graph_ids = self.parse_txt_file(fpath, line_parse_fn=lambda s: int(s.rstrip()))
        nodes, graphs = {}, {}
        for node_id, graph_id in enumerate(graph_ids):
            if graph_id not in graphs:
                graphs[graph_id] = []
            graphs[graph_id].append(node_id)
            nodes[node_id] = graph_id
        for graph_id in graphs:
            graphs[graph_id] = np.array(graphs[graph_id])
        return nodes, graphs

    def read_node_features(self, fpath, nodes, graphs):
        """Read per-node feature vectors and group them by graph.

        The first field of every line (the node ID) is skipped; the remaining fields are
        parsed as floats.

        Returns:
            List of ``(n_nodes, n_features)`` arrays ordered by sorted graph id.
        """
        def fn_read_node_features(s):
            if ',' in s:
                return list(map(float,(s.strip().split(',')[1:])))
            else:
                return list(map(float,(s.strip().split()[1:])))
        node_features_all = self.parse_txt_file(fpath, line_parse_fn=fn_read_node_features)
        node_features = {}  # {graph id: [feature vector of each node, in graph order]}
        for node_id, x in enumerate(node_features_all):
            graph_id = nodes[node_id]
            if graph_id not in node_features:
                node_features[graph_id] = [ None ] * len(graphs[graph_id])
            ind = np.where(graphs[graph_id] == node_id)[0]
            assert len(ind) == 1, ind
            # Every slot must be filled exactly once.
            assert node_features[graph_id][ind[0]] is None, node_features[graph_id][ind[0]]
            node_features[graph_id][ind[0]] = x
        node_features_lst = [np.array(node_features[graph_id]) for graph_id in sorted(list(graphs.keys()))]
        return node_features_lst

    def read_node_ID(self, fpath):
        """Read the node IDs (first field of every line of the node-feature file).

        Returns:
            ``(node_id_to_idx, node_idx_to_id)``: mappings between the node ID string and the
            0-based line number of the node.
        """
        def fn_read_node_ID(s):
            if ',' in s:
                return s.strip().split(',')[0]
            else:
                return s.strip().split()[0]
        node_ID_all = self.parse_txt_file(fpath, line_parse_fn=fn_read_node_ID)
        assert len(node_ID_all) == len(set(node_ID_all))  # node IDs must be unique

        node_id_to_idx = {}  # str -> int
        node_idx_to_id = {}  # int -> str
        for node_idx, node_id in enumerate(node_ID_all):
            node_id_to_idx[node_id] = node_idx
            node_idx_to_id[node_idx] = node_id
        return node_id_to_idx, node_idx_to_id

In [7]:
class GraphAttentionLayer(nn.Module):
    """Single-head, dense, batched graph attention layer.

    For every node pair ``(i, j)`` the logit ``LeakyReLU(a([W h_i || W h_j]))`` is computed,
    logits of non-adjacent pairs are masked out (self loops are always kept), the rest is
    soft-maxed over ``j`` and used to average the projected features ``W h_j``.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        """
        Args:
            in_features: size of each input node feature vector.
            out_features: size of each output node feature vector.
            dropout: dropout probability applied to the attention coefficients.
            alpha: negative slope of the LeakyReLU applied to the attention logits.
            concat: if True apply ELU to the output (used for heads whose outputs are
                concatenated); if False return the raw weighted sum.
        """
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.W = nn.Linear(in_features=in_features, out_features=out_features,bias=False)
        self.a = nn.Linear(in_features=2*out_features, out_features=1,bias=False)

        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.concatenate = torch.cat
        self.elu = nn.ELU(inplace=True)

    def forward(self, data):
        """
        Args:
            data: sequence whose first two entries are ``inp`` of shape
                ``[batch, N, in_features]`` and ``adj`` of shape ``[batch, N, N]``.

        Returns:
            Tensor of shape ``[batch, N, out_features]``.
        """
        inp, adj = data[:2]
        h = self.W(inp)
        n_batch = h.size()[0]
        N = h.size()[1]  # number of (padded) nodes

        # Row i*N + j of `h_i` holds h[i]; the same row of `h_j` holds h[j].
        h_i = h.repeat(1, 1, N).view(n_batch, N * N, -1)
        h_j = h.repeat(1, N, 1)
        # Concatenate along the FEATURE axis (dim=2) so that entry [i, j] is [h_i || h_j].
        # The original used dim=1, which stacks along the pair axis and, after the view,
        # paired unrelated rows with each other.
        # NOTE: this materialises a [batch, N, N, 2*out_features] tensor.
        a_inp = self.concatenate([h_i, h_j], dim=2).view(n_batch, N, N, 2 * self.out_features)
        e = self.leakyrelu(self.a(a_inp)).squeeze(3)  # [batch, N, N]

        zero_vec = -9e15*torch.ones_like(e)
        # Build the self-loop matrix on the input's device instead of a global `device`.
        I = torch.eye(N, device=adj.device).unsqueeze(0)
        adj_I = adj+I
        attention = torch.where(adj_I>0, e, zero_vec)  # not adjacent -> very negative logit
        attention = F.softmax(attention, dim=2)  # very negative logits become ~0 weight
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)
        if self.concat:
            return self.elu(h_prime)
        else:
            return h_prime

    def extra_repr(self):
        """Human-readable, multi-line description of the operations performed by the layer."""
        lines = []
        lines.append('(hidden_features): Linear(in_features=%s, out_features=%s, bias=False)'%(self.in_features,self.out_features))
        lines.append('(attetion): Attetion(')
        lines.append('  (concat_ij): Concat(in_features=%s, out_features=%s*2)'%(self.out_features,self.out_features))
        lines.append('  (a): Linear(in_features=%s, out_features=1, bias=False)'%(self.out_features*2))
        lines.append('  (leakyrelu): LeakyReLU(negative_slope=%s)'%(self.alpha))
        lines.append('  (concat_edges): Concat(in_features=1, out_features=1-hop edges)')
        lines.append('  (softmax): Softmax(in_features=1-hop edges, out_features=1-hop edges)')
        lines.append('  (dropout_attetion): Dropout(p=%s)'%self.dropout)
        lines.append(')')
        lines.append('(weighted_hidden_features): Matmul(attention, hidden_features)')
        if self.concat:
            lines.append('(elu): ELU(alpha=1.0)')
        lines = '\n'.join(lines)
        return lines
        

    
class GAT(nn.Module):
    """Graph-level classifier: multi-head GAT -> output GAT layer -> pooling over nodes -> MLP.

    ``forward`` returns per-class log-probabilities (``log_softmax``).
    """

    _POOLING_CHOICES = ('max', 'sum', 'mean')

    def __init__(self, 
                 nfeat, 
                 nhid, 
                 nclass, 
                 dropout, 
                 alpha, 
                 nheads,
                 gat_out_dim,
                 pooling_method,
                 n_hidden_fc, 
                 dropout_fc,
                 fc_bias,
                 fc_activation):
        """
        Args:
            nfeat: number of input features per node.
            nhid: output features of every first-layer attention head.
            nclass: number of graph classes.
            dropout: dropout probability used inside the GAT part.
            alpha: negative slope of the LeakyReLU used by the attention layers.
            nheads: number of first-layer attention heads (their outputs are concatenated).
            gat_out_dim: output features of the final attention layer.
            pooling_method: comma separated subset of 'max', 'sum', 'mean'; the pooled
                vectors are concatenated in that fixed order (max, sum, mean).
            n_hidden_fc: list with the width of every hidden fully connected layer, or the
                string 'None' for no hidden layer.
            dropout_fc: dropout probability in front of every fully connected layer.
            fc_bias: whether the fully connected layers use a bias.
            fc_activation: activation module inserted after every hidden fully connected layer.

        Raises:
            ValueError: if ``pooling_method`` contains an unknown or repeated name.
        """
        super(GAT, self).__init__()
        self.dropout = dropout
        self.nhid = nhid
        self.nheads = nheads
        self.nclass = nclass
        self.pooling_method = pooling_method
        self.n_hidden_fc = n_hidden_fc
        self.dropout_fc = dropout_fc
        self.fc_bias = fc_bias
        self.fc_activation = fc_activation

        # Validate here: the width of the first FC layer depends on the number of pooling
        # operators, so an unknown or repeated name would only fail later inside forward().
        pooling_names = [name.strip() for name in pooling_method.split(',')]
        if (len(set(pooling_names)) != len(pooling_names)
                or any(name not in self._POOLING_CHOICES for name in pooling_names)):
            raise ValueError("pooling_method must be a comma separated list of distinct names "
                             "from %s, got %r" % (self._POOLING_CHOICES, pooling_method))
        self._pooling_names = pooling_names

        #GAT
        # The heads are kept in a plain list and registered one by one so that the
        # parameter names stay 'attention_<i>.*'.
        self.attentions = [GraphAttentionLayer(in_features=nfeat,
                                               out_features=nhid, 
                                               dropout=dropout, 
                                               alpha=alpha, 
                                               concat=True) for _ in range(nheads)]     
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        self.out_att = GraphAttentionLayer(nhid * nheads, gat_out_dim, dropout=dropout, alpha=alpha, concat=False)

        #fc
        fc = []
        pooling_noumber = len(pooling_names)
        if n_hidden_fc != 'None':
            for layer, f in enumerate(n_hidden_fc):
                if dropout_fc > 0:
                    fc.append(nn.Dropout(p=dropout_fc))
                else:
                    fc.append(nn.Identity())
                if layer == 0:
                    fc.append(nn.Linear(gat_out_dim*pooling_noumber, n_hidden_fc[layer], bias=fc_bias)) 
                    fc.append(fc_activation)
                else:
                    fc.append(nn.Linear(n_hidden_fc[layer-1], n_hidden_fc[layer], bias=fc_bias))   
                    fc.append(fc_activation)
            n_last = n_hidden_fc[-1]
        else:
            n_last = gat_out_dim*pooling_noumber

        #last layer
        if dropout_fc > 0:
            fc.append(nn.Dropout(p=dropout_fc))
        else:
            fc.append(nn.Identity())            
        fc.append(nn.Linear(n_last, nclass, bias=fc_bias))
        self.fc = nn.Sequential(*fc) 

    def forward(self, data):
        """
        Args:
            data: sequence ``(x, adj, mask, ...)`` with node features ``x``
                ``[batch, N, nfeat]``, adjacency ``adj`` ``[batch, N, N]`` and node mask
                ``mask`` ``[batch, N]`` (1 for real nodes, 0 for padding).

        Returns:
            Log-probabilities of shape ``[batch, nclass]``.
        """
        x, adj = data[:2]
        N_nodes = torch.sum(data[2], dim=1, keepdim=True)  # real node count per graph, [batch, 1]

        #GAT
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att((x, adj)) for att in self.attentions], dim=2)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att((x, adj)))

        #pooling
        # Use the configuration stored on the module; the original read the notebook-level
        # global `pooling_method` here.
        # NOTE(review): rows of padded nodes take part in the max; when node features are
        # zero-padded they appear to come out as exactly 0, which would floor the max at 0
        # for every padded graph - confirm whether a masked max is wanted.
        pooling_ls = []
        if 'max' in self._pooling_names:
            max_pooling = torch.max(x, 1)[0]
            pooling_ls.append(max_pooling)
        if 'sum' in self._pooling_names:
            sum_pooling = torch.sum(x, 1)
            pooling_ls.append(sum_pooling)
        if 'mean' in self._pooling_names:
            mean_pooling = torch.sum(x, 1)/N_nodes  # divide by the real node count, not the padded size
            pooling_ls.append(mean_pooling)
        x = torch.cat(pooling_ls,1)  

        #fc
        x = self.fc(x) 
        x = F.log_softmax(x, dim=1)

        return x    

    def extra_repr(self):
        """Human-readable, multi-line description of the whole pipeline."""
        lines = []
        lines.append('GAT(')
        lines.append('  (dropout): Dropout(p=%s)'%self.dropout)
        lines.append('  ->')
        for i, att in enumerate(self.attentions):
            lines.append('  (attention_%d): GraphAttentionLayer('%(i))
            for j in att.extra_repr().split('\n'):
                lines.append('    '+j)
            lines.append('  )')
        lines.append('  ->')             
        lines.append('  (concat_attention_0-%s): Concat(in_features=%s, out_features=%s*%slayers)'%(self.nheads-1,
                                                                                    self.nhid, self.nhid, self.nheads))
        lines.append('  ->')               
        lines.append('  (dropout): Dropout(p=%s)'%self.dropout)
        lines.append('  ->')  
        lines.append('  (attention_out): GraphAttentionLayer(')
        for j in self.out_att.extra_repr().split('\n'):
            lines.append('    '+j)
        lines.append('  )')
        lines.append('  ->')
        lines.append('  (elu): ELU(alpha=1.0)')  
        lines.append(')')
        lines.append('->')
        lines.append('Pooling(')
        lines.append('  (concat): Concat(%s)'%self.pooling_method)
        lines.append(')')
        lines.append('->')
        lines.append('FullyConnected(')
        # Describe this instance's own layers; the original iterated the global `model.fc`.
        for i in self.fc:
            i = str(i)
            if i[0]=='D':
                i = '  (dropout): ' + i
            elif i[0]=='L':
                i = '  (fc): ' + i
            elif i[0]=='R':
                i = '  (activation): ' + i     
            lines.append(i)    
        lines.append('  (softmax): Softmax(in_features=%s, out_features=%s)'%(self.nclass,self.nclass))
        lines.append(')')    
        lines='\n'.join(lines)

        return lines

In [None]:
print('Loading data')

datareader = DataReader(data_dir=dataset, 
                        rnd_state=seed,training_size_p=training_size_p,folds=n_folds,balance=balance)

train_acc_folds = []  # final training accuracy (%) of every fold
test_acc_folds = []  # final test accuracy (%) of every fold
for fold_id in range(n_folds):
    print('\nFOLD', fold_id+1)
    loaders = []  # [train loader, test loader]
    for split in ['train', 'test']:
        # Build the dataset of this split; only the training split is shuffled.
        gdata = GraphData(fold_id=fold_id, datareader=datareader, split=split)
        loader = torch.utils.data.DataLoader(gdata, 
                                             batch_size=batch_size,
                                             shuffle=split.find('train') >= 0,
                                             num_workers=threads)
        loaders.append(loader)
        if split == 'train':
            training_size = len(gdata.idx)
    if model_name == 'GAT':
        model = GAT(nfeat=loaders[0].dataset.features_dim,
                    nhid=n_hidden,
                    nclass=loaders[0].dataset.n_classes,
                    dropout=dropout_gat,
                    alpha=alpha_leakyReLU,
                    nheads=n_att,
                    gat_out_dim=gat_out_dim,
                    pooling_method=pooling_method,
                    n_hidden_fc=n_hidden_fc, 
                    dropout_fc=dropout_fc,
                    fc_bias=fc_bias,
                    fc_activation=fc_activation).to(device)    
    else:
        # Without this the loop would reuse a stale `model` or fail later with a NameError.
        raise ValueError('unsupported model_name: %r' % (model_name,))

    print('\nInitialize model')
    print(model.extra_repr())
    c = 0
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        c += p.numel()
    print('N trainable parameters:', c)

    optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=lr,
                weight_decay=wdecay,
                betas=(0.5, 0.999))
    # Learning rate is multiplied by 0.1 after 20 and after 30 completed epochs.
    scheduler = lr_scheduler.MultiStepLR(optimizer, [20, 30], gamma=0.1)


    def train(train_loader):
        """Run one training epoch.

        Reads ``model``, ``optimizer``, ``scheduler``, ``loss_fn`` and ``epoch`` from the
        enclosing notebook scope.

        Returns:
            (running average loss after every batch, running accuracy in [0, 1] after every batch)
        """
        model.train()
        start = time.time()
        train_loss, correct, n_samples = 0, 0, 0
        train_loss_batch_ls = []
        train_acc_batch_ls = []
        for batch_idx, data in enumerate(train_loader):
            for i in range(len(data)):
                data[i] = data[i].to(device)
            optimizer.zero_grad()
            output = model(data)     
            loss = loss_fn(output, data[4])  # data[4] holds the graph labels
            loss.backward()
            optimizer.step()
            time_iter = time.time() - start
            train_loss += loss.item() * len(output)
            n_samples += len(output)
            pred = output.detach().cpu().max(1, keepdim=True)[1]
            correct += pred.eq(data[4].detach().cpu().view_as(pred)).sum().item()
            acc = 100. * correct / n_samples
            train_loss_batch_ls.append(train_loss/n_samples)
            train_acc_batch_ls.append(acc/100)
            if batch_idx % log_interval == 0 or batch_idx == len(train_loader) - 1:
                # Epoch is printed 1-based, like in test().
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}(avg: {:.4f})\tAcc: {:.2f}%({}/{}) \tsec/iter: {:.4f}'.format(
                    epoch+1, n_samples, len(train_loader.dataset),
                    100. * (batch_idx + 1) / len(train_loader), loss.item(), train_loss / n_samples, 
                    acc, correct, n_samples, time_iter / (batch_idx + 1) ))    
        # Step the scheduler once per epoch, AFTER the optimizer updates. Stepping it first
        # (as before) makes recent PyTorch skip the first value of the schedule.
        scheduler.step()
        return train_loss_batch_ls, train_acc_batch_ls

    def test(test_loader):
        """Evaluate ``model`` on ``test_loader``.

        Returns:
            (average loss per graph, accuracy in [0, 1])
        """
        model.eval()
        test_loss, correct, n_samples = 0, 0, 0
        with torch.no_grad():  # no autograd graph is needed for evaluation
            for batch_idx, data in enumerate(test_loader):
                for i in range(len(data)):
                    data[i] = data[i].to(device)
                output = model(data)
                loss = loss_fn(output, data[4], reduction='sum')
                test_loss += loss.item()
                n_samples += len(output)
                pred = output.detach().cpu().max(1, keepdim=True)[1]
                correct += pred.eq(data[4].detach().cpu().view_as(pred)).sum().item()
        test_loss /= n_samples
        acc = 100. * correct / n_samples
        print('Test set (epoch {}): Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(epoch+1, 
                                                                                              test_loss, 
                                                                                              correct, 
                                                                                              n_samples, acc))
        return test_loss,acc/100

    def predict(loader_full):
        """Run the trained model over the train loader and then the test loader.

        Appends the final accuracies (%) to ``train_acc_folds`` / ``test_acc_folds``.

        Returns:
            ``(idx_ls, pred_ls, label_ls, length_ls, output_ls)``: dataset-wide graph indices,
            predicted labels, true labels, number of graphs per split, and the per-graph
            log-probability tensors (train graphs first, then test graphs).
        """
        idx_ls = []
        pred_ls = [] 
        label_ls = []
        length_ls = []
        output_ls = []
        print('[Trained Model]')
        model.eval()
        for i in [0,1]:
            pred_tmp = []
            label_tmp = []
            with torch.no_grad():  # keeps the stored outputs free of autograd graphs
                for batch_idx, data in enumerate(loader_full[i]):
                    for j in range(len(data)):
                        data[j] = data[j].to(device)
                    output = model(data)
                    idx_ls.extend(data[5].tolist())  # data[5] holds the dataset-wide graph index
                    pred = output.detach().cpu().max(1, keepdim=True)[1]
                    pred_ls.extend(pred.reshape(pred.shape[0]).tolist())
                    label_ls.extend(data[4].tolist())
                    output_ls.extend(output.detach().cpu())
                    pred_tmp.extend(pred.reshape(pred.shape[0]).tolist())
                    label_tmp.extend(data[4].tolist())
            total = len(pred_tmp)
            c = int(np.sum(np.array(pred_tmp)==np.array(label_tmp)))
            if i==0:
                print('Training Set: Accuracy=%.2f%%(%s/%s)'%(c*100/total,c,total))
                train_acc_folds.append(c*100/total)
            elif i==1:
                print('Testing Set: Accuracy=%.2f%%(%s/%s)'%(c*100/total,c,total))  
                test_acc_folds.append(c*100/total)
            length_ls.append(total)
        return idx_ls, pred_ls, label_ls, length_ls, output_ls

    train_loss_ls = []  # one entry per training batch
    train_acc_ls = []  # one entry per training batch
    test_loss_ls = []  # one entry per epoch
    test_acc_ls = []  # one entry per epoch
    loss_fn = F.nll_loss  # the model already outputs log-probabilities
    for epoch in range(epochs):
        train_loss, train_acc = train(loaders[0])
        test_loss, test_acc = test(loaders[1])
        train_loss_ls.extend(train_loss)
        train_acc_ls.extend(train_acc)
        test_loss_ls.append(test_loss)
        test_acc_ls.append(test_acc)   

    idx_ls, pred_ls, label_ls, length_ls, output_ls = predict(loaders)

    #plot
    # Training curves have one point per batch; test curves have one point per epoch,
    # placed at the last batch of that epoch.
    batches_per_epoch = int(np.ceil(training_size/batch_size))
    length_train = range(len(train_loss_ls))
    length_test = range(batches_per_epoch-1,len(train_loss_ls),batches_per_epoch)
    plt.plot(length_train,train_acc_ls,label='training accuracy')
    plt.plot(length_test,test_acc_ls,label='validation accuracy')
    x_ticks = [0]+list(length_test)
    plt.xticks(x_ticks,list(range(0,epochs+1)))
    plt.xlabel('epoch',fontsize=18)
    plt.ylabel('accuracy',fontsize=18)
    plt.legend(fontsize=12)
    plt.grid(linestyle='--')
    plt.savefig(output_folder+'/acc_fold%s.png'%(fold_id+1))
    plt.clf()

    plt.plot(length_train,train_loss_ls,label='training loss')
    plt.plot(length_test,test_loss_ls,label='validation loss')
    plt.xticks(x_ticks,list(range(0,epochs+1)))
    plt.xlabel('epoch',fontsize=18)
    plt.ylabel('loss',fontsize=18)
    plt.legend(fontsize=12)
    plt.grid(linestyle='--')
    plt.savefig(output_folder+'/loss_fold%s.png'%(fold_id+1))
    plt.clf()    

Loading data
N nodes avg/std/min/max: 	39.06/45.76/4/620
N edges avg/std/min/max: 	72.82/84.60/5/1049
Node degree avg/std/min/max: 	3.73/1.15/0/25
Node features dim: 		3
N classes: 			2
Classes: 			1, 2
Class 1: 			663 samples
Class 2: 			450 samples

FOLD 1
TRAIN: 890/1113
TEST: 223/1113

Initialize model
GAT(
  (dropout): Dropout(p=0.6)
  ->
  (attention_0): GraphAttentionLayer(
    (hidden_features): Linear(in_features=3, out_features=8, bias=False)
    (attetion): Attetion(
      (concat_ij): Concat(in_features=8, out_features=8*2)
      (a): Linear(in_features=16, out_features=1, bias=False)
      (leakyrelu): LeakyReLU(negative_slope=0.2)
      (concat_edges): Concat(in_features=1, out_features=1-hop edges)
      (softmax): Softmax(in_features=1-hop edges, out_features=1-hop edges)
      (dropout_attetion): Dropout(p=0.6)
    )
    (weighted_hidden_features): Matmul(attention, hidden_features)
    (elu): ELU(alpha=1.0)
  )
  (attention_1): GraphAttentionLayer(
    (hidden_feature



