In [None]:
import pickle
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statistics as stat
import time
%matplotlib inline

from sklearn.metrics import accuracy_score

from nilearn.connectome import ConnectivityMeasure

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

import networkx as nx
import torch
import torch_geometric.utils
from torch_geometric.data import Data, DataLoader
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool

from sklearn import svm

import pickle5 as pickle
import os

from typing import Optional, Tuple
from torch_geometric.typing import Adj, OptTensor, PairTensor

# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparsesum, mul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes

In [None]:
def load_obj(path):
    """Unpickle and return the object stored in the file ``path + '.pkl'``.

    Args:
        path (str): File path *without* the ``.pkl`` extension; the extension
            is appended here.

    Returns:
        The object that was pickled into the file.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    pickle_file = path + '.pkl'
    with open(pickle_file, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

In [None]:
# Per-group data: these objects are iterated by subject id and indexed with it
# to obtain the time series handed to create_graph.
AD_dict, CN_dict = (
    load_obj('AAL_data/timeseries/AD'),
    load_obj('AAL_data/timeseries/CN'),
)

# Pre-computed train/val/test splits: each object is indexed by (seed - 1) and
# the selected entry is iterated for subject ids.
AD_train, AD_val, AD_test = (
    load_obj('AAL_data/samples/AD_train'),
    load_obj('AAL_data/samples/AD_val'),
    load_obj('AAL_data/samples/AD_test'),
)

CN_train, CN_val, CN_test = (
    load_obj('AAL_data/samples/CN_train'),
    load_obj('AAL_data/samples/CN_val'),
    load_obj('AAL_data/samples/CN_test'),
)

In [None]:
def get_correlation_matrix(timeseries, msr):
    """Compute the connectivity matrix of a single time series.

    Args:
        timeseries: Time series of one subject, passed to nilearn's
            ``ConnectivityMeasure.fit_transform`` as a one-element list.
        msr: Value forwarded as ``kind`` to ``ConnectivityMeasure``
            (e.g. ``'correlation'``).

    Returns:
        The connectivity matrix for this one subject.
    """
    estimator = ConnectivityMeasure(kind=msr)
    # fit_transform works on a list of subjects: wrap the single series, then
    # take the single result back out.
    matrices = estimator.fit_transform([timeseries])
    return matrices[0]


def create_graph(timeseries, y, measure = 'correlation'):
    """Build a PyTorch Geometric graph from one subject's time series.

    The connectivity matrix is used twice: its absolute value is the weighted
    adjacency matrix of a directed graph (edge attribute ``weight``), and the
    signed matrix itself is stored as the node feature matrix ``x``.

    Args:
        timeseries: Time series of one subject (see ``get_correlation_matrix``).
        y: Class label stored on the graph as a one-element tensor.
        measure (str): Connectivity kind forwarded to
            ``get_correlation_matrix``. Default: ``'correlation'``.

    Returns:
        The graph data object, moved to the CUDA device when one is available.
    """
    correlation_matrix = get_correlation_matrix(timeseries, measure)
    adj_mat = np.abs(np.asarray(correlation_matrix))

    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is the
    # drop-in replacement for plain ndarrays and exists since NetworkX 2.0.
    G = nx.from_numpy_array(adj_mat, create_using=nx.DiGraph)
    data = torch_geometric.utils.from_networkx(G)
    data['x'] = torch.tensor(correlation_matrix, dtype=torch.float)
    data['y'] = torch.tensor([y])

    if torch.cuda.is_available():
        data = data.to(torch.device('cuda'))

    return data

In [None]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GraphConv

num_classes = 2

class GNN(torch.nn.Module):
    """Three-layer GCN graph classifier with a learnable edge-weight threshold.

    Edge weights are softly gated around ``self.threshold`` before being passed
    to the three GCN layers; node embeddings are mean-pooled per graph and fed
    to a linear classification head.

    Args:
        dim1 (int): Number of input features per node.
        num_hidden_channels (int): Stored on the instance only; the network
            always builds exactly three convolution layers.
        hidden_channels_dims (sequence of int): Output sizes of the three GCN
            layers; the first three entries are used.
        initial_thr (float): Initial value of the learnable threshold.

    Note:
        The constructor calls ``torch.manual_seed(0)``, which resets the global
        torch RNG, so every instance starts from the same initial weights.
    """

    def __init__(self, dim1, num_hidden_channels, hidden_channels_dims,initial_thr): ## 'initial_thr' new parameter
        super(GNN, self).__init__()
        torch.manual_seed(0)
        self.num_hidden_channels = num_hidden_channels
        self.hidden_channels_dims = hidden_channels_dims
        self.conv1 = GCNConv(dim1, hidden_channels_dims[0])
        self.conv2 = GCNConv(hidden_channels_dims[0], hidden_channels_dims[1])
        self.conv3 = GCNConv(hidden_channels_dims[1], hidden_channels_dims[2])
        # The head consumes conv3's output, so size it from the same entry.
        self.lin1 = Linear(hidden_channels_dims[2], num_classes)

        # Cast to float: only floating-point tensors can require gradients, so
        # an integer initial threshold would otherwise raise.
        self.threshold = torch.nn.Parameter(torch.tensor(float(initial_thr)))

    def forward(self, x1, edge_index1, edge_weight1, batch1):
        """Return per-graph class logits.

        Args:
            x1: Node feature matrix.
            edge_index1: Edge index tensor of the batched graphs.
            edge_weight1: Edge weights; each weight w is replaced by
                ``w * sigmoid(50 * (w - threshold))``.
            batch1: Graph assignment of every node, used for mean pooling.

        Returns:
            Tensor with one row of ``num_classes`` raw logits per graph. No
            softmax is applied because ``CrossEntropyLoss`` expects unnormalised
            logits; use ``torch.softmax(out, dim=1)`` for probabilities
            (``argmax`` predictions are the same either way).
        """
        # Smooth, differentiable gate: weights well below the threshold are
        # pushed towards 0, weights well above it are kept almost unchanged.
        edge_weight1 = edge_weight1 * torch.sigmoid(50 * (edge_weight1 - self.threshold))

        for conv in (self.conv1, self.conv2, self.conv3):
            x1 = conv(x1, edge_index1, edge_weight1)
            x1 = x1.relu()
            x1 = F.dropout(x1, p=0.5, training=self.training)

        x1 = global_mean_pool(x1, batch1)
        return self.lin1(x1)


In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How many consecutive calls without improvement to tolerate
                            before setting ``early_stop``.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum decrease of the validation loss to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float("inf")
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation loss; checkpoint on improvement, otherwise count towards patience.

        Args:
            val_loss (float): Validation loss of the current epoch (lower is better).
            model: Model whose ``state_dict`` is saved when the loss improves.
        """
        score = val_loss

        if self.best_score is None:
            # First call: whatever we have is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score > self.best_score - self.delta:
            # The loss did not drop by at least `delta` below the best value.
            # (With delta == 0 a tie still counts as an improvement.)
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
def train(model):
    """Train `model` for one pass over the notebook-level `train_loader`.

    Uses the notebook-level `criterion` and `optimizer`; performs one
    optimiser step per batch.
    """
    model.train()

    for batch in train_loader:
        # Forward pass on the whole mini-batch of graphs.
        output = model(batch.x, batch.edge_index, batch.weight, batch.batch)
        batch_loss = criterion(output, batch.y)
        # Back-propagate, update the parameters, then reset the gradients so
        # they do not accumulate into the next batch.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

def test(test_loader):
    model.eval()

    correct = 0
    for data in test_loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.weight, data.batch)  
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == data.y).sum())  
    return correct / len(test_loader.dataset)

def find_loss(loader, net=None):
    """Return the mean cross-entropy loss per graph over `loader`.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``weight``,
            ``batch`` and ``y``; must also expose ``dataset`` whose length is
            the total number of graphs.
        net: Model to evaluate. Defaults to the notebook-level ``model`` so
            existing ``find_loss(loader)`` calls keep working.

    Returns:
        float: Summed cross-entropy over all batches divided by the dataset size.
    """
    net = model if net is None else net
    net.eval()
    # Sum per batch and divide once at the end so that a smaller last batch
    # does not skew the average.
    func = torch.nn.CrossEntropyLoss(reduction='sum')
    loss = 0
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for i in loader:
            out = net(i.x, i.edge_index, i.weight, i.batch)
            loss += func(out, i.y).item()
    return loss/len(loader.dataset)

In [None]:
# Sanity check: build one model and report its architecture and parameter count.
# num_hidden_channels matches the three entries of hidden_channels_dims.
model = GNN(dim1 = 116, num_hidden_channels = 3, hidden_channels_dims = [32, 24, 16], initial_thr = 0.0)
print(model)
print("Number of parameters: ", sum(p.numel() for p in model.parameters()))

In [None]:
# Build one graph per subject up front so the training loop can reuse them:
# AD subjects get label 0, CN subjects get label 1.
AD_graph = {
    sub_id: create_graph(timeseries=AD_dict[sub_id], y=0)
    for sub_id in AD_dict
}
CN_graph = {
    sub_id: create_graph(timeseries=CN_dict[sub_id], y=1)
    for sub_id in CN_dict
}

In [None]:
# Experiment configuration (also written verbatim into each report).
num_hidden_channels = 3
hidden_channels_dims = [32,24,16]
activation = 'Relu'
Patience = 20
Dropout = 0.5  # only recorded in the report; GNN.forward applies its own hard-coded p=0.5
Learning_rate = 0.001

# For every initial threshold: train one model per pre-computed split, restore
# the checkpoint with the best validation loss, and report accuracy statistics.
for init_th in [0.1,0.3,0.5,0.7,0.9]:

    start = time.time()

    # Checkpoints and the report for this initial threshold share one directory.
    # NOTE(review): ':' in these directory names is not a valid path character on Windows.
    checkpt_dir = f"train_thr_VAL_LOSS_SIGMOID/{activation}/hidden:{hidden_channels_dims}/checkpoints_th:{init_th}"
    os.makedirs(checkpt_dir, exist_ok=True)

    final_train_accs = []
    final_test_accs = []
    final_val_accs = []

    final_ths = []

    seeds=100
    for i in range(1,seeds+1):
        random.seed(i)
        np.random.seed(i)
        # NOTE(review): torch's RNG is not seeded with `i`; GNN.__init__ calls
        # torch.manual_seed(0), so every run starts from the same initial
        # weights and only the data split differs -- confirm this is intended.

        # Split i-1 of the pre-computed subject-id lists selects the graphs.
        train_data = [AD_graph[sub_id] for sub_id in AD_train[i-1]] + [CN_graph[sub_id] for sub_id in CN_train[i-1]]
        train_label = [0]*len(AD_train[i-1]) + [1]*len(CN_train[i-1])
        train_data,train_label = shuffle(train_data, train_label, random_state=i)

        val_data = [AD_graph[sub_id] for sub_id in AD_val[i-1]] + [CN_graph[sub_id] for sub_id in CN_val[i-1]]
        val_label = [0]*len(AD_val[i-1]) + [1]*len(CN_val[i-1])
        val_data, val_label = shuffle(val_data, val_label, random_state=i)

        test_data = [AD_graph[sub_id] for sub_id in AD_test[i-1]] + [CN_graph[sub_id] for sub_id in CN_test[i-1]]
        test_label = [0]*len(AD_test[i-1]) + [1]*len(CN_test[i-1])
        test_data,test_label = shuffle(test_data, test_label, random_state=i)

        train_losses = []
        val_losses = []
        test_losses = []

        train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_data, batch_size=32, shuffle=True)

        model = GNN(dim1=116, num_hidden_channels=num_hidden_channels, hidden_channels_dims=hidden_channels_dims, initial_thr = init_th)
        if torch.cuda.is_available():
            device = torch.device('cuda')
            model.to(device)
        # Use the configured learning rate so the report matches what was run.
        optimizer = torch.optim.Adam(model.parameters(), lr=Learning_rate)
        criterion = torch.nn.CrossEntropyLoss()

        epochs=200

        checkpt_path = f"{checkpt_dir}/checkpoint_seed_{i}.pt"
        early_stopping = EarlyStopping(patience=Patience, verbose=True, path=checkpt_path)

        for epoch in range(1,epochs+1):
            train(model)
            train_loss = find_loss(train_loader)
            val_loss = find_loss(val_loader)
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print(f"Early stopping at epoch:{epoch}")
            # Report accuracy once, on whichever epoch ends the run.
            if early_stopping.early_stop or epoch == epochs:
                print(f'GCN graph classification Epoch: {epoch:03d}, Train Acc: {test(train_loader):.4f}, Val Acc: {test(val_loader):.4f}')
                break

        plt.plot(train_losses,label='train')
        plt.plot(val_losses,label='val')
        plt.title(f"init_th={init_th}, seed={i}")
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.legend()
        plt.show()

        # Restore the weights with the best validation loss before scoring.
        # Loading the state dict replaces every parameter, so the existing
        # model instance can be reused instead of building a new one.
        model.load_state_dict(torch.load(checkpt_path))
        final_train_accs.append(test(train_loader))
        final_val_accs.append(test(val_loader))
        final_test_accs.append(test(test_loader))
        final_ths.append(model.threshold.item())

    with open(f"{checkpt_dir}/report", "w") as file:
        file.write(f"num_hidden_channels={num_hidden_channels}\n"
                   f"hidden_channels_dims={hidden_channels_dims}\n"
                   f"Activation={activation}\n"
                   f"Patience={Patience}\n"
                   f"Dropout={Dropout}\n"
                   f"Learning_rate={Learning_rate}\n"
                   f"seeds={seeds}\n"
                   f"epochs={epochs}\n"
                   f"Initial_threshold={init_th}\n"
                   f"train_avg_accuracy={stat.mean(final_train_accs)*100:0.2f}%\n"
                   f"train_accuracy_stdev={stat.stdev(final_train_accs)*100:0.2f}%\n"
                   f"val_avg_accuracy={stat.mean(final_val_accs)*100:0.2f}%\n"
                   f"val_accuracy_stdev={stat.stdev(final_val_accs)*100:0.2f}%\n"
                   f"test_avg_accuracy={stat.mean(final_test_accs)*100:0.2f}%\n"
                   f"test_accuracy_stdev={stat.stdev(final_test_accs)*100:0.2f}%\n"
                   f"threshold_avg={stat.mean(final_ths)}\n"
                   f"threshold_stdev={stat.stdev(final_ths)}\n"
                   f"threshold_min={min(final_ths)}\n"
                   f"threshold_max={max(final_ths)}\n")
    end = time.time()
    # Append the wall-clock time of this threshold's full sweep.
    with open("time_report","a") as f:
        f.write(f"time_{init_th} = {end-start:0.2f}\n")

In [None]:
# NOTE(review): looks like a leftover debug/marker cell; probably safe to delete -- confirm.
print("jhj")