In [1]:
import os
import csv
import numpy as np
import scipy.io as sio
import pandas as pd

from sklearn.linear_model import RidgeClassifier
from sklearn.feature_selection import RFE
from nilearn import connectome

from scipy.spatial import distance


# Reading and computing the input data

# Selected pipeline
pipeline = 'cpac'

# Input data variables
root_folder = '../ABIDE/'
data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal')
phenotype = os.path.join(root_folder, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv')


def fetch_filenames(subject_IDs, file_type):

    """
        subject_list : list of short subject IDs in string format
        file_type    : must be one of the available file types

    returns:

        filenames    : list of filetypes (same length as subject_list)
    """

    import glob

    # Specify file mappings for the possible file types
    filemapping = {'func_preproc': '_func_preproc.nii.gz',
                   'rois_ho': '_rois_ho.1D'}

    # The list to be filled
    filenames = []

    # Fill list with requested file paths
    for i in range(len(subject_IDs)):
        os.chdir(data_folder)  # os.path.join(data_folder, subject_IDs[i]))
        try:
            filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0])
        except IndexError:
            # Return N/A if subject ID is not found
            filenames.append('N/A')

    return filenames


# Get timeseries arrays for list of subjects
def get_timeseries(subject_list, atlas_name):
    """
        subject_list : list of short subject IDs in string format
        atlas_name   : the atlas based on which the timeseries are generated e.g. aal, cc200

    returns:
        time_series  : list of timeseries arrays, each of shape (timepoints x regions)
    """

    timeseries = []
    for i in range(len(subject_list)):
        subject_folder = os.path.join(data_folder, subject_list[i])
        ro_file = [f for f in os.listdir(subject_folder) if f.endswith('_rois_' + atlas_name + '.1D')]
        fl = os.path.join(subject_folder, ro_file[0])
        print("Reading timeseries file %s" %fl)
        timeseries.append(np.loadtxt(fl, skiprows=0))

    return timeseries


# Compute connectivity matrices
def subject_connectivity(timeseries, subject, atlas_name, kind, save=True, save_path=data_folder):
    """
        timeseries   : timeseries table for subject (timepoints x regions)
        subject      : the subject ID
        atlas_name   : name of the parcellation atlas used
        kind         : the kind of connectivity to be used, e.g. lasso, partial correlation, correlation
        save         : save the connectivity matrix to a file
        save_path    : specify path to save the matrix if different from subject folder

    returns:
        connectivity : connectivity matrix (regions x regions)
    """

    print("Estimating %s matrix for subject %s" % (kind, subject))

    if kind in ['tangent', 'partial correlation', 'correlation']:
        conn_measure = connectome.ConnectivityMeasure(kind=kind)
        connectivity = conn_measure.fit_transform([timeseries])[0]

    if save:
        subject_file = os.path.join(save_path, subject,
                                    subject + '_' + atlas_name + '_' + kind.replace(' ', '_') + '.mat')
        sio.savemat(subject_file, {'connectivity': connectivity})

    return connectivity


# Get the list of subject IDs
def get_ids(num_subjects=None):
    """

    return:
        subject_IDs    : list of all subject IDs
    """

    subject_IDs = np.genfromtxt(os.path.join(data_folder, 'subject_IDs.txt'), dtype=str)

    if num_subjects is not None:
        subject_IDs = subject_IDs[:num_subjects]

    return subject_IDs


# Get phenotype values for a list of subjects
def get_subject_score(subject_list, score):
    scores_dict = {}

    with open(phenotype) as csv_file:
        reader = csv.DictReader(csv_file)
        for row in reader:
            if row['SUB_ID'] in subject_list:
                scores_dict[row['SUB_ID']] = row[score]

    return scores_dict


# Dimensionality reduction step for the feature vector using a ridge classifier
def feature_selection(matrix, labels, train_ind, fnum):
    """
        matrix       : feature matrix (num_subjects x num_features)
        labels       : ground truth labels (num_subjects x 1)
        train_ind    : indices of the training samples
        fnum         : size of the feature vector after feature selection

    return:
        x_data      : feature matrix of lower dimension (num_subjects x fnum)
    """

    estimator = RidgeClassifier()
    selector = RFE(estimator, n_features_to_select=fnum, step=100, verbose=1)

    featureX = matrix[train_ind, :]
    featureY = labels[train_ind]
    selector = selector.fit(featureX, featureY.ravel())
    x_data = selector.transform(matrix)

    print("Number of labeled samples %d" % len(train_ind))
    print("Number of features selected %d" % x_data.shape[1])

    return x_data


# Make sure each site is represented in the training set when selecting a subset of the training set
def site_percentage(train_ind, perc, subject_list):
    """
        train_ind    : indices of the training samples
        perc         : percentage of training set used
        subject_list : list of subject IDs

    return:
        labeled_indices      : indices of the subset of training samples
    """

    train_list = subject_list[train_ind]
    sites = get_subject_score(train_list, score='SITE_ID')
    unique = np.unique(list(sites.values())).tolist()
    site = np.array([unique.index(sites[train_list[x]]) for x in range(len(train_list))])

    labeled_indices = []

    for i in np.unique(site):
        id_in_site = np.argwhere(site == i).flatten()

        num_nodes = len(id_in_site)
        labeled_num = int(round(perc * num_nodes))
        labeled_indices.extend(train_ind[id_in_site[:labeled_num]])

    return labeled_indices


# Load precomputed fMRI connectivity networks
def get_networks(subject_list, kind, atlas_name="aal", variable='connectivity'):
    """
        subject_list : list of subject IDs
        kind         : the kind of connectivity to be used, e.g. lasso, partial correlation, correlation
        atlas_name   : name of the parcellation atlas used
        variable     : variable name in the .mat file that has been used to save the precomputed networks


    return:
        matrix      : feature matrix of connectivity networks (num_subjects x network_size)
    """

    all_networks = []
    for subject1 in subject_list:
        fl = f'../Datasets/all_fc_matrix_rois_cc400_2_a/matrix_rois_cc400_{subject1}.mat'
        try:  
            matrix = sio.loadmat(fl)[variable]
            all_networks.append(matrix)
        except FileNotFoundError:
            fl = f'../Datasets/all_fc_matrix_rois_cc400_2_a/matrix_rois_cc400_{50002}.mat'
            matrix = sio.loadmat(fl)[variable]
            all_networks.append(matrix)
            
            
    # all_networks=np.array(all_networks)

    idx = np.triu_indices_from(all_networks[0], 1)
    norm_networks = [np.arctanh(mat) if not np.all(np.abs(mat) == 1) else mat for mat in all_networks]
    vec_networks = [mat[idx] for mat in norm_networks]
    matrix = np.vstack(vec_networks)

    return matrix


# Construct the adjacency matrix of the population from phenotypic scores
def create_affinity_graph_from_scores(scores, pd_dict):
    num_nodes = len(pd_dict[scores[0]]) 
    graph = np.zeros((num_nodes, num_nodes))

    for l in scores:
        label_dict = pd_dict[l]

        if l in ['AGE_AT_SCAN', 'FIQ']:
            for k in range(num_nodes):
                for j in range(k + 1, num_nodes):
                    try:
                        val = abs(float(label_dict[k]) - float(label_dict[j]))
                        if val < 2:
                            graph[k, j] += 1
                            graph[j, k] += 1
                    except ValueError:  # missing label
                        pass

        else:
            for k in range(num_nodes):
                for j in range(k + 1, num_nodes):
                    if label_dict[k] == label_dict[j]:
                        graph[k, j] += 1
                        graph[j, k] += 1

    return graph

def get_static_affinity_adj(features, pd_dict):
    pd_affinity = create_affinity_graph_from_scores(['SEX', 'SITE_ID'], pd_dict) 
    distv = distance.pdist(features, metric='correlation') 
    dist = distance.squareform(distv)  
    sigma = np.mean(dist)
    feature_sim = np.exp(- dist ** 2 / (2 * sigma ** 2))
    adj = pd_affinity * feature_sim  

    return adj


In [2]:
import numpy as np
import scipy.sparse as sp
import torch


from sklearn.model_selection import StratifiedKFold
from scipy.spatial import distance
from scipy.sparse.linalg.eigen import eigsh


def encode_onehot(labels):
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot

def sample_mask(idx, l):
    """Create mask."""
    mask = np.zeros(l)
    mask[idx] = 1
    return np.array(mask, dtype=np.bool)

def get_train_test_masks(labels, idx_train, idx_val, idx_test):
    train_mask = sample_mask(idx_train, labels.shape[0])
    val_mask = sample_mask(idx_val, labels.shape[0])
    test_mask = sample_mask(idx_test, labels.shape[0])

    y_train = np.zeros(labels.shape)
    y_val = np.zeros(labels.shape)
    y_test = np.zeros(labels.shape)
    y_train[train_mask, :] = labels[train_mask, :]
    y_val[val_mask, :] = labels[val_mask, :]
    y_test[test_mask, :] = labels[test_mask, :]

    return y_train, y_val, y_test, train_mask, val_mask, test_mask

def load_data(subject_IDs, params): 
    
    # labels
    num_classes = 2
    num_nodes = len(subject_IDs)
    
    # , y
    y_data = np.zeros([num_nodes, num_classes])
    y = np.zeros([num_nodes, 1])
    
    labels = get_subject_score(subject_IDs, score='DX_GROUP')
    features = get_networks(subject_IDs, kind=params['connectivity'], atlas_name=params['atlas'])
    
    for i in range(num_nodes):
        y_data[i, int(labels[subject_IDs[i]]) - 1] = 1 # (871,2)
        y[i] = int(labels[subject_IDs[i]]) # (871,)
        
    skf = StratifiedKFold(n_splits=10)
    cv_splits = list(skf.split(features, np.squeeze(y)))
    train = cv_splits[params['folds']][0]
    test = cv_splits[params['folds']][1]
    val = test
    
    print('Number of train sample:{}' .format(len(train)))
        
    y_train, y_val, y_test, train_mask, val_mask, test_mask = get_train_test_masks(y_data, train, val, test)
    
    y_data = torch.LongTensor(np.where(y_data)[1])
    y = torch.LongTensor(y)
    y_train = torch.LongTensor(y_train[1])
    y_val = torch.LongTensor(y_val[1])
    y_test = torch.LongTensor(y_test[1])
    
    train = torch.LongTensor(train)
    val = torch.LongTensor(val)
    test = torch.LongTensor(test)
    train_mask = torch.LongTensor(train_mask)
    val_mask = torch.LongTensor(val_mask)
    test_mask = torch.LongTensor(test_mask)
    
    # Eigenvector
    labeled_ind = site_percentage(train, params['num_training'], subject_IDs)
    x_data = feature_selection(features, y, labeled_ind, params['num_features'])
    features = preprocess_features(sp.coo_matrix(x_data).tolil())
    features = torch.FloatTensor(np.array(features.todense()))
    
    # Adjacency matrix
    graph = create_affinity_graph_from_scores(['SEX', 'SITE_ID'], subject_IDs)
    distv = distance.pdist(x_data, metric='correlation')
    dist = distance.squareform(distv)
    sigma = np.mean(dist)
    sparse_graph = np.exp(- dist ** 2 / (2 * sigma ** 2))
    final_graph = graph * sparse_graph

    return final_graph, features, y, y_data, y_train, y_val, y_test, train, val, test, train_mask, val_mask, test_mask


def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def sparse_to_tuple(sparse_mx):
    """Convert sparse matrix to tuple representation."""
    def to_tuple(mx):
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        coords = np.vstack((mx.row, mx.col)).transpose()
        values = mx.data
        shape = mx.shape
        coords = torch.from_numpy(coords)
        values = torch.from_numpy(values)
        shape = torch.tensor(shape)
        return coords, values, shape

    if isinstance(sparse_mx, list):
        for i in range(len(sparse_mx)):
            sparse_mx[i] = to_tuple(sparse_mx[i])
    else:
        sparse_mx = to_tuple(sparse_mx)

    return sparse_mx

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)


def preprocess_features(features):
    """Row-normalize feature matrix"""
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features

def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()

def preprocess_adj(adj):
    """Preprocessing of adjacency matrix for simple GCN model."""
    adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0]))
    return adj_normalized

def chebyshev_polynomials(adj, k):
    """Calculate Chebyshev polynomials up to order k. Return a list of sparse matrices (tuple representation)."""
    print("Calculating Chebyshev polynomials up to order {}...".format(k))

    adj_normalized = normalize_adj(adj)
    laplacian = sp.eye(adj.shape[0]) - adj_normalized
    largest_eigval, _ = eigsh(laplacian, 1, which='LM')
    scaled_laplacian = (2. / largest_eigval[0]) * laplacian - sp.eye(adj.shape[0])

    t_k = list()
    t_k.append(sp.eye(adj.shape[0]))
    t_k.append(scaled_laplacian)

    def chebyshev_recurrence(t_k_minus_one, t_k_minus_two, scaled_lap):
        s_lap = sp.csr_matrix(scaled_lap, copy=True)
        return 2 * s_lap.dot(t_k_minus_one) - t_k_minus_two

    for i in range(2, k+1):
        t_k.append(chebyshev_recurrence(t_k[-1], t_k[-2], scaled_laplacian))

    return t_k



  from scipy.sparse.linalg.eigen import eigsh


In [3]:
import numpy as np
import torch
# from utils import preprocess_features
from sklearn.model_selection import StratifiedKFold


class dataloader():
    def __init__(self): 
        self.pd_dict = {}
        self.node_ftr_dim = 2000
        self.num_classes = 2 

    def load_data(self, params, connectivity='correlation', atlas='cc400'):
        ''' load multimodal data from ABIDE
        return: imaging features (raw), labels, non-image data
        '''
        subject_IDs = get_ids()
        labels = get_subject_score(subject_IDs, score='DX_GROUP')
        num_nodes = len(subject_IDs)

        sites = get_subject_score(subject_IDs, score='SITE_ID')
        unique = np.unique(list(sites.values())).tolist()
        ages = get_subject_score(subject_IDs, score='AGE_AT_SCAN')
        genders = get_subject_score(subject_IDs, score='SEX') 

        y_onehot = np.zeros([num_nodes, self.num_classes])
        y = np.zeros([num_nodes])
        site = np.zeros([num_nodes], dtype=int)
        age = np.zeros([num_nodes], dtype=np.float32)
        gender = np.zeros([num_nodes], dtype=int)
        for i in range(num_nodes):
            y_onehot[i, int(labels[subject_IDs[i]])-1] = 1
            y[i] = int(labels[subject_IDs[i]])
            site[i] = unique.index(sites[subject_IDs[i]])
            age[i] = float(ages[subject_IDs[i]])
            gender[i] = genders[subject_IDs[i]]
        
        self.y = y -1  

        self.raw_features = get_networks(subject_IDs, kind=connectivity, atlas_name=atlas)

        phonetic_data = np.zeros([num_nodes, 3], dtype=np.float32)
        phonetic_data[:,0] = site 
        phonetic_data[:,1] = gender 
        phonetic_data[:,2] = age 

        self.pd_dict['SITE_ID'] = np.copy(phonetic_data[:,0])
        self.pd_dict['SEX'] = np.copy(phonetic_data[:,1])
        self.pd_dict['AGE_AT_SCAN'] = np.copy(phonetic_data[:,2]) 
        
        return self.raw_features, self.y, phonetic_data

    def data_split(self, n_folds):
        # split data by k-fold CV
        skf = StratifiedKFold(n_splits=n_folds)
        cv_splits = list(skf.split(self.raw_features, self.y))
        return cv_splits 

    def get_node_features(self, train_ind):
        '''preprocess node features for wl-deepgcn
        '''
        node_ftr = feature_selection(self.raw_features, self.y, train_ind, self.node_ftr_dim)
        self.node_ftr = preprocess_features(node_ftr) 
        return self.node_ftr

    def get_WL_inputs(self, nonimg):
        '''get WL inputs for wl-deepgcn 
        '''
        # construct edge network inputs 
        n = self.node_ftr.shape[0] 
        num_edge = n*(1+n)//2 - n  # n*(n-1)//2,HO=6105
        pd_ftr_dim = nonimg.shape[1]
        edge_index = np.zeros([2, num_edge], dtype=np.int64) 
        edgenet_input = np.zeros([num_edge, 2*pd_ftr_dim], dtype=np.float32)  
        aff_score = np.zeros(num_edge, dtype=np.float32)
        # static affinity score used to pre-prune edges 
        aff_adj = get_static_affinity_adj(self.node_ftr, self.pd_dict)  
        flatten_ind = 0 
        for i in range(n):
            for j in range(i+1, n):
                edge_index[:,flatten_ind] = [i,j]
                edgenet_input[flatten_ind]  = np.concatenate((nonimg[i], nonimg[j]))
                aff_score[flatten_ind] = aff_adj[i][j]  
                flatten_ind +=1

        assert flatten_ind == num_edge, "Error in computing edge input"
        
        keep_ind = np.where(aff_score > 1.1)[0]  
        edge_index = edge_index[:, keep_ind]
        edgenet_input = edgenet_input[keep_ind]

        return edge_index, edgenet_input
    

In [4]:

import torch
from torch.nn import Linear as Lin, Sequential as Seq
import torch.nn.functional as F
from torch import nn

class WL(torch.nn.Module):
    def __init__(self, input_dim, dropout=0.3):
        super(WL, self).__init__()
        h1=256
        h2=128
        self.parser =nn.Sequential(
                nn.Linear(input_dim, h1, bias=True),
                nn.LeakyReLU(inplace=True),
                nn.BatchNorm1d(h1),
                nn.Dropout(dropout),
                nn.Linear(h1, h2, bias=True),
                nn.LeakyReLU(inplace=True),
                nn.BatchNorm1d(h2),
                nn.Dropout(dropout),
                nn.Linear(h2, h2, bias=True),
                )
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-8)
        self.input_dim = input_dim
        self.model_init()
        self.relu = nn.ReLU(inplace=True)
        self.elu = nn.ReLU()

    def forward(self, x):
        x1 = x[:,0:self.input_dim]
        x2 = x[:,self.input_dim:]
        h1 = self.parser(x1) 
        h2 = self.parser(x2) 
        p = (self.cos(h1,h2) + 1)*0.5
        return p

    def model_init(self):
        for m in self.modules():
            if isinstance(m, Lin):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True



In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import Linear as Lin, Sequential as Seq
import torch_geometric as tg
# from wl import WL


class MLP(nn.Module):
    def __init__(self, input_dim, nhid):
        super(MLP,self).__init__()
        self.cls = nn.Sequential(
            torch.nn.Linear(input_dim,nhid))
        
    def forward(self, features):
        output = self.cls(features)
        return output
            
class GCN(nn.Module):
    def __init__(self, input_dim, nhid, num_classes, ngl, dropout, edge_dropout, edgenet_input_dim):
        super(GCN, self).__init__()
        K=3   
        hidden = [nhid for i in range(ngl)] 
        self.dropout = dropout
        self.edge_dropout = edge_dropout 
        bias = False 
        self.relu = torch.nn.ReLU(inplace=True) 
        self.ngl = ngl 
        self.gconv = nn.ModuleList()
        for i in range(ngl):
            in_channels = input_dim if i==0  else hidden[i-1]
            self.gconv.append(tg.nn.ChebConv(in_channels, hidden[i], K, normalization='sym', bias=bias)) 
          
        self.cls = nn.Sequential(
                torch.nn.Linear(16, 128),
                torch.nn.ReLU(inplace=True),
                nn.BatchNorm1d(128), 
                torch.nn.Linear(128, num_classes))

        self.edge_net = WL(input_dim=edgenet_input_dim//2, dropout=dropout)
        self.model_init()

    def model_init(self):
        for m in self.modules():
            if isinstance(m, Lin):
                torch.nn.init.kaiming_normal_(m.weight) # He init
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True

    def forward(self, features, edge_index, edgenet_input, enforce_edropout=False): 
        if self.edge_dropout>0:
            if enforce_edropout or self.training:
                one_mask = torch.ones([edgenet_input.shape[0],1])
                self.drop_mask = F.dropout(one_mask, self.edge_dropout, True)
                self.bool_mask = torch.squeeze(self.drop_mask.type(torch.bool))
                edge_index = edge_index[:, self.bool_mask] 
                edgenet_input = edgenet_input[self.bool_mask] # Weights
            
        edge_weight = torch.squeeze(self.edge_net(edgenet_input))
        

        # GCN residual connection
        # input layer
        features = F.dropout(features, self.dropout, self.training)
        x = self.relu(self.gconv[0](features, edge_index, edge_weight)) 
        x_temp = x
        
        # hidden layers
        for i in range(1, self.ngl - 1): # self.ngl→7
            x = F.dropout(x_temp, self.dropout, self.training)
            x = self.relu(self.gconv[i](x, edge_index, edge_weight)) 
            x_temp = x_temp + x # ([871,64])

        # output layer
        x = F.dropout(x_temp, self.dropout, self.training)
        x = self.relu(self.gconv[self.ngl - 1](x, edge_index, edge_weight))
        x_temp = x_temp + x

        output = x # Final output is not cumulative
        output = self.cls(output) 
        
        return output, edge_weight
    

    



In [6]:
import torch
import torchmetrics
from torchmetrics.classification import MulticlassSpecificity
import matplotlib.pyplot as plt
import numpy as np
import itertools

from sklearn.metrics import precision_recall_fscore_support

def torchmetrics_accuracy(preds, labels):
    acc = torchmetrics.functional.accuracy(preds, labels,task="multiclass", num_classes=2)
    return acc

def torchmetrics_spef(preds, labels):
    metric = MulticlassSpecificity(num_classes=2)
    spef = metric(preds, labels)
    return spef

def torchmetrics_auc(preds, labels):
    auc = torchmetrics.functional.auroc(preds, labels, task="multiclass", num_classes=2)
    return auc

def confusion_matrix(preds, labels):
    conf_matrix = torch.zeros(2, 2)
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[t, p] += 1 
    return conf_matrix
def plot_confusion_matrix(cm, normalize=False, title='Confusion matrix', cmap=plt.cm.Oranges):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    Input
    - cm : computer the value of confusion matrix
    - normalize : True: %, False: 123
    """
    classes = ['0:ASD','1:TC']
    if normalize:
        cm = cm.numpy()
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    fmt = '.2f' if normalize else '.0f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    
def correct_num(preds, labels):
    """Accuracy, auc with masking.Acc of the masked samples"""
    correct_prediction = np.equal(np.argmax(preds, 1), labels).astype(np.float32)
    return np.sum(correct_prediction)

def prf(preds, labels, is_logit=True):
    ''' input: logits, labels  ''' 
    pred_lab= np.argmax(preds, 1)
    p,r,f,s  = precision_recall_fscore_support(labels, pred_lab, average='binary')
    return [p,r,f]







In [7]:
import argparse
parser = argparse.ArgumentParser()
args, unknown = parser.parse_known_args()

In [None]:
from __future__ import division
from __future__ import print_function

import os
import time
import argparse
import numpy as np
import io
import sys

import torch
import torch.optim as optim



# from dataloader import dataloader

from sklearn.model_selection import train_test_split
from sklearn.metrics import auc

if hasattr(sys.stdout, 'buffer'):
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')


class Args:
    def __init__(self):
        self.no_cuda = False
        self.seed = 46
        self.epochs = 200
        self.lr = 0.001
        self.weight_decay = 5e-5
        self.hidden = 16
        self.dropout = 0.2
        self.atlas = 'cc400'
        self.num_features = 2000
        self.folds = 10
        self.connectivity = 'correlation'
        self.max_degree = 3
        self.ngl = 4
        self.edropout = 0.3
        self.train = 1
        self.ckpt_path = '../folds/rois_cc400_pth_4_layer'
        self.early_stopping = True
        self.early_stopping_patience = 20

# Instantiate Args class
args = Args()

# Check if CUDA is available
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Set random seeds
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Create params dictionary
params = vars(args)

# Print Hyperparameters
print('Hyperparameters:')
for key, value in params.items():
    print(key + ":", value)

corrects = np.zeros(args.folds, dtype=np.int32) 
accs = np.zeros(args.folds, dtype=np.float32) 
aucs = np.zeros(args.folds, dtype=np.float32)
prfs = np.zeros([args.folds,3], dtype=np.float32) # Save Precision, Recall, F1
test_num = np.zeros(args.folds, dtype=np.float32)


print('  Loading dataset ...')
dataloader = dataloader()
raw_features, y, nonimg = dataloader.load_data(params) 
cv_splits = dataloader.data_split(args.folds)
features=raw_features

t1 = time.time()
count=1;
for i in range(args.folds):
    
    
    
    t_start = time.time()
    train_ind, test_ind = cv_splits[i]

    train_ind, valid_ind = train_test_split(train_ind, test_size=0.1, random_state = 24)
    
    cv_splits[i] = (train_ind, valid_ind)
    cv_splits[i] = cv_splits[i] + (test_ind,)
    print('Size of the {}-fold Training, Validation, and Test Sets:{},{},{}' .format(i+1, len(cv_splits[i][0]), len(cv_splits[i][1]), len(cv_splits[i][2])))

    if args.train == 1:
        for j in range(args.folds):
            print(' Starting the {}-{} Fold:：'.format(i+1,j+1))
            node_ftr = dataloader.get_node_features(train_ind)
            edge_index, edgenet_input = dataloader.get_WL_inputs(nonimg)
            edgenet_input = (edgenet_input - edgenet_input.mean(axis=0)) / edgenet_input.std(axis=0)
            
            model = GCN(input_dim = args.num_features,
                        nhid = args.hidden, 
                        num_classes = 2, 
                        ngl = args.ngl, 
                        dropout = args.dropout, 
                        edge_dropout = args.edropout, 
                        edgenet_input_dim = 2*nonimg.shape[1])
            optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            
#             if args.cuda:
            model
            features = torch.tensor(node_ftr, dtype=torch.float32)
            edge_index = torch.tensor(edge_index, dtype=torch.long)
            edgenet_input = torch.tensor(edgenet_input, dtype=torch.float32)
            labels = torch.tensor(y, dtype=torch.long)
            fold_model_path = args.ckpt_path + "/fold{}.pth".format(i+1)
                
            acc = 0
            best_val_loss = float('inf') # early stoppping: Initialized to positive infinity
            current_patience = 0 # early stopping: Used to record the epochs of the current early stopping
            
            epoch_store = []
            acc_train_store =[]        
            pre_train_store =[]
            recall_train_store =[]
            F1_train_store =[]
            AUC_train_store =[]
            acc_val_store=[]
            pre_val_store=[]
            recall_val_store=[]
            F1_val_store=[]
            AUC_val_store=[]
            
            for epoch in range(args.epochs):
                # train
                model.train()
                with torch.set_grad_enabled(True):
                    optimizer.zero_grad()
                    output, edge_weights = model(features, edge_index, edgenet_input)
                    loss_train = torch.nn.CrossEntropyLoss()(output[train_ind], labels[train_ind])
                    loss_train.backward()
                    optimizer.step()
                acc_train = torchmetrics_accuracy(output[train_ind], labels[train_ind])
                auc_train = torchmetrics_auc(output[train_ind], labels[train_ind])
                logits_train = output[train_ind].detach().cpu().numpy()
                prf_train = prf(logits_train, y[train_ind])

                
                # valid
                model.eval()
                with torch.set_grad_enabled(False):
                    output, edge_weights = model(features, edge_index, edgenet_input)
                loss_val = torch.nn.CrossEntropyLoss()(output[valid_ind], labels[valid_ind])
                acc_val = torchmetrics_accuracy(output[valid_ind], labels[valid_ind])
                auc_val = torchmetrics_auc(output[valid_ind], labels[valid_ind])
                logits_val = output[valid_ind].detach().cpu().numpy()
                prf_val = prf(logits_val, y[valid_ind])

                
                print('Epoch:{:04d}'.format(epoch+1))
                print('acc_train:{:.4f}'.format(acc_train),
                      'pre_train:{:.4f}'.format(prf_train[0]),
                      'recall_train:{:.4f}'.format(prf_train[1]),
                      'F1_train:{:.4f}'.format(prf_train[2]),
                      'AUC_train:{:.4f}'.format(auc_train))
                print('acc_val:{:.4f}'.format(acc_val),
                      'pre_val:{:.4f}'.format(prf_val[0]),
                      'recall_val:{:.4f}'.format(prf_val[1]),
                      'F1_val:{:4f}'.format(prf_val[2]),
                      'AUC_val:{:.4f}'.format(auc_val))
                
                epoch_store.append(epoch+1)
                acc_train_store.append(acc_train)       
                pre_train_store.append(prf_train[0])
                recall_train_store.append(prf_train[1])
                F1_train_store.append(prf_train[2])
                AUC_train_store.append(auc_train)
                acc_val_store.append(acc_val)
                pre_val_store.append(prf_val[0])
                recall_val_store.append(prf_val[1])
                F1_val_store.append(prf_val[2])
                AUC_val_store.append(auc_val)
                
                # save pth
                if acc_val > acc and epoch > 50:
                    acc = acc_val
                    if args.ckpt_path != '':
                        if not os.path.exists(args.ckpt_path):
                            os.makedirs(args.ckpt_path)
                        torch.save(model.state_dict(), fold_model_path)
                
                # Early Stopping
                if epoch > 50 and args.early_stopping == True:
                    if loss_val < best_val_loss:
                        best_val_loss = loss_val
                        current_patience = 0
                    else:
                        current_patience += 1
                    if current_patience >= args.early_stopping_patience:
                        print('Early Stopping!!! epoch：{}'.format(epoch))
                        break
        print("===================================================================",i,"_",j)
        data  = { 
              "epoch" : epoch_store ,
              "acc_train" : acc_train_store ,        
              "pre_train" : pre_train_store ,
              "recall_train" : recall_train_store ,
              "F1_train" : F1_train_store ,
              "AUC_train" : AUC_train_store ,
              "acc_val" : acc_val_store,
               "pre_val" : pre_val_store ,
              "recall_val" : recall_val_store ,
              "F1_val" : F1_val_store ,
              "AUC_val" : AUC_val_store  
        }
        
        
        epoch_file_path =  f'../files/rois_cc400_4_layer/file_{i}_{j}_{count}.csv'
        data_file = pd.DataFrame(data);
        data_file.to_csv(epoch_file_path , index=False);
        count=count+1;
        # test
        print("Loading the Model for the {}-th Fold:... ...".format(i+1),
              "Size of samples in the test set:{}".format(len(test_ind)))
#         model.load_state_dict(torch.load(fold_model_path))
        model.eval()
        
        with torch.set_grad_enabled(False):
            output, edge_weights = model(features, edge_index, edgenet_input)
        acc_test = torchmetrics_accuracy(output[test_ind], labels[test_ind])
        auc_test = torchmetrics_auc(output[test_ind], labels[test_ind])
        logits_test = output[test_ind].detach().cpu().numpy()
        correct_test = correct_num(logits_test, y[test_ind])
        prf_test =  prf(logits_test, y[test_ind])
        
        t_end = time.time()
        t = t_end - t_start
        print('Fold {} Results:'.format(i+1),
              'test acc:{:.4f}'.format(acc_test),
              'test_pre:{:.4f}'.format(prf_test[0]),
              'test_recall:{:.4f}'.format(prf_test[1]),
              'test_F1:{:.4f}'.format(prf_test[2]),
              'test_AUC:{:.4f}'.format(auc_test),
              'time:{:.3f}s'.format(t))
        
        correct = correct_test
        aucs[i] = auc_test
        prfs[i] = prf_test
        corrects[i] = correct
        test_num[i] = len(test_ind)
    
    
    if args.train == 0:
        node_ftr = dataloader.get_node_features(train_ind)
        edge_index, edgenet_input = dataloader.get_WL_inputs(nonimg)
        edgenet_input = (edgenet_input - edgenet_input.mean(axis=0)) / edgenet_input.std(axis=0)
        
        model = GCN(input_dim = args.num_features,
                    nhid = args.hidden, 
                    num_classes = 2, 
                    ngl = args.ngl, 
                    dropout = args.dropout, 
                    edge_dropout = args.edropout, 
                    edgenet_input_dim = 2*nonimg.shape[1])
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        
#         if args.cuda
        model
        features = torch.tensor(node_ftr, dtype=torch.float)
        edge_index = torch.tensor(edge_index, dtype=torch.long)
        edgenet_input = torch.tensor(edgenet_input, dtype=torch.float32)
        labels = torch.tensor(y, dtype=torch.long)
        fold_model_path = args.ckpt_path + "/fold{}.pth".format(i+1)
        
        model.load_state_dict(torch.load(fold_model_path))
        model.eval()
        
        with torch.set_grad_enabled(False):
            output, edge_weights = model(features, edge_index, edgenet_input)
        acc_test = torchmetrics_accuracy(output[test_ind], labels[test_ind])
        auc_test = torchmetrics_auc(output[test_ind], labels[test_ind])
        logits_test = output[test_ind].detach().cpu().numpy()
        correct_test = correct_num(logits_test, y[test_ind])
        prf_test =  prf(logits_test, y[test_ind])
        
        t_end = time.time()
        t = t_end - t_start
        print('Fold {} Results:'.format(i+1),
              'test acc:{:.4f}'.format(acc_test),
              'test_pre:{:.4f}'.format(prf_test[0]),
              'test_recall:{:.4f}'.format(prf_test[1]),
              'test_F1:{:.4f}'.format(prf_test[2]),
              'test_AUC:{:.4f}'.format(auc_test),
              'time:{:.3f}s'.format(t))
        
        correct = correct_test
        aucs[i] = auc_test
        prfs[i] = prf_test
        corrects[i] = correct
        test_num[i] = len(test_ind)

t2 = time.time()

print('\r\n======Finish Results for Nested 10-fold cross-validation======')
Nested10kCV_acc = np.sum(corrects) / np.sum(test_num)
Nested10kCV_auc = np.mean(aucs)
Nested10kCV_precision, Nested10kCV_recall, Nested10kCV_F1 = np.mean(prfs, axis=0)
print('Test:',
      'acc:{}'.format(Nested10kCV_acc),
      'precision:{}'.format(Nested10kCV_precision),
      'recall:{}'.format(Nested10kCV_recall),
      'F1:{}'.format(Nested10kCV_F1),
      'AUC:{}'.format(Nested10kCV_auc))
print('Total duration:{}'.format(t2 - t1))



Hyperparameters:
no_cuda: False
seed: 46
epochs: 200
lr: 0.001
weight_decay: 5e-05
hidden: 16
dropout: 0.2
atlas: cc400
num_features: 2000
folds: 10
connectivity: correlation
max_degree: 3
ngl: 4
edropout: 0.3
train: 1
ckpt_path: ../folds/rois_cc400_pth_4_layer
early_stopping: True
early_stopping_patience: 20
cuda: False
  Loading dataset ...


  norm_networks = [np.arctanh(mat) if not np.all(np.abs(mat) == 1) else mat for mat in all_networks]


Size of the 1-fold Training, Validation, and Test Sets:900,100,112
 Starting the 1-1 Fold:：
Fitting estimator with 76636 features.
Fitting estimator with 76536 features.
Fitting estimator with 76436 features.
Fitting estimator with 76336 features.
Fitting estimator with 76236 features.
Fitting estimator with 76136 features.
Fitting estimator with 76036 features.
Fitting estimator with 75936 features.
Fitting estimator with 75836 features.
Fitting estimator with 75736 features.
Fitting estimator with 75636 features.
Fitting estimator with 75536 features.
Fitting estimator with 75436 features.
Fitting estimator with 75336 features.
Fitting estimator with 75236 features.
Fitting estimator with 75136 features.
Fitting estimator with 75036 features.
Fitting estimator with 74936 features.
Fitting estimator with 74836 features.
Fitting estimator with 74736 features.
Fitting estimator with 74636 features.
Fitting estimator with 74536 features.
Fitting estimator with 74436 features.
Fitting est

Fitting estimator with 55836 features.
Fitting estimator with 55736 features.
Fitting estimator with 55636 features.
Fitting estimator with 55536 features.
Fitting estimator with 55436 features.
Fitting estimator with 55336 features.
Fitting estimator with 55236 features.
Fitting estimator with 55136 features.
Fitting estimator with 55036 features.
Fitting estimator with 54936 features.
Fitting estimator with 54836 features.
Fitting estimator with 54736 features.
Fitting estimator with 54636 features.
Fitting estimator with 54536 features.
Fitting estimator with 54436 features.
Fitting estimator with 54336 features.
Fitting estimator with 54236 features.
Fitting estimator with 54136 features.
Fitting estimator with 54036 features.
Fitting estimator with 53936 features.
Fitting estimator with 53836 features.
Fitting estimator with 53736 features.
Fitting estimator with 53636 features.
Fitting estimator with 53536 features.
Fitting estimator with 53436 features.
Fitting estimator with 53

Fitting estimator with 34736 features.
Fitting estimator with 34636 features.
Fitting estimator with 34536 features.
Fitting estimator with 34436 features.
Fitting estimator with 34336 features.
Fitting estimator with 34236 features.
Fitting estimator with 34136 features.
Fitting estimator with 34036 features.
Fitting estimator with 33936 features.
Fitting estimator with 33836 features.
Fitting estimator with 33736 features.
Fitting estimator with 33636 features.
Fitting estimator with 33536 features.
Fitting estimator with 33436 features.
Fitting estimator with 33336 features.
Fitting estimator with 33236 features.
Fitting estimator with 33136 features.
Fitting estimator with 33036 features.
Fitting estimator with 32936 features.
Fitting estimator with 32836 features.
Fitting estimator with 32736 features.
Fitting estimator with 32636 features.
Fitting estimator with 32536 features.
Fitting estimator with 32436 features.
Fitting estimator with 32336 features.
Fitting estimator with 32

Fitting estimator with 13636 features.
Fitting estimator with 13536 features.
Fitting estimator with 13436 features.
Fitting estimator with 13336 features.
Fitting estimator with 13236 features.
Fitting estimator with 13136 features.
Fitting estimator with 13036 features.
Fitting estimator with 12936 features.
Fitting estimator with 12836 features.
Fitting estimator with 12736 features.
Fitting estimator with 12636 features.
Fitting estimator with 12536 features.
Fitting estimator with 12436 features.
Fitting estimator with 12336 features.
Fitting estimator with 12236 features.
Fitting estimator with 12136 features.
Fitting estimator with 12036 features.
Fitting estimator with 11936 features.
Fitting estimator with 11836 features.
Fitting estimator with 11736 features.
Fitting estimator with 11636 features.
Fitting estimator with 11536 features.
Fitting estimator with 11436 features.
Fitting estimator with 11336 features.
Fitting estimator with 11236 features.
Fitting estimator with 11

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0001
acc_train:0.4700 pre_train:0.4902 recall_train:0.6473 F1_train:0.5579 AUC_train:0.4541
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5972


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0002
acc_train:0.5356 pre_train:0.5439 recall_train:0.6258 F1_train:0.5820 AUC_train:0.5373
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5568


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0003
acc_train:0.5678 pre_train:0.5754 recall_train:0.6237 F1_train:0.5986 AUC_train:0.6190
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5433


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0004
acc_train:0.6089 pre_train:0.6336 recall_train:0.5763 F1_train:0.6036 AUC_train:0.6380
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5472
Epoch:0005
acc_train:0.6044 pre_train:0.6394 recall_train:0.5376 F1_train:0.5841 AUC_train:0.6309
acc_val:0.4800 pre_val:0.2500 recall_val:0.0200 F1_val:0.037037 AUC_val:0.5704
Epoch:0006
acc_train:0.6367 pre_train:0.6778 recall_train:0.5656 F1_train:0.6166 AUC_train:0.6651
acc_val:0.5000 pre_val:0.5000 recall_val:0.0600 F1_val:0.107143 AUC_val:0.5892
Epoch:0007
acc_train:0.5900 pre_train:0.6250 recall_train:0.5161 F1_train:0.5654 AUC_train:0.6158
acc_val:0.5000 pre_val:0.5000 recall_val:0.0600 F1_val:0.107143 AUC_val:0.6140
Epoch:0008
acc_train:0.6189 pre_train:0.6540 recall_train:0.5570 F1_train:0.6016 AUC_train:0.6576
acc_val:0.5700 pre_val:0.7692 recall_val:0.2000 F1_val:0.317460 AUC_val:0.6288
Epoch:0009
acc_train:0.6222 pre_train:0.6658 recall_train:0.5398 F1_train:0.5962 AUC_train:0.6612
acc_val:0.5900 pr

Epoch:0051
acc_train:0.9544 pre_train:0.9327 recall_train:0.9828 F1_train:0.9571 AUC_train:0.9919
acc_val:0.8000 pre_val:0.7885 recall_val:0.8200 F1_val:0.803922 AUC_val:0.8652
Epoch:0052
acc_train:0.9400 pre_train:0.9290 recall_train:0.9570 F1_train:0.9428 AUC_train:0.9841
acc_val:0.8500 pre_val:0.8723 recall_val:0.8200 F1_val:0.845361 AUC_val:0.8824
Epoch:0053
acc_train:0.9622 pre_train:0.9518 recall_train:0.9763 F1_train:0.9639 AUC_train:0.9942
acc_val:0.8300 pre_val:0.9024 recall_val:0.7400 F1_val:0.813187 AUC_val:0.8884
Epoch:0054
acc_train:0.9511 pre_train:0.9340 recall_train:0.9742 F1_train:0.9537 AUC_train:0.9884
acc_val:0.8400 pre_val:0.9048 recall_val:0.7600 F1_val:0.826087 AUC_val:0.8904
Epoch:0055
acc_train:0.9489 pre_train:0.9374 recall_train:0.9656 F1_train:0.9513 AUC_train:0.9900
acc_val:0.8600 pre_val:0.8750 recall_val:0.8400 F1_val:0.857143 AUC_val:0.8872
Epoch:0056
acc_train:0.9667 pre_train:0.9485 recall_train:0.9892 F1_train:0.9684 AUC_train:0.9910
acc_val:0.7900 pr

Epoch:0098
acc_train:0.9822 pre_train:0.9726 recall_train:0.9935 F1_train:0.9830 AUC_train:0.9972
acc_val:0.7400 pre_val:0.6875 recall_val:0.8800 F1_val:0.771930 AUC_val:0.8772
Epoch:0099
acc_train:0.9933 pre_train:0.9873 recall_train:1.0000 F1_train:0.9936 AUC_train:0.9997
acc_val:0.7500 pre_val:0.6984 recall_val:0.8800 F1_val:0.778761 AUC_val:0.8732
Early Stopping!!! epoch：98
 Starting the 1-2 Fold:：
Fitting estimator with 76636 features.
Fitting estimator with 76536 features.
Fitting estimator with 76436 features.
Fitting estimator with 76336 features.
Fitting estimator with 76236 features.
Fitting estimator with 76136 features.
Fitting estimator with 76036 features.
Fitting estimator with 75936 features.
Fitting estimator with 75836 features.
Fitting estimator with 75736 features.
Fitting estimator with 75636 features.
Fitting estimator with 75536 features.
Fitting estimator with 75436 features.
Fitting estimator with 75336 features.
Fitting estimator with 75236 features.
Fitting e

Fitting estimator with 56636 features.
Fitting estimator with 56536 features.
Fitting estimator with 56436 features.
Fitting estimator with 56336 features.
Fitting estimator with 56236 features.
Fitting estimator with 56136 features.
Fitting estimator with 56036 features.
Fitting estimator with 55936 features.
Fitting estimator with 55836 features.
Fitting estimator with 55736 features.
Fitting estimator with 55636 features.
Fitting estimator with 55536 features.
Fitting estimator with 55436 features.
Fitting estimator with 55336 features.
Fitting estimator with 55236 features.
Fitting estimator with 55136 features.
Fitting estimator with 55036 features.
Fitting estimator with 54936 features.
Fitting estimator with 54836 features.
Fitting estimator with 54736 features.
Fitting estimator with 54636 features.
Fitting estimator with 54536 features.
Fitting estimator with 54436 features.
Fitting estimator with 54336 features.
Fitting estimator with 54236 features.
Fitting estimator with 54

Fitting estimator with 35536 features.
Fitting estimator with 35436 features.
Fitting estimator with 35336 features.
Fitting estimator with 35236 features.
Fitting estimator with 35136 features.
Fitting estimator with 35036 features.
Fitting estimator with 34936 features.
Fitting estimator with 34836 features.
Fitting estimator with 34736 features.
Fitting estimator with 34636 features.
Fitting estimator with 34536 features.
Fitting estimator with 34436 features.
Fitting estimator with 34336 features.
Fitting estimator with 34236 features.
Fitting estimator with 34136 features.
Fitting estimator with 34036 features.
Fitting estimator with 33936 features.
Fitting estimator with 33836 features.
Fitting estimator with 33736 features.
Fitting estimator with 33636 features.
Fitting estimator with 33536 features.
Fitting estimator with 33436 features.
Fitting estimator with 33336 features.
Fitting estimator with 33236 features.
Fitting estimator with 33136 features.
Fitting estimator with 33

Fitting estimator with 14436 features.
Fitting estimator with 14336 features.
Fitting estimator with 14236 features.
Fitting estimator with 14136 features.
Fitting estimator with 14036 features.
Fitting estimator with 13936 features.
Fitting estimator with 13836 features.
Fitting estimator with 13736 features.
Fitting estimator with 13636 features.
Fitting estimator with 13536 features.
Fitting estimator with 13436 features.
Fitting estimator with 13336 features.
Fitting estimator with 13236 features.
Fitting estimator with 13136 features.
Fitting estimator with 13036 features.
Fitting estimator with 12936 features.
Fitting estimator with 12836 features.
Fitting estimator with 12736 features.
Fitting estimator with 12636 features.
Fitting estimator with 12536 features.
Fitting estimator with 12436 features.
Fitting estimator with 12336 features.
Fitting estimator with 12236 features.
Fitting estimator with 12136 features.
Fitting estimator with 12036 features.
Fitting estimator with 11

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0001
acc_train:0.4356 pre_train:0.4667 recall_train:0.6473 F1_train:0.5423 AUC_train:0.4271
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.3004


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0002
acc_train:0.4911 pre_train:0.5074 recall_train:0.5183 F1_train:0.5128 AUC_train:0.5075
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.3012


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0003
acc_train:0.5544 pre_train:0.5658 recall_train:0.5914 F1_train:0.5783 AUC_train:0.5822
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.3188
Epoch:0004
acc_train:0.5578 pre_train:0.5717 recall_train:0.5742 F1_train:0.5730 AUC_train:0.6083
acc_val:0.4100 pre_val:0.3333 recall_val:0.1800 F1_val:0.233766 AUC_val:0.3212
Epoch:0005
acc_train:0.5833 pre_train:0.5983 recall_train:0.5892 F1_train:0.5937 AUC_train:0.6282
acc_val:0.3500 pre_val:0.3913 recall_val:0.5400 F1_val:0.453782 AUC_val:0.3300
Epoch:0006
acc_train:0.5811 pre_train:0.5894 recall_train:0.6237 F1_train:0.6061 AUC_train:0.6209
acc_val:0.3800 pre_val:0.4211 recall_val:0.6400 F1_val:0.507937 AUC_val:0.3308
Epoch:0007
acc_train:0.5956 pre_train:0.6095 recall_train:0.6043 F1_train:0.6069 AUC_train:0.6324
acc_val:0.4000 pre_val:0.4375 recall_val:0.7000 F1_val:0.538462 AUC_val:0.3260
Epoch:0008
acc_train:0.6078 pre_train:0.6181 recall_train:0.6301 F1_train:0.6241 AUC_train:0.6237
acc_val:0.4400 pr

Epoch:0050
acc_train:0.6211 pre_train:0.6403 recall_train:0.6086 F1_train:0.6240 AUC_train:0.6786
acc_val:0.6700 pre_val:0.7297 recall_val:0.5400 F1_val:0.620690 AUC_val:0.6980
Epoch:0051
acc_train:0.6567 pre_train:0.6840 recall_train:0.6237 F1_train:0.6524 AUC_train:0.7080
acc_val:0.6800 pre_val:0.7500 recall_val:0.5400 F1_val:0.627907 AUC_val:0.7004
Epoch:0052
acc_train:0.6500 pre_train:0.6659 recall_train:0.6473 F1_train:0.6565 AUC_train:0.7101
acc_val:0.6800 pre_val:0.7500 recall_val:0.5400 F1_val:0.627907 AUC_val:0.7036
Epoch:0053
acc_train:0.6500 pre_train:0.6894 recall_train:0.5871 F1_train:0.6341 AUC_train:0.7046
acc_val:0.6800 pre_val:0.7500 recall_val:0.5400 F1_val:0.627907 AUC_val:0.7064
Epoch:0054
acc_train:0.6444 pre_train:0.6808 recall_train:0.5871 F1_train:0.6305 AUC_train:0.7150
acc_val:0.6900 pre_val:0.7714 recall_val:0.5400 F1_val:0.635294 AUC_val:0.7108
Epoch:0055
acc_train:0.6567 pre_train:0.6741 recall_train:0.6495 F1_train:0.6616 AUC_train:0.7158
acc_val:0.6800 pr

Epoch:0097
acc_train:0.9378 pre_train:0.9115 recall_train:0.9742 F1_train:0.9418 AUC_train:0.9832
acc_val:0.7900 pre_val:0.7843 recall_val:0.8000 F1_val:0.792079 AUC_val:0.8412
Epoch:0098
acc_train:0.9333 pre_train:0.9074 recall_train:0.9699 F1_train:0.9376 AUC_train:0.9755
acc_val:0.7500 pre_val:0.7778 recall_val:0.7000 F1_val:0.736842 AUC_val:0.8356
Epoch:0099
acc_train:0.9311 pre_train:0.9038 recall_train:0.9699 F1_train:0.9357 AUC_train:0.9822
acc_val:0.7400 pre_val:0.7857 recall_val:0.6600 F1_val:0.717391 AUC_val:0.8316
Epoch:0100
acc_train:0.9311 pre_train:0.9172 recall_train:0.9527 F1_train:0.9346 AUC_train:0.9834
acc_val:0.7400 pre_val:0.8000 recall_val:0.6400 F1_val:0.711111 AUC_val:0.8328
Epoch:0101
acc_train:0.9322 pre_train:0.9226 recall_train:0.9484 F1_train:0.9353 AUC_train:0.9798
acc_val:0.7500 pre_val:0.8205 recall_val:0.6400 F1_val:0.719101 AUC_val:0.8400
Epoch:0102
acc_train:0.9422 pre_train:0.9138 recall_train:0.9806 F1_train:0.9461 AUC_train:0.9899
acc_val:0.7500 pr

Fitting estimator with 72936 features.
Fitting estimator with 72836 features.
Fitting estimator with 72736 features.
Fitting estimator with 72636 features.
Fitting estimator with 72536 features.
Fitting estimator with 72436 features.
Fitting estimator with 72336 features.
Fitting estimator with 72236 features.
Fitting estimator with 72136 features.
Fitting estimator with 72036 features.
Fitting estimator with 71936 features.
Fitting estimator with 71836 features.
Fitting estimator with 71736 features.
Fitting estimator with 71636 features.
Fitting estimator with 71536 features.
Fitting estimator with 71436 features.
Fitting estimator with 71336 features.
Fitting estimator with 71236 features.
Fitting estimator with 71136 features.
Fitting estimator with 71036 features.
Fitting estimator with 70936 features.
Fitting estimator with 70836 features.
Fitting estimator with 70736 features.
Fitting estimator with 70636 features.
Fitting estimator with 70536 features.
Fitting estimator with 70

Fitting estimator with 51836 features.
Fitting estimator with 51736 features.
Fitting estimator with 51636 features.
Fitting estimator with 51536 features.
Fitting estimator with 51436 features.
Fitting estimator with 51336 features.
Fitting estimator with 51236 features.
Fitting estimator with 51136 features.
Fitting estimator with 51036 features.
Fitting estimator with 50936 features.
Fitting estimator with 50836 features.
Fitting estimator with 50736 features.
Fitting estimator with 50636 features.
Fitting estimator with 50536 features.
Fitting estimator with 50436 features.
Fitting estimator with 50336 features.
Fitting estimator with 50236 features.
Fitting estimator with 50136 features.
Fitting estimator with 50036 features.
Fitting estimator with 49936 features.
Fitting estimator with 49836 features.
Fitting estimator with 49736 features.
Fitting estimator with 49636 features.
Fitting estimator with 49536 features.
Fitting estimator with 49436 features.
Fitting estimator with 49

Fitting estimator with 30736 features.
Fitting estimator with 30636 features.
Fitting estimator with 30536 features.
Fitting estimator with 30436 features.
Fitting estimator with 30336 features.
Fitting estimator with 30236 features.
Fitting estimator with 30136 features.
Fitting estimator with 30036 features.
Fitting estimator with 29936 features.
Fitting estimator with 29836 features.
Fitting estimator with 29736 features.
Fitting estimator with 29636 features.
Fitting estimator with 29536 features.
Fitting estimator with 29436 features.
Fitting estimator with 29336 features.
Fitting estimator with 29236 features.
Fitting estimator with 29136 features.
Fitting estimator with 29036 features.
Fitting estimator with 28936 features.
Fitting estimator with 28836 features.
Fitting estimator with 28736 features.
Fitting estimator with 28636 features.
Fitting estimator with 28536 features.
Fitting estimator with 28436 features.
Fitting estimator with 28336 features.
Fitting estimator with 28

Fitting estimator with 9636 features.
Fitting estimator with 9536 features.
Fitting estimator with 9436 features.
Fitting estimator with 9336 features.
Fitting estimator with 9236 features.
Fitting estimator with 9136 features.
Fitting estimator with 9036 features.
Fitting estimator with 8936 features.
Fitting estimator with 8836 features.
Fitting estimator with 8736 features.
Fitting estimator with 8636 features.
Fitting estimator with 8536 features.
Fitting estimator with 8436 features.
Fitting estimator with 8336 features.
Fitting estimator with 8236 features.
Fitting estimator with 8136 features.
Fitting estimator with 8036 features.
Fitting estimator with 7936 features.
Fitting estimator with 7836 features.
Fitting estimator with 7736 features.
Fitting estimator with 7636 features.
Fitting estimator with 7536 features.
Fitting estimator with 7436 features.
Fitting estimator with 7336 features.
Fitting estimator with 7236 features.
Fitting estimator with 7136 features.
Fitting esti

Epoch:0031
acc_train:0.7022 pre_train:0.7184 recall_train:0.6968 F1_train:0.7074 AUC_train:0.7689
acc_val:0.5000 pre_val:0.5000 recall_val:1.0000 F1_val:0.666667 AUC_val:0.7716
Epoch:0032
acc_train:0.6778 pre_train:0.6797 recall_train:0.7118 F1_train:0.6954 AUC_train:0.7534
acc_val:0.5200 pre_val:0.5102 recall_val:1.0000 F1_val:0.675676 AUC_val:0.7720
Epoch:0033
acc_train:0.6744 pre_train:0.6629 recall_train:0.7527 F1_train:0.7049 AUC_train:0.7426
acc_val:0.5300 pre_val:0.5158 recall_val:0.9800 F1_val:0.675862 AUC_val:0.7736
Epoch:0034
acc_train:0.7122 pre_train:0.7068 recall_train:0.7570 F1_train:0.7310 AUC_train:0.7937
acc_val:0.5600 pre_val:0.5326 recall_val:0.9800 F1_val:0.690141 AUC_val:0.7768
Epoch:0035
acc_train:0.7389 pre_train:0.7426 recall_train:0.7570 F1_train:0.7497 AUC_train:0.8341
acc_val:0.6400 pre_val:0.5897 recall_val:0.9200 F1_val:0.718750 AUC_val:0.7812
Epoch:0036
acc_train:0.7256 pre_train:0.7271 recall_train:0.7505 F1_train:0.7386 AUC_train:0.8148
acc_val:0.6800 pr

Epoch:0078
acc_train:0.9622 pre_train:0.9757 recall_train:0.9505 F1_train:0.9630 AUC_train:0.9913
acc_val:0.7800 pre_val:0.8043 recall_val:0.7400 F1_val:0.770833 AUC_val:0.8732
Epoch:0079
acc_train:0.9600 pre_train:0.9799 recall_train:0.9419 F1_train:0.9605 AUC_train:0.9956
acc_val:0.7900 pre_val:0.8222 recall_val:0.7400 F1_val:0.778947 AUC_val:0.8732
Epoch:0080
acc_train:0.9722 pre_train:0.9825 recall_train:0.9634 F1_train:0.9729 AUC_train:0.9933
acc_val:0.7900 pre_val:0.8222 recall_val:0.7400 F1_val:0.778947 AUC_val:0.8760
Epoch:0081
acc_train:0.9656 pre_train:0.9865 recall_train:0.9462 F1_train:0.9660 AUC_train:0.9969
acc_val:0.8000 pre_val:0.8000 recall_val:0.8000 F1_val:0.800000 AUC_val:0.8744
Epoch:0082
acc_train:0.9633 pre_train:0.9758 recall_train:0.9527 F1_train:0.9641 AUC_train:0.9963
acc_val:0.7700 pre_val:0.7368 recall_val:0.8400 F1_val:0.785047 AUC_val:0.8644
Epoch:0083
acc_train:0.9789 pre_train:0.9934 recall_train:0.9656 F1_train:0.9793 AUC_train:0.9957
acc_val:0.7300 pr

Fitting estimator with 68836 features.
Fitting estimator with 68736 features.
Fitting estimator with 68636 features.
Fitting estimator with 68536 features.
Fitting estimator with 68436 features.
Fitting estimator with 68336 features.
Fitting estimator with 68236 features.
Fitting estimator with 68136 features.
Fitting estimator with 68036 features.
Fitting estimator with 67936 features.
Fitting estimator with 67836 features.
Fitting estimator with 67736 features.
Fitting estimator with 67636 features.
Fitting estimator with 67536 features.
Fitting estimator with 67436 features.
Fitting estimator with 67336 features.
Fitting estimator with 67236 features.
Fitting estimator with 67136 features.
Fitting estimator with 67036 features.
Fitting estimator with 66936 features.
Fitting estimator with 66836 features.
Fitting estimator with 66736 features.
Fitting estimator with 66636 features.
Fitting estimator with 66536 features.
Fitting estimator with 66436 features.
Fitting estimator with 66

Fitting estimator with 47736 features.
Fitting estimator with 47636 features.
Fitting estimator with 47536 features.
Fitting estimator with 47436 features.
Fitting estimator with 47336 features.
Fitting estimator with 47236 features.
Fitting estimator with 47136 features.
Fitting estimator with 47036 features.
Fitting estimator with 46936 features.
Fitting estimator with 46836 features.
Fitting estimator with 46736 features.
Fitting estimator with 46636 features.
Fitting estimator with 46536 features.
Fitting estimator with 46436 features.
Fitting estimator with 46336 features.
Fitting estimator with 46236 features.
Fitting estimator with 46136 features.
Fitting estimator with 46036 features.
Fitting estimator with 45936 features.
Fitting estimator with 45836 features.
Fitting estimator with 45736 features.
Fitting estimator with 45636 features.
Fitting estimator with 45536 features.
Fitting estimator with 45436 features.
Fitting estimator with 45336 features.
Fitting estimator with 45

Fitting estimator with 26636 features.
Fitting estimator with 26536 features.
Fitting estimator with 26436 features.
Fitting estimator with 26336 features.
Fitting estimator with 26236 features.
Fitting estimator with 26136 features.
Fitting estimator with 26036 features.
Fitting estimator with 25936 features.
Fitting estimator with 25836 features.
Fitting estimator with 25736 features.
Fitting estimator with 25636 features.
Fitting estimator with 25536 features.
Fitting estimator with 25436 features.
Fitting estimator with 25336 features.
Fitting estimator with 25236 features.
Fitting estimator with 25136 features.
Fitting estimator with 25036 features.
Fitting estimator with 24936 features.
Fitting estimator with 24836 features.
Fitting estimator with 24736 features.
Fitting estimator with 24636 features.
Fitting estimator with 24536 features.
Fitting estimator with 24436 features.
Fitting estimator with 24336 features.
Fitting estimator with 24236 features.
Fitting estimator with 24

Fitting estimator with 5436 features.
Fitting estimator with 5336 features.
Fitting estimator with 5236 features.
Fitting estimator with 5136 features.
Fitting estimator with 5036 features.
Fitting estimator with 4936 features.
Fitting estimator with 4836 features.
Fitting estimator with 4736 features.
Fitting estimator with 4636 features.
Fitting estimator with 4536 features.
Fitting estimator with 4436 features.
Fitting estimator with 4336 features.
Fitting estimator with 4236 features.
Fitting estimator with 4136 features.
Fitting estimator with 4036 features.
Fitting estimator with 3936 features.
Fitting estimator with 3836 features.
Fitting estimator with 3736 features.
Fitting estimator with 3636 features.
Fitting estimator with 3536 features.
Fitting estimator with 3436 features.
Fitting estimator with 3336 features.
Fitting estimator with 3236 features.
Fitting estimator with 3136 features.
Fitting estimator with 3036 features.
Fitting estimator with 2936 features.
Fitting esti

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0001
acc_train:0.4378 pre_train:0.4685 recall_train:0.6559 F1_train:0.5466 AUC_train:0.3913
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.3176


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0002
acc_train:0.4978 pre_train:0.5103 recall_train:0.6925 F1_train:0.5876 AUC_train:0.5002
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.4948


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0003
acc_train:0.4989 pre_train:0.5126 recall_train:0.6108 F1_train:0.5574 AUC_train:0.5206
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5452


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0004
acc_train:0.5767 pre_train:0.5875 recall_train:0.6065 F1_train:0.5968 AUC_train:0.5976
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5824


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0005
acc_train:0.5789 pre_train:0.5991 recall_train:0.5591 F1_train:0.5784 AUC_train:0.6212
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5924


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0006
acc_train:0.6244 pre_train:0.6591 recall_train:0.5656 F1_train:0.6088 AUC_train:0.6395
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5876


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0007
acc_train:0.6233 pre_train:0.6522 recall_train:0.5806 F1_train:0.6143 AUC_train:0.6439
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5900


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0008
acc_train:0.6167 pre_train:0.6442 recall_train:0.5763 F1_train:0.6084 AUC_train:0.6374
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5960


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0009
acc_train:0.6322 pre_train:0.6650 recall_train:0.5806 F1_train:0.6200 AUC_train:0.6690
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.6000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0010
acc_train:0.6233 pre_train:0.6552 recall_train:0.5720 F1_train:0.6108 AUC_train:0.6612
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.5996


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0011
acc_train:0.6311 pre_train:0.6802 recall_train:0.5398 F1_train:0.6019 AUC_train:0.6634
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.6104


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0012
acc_train:0.6067 pre_train:0.6364 recall_train:0.5570 F1_train:0.5940 AUC_train:0.6410
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.6144


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0013
acc_train:0.6344 pre_train:0.6762 recall_train:0.5613 F1_train:0.6134 AUC_train:0.6732
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.6236


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch:0014
acc_train:0.6100 pre_train:0.6432 recall_train:0.5505 F1_train:0.5933 AUC_train:0.6391
acc_val:0.5000 pre_val:0.0000 recall_val:0.0000 F1_val:0.000000 AUC_val:0.6272
Epoch:0015
acc_train:0.6133 pre_train:0.6481 recall_train:0.5505 F1_train:0.5953 AUC_train:0.6448
acc_val:0.5400 pre_val:1.0000 recall_val:0.0800 F1_val:0.148148 AUC_val:0.6472
Epoch:0016
acc_train:0.6200 pre_train:0.6461 recall_train:0.5849 F1_train:0.6140 AUC_train:0.6678
acc_val:0.5600 pre_val:0.7500 recall_val:0.1800 F1_val:0.290323 AUC_val:0.6608
Epoch:0017
acc_train:0.6256 pre_train:0.6488 recall_train:0.6000 F1_train:0.6235 AUC_train:0.6523
acc_val:0.6100 pre_val:0.8235 recall_val:0.2800 F1_val:0.417910 AUC_val:0.6660
Epoch:0018
acc_train:0.5878 pre_train:0.6193 recall_train:0.5247 F1_train:0.5681 AUC_train:0.6274
acc_val:0.6600 pre_val:0.8077 recall_val:0.4200 F1_val:0.552632 AUC_val:0.6724
Epoch:0019
acc_train:0.5878 pre_train:0.6152 recall_train:0.5398 F1_train:0.5750 AUC_train:0.6467
acc_val:0.6400 pr

In [None]:
# from __future__ import division
# from __future__ import print_function

# import os
# import time
# import argparse
# import numpy as np
# import io
# import sys

# import torch
# import torch.optim as optim

# from models import GCN

# from metrics import torchmetrics_accuracy, torchmetrics_auc, correct_num, prf

# # from dataloader import dataloader

# from sklearn.model_selection import train_test_split
# from sklearn.metrics import auc

# if hasattr(sys.stdout, 'buffer'):
#     sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

# # Training settings
# # parser = argparse.ArgumentParser()
# # parser.add_argument('--no_cuda', action='store_true', default=False, help='Disables CUDA training.')
# # parser.add_argument('--seed', type=int, default=46, help='Random seed.')
# # parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
# # parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.')
# # parser.add_argument('--weight_decay', type=float, default=5e-5, help='Weight decay (L2 loss on parameters).')
# # parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
# # parser.add_argument('--dropout', type=float, default=0.2, help='Dropout rate (1 - keep probability).')
# # parser.add_argument('--atlas', default='cc200', help='atlas for network construction (node definition) default: ho, see preprocessed-connectomes-project.org/abide/Pipelines.html for more options ')
# # parser.add_argument('--num_features', default=2000, type=int, help='Number of features to keep for the feature selection step (default: 2000)')
# # parser.add_argument('--folds', default=10, type=int, help='For cross validation, specifies which fold will be used. All folds are used if set to 11 (default: 11)')
# # parser.add_argument('--connectivity', default='correlation', help='Type of connectivity used for network construction (default: correlation, options: correlation, partial correlation, tangent)')
# # parser.add_argument('--max_degree', type=int, default=3, help='Maximum Chebyshev polynomial degree.')
# # parser.add_argument('--ngl', default=8, type=int, help='number of gcn hidden layders (default: 8)')
# # parser.add_argument('--edropout', type=float, default=0.3, help='edge dropout rate')
# # parser.add_argument('--train', default=1, type=int, help='train(default: 1) or evaluate(0)')
# # parser.add_argument('--ckpt_path', type=str, default='./pth', help='checkpoint path to save trained models')
# # parser.add_argument('--early_stopping', action='store_true', default=True, help='early stopping switch')
# # parser.add_argument('--early_stopping_patience', type=int, default=20, help='early stoppng epochs')

# # args = parser.parse_args()
# # args.cuda = not args.no_cuda and torch.cuda.is_available()

# # np.random.seed(args.seed)
# # torch.manual_seed(args.seed)
# # if args.cuda:
# #     torch.cuda.manual_seed(args.seed)
    
# # params = dict()
# # params['no_cuda'] = args.no_cuda
# # params['seed'] = args.seed
# # params['epochs'] = args.epochs
# # params['lr'] = args.lr
# # params['weight_decay'] = args.weight_decay
# # params['hidden'] = args.hidden
# # params['dropout'] = args.dropout
# # params['atlas'] = args.atlas
# # params['num_features'] = args.num_features
# # params['folds'] = args.folds
# # params['connectivity'] = args.connectivity
# # params['max_degree'] = args.max_degree
# # params['ngl'] = args.ngl
# # params['edropout'] = args.edropout
# # params['train'] = args.train
# # params['ckpt_path'] = args.ckpt_path
# # params['early_stopping'] = args.early_stopping
# # params['early_stopping_patience'] = args.early_stopping_patience

# class Args:
#     def __init__(self):
#         self.no_cuda = False
#         self.seed = 46
#         self.epochs = 200
#         self.lr = 0.001
#         self.weight_decay = 5e-5
#         self.hidden = 16
#         self.dropout = 0.2
#         self.atlas = 'cc400'
#         self.num_features = 2000
#         self.folds = 10
#         self.connectivity = 'correlation'
#         self.max_degree = 3
#         self.ngl = 8
#         self.edropout = 0.3
#         self.train = 0
#         self.ckpt_path = './pth'
#         self.early_stopping = True
#         self.early_stopping_patience = 20

# # Instantiate Args class
# args = Args()

# # Check if CUDA is available
# args.cuda = not args.no_cuda and torch.cuda.is_available()

# # Set random seeds
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)
# if args.cuda:
#     torch.cuda.manual_seed(args.seed)

# # Create params dictionary
# params = vars(args)

# # Print Hyperparameters
# print('Hyperparameters:')
# for key, value in params.items():
#     print(key + ":", value)

# corrects = np.zeros(args.folds, dtype=np.int32) 
# accs = np.zeros(args.folds, dtype=np.float32) 
# aucs = np.zeros(args.folds, dtype=np.float32)
# prfs = np.zeros([args.folds,3], dtype=np.float32) # Save Precision, Recall, F1
# test_num = np.zeros(args.folds, dtype=np.float32)

# print('  Loading dataset ...')
# dataloader = dataloader()
# raw_features, y, nonimg = dataloader.load_data(params) 
# cv_splits = dataloader.data_split(args.folds)
# features=raw_features

# t1 = time.time()

# for i in range(args.folds):
#     t_start = time.time()
#     train_ind, test_ind = cv_splits[i]

#     train_ind, valid_ind = train_test_split(train_ind, test_size=0.1, random_state = 24)
    
#     cv_splits[i] = (train_ind, valid_ind)
#     cv_splits[i] = cv_splits[i] + (test_ind,)
#     print('Size of the {}-fold Training, Validation, and Test Sets:{},{},{}' .format(i+1, len(cv_splits[i][0]), len(cv_splits[i][1]), len(cv_splits[i][2])))

#     if args.train == 1:
#         for j in range(args.folds):
#             print(' Starting the {}-{} Fold:：'.format(i+1,j+1))
#             node_ftr = dataloader.get_node_features(train_ind)
#             edge_index, edgenet_input = dataloader.get_WL_inputs(nonimg)
#             edgenet_input = (edgenet_input - edgenet_input.mean(axis=0)) / edgenet_input.std(axis=0)
            
#             model = GCN(input_dim = args.num_features,
#                         nhid = args.hidden, 
#                         num_classes = 2, 
#                         ngl = args.ngl, 
#                         dropout = args.dropout, 
#                         edge_dropout = args.edropout, 
#                         edgenet_input_dim = 2*nonimg.shape[1])
#             optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            
# #             if args.cuda:
#             model
#             features = torch.tensor(node_ftr, dtype=torch.float32)
#             edge_index = torch.tensor(edge_index, dtype=torch.long)
#             edgenet_input = torch.tensor(edgenet_input, dtype=torch.float32)
#             labels = torch.tensor(y, dtype=torch.long)
#             fold_model_path = args.ckpt_path + "/fold{}.pth".format(i+1)
                
#             acc = 0
#             best_val_loss = float('inf') # early stoppping: Initialized to positive infinity
#             current_patience = 0 # early stopping: Used to record the epochs of the current early stopping

#             for epoch in range(args.epochs):
#                 # train
#                 model.train()
#                 with torch.set_grad_enabled(True):
#                     optimizer.zero_grad()
#                     output, edge_weights = model(features, edge_index, edgenet_input)
#                     loss_train = torch.nn.CrossEntropyLoss()(output[train_ind], labels[train_ind])
#                     loss_train.backward()
#                     optimizer.step()
#                 acc_train = torchmetrics_accuracy(output[train_ind], labels[train_ind])
#                 auc_train = torchmetrics_auc(output[train_ind], labels[train_ind])
#                 logits_train = output[train_ind].detach().cpu().numpy()
#                 prf_train = prf(logits_train, y[train_ind])

                
#                 # valid
#                 model.eval()
#                 with torch.set_grad_enabled(False):
#                     output, edge_weights = model(features, edge_index, edgenet_input)
#                 loss_val = torch.nn.CrossEntropyLoss()(output[valid_ind], labels[valid_ind])
#                 acc_val = torchmetrics_accuracy(output[valid_ind], labels[valid_ind])
#                 auc_val = torchmetrics_auc(output[valid_ind], labels[valid_ind])
#                 logits_val = output[valid_ind].detach().cpu().numpy()
#                 prf_val = prf(logits_val, y[valid_ind])

                
#                 print('Epoch:{:04d}'.format(epoch+1))
#                 print('acc_train:{:.4f}'.format(acc_train),
#                       'pre_train:{:.4f}'.format(prf_train[0]),
#                       'recall_train:{:.4f}'.format(prf_train[1]),
#                       'F1_train:{:.4f}'.format(prf_train[2]),
#                       'AUC_train:{:.4f}'.format(auc_train))
#                 print('acc_val:{:.4f}'.format(acc_val),
#                       'pre_val:{:.4f}'.format(prf_val[0]),
#                       'recall_val:{:.4f}'.format(prf_val[1]),
#                       'F1_val:{:4f}'.format(prf_val[2]),
#                       'AUC_val:{:.4f}'.format(auc_val))
                
#                 # save pth
#                 if acc_val > acc and epoch > 50:
#                     acc = acc_val
#                     if args.ckpt_path != '':
#                         if not os.path.exists(args.ckpt_path):
#                             os.makedirs(args.ckpt_path)
#                         torch.save(model.state_dict(), fold_model_path)
                
#                 # Early Stopping
#                 if epoch > 50 and args.early_stopping == True:
#                     if loss_val < best_val_loss:
#                         best_val_loss = loss_val
#                         current_patience = 0
#                     else:
#                         current_patience += 1
#                     if current_patience >= args.early_stopping_patience:
#                         print('Early Stopping!!! epoch：{}'.format(epoch))
#                         break
                        
#         # test
#         print("Loading the Model for the {}-th Fold:... ...".format(i+1),
#               "Size of samples in the test set:{}".format(len(test_ind)))
#         model.load_state_dict(torch.load(fold_model_path))
#         model.eval()
        
#         with torch.set_grad_enabled(False):
#             output, edge_weights = model(features, edge_index, edgenet_input)
#         acc_test = torchmetrics_accuracy(output[test_ind], labels[test_ind])
#         auc_test = torchmetrics_auc(output[test_ind], labels[test_ind])
#         logits_test = output[test_ind].detach().cpu().numpy()
#         correct_test = correct_num(logits_test, y[test_ind])
#         prf_test =  prf(logits_test, y[test_ind])
        
#         t_end = time.time()
#         t = t_end - t_start
#         print('Fold {} Results:'.format(i+1),
#               'test acc:{:.4f}'.format(acc_test),
#               'test_pre:{:.4f}'.format(prf_test[0]),
#               'test_recall:{:.4f}'.format(prf_test[1]),
#               'test_F1:{:.4f}'.format(prf_test[2]),
#               'test_AUC:{:.4f}'.format(auc_test),
#               'time:{:.3f}s'.format(t))
        
#         correct = correct_test
#         aucs[i] = auc_test
#         prfs[i] = prf_test
#         corrects[i] = correct
#         test_num[i] = len(test_ind)
    
    
#     if args.train == 0:
#         node_ftr = dataloader.get_node_features(train_ind)
#         edge_index, edgenet_input = dataloader.get_WL_inputs(nonimg)
#         edgenet_input = (edgenet_input - edgenet_input.mean(axis=0)) / edgenet_input.std(axis=0)
        
#         model = GCN(input_dim = args.num_features,
#                     nhid = args.hidden, 
#                     num_classes = 2, 
#                     ngl = args.ngl, 
#                     dropout = args.dropout, 
#                     edge_dropout = args.edropout, 
#                     edgenet_input_dim = 2*nonimg.shape[1])
#         optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        
# #         if args.cuda
#         model
#         features = torch.tensor(node_ftr, dtype=torch.float)
#         edge_index = torch.tensor(edge_index, dtype=torch.long)
#         edgenet_input = torch.tensor(edgenet_input, dtype=torch.float32)
#         labels = torch.tensor(y, dtype=torch.long)
#         fold_model_path = args.ckpt_path + "/fold{}.pth".format(i+1)
        
#         model.load_state_dict(torch.load(fold_model_path))
#         model.eval()
        
#         with torch.set_grad_enabled(False):
#             output, edge_weights = model(features, edge_index, edgenet_input)
#         acc_test = torchmetrics_accuracy(output[test_ind], labels[test_ind])
#         auc_test = torchmetrics_auc(output[test_ind], labels[test_ind])
#         logits_test = output[test_ind].detach().cpu().numpy()
#         correct_test = correct_num(logits_test, y[test_ind])
#         prf_test =  prf(logits_test, y[test_ind])
        
#         t_end = time.time()
#         t = t_end - t_start
#         print('Fold {} Results:'.format(i+1),
#               'test acc:{:.4f}'.format(acc_test),
#               'test_pre:{:.4f}'.format(prf_test[0]),
#               'test_recall:{:.4f}'.format(prf_test[1]),
#               'test_F1:{:.4f}'.format(prf_test[2]),
#               'test_AUC:{:.4f}'.format(auc_test),
#               'time:{:.3f}s'.format(t))
        
#         correct = correct_test
#         aucs[i] = auc_test
#         prfs[i] = prf_test
#         corrects[i] = correct
#         test_num[i] = len(test_ind)

# t2 = time.time()

# print('\r\n======Finish Results for Nested 10-fold cross-validation======')
# Nested10kCV_acc = np.sum(corrects) / np.sum(test_num)
# Nested10kCV_auc = np.mean(aucs)
# Nested10kCV_precision, Nested10kCV_recall, Nested10kCV_F1 = np.mean(prfs, axis=0)
# print('Test:',
#       'acc:{}'.format(Nested10kCV_acc),
#       'precision:{}'.format(Nested10kCV_precision),
#       'recall:{}'.format(Nested10kCV_recall),
#       'F1:{}'.format(Nested10kCV_F1),
#       'AUC:{}'.format(Nested10kCV_auc))
# print('Total duration:{}'.format(t2 - t1))

In [None]:
# from __future__ import division
# from __future__ import print_function

# import os
# import time
# import argparse
# import numpy as np
# import io
# import sys

# import torch
# import torch.optim as optim

# from models import GCN

# from metrics import torchmetrics_accuracy, torchmetrics_auc, correct_num, prf

# # from dataloader import dataloader

# from sklearn.model_selection import train_test_split
# from sklearn.metrics import auc

# if hasattr(sys.stdout, 'buffer'):
#     sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

# # Training settings
# # parser = argparse.ArgumentParser()
# # parser.add_argument('--no_cuda', action='store_true', default=False, help='Disables CUDA training.')
# # parser.add_argument('--seed', type=int, default=46, help='Random seed.')
# # parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
# # parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.')
# # parser.add_argument('--weight_decay', type=float, default=5e-5, help='Weight decay (L2 loss on parameters).')
# # parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
# # parser.add_argument('--dropout', type=float, default=0.2, help='Dropout rate (1 - keep probability).')
# # parser.add_argument('--atlas', default='cc200', help='atlas for network construction (node definition) default: ho, see preprocessed-connectomes-project.org/abide/Pipelines.html for more options ')
# # parser.add_argument('--num_features', default=2000, type=int, help='Number of features to keep for the feature selection step (default: 2000)')
# # parser.add_argument('--folds', default=10, type=int, help='For cross validation, specifies which fold will be used. All folds are used if set to 11 (default: 11)')
# # parser.add_argument('--connectivity', default='correlation', help='Type of connectivity used for network construction (default: correlation, options: correlation, partial correlation, tangent)')
# # parser.add_argument('--max_degree', type=int, default=3, help='Maximum Chebyshev polynomial degree.')
# # parser.add_argument('--ngl', default=8, type=int, help='number of gcn hidden layders (default: 8)')
# # parser.add_argument('--edropout', type=float, default=0.3, help='edge dropout rate')
# # parser.add_argument('--train', default=1, type=int, help='train(default: 1) or evaluate(0)')
# # parser.add_argument('--ckpt_path', type=str, default='./pth', help='checkpoint path to save trained models')
# # parser.add_argument('--early_stopping', action='store_true', default=True, help='early stopping switch')
# # parser.add_argument('--early_stopping_patience', type=int, default=20, help='early stoppng epochs')

# # args = parser.parse_args()
# # args.cuda = not args.no_cuda and torch.cuda.is_available()

# # np.random.seed(args.seed)
# # torch.manual_seed(args.seed)
# # if args.cuda:
# #     torch.cuda.manual_seed(args.seed)
    
# # params = dict()
# # params['no_cuda'] = args.no_cuda
# # params['seed'] = args.seed
# # params['epochs'] = args.epochs
# # params['lr'] = args.lr
# # params['weight_decay'] = args.weight_decay
# # params['hidden'] = args.hidden
# # params['dropout'] = args.dropout
# # params['atlas'] = args.atlas
# # params['num_features'] = args.num_features
# # params['folds'] = args.folds
# # params['connectivity'] = args.connectivity
# # params['max_degree'] = args.max_degree
# # params['ngl'] = args.ngl
# # params['edropout'] = args.edropout
# # params['train'] = args.train
# # params['ckpt_path'] = args.ckpt_path
# # params['early_stopping'] = args.early_stopping
# # params['early_stopping_patience'] = args.early_stopping_patience

# class Args:
#     def __init__(self):
#         self.no_cuda = False
#         self.seed = 46
#         self.epochs = 200
#         self.lr = 0.001
#         self.weight_decay = 5e-5
#         self.hidden = 16
#         self.dropout = 0.2
#         self.atlas = 'cc400'
#         self.num_features = 2000
#         self.folds = 10
#         self.connectivity = 'correlation'
#         self.max_degree = 3
#         self.ngl = 8
#         self.edropout = 0.3
#         self.train = 0
#         self.ckpt_path ='./pth'
#         self.early_stopping = True
#         self.early_stopping_patience = 20

# # Instantiate Args class
# args = Args()

# # Check if CUDA is available
# args.cuda = not args.no_cuda and torch.cuda.is_available()

# # Set random seeds
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)
# if args.cuda:
#     torch.cuda.manual_seed(args.seed)

# # Create params dictionary
# params = vars(args)

# # Print Hyperparameters
# print('Hyperparameters:')
# for key, value in params.items():
#     print(key + ":", value)

# corrects = np.zeros(args.folds, dtype=np.int32) 
# accs = np.zeros(args.folds, dtype=np.float32) 
# aucs = np.zeros(args.folds, dtype=np.float32)
# prfs = np.zeros([args.folds,3], dtype=np.float32) # Save Precision, Recall, F1
# test_num = np.zeros(args.folds, dtype=np.float32)

# print('  Loading dataset ...')
# dataloader = dataloader()
# raw_features, y, nonimg = dataloader.load_data(params) 
# cv_splits = dataloader.data_split(args.folds)
# features=raw_features

# t1 = time.time()

# for i in range(args.folds):
#     t_start = time.time()
#     train_ind, test_ind = cv_splits[i]

#     train_ind, valid_ind = train_test_split(train_ind, test_size=0.1, random_state = 24)
    
#     cv_splits[i] = (train_ind, valid_ind)
#     cv_splits[i] = cv_splits[i] + (test_ind,)
#     print('Size of the {}-fold Training, Validation, and Test Sets:{},{},{}' .format(i+1, len(cv_splits[i][0]), len(cv_splits[i][1]), len(cv_splits[i][2])))

#     if args.train == 1:
#         for j in range(args.folds):
#             print(' Starting the {}-{} Fold:：'.format(i+1,j+1))
#             node_ftr = dataloader.get_node_features(train_ind)
#             edge_index, edgenet_input = dataloader.get_WL_inputs(nonimg)
#             edgenet_input = (edgenet_input - edgenet_input.mean(axis=0)) / edgenet_input.std(axis=0)
            
#             model = GCN(input_dim = args.num_features,
#                         nhid = args.hidden, 
#                         num_classes = 2, 
#                         ngl = args.ngl, 
#                         dropout = args.dropout, 
#                         edge_dropout = args.edropout, 
#                         edgenet_input_dim = 2*nonimg.shape[1])
#             optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            
# #             if args.cuda:
#             model
#             features = torch.tensor(node_ftr, dtype=torch.float32)
#             edge_index = torch.tensor(edge_index, dtype=torch.long)
#             edgenet_input = torch.tensor(edgenet_input, dtype=torch.float32)
#             labels = torch.tensor(y, dtype=torch.long)
#             fold_model_path = args.ckpt_path + "/fold{}.pth".format(i+1)
                
#             acc = 0
#             best_val_loss = float('inf') # early stoppping: Initialized to positive infinity
#             current_patience = 0 # early stopping: Used to record the epochs of the current early stopping

#             for epoch in range(args.epochs):
#                 # train
#                 model.train()
#                 with torch.set_grad_enabled(True):
#                     optimizer.zero_grad()
#                     output, edge_weights = model(features, edge_index, edgenet_input)
#                     loss_train = torch.nn.CrossEntropyLoss()(output[train_ind], labels[train_ind])
#                     loss_train.backward()
#                     optimizer.step()
#                 acc_train = torchmetrics_accuracy(output[train_ind], labels[train_ind])
#                 auc_train = torchmetrics_auc(output[train_ind], labels[train_ind])
#                 logits_train = output[train_ind].detach().cpu().numpy()
#                 prf_train = prf(logits_train, y[train_ind])

                
#                 # valid
#                 model.eval()
#                 with torch.set_grad_enabled(False):
#                     output, edge_weights = model(features, edge_index, edgenet_input)
#                 loss_val = torch.nn.CrossEntropyLoss()(output[valid_ind], labels[valid_ind])
#                 acc_val = torchmetrics_accuracy(output[valid_ind], labels[valid_ind])
#                 auc_val = torchmetrics_auc(output[valid_ind], labels[valid_ind])
#                 logits_val = output[valid_ind].detach().cpu().numpy()
#                 prf_val = prf(logits_val, y[valid_ind])

                
#                 print('Epoch:{:04d}'.format(epoch+1))
#                 print('acc_train:{:.4f}'.format(acc_train),
#                       'pre_train:{:.4f}'.format(prf_train[0]),
#                       'recall_train:{:.4f}'.format(prf_train[1]),
#                       'F1_train:{:.4f}'.format(prf_train[2]),
#                       'AUC_train:{:.4f}'.format(auc_train))
#                 print('acc_val:{:.4f}'.format(acc_val),
#                       'pre_val:{:.4f}'.format(prf_val[0]),
#                       'recall_val:{:.4f}'.format(prf_val[1]),
#                       'F1_val:{:4f}'.format(prf_val[2]),
#                       'AUC_val:{:.4f}'.format(auc_val))
                
#                 # save pth
#                 if acc_val > acc and epoch > 50:
#                     acc = acc_val
#                     if args.ckpt_path != '':
#                         if not os.path.exists(args.ckpt_path):
#                             os.makedirs(args.ckpt_path)
#                         torch.save(model.state_dict(), fold_model_path)
                
#                 # Early Stopping
#                 if epoch > 50 and args.early_stopping == True:
#                     if loss_val < best_val_loss:
#                         best_val_loss = loss_val
#                         current_patience = 0
#                     else:
#                         current_patience += 1
#                     if current_patience >= args.early_stopping_patience:
#                         print('Early Stopping!!! epoch：{}'.format(epoch))
#                         break
                        
#         # test
#         print("Loading the Model for the {}-th Fold:... ...".format(i+1),
#               "Size of samples in the test set:{}".format(len(test_ind)))
#         model.load_state_dict(torch.load(fold_model_path))
#         model.eval()
        
#         with torch.set_grad_enabled(False):
#             output, edge_weights = model(features, edge_index, edgenet_input)
#         acc_test = torchmetrics_accuracy(output[test_ind], labels[test_ind])
#         auc_test = torchmetrics_auc(output[test_ind], labels[test_ind])
#         logits_test = output[test_ind].detach().cpu().numpy()
#         correct_test = correct_num(logits_test, y[test_ind])
#         prf_test =  prf(logits_test, y[test_ind])
        
#         t_end = time.time()
#         t = t_end - t_start
#         print('Fold {} Results:'.format(i+1),
#               'test acc:{:.4f}'.format(acc_test),
#               'test_pre:{:.4f}'.format(prf_test[0]),
#               'test_recall:{:.4f}'.format(prf_test[1]),
#               'test_F1:{:.4f}'.format(prf_test[2]),
#               'test_AUC:{:.4f}'.format(auc_test),
#               'time:{:.3f}s'.format(t))
        
#         correct = correct_test
#         aucs[i] = auc_test
#         prfs[i] = prf_test
#         corrects[i] = correct
#         test_num[i] = len(test_ind)
    
    
#     if args.train == 0:
#         node_ftr = dataloader.get_node_features(train_ind)
#         edge_index, edgenet_input = dataloader.get_WL_inputs(nonimg)
#         edgenet_input = (edgenet_input - edgenet_input.mean(axis=0)) / edgenet_input.std(axis=0)
        
#         model = GCN(input_dim = args.num_features,
#                     nhid = args.hidden, 
#                     num_classes = 2, 
#                     ngl = args.ngl, 
#                     dropout = args.dropout, 
#                     edge_dropout = args.edropout, 
#                     edgenet_input_dim = 2*nonimg.shape[1])
#         optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        
# #         if args.cuda
#         model
#         features = torch.tensor(node_ftr, dtype=torch.float)
#         edge_index = torch.tensor(edge_index, dtype=torch.long)
#         edgenet_input = torch.tensor(edgenet_input, dtype=torch.float32)
#         labels = torch.tensor(y, dtype=torch.long)
#         fold_model_path = args.ckpt_path + "/fold{}.pth".format(i+1)
        
#         model.load_state_dict(torch.load(fold_model_path))
#         model.eval()
        
#         with torch.set_grad_enabled(False):
#             output, edge_weights = model(features, edge_index, edgenet_input)
#         acc_test = torchmetrics_accuracy(output[test_ind], labels[test_ind])
#         auc_test = torchmetrics_auc(output[test_ind], labels[test_ind])
#         logits_test = output[test_ind].detach().cpu().numpy()
#         correct_test = correct_num(logits_test, y[test_ind])
#         prf_test =  prf(logits_test, y[test_ind])
        
#         t_end = time.time()
#         t = t_end - t_start
#         print('Fold {} Results:'.format(i+1),
#               'test acc:{:.4f}'.format(acc_test),
#               'test_pre:{:.4f}'.format(prf_test[0]),
#               'test_recall:{:.4f}'.format(prf_test[1]),
#               'test_F1:{:.4f}'.format(prf_test[2]),
#               'test_AUC:{:.4f}'.format(auc_test),
#               'time:{:.3f}s'.format(t))
        
#         correct = correct_test
#         aucs[i] = auc_test
#         prfs[i] = prf_test
#         corrects[i] = correct
#         test_num[i] = len(test_ind)

# t2 = time.time()

# print('\r\n======Finish Results for Nested 10-fold cross-validation======')
# Nested10kCV_acc = np.sum(corrects) / np.sum(test_num)
# Nested10kCV_auc = np.mean(aucs)
# Nested10kCV_precision, Nested10kCV_recall, Nested10kCV_F1 = np.mean(prfs, axis=0)
# print('Test:',
#       'acc:{}'.format(Nested10kCV_acc),
#       'precision:{}'.format(Nested10kCV_precision),
#       'recall:{}'.format(Nested10kCV_recall),
#       'F1:{}'.format(Nested10kCV_F1),
#       'AUC:{}'.format(Nested10kCV_auc))
# print('Total duration:{}'.format(t2 - t1))

