In [None]:
from community import community_louvain

import pickle
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statistics as stat
import time
%matplotlib inline

from sklearn.metrics import accuracy_score, f1_score

from nilearn.connectome import ConnectivityMeasure

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

import networkx as nx
import torch
import torch_geometric.utils
from torch_geometric.data import Data, DataLoader
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool

from sklearn import svm

import pickle5 as pickle
import os

from typing import Optional, Tuple
from torch_geometric.typing import Adj, OptTensor, PairTensor

# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparsesum, mul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes

In [None]:
def load_obj(path):
    """Load a pickled object from ``<path>.pkl``.

    Args:
        path: File location without the ``.pkl`` suffix. May be a ``str`` or
            any ``os.PathLike`` (e.g. ``pathlib.Path``).

    Returns:
        Whatever object was pickled into the file.
    """
    # os.fspath lets callers pass pathlib.Path as well as plain strings.
    # SECURITY: unpickling executes arbitrary code; only load trusted files.
    with open(os.fspath(path) + '.pkl', 'rb') as f:
        return pickle.load(f)

In [None]:
# Per-group dictionaries; the experiment cell iterates their keys as subject
# ids and passes each value to create_graph as a time series.
AD_dict, CN_dict = load_obj('AAL_data/timeseries/AD'), load_obj('AAL_data/timeseries/CN')

# Pre-computed splits: the experiment cell indexes each of these by
# (seed - 1) and reads a collection of subject ids from it.
AD_train, AD_val, AD_test = (
    load_obj(split_path)
    for split_path in ('AAL_data/AD_train_full', 'AAL_data/AD_val_full', 'AAL_data/AD_test_full')
)
CN_train, CN_val, CN_test = (
    load_obj(split_path)
    for split_path in ('AAL_data/CN_train_full', 'AAL_data/CN_val_full', 'AAL_data/CN_test_full')
)

In [None]:
def get_avg_corr_mat(train_series):
    """Return the element-wise mean of the subjects' correlation matrices.

    Args:
        train_series: Iterable of time series, one per subject; each is passed
            unchanged to ``get_correlation_matrix(..., 'correlation')``.

    Returns:
        The mean over the first axis of the stacked correlation matrices.
    """
    per_subject = [get_correlation_matrix(series, 'correlation') for series in train_series]
    return np.mean(np.array(per_subject), axis=0)

def partitions(avg_correlation_matrices,seed=None):
    """Detect communities in the averaged connectivity graph with Louvain.

    The graph is undirected and its edge weights are the absolute values of
    the supplied matrix.

    Args:
        avg_correlation_matrices: Square matrix of (averaged) correlations.
        seed: Forwarded to ``community_louvain.best_partition`` as
            ``random_state``; pass a value for a reproducible partition.

    Returns:
        The partition returned by ``community_louvain.best_partition``.
    """
    weights = np.abs(np.array(avg_correlation_matrices))
    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
    # the supported equivalent.
    G = nx.from_numpy_array(weights, create_using=nx.Graph)
    return community_louvain.best_partition(G, random_state=seed)

In [None]:
def get_correlation_matrix(timeseries,msr):
    """Compute one subject's connectivity matrix with nilearn.

    Args:
        timeseries: A single subject's time series
            (presumably shaped (n_timepoints, n_regions) - TODO confirm).
        msr: Value passed as ``kind`` to ``ConnectivityMeasure``
            (e.g. ``'correlation'``).

    Returns:
        The first (and only) matrix produced by ``fit_transform``.
    """
    estimator = ConnectivityMeasure(kind=msr)
    # fit_transform expects a list of subjects; wrap the single subject and
    # unwrap the single result.
    matrices = estimator.fit_transform([timeseries])
    return matrices[0]

def get_upper_triangular_matrix(matrix):
    """Flatten the strict upper triangle of a square matrix, row by row.

    Args:
        matrix: Indexable square matrix (``matrix[i][j]``).

    Returns:
        A list of the entries above the main diagonal, in row-major order.
    """
    size = len(matrix)
    return [
        matrix[row][col]
        for row in range(size)
        for col in range(row + 1, size)
    ]

    
# def get_adj_mat(correlation_matrix, threshold_value, weighted = True):
#     adj_mat = []
#     for i in correlation_matrix:
#         row = []
#         for j in i:
#             if abs(j)>threshold_value:
#                 if not weighted:
#                     row.append(1)
#                 else:
#                     row.append(abs(j))
#             else:
#                 row.append(0)
#         adj_mat.append(row)
#     return adj_mat


def get_adj_mat(correlation_matrix, th_inter, th_intra, partition):
    """Build a weighted adjacency matrix using community-aware thresholds.

    An entry is kept (as its absolute correlation) only when it exceeds the
    threshold for its pair of nodes; otherwise it becomes 0.

    NOTE: earlier revisions applied the two thresholds the other way round
    (``th_inter`` within a community, ``th_intra`` across communities), which
    contradicted the parameter names. Results produced by that version are
    labelled with the thresholds swapped.

    Args:
        correlation_matrix: Square matrix of correlations (``m[i][j]``).
        th_inter: Threshold for pairs of nodes in different communities.
        th_intra: Threshold for pairs of nodes in the same community
            (this includes the diagonal).
        partition: Mapping/sequence giving the community id of node ``i`` as
            ``partition[i]``.

    Returns:
        A list of rows (lists) holding ``abs(correlation)`` or 0.
    """
    adj_mat = []
    for i, corr_row in enumerate(correlation_matrix):
        row = []
        for j, value in enumerate(corr_row):
            strength = abs(value)
            # intra = same community, inter = different communities.
            threshold = th_intra if partition[i] == partition[j] else th_inter
            row.append(strength if strength > threshold else 0)
        adj_mat.append(row)
    return adj_mat


def connect_isolated_nodes(adj_mat, correlation_matrix):
    """Give every all-zero adjacency row one edge to its best-correlated node.

    For each row of ``adj_mat`` whose entries sum to 0, the column holding the
    largest (signed) off-diagonal correlation for that node is set to 1.
    Only that row is written, so the added edge is one-directional.

    Args:
        adj_mat: Adjacency matrix as a list of rows; modified in place.
        correlation_matrix: Square correlation matrix for the same nodes.

    Returns:
        The same ``adj_mat`` object, after modification.
    """
    # Subtract the identity so the self-correlation on the diagonal does not
    # win the max search.
    off_diagonal = np.array(correlation_matrix) - np.array(np.eye(len(correlation_matrix)))
    corr_rows = [list(values) for values in off_diagonal]

    for node, neighbours in enumerate(adj_mat):
        if sum(neighbours) != 0:
            continue
        strongest = max(corr_rows[node])
        adj_mat[node][corr_rows[node].index(strongest)] = 1
    return adj_mat
    
# def get_threshold_value(ad_timeseires, cn_timeseries, measure, threshold_percent):
#     ad_corr_mats = [get_correlation_matrix(ts, measure) for ts in ad_timeseires]
#     cn_corr_mats = [get_correlation_matrix(ts, measure) for ts in cn_timeseries]

#     ad_upper = [get_upper_triangular_matrix(matrix) for matrix in ad_corr_mats]
#     cn_upper = [get_upper_triangular_matrix(matrix) for matrix in cn_corr_mats]

#     all_correlation_values = ad_upper + cn_upper
#     all_correlation_values = np.array(all_correlation_values).flatten()

#     all_correlation_values = np.array([abs(i) for i in all_correlation_values])
#     all_correlation_values = np.sort(all_correlation_values)[::-1]

#     th_val_index = (len(all_correlation_values)*threshold_percent)//100
#     return all_correlation_values[int(th_val_index)]


# def get_threshold_value(ad_timeseires, cn_timeseries, measure, threshold_percent):
#     ad_corr_mats = [get_correlation_matrix(ts, measure) for ts in ad_timeseires]
#     cn_corr_mats = [get_correlation_matrix(ts, measure) for ts in cn_timeseries]

#     ad_upper = [get_upper_triangular_matrix(matrix) for matrix in ad_corr_mats]
#     cn_upper = [get_upper_triangular_matrix(matrix) for matrix in cn_corr_mats]

#     all_correlation_values = ad_upper + cn_upper
#     all_correlation_values = np.array(all_correlation_values).flatten()
    
#     all_correlation_values_pos=[]
#     all_correlation_values_neg=[]
#     for i in all_correlation_values:
#         if i==1:
#             continue
#         elif i>0:
#             all_correlation_values_pos.append(i)
#         else:  
#             all_correlation_values_neg.append(abs(i))

#     all_correlation_values_pos = np.array(all_correlation_values_pos)
#     all_correlation_values_pos = np.sort(all_correlation_values_pos)[::-1]
    
#     all_correlation_values_neg = np.array(all_correlation_values_neg)
#     all_correlation_values_neg = np.sort(all_correlation_values_neg)[::-1]

#     th_val_index = (len(all_correlation_values)*threshold_percent)//100
    
#     return all_correlation_values_pos[int(th_val_index)], all_correlation_values_neg[int(th_val_index)]

def create_graph(timeseries, th_inter, th_intra, y, partition, measure='correlation'):
    """Turn one subject's time series into a PyTorch Geometric graph.

    Node features (``x``) are the rows of the subject's connectivity matrix;
    edges come from the thresholded adjacency matrix built by ``get_adj_mat``;
    ``y`` is stored as a one-element label tensor.

    Args:
        timeseries: The subject's time series, passed to
            ``get_correlation_matrix``.
        th_inter: Threshold forwarded to ``get_adj_mat``.
        th_intra: Threshold forwarded to ``get_adj_mat``.
        y: Class label for the graph.
        partition: Node-to-community assignment forwarded to ``get_adj_mat``.
        measure: Connectivity kind understood by ``get_correlation_matrix``.

    Returns:
        The graph object produced by ``torch_geometric.utils.from_networkx``,
        moved to ``cuda:1`` when CUDA is available.
    """
    correlation_matrix = get_correlation_matrix(timeseries, measure)
    adj_mat = get_adj_mat(correlation_matrix, th_inter, th_intra, partition)

    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
    # the supported equivalent and likewise stores entries as edge 'weight'.
    G = nx.from_numpy_array(np.array(adj_mat), create_using=nx.DiGraph)
    data = torch_geometric.utils.from_networkx(G)
    data['x'] = torch.tensor(correlation_matrix, dtype=torch.float)
    data['y'] = torch.tensor([y])

    if torch.cuda.is_available():
        # The rest of the notebook places its models on the same device.
        data = data.to(torch.device('cuda:1'))
    return data

In [None]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GraphConv

num_classes = 2  # size of the classifier's output layer

class GNN(torch.nn.Module):
    """Three-layer GCN graph classifier with mean-pool readout.

    Each GCN layer is followed by ReLU and dropout (p=0.5); node embeddings
    are mean-pooled per graph and mapped to ``num_classes`` scores by a
    linear layer.

    Args:
        dim1: Number of input features per node.
        num_hidden_channels: Stored on the instance only; the architecture
            always has exactly three GCN layers.
        hidden_channels_dims: Output widths of the three GCN layers; at least
            three entries are required.
    """
    def __init__(self, dim1, num_hidden_channels, hidden_channels_dims):
        super(GNN, self).__init__()
        # Fixes the global torch RNG so weight initialisation is identical
        # for every instance.
        torch.manual_seed(0)
        self.num_hidden_channels = num_hidden_channels
        self.hidden_channels_dims = hidden_channels_dims
        self.conv1 = GCNConv(dim1, hidden_channels_dims[0])
        self.conv2 = GCNConv(hidden_channels_dims[0], hidden_channels_dims[1])
        self.conv3 = GCNConv(hidden_channels_dims[1], hidden_channels_dims[2])
        # The head consumes conv3's output, so size it from index 2 rather
        # than from the last list entry.
        self.lin1 = Linear(hidden_channels_dims[2], num_classes)

    def forward(self, x1, edge_index1, edge_weight1, batch1):
        """Return unnormalised class scores (logits), one row per graph.

        No softmax is applied: ``torch.nn.CrossEntropyLoss`` expects raw
        logits and applies log-softmax itself. ``argmax`` over the logits
        gives the same predictions as over softmax probabilities.

        Args:
            x1: Node feature matrix.
            edge_index1: Edge index tensor.
            edge_weight1: Per-edge weights passed to every GCN layer.
            batch1: Graph-assignment vector used for pooling.
        """
        x1 = self.conv1(x1, edge_index1, edge_weight1)
        x1 = x1.relu()
        x1 = F.dropout(x1, p=0.5, training=self.training)

        x1 = self.conv2(x1, edge_index1, edge_weight1)
        x1 = x1.relu()
        x1 = F.dropout(x1, p=0.5, training=self.training)

        x1 = self.conv3(x1, edge_index1, edge_weight1)
        x1 = x1.relu()
        x1 = F.dropout(x1, p=0.5, training=self.training)

        x1 = global_mean_pool(x1, batch1)
        return self.lin1(x1)

In [None]:
class EarlyStopping:
    """Track a validation score and signal when it has stopped improving.

    The monitored score is treated as higher-is-better. The model's
    ``state_dict`` is written to ``path`` every time the score is accepted as
    a new best.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): Number of consecutive non-improving calls after
                            which ``early_stop`` becomes True. Default: 7
            verbose (bool): If True, report every checkpoint save through
                            ``trace_func``. Default: False
            delta (float): A score below ``best_score + delta`` counts as
                            non-improving. Default: 0
            path (str): Where the checkpoint is saved.
                            Default: 'checkpoint.pt'
            trace_func (function): Callable used for messages.
                            Default: print
        """
        # Configuration
        self.patience = patience
        self.delta = delta
        self.path = path
        self.verbose = verbose
        self.trace_func = trace_func
        # Running state
        self.best_score = None
        self.val_acc_max = 0
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_acc, model):
        """Record one validation score; checkpoint ``model`` if it is a new best."""
        first_call = self.best_score is None

        if not first_call and val_acc < self.best_score + self.delta:
            # No improvement: count it and raise the flag once patience runs out.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = val_acc
        self.save_checkpoint(val_acc, model)
        if not first_call:
            self.counter = 0

    def save_checkpoint(self, val_acc, model):
        '''Write the model weights to ``self.path`` and remember the score.'''
        if self.verbose:
            self.trace_func(f'Validation acc increased ({self.val_acc_max:.6f} --> {val_acc:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_acc_max = val_acc

In [None]:
def train(model):
    """Run one optimisation pass over the global ``train_loader``.

    Relies on the notebook globals ``train_loader``, ``criterion`` and
    ``optimizer``.

    Args:
        model: The network to update; it is switched to training mode.
    """
    model.train()

    for batch in train_loader:
        # Forward pass on the whole mini-batch of graphs.
        scores = model(batch.x, batch.edge_index, batch.weight, batch.batch)
        # Back-propagate the loss, apply the update, then reset gradients.
        criterion(scores, batch.y).backward()
        optimizer.step()
        optimizer.zero_grad()

def test(test_loader):
    """Return the accuracy of the global ``model`` on a data loader.

    Args:
        test_loader: Loader yielding graph batches; ``test_loader.dataset``
            must support ``len``.

    Returns:
        Fraction of graphs whose arg-max prediction equals ``data.y``.
    """
    model.eval()

    correct = 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for data in test_loader:  # Iterate in batches over the given dataset.
            out = model(data.x, data.edge_index, data.weight, data.batch)
            pred = out.argmax(dim=1)  # Use the class with the highest score.
            correct += int((pred == data.y).sum())
    return correct / len(test_loader.dataset)

def f1(loader):
    """Return the F1 score of the global ``model`` on a data loader.

    Predictions are the arg-max class per graph; the score is computed by
    ``sklearn.metrics.f1_score`` with its default settings.

    Args:
        loader: Loader yielding graph batches with a ``y`` label tensor.

    Returns:
        The F1 score as returned by ``f1_score(y_true, y_pred)``.
    """
    model.eval()

    y_true = []
    y_pred = []

    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for data in loader:
            out = model(data.x, data.edge_index, data.weight, data.batch)
            pred = out.argmax(dim=1)
            y_pred += pred.cpu().tolist()
            y_true += data.y.cpu().tolist()

    return f1_score(y_true,y_pred)

In [None]:
def avg_acc(model,checkpt_path,loader,seeds):
    """Evaluate ``model`` with each seed's checkpoint and collect accuracies.

    Despite the name, the per-seed accuracies are returned as a list; no
    average is taken here.

    Args:
        model: Network whose weights are overwritten by each checkpoint and
            which is then evaluated.
        checkpt_path: Directory containing ``checkpoint_seed_<i>.pt`` files.
        loader: Loader yielding graph batches; ``loader.dataset`` must
            support ``len``.
        seeds: Number of checkpoints; seeds ``1..seeds`` are loaded.

    Returns:
        List of accuracies, one per seed, in seed order.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda:1')
        model.to(device)
    accs = []
    for i in range(1,seeds+1):
        model.load_state_dict(torch.load(checkpt_path+f"/checkpoint_seed_{i}.pt"))
        # Score the model that was just loaded. The module-level test()
        # helper evaluates the *global* model, which is not necessarily the
        # one passed in here.
        model.eval()
        correct = 0
        with torch.no_grad():
            for data in loader:
                out = model(data.x, data.edge_index, data.weight, data.batch)
                correct += int((out.argmax(dim=1) == data.y).sum())
        accs.append(correct / len(loader.dataset))
    return accs

In [None]:
# Build a throw-away model to inspect its layers and size.
model = GNN(
    dim1=116,
    num_hidden_channels=3,
    hidden_channels_dims=[32, 16, 8],
)
print(model)
print("Number of parameters: ", sum(map(torch.numel, model.parameters())))

In [None]:
# --- Experiment configuration -------------------------------------------
num_hidden_channels = 3
hidden_channels_dims = [32,24,16]
activation = 'ReLU'      # used in output paths and the report
Patience = 200           # early-stopping patience (epochs)
Dropout = 0.5            # written to the report; the GNN class hard-codes p=0.5
Learning_rate = 0.001    # Adam learning rate (written to the report)
PARTITION_SEED = 0       # makes the Louvain community partition reproducible


# Community structure is estimated once, on the mean correlation matrix of
# every subject in AD_dict and CN_dict.
# NOTE(review): this includes subjects that later land in the validation and
# test splits - confirm that is acceptable for the study design.
avg_correlation_matrix = get_avg_corr_mat(list(AD_dict.values()) + list(CN_dict.values()))
# Louvain is randomised; without a seed the partition (and therefore every
# graph built from it) can change between kernel restarts.
partition = partitions(avg_correlation_matrix, seed=PARTITION_SEED)

# Grid search over the two edge thresholds. For every (th_inter, th_intra)
# pair: rebuild all subject graphs, train one model per pre-computed split,
# reload each split's best checkpoint, and write aggregate accuracy / F1
# statistics to a report file.
seeds = 100   # number of splits evaluated (indices 0..seeds-1 of the split lists)
epochs = 200  # maximum training epochs per split

for th_inter in [0.1,0.3,0.5,0.7,0.9]:
    for th_intra in [0.1,0.3,0.5,0.7,0.9]:
        start = time.time()

        # Graphs depend only on the thresholds, so build them once per pair.
        AD_graph={}
        CN_graph={}
        for sub_id in (AD_dict):
            AD_graph[sub_id] = create_graph(timeseries=AD_dict[sub_id], th_inter=th_inter, th_intra=th_intra, y=0, partition=partition)
        for sub_id in (CN_dict):
            CN_graph[sub_id] = create_graph(timeseries=CN_dict[sub_id], th_inter=th_inter, th_intra=th_intra, y=1, partition=partition)

        # All checkpoints and the report for this threshold pair live here.
        run_dir = f"manual_community_f1_OVERSAMPLE/{activation}/hidden:{hidden_channels_dims}/th_inter:{th_inter},th_intra:{th_intra}"
        os.makedirs(run_dir, exist_ok=True)

        final_train_accs = []
        final_test_accs = []
        final_val_accs = []
        final_train_f1s = []
        final_val_f1s = []
        final_test_f1s = []
        # Never filled in this cell; kept because they are notebook globals
        # that other cells may read.
        all_train_accs = []
        all_val_accs = []
        all_test_accs = []

        for i in range(1,seeds+1):
            random.seed(i)
            np.random.seed(i)

            # Oversample AD so both classes are the same size in training:
            # every AD graph twice, plus `residual` randomly chosen extras.
            # random.sample raises ValueError unless
            # 0 <= residual <= len(AD_train_data).
            AD_train_data = [AD_graph[sub_id] for sub_id in AD_train[i-1]]
            CN_train_data = [CN_graph[sub_id] for sub_id in CN_train[i-1]]
            residual = len(CN_train_data) - 2*len(AD_train_data)
            extra = random.sample(AD_train_data, residual)
            AD_train_data = AD_train_data*2 + extra

            train_data = AD_train_data + CN_train_data
            train_label = [0]*len(AD_train_data) + [1]*len(CN_train_data)
            train_data,train_label = shuffle(train_data, train_label, random_state=i)

            val_data = [AD_graph[sub_id] for sub_id in AD_val[i-1]] + [CN_graph[sub_id] for sub_id in CN_val[i-1]]
            val_label = [0]*len(AD_val[i-1]) + [1]*len(CN_val[i-1])
            val_data, val_label = shuffle(val_data, val_label, random_state=i)

            test_data = [AD_graph[sub_id] for sub_id in AD_test[i-1]] + [CN_graph[sub_id] for sub_id in CN_test[i-1]]
            test_label = [0]*len(AD_test[i-1]) + [1]*len(CN_test[i-1])
            test_data,test_label = shuffle(test_data, test_label, random_state=i)

            # Never filled in this cell; kept as notebook globals.
            train_losses = []
            val_losses = []
            test_losses = []

            # Per-epoch F1 curves (test_f1s is never filled).
            train_f1s = []
            val_f1s = []
            test_f1s = []

            train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
            val_loader = DataLoader(val_data, batch_size=32, shuffle=True)
            test_loader = DataLoader(test_data, batch_size=32, shuffle=True)

            model = GNN(dim1=116, num_hidden_channels=num_hidden_channels, hidden_channels_dims=hidden_channels_dims)
            if torch.cuda.is_available():
                device = torch.device('cuda:1')
                model.to(device)
            # Use the configured constant so the report cannot drift from the
            # value actually used.
            optimizer = torch.optim.Adam(model.parameters(), lr=Learning_rate)
            criterion = torch.nn.CrossEntropyLoss()

            checkpt_path = f"{run_dir}/checkpoint_seed_{i}.pt"
            # Checkpoints the model whenever validation F1 reaches a new best.
            # With Patience equal to `epochs` the stop flag can never be
            # raised, so this only does best-model selection.
            early_stopping = EarlyStopping(patience=Patience, verbose=True, path=checkpt_path)

            for epoch in range(1,epochs+1):
                train(model)
                train_f1 = f1(train_loader)
                val_f1 = f1(val_loader)
                train_f1s.append(train_f1)
                val_f1s.append(val_f1)
                early_stopping(val_f1, model)
                if early_stopping.early_stop:
                    print(f"Early stopping at epoch:{epoch}")
                    print(f'GCN graph classification Epoch: {epoch:03d}, Train f1: {train_f1:.4f}, Val f1: {early_stopping.val_acc_max:.4f}')
                    break
                if epoch==epochs:
                    print(f'GCN graph classification Epoch: {epoch:03d}, Train f1: {train_f1:.4f}, Val f1: {early_stopping.val_acc_max:.4f}')

            plt.plot(train_f1s,label='train')
            plt.plot(val_f1s,label='val')
            plt.legend()
            plt.show()

            # Score the best checkpoint for this split, not the last epoch.
            model = GNN(dim1=116, num_hidden_channels=num_hidden_channels, hidden_channels_dims=hidden_channels_dims)
            if torch.cuda.is_available():
                device = torch.device('cuda:1')
                model.to(device)
            checkpt_path = run_dir
            model.load_state_dict(torch.load(checkpt_path+f"/checkpoint_seed_{i}.pt"))

            final_train_accs.append(test(train_loader))
            final_test_accs.append(test(test_loader))
            final_val_accs.append(test(val_loader))

            final_train_f1s.append(f1(train_loader))
            final_test_f1s.append(f1(test_loader))
            final_val_f1s.append(f1(val_loader))


        with open(f"{run_dir}/report", "w") as file:
            file.write(f"num_hidden_channels={num_hidden_channels}\n"
                       f"hidden_channels_dims={hidden_channels_dims}\n"
                       f"Activation={activation}\n"
                       f"Patience={Patience}\n"
                       f"Dropout={Dropout}\n"
                       f"Learning_rate={Learning_rate}\n"
                       f"seeds={seeds}\n"
                       f"epochs={epochs}\n"
                       f"Threshold inter={th_inter}\n"
                       f"threshold intra={th_intra}\n"
                       f"train_avg_accuracy={stat.mean(final_train_accs)*100:0.2f}%\n"
                       f"train_accuracy_stdev={stat.stdev(final_train_accs)*100:0.2f}%\n"
                       f"train_avg_f1={stat.mean(final_train_f1s):0.2f}\n"
                       f"train_f1_stdev={stat.stdev(final_train_f1s):0.2f}\n"
                       f"val_avg_accuracy={stat.mean(final_val_accs)*100:0.2f}%\n"
                       f"val_accuracy_stdev={stat.stdev(final_val_accs)*100:0.2f}%\n"
                       f"val_avg_f1={stat.mean(final_val_f1s):0.2f}\n"
                       f"val_f1_stdev={stat.stdev(final_val_f1s):0.2f}\n"
                       f"test_avg_accuracy={stat.mean(final_test_accs)*100:0.2f}%\n"
                       f"test_accuracy_stdev={stat.stdev(final_test_accs)*100:0.2f}%\n"
                       f"test_avg_f1={stat.mean(final_test_f1s):0.2f}\n"
                       f"test_f1_stdev={stat.stdev(final_test_f1s):0.2f}\n")
        end = time.time()
        # Append wall-clock time for this threshold pair to a shared log.
        with open("time_report","a") as f:
            f.write(f"time_th_inter:{th_inter},th_intra:{th_intra} = {end-start:0.2f}\n")