In [None]:
import pandas as pd
import os
import nibabel as nib
import pickle
import numpy as np
from nilearn.datasets import fetch_atlas_schaefer_2018
from nilearn.image import load_img
from scipy.stats import zscore
import torch
from torch_geometric.data import Data,InMemoryDataset
from random import randrange
import math
import zipfile
from joblib import Parallel, delayed
from tqdm import tqdm
import itertools
import torch
import NeuroGraph
from NeuroGraph import preprocess
import boto3
from pathos.multiprocessing import ProcessingPool as Pool
#from connectivity_matrices import KendallConnectivityMeasure
from nilearn.connectome import ConnectivityMeasure


def worker_function(args):
    """Unpack one pooled task tuple and build the graphs for that subject.

    A pool ``map`` hands each worker a single argument, so the per-subject
    inputs travel as one tuple.

    Args:
        args: ``(iid, BUCKET_NAME, volume)`` or ``(iid, BUCKET_NAME, volume, lag)``.
            When ``lag`` is omitted it is forwarded as ``None``.

    Returns:
        Whatever ``Brain_Connectome_Task_Download.get_data_obj_task`` returns
        for this subject, unchanged.
    """
    # Star-unpacking tolerates the three-element form; a strict four-way
    # unpack raised ValueError for callers that do not supply a lag.
    iid, BUCKET_NAME, volume, *rest = args
    lag = rest[0] if rest else None

    # Directly call the static processing method with the unpacked arguments
    return Brain_Connectome_Task_Download.get_data_obj_task(iid, BUCKET_NAME, volume, lag)



class Brain_Connectome_Task_Download(InMemoryDataset):
    """In-memory graph dataset of HCP task-fMRI connectomes (one graph per subject and task).

    For every subject id stored in ``data/ids.pkl`` and every task in
    ``TASK_NAMES``, an ROI time-series ``.npy`` file is loaded and turned into
    a correlation matrix (node features) plus an edge set made of the
    strongest positive correlations.  The label ``y`` is the task's index in
    ``TASK_NAMES``.
    """

    # Task order defines the integer class label stored in ``Data.y``.
    TASK_NAMES = ("EMOTION", "GAMBLING", "LANGUAGE", "MOTOR", "RELATIONAL", "SOCIAL", "WM")

    def __init__(self, root, dataset_name,n_rois, threshold,path_to_data,n_jobs,s3,transform=None, pre_transform=None, pre_filter=None):
        """Store the configuration, let the base class build the dataset, then load it.

        Args:
            root: Dataset root directory handed to ``InMemoryDataset``.
            dataset_name: Stem of the processed file (``<dataset_name>.pt``).
            n_rois: Number of ROIs requested from the Schaefer 2018 atlas in ``process``.
            threshold: Stored on the instance; not read by the methods of this class.
            path_to_data: Stored as ``self.target_path``; not read by the methods of this class.
            n_jobs: Number of pool workers used by ``process``.
            s3: S3 client; stored on the instance, not read by the methods of this class.
            transform, pre_transform, pre_filter: Forwarded to ``InMemoryDataset``.
        """
        self.root, self.dataset_name,self.n_rois,self.threshold,self.target_path,self.n_jobs,self.s3 = root, dataset_name,n_rois,threshold,path_to_data,n_jobs,s3
        super().__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Name of the single processed file that holds the collated dataset."""
        return [self.dataset_name+'.pt']

    @staticmethod
    def get_data_obj_task(iid, BUCKET_NAME, volume, lag=None):
        """Build one ``Data`` graph per task for subject ``iid``.

        Time series are read from
        ``data/raw/HCPState/time_series_100/<iid>_tfMRI_<TASK>_LR_time_series.npy``
        (presumably time points x ROIs, z-scored -- confirm against the
        extraction step).

        Args:
            iid: Subject identifier used in the file name.
            BUCKET_NAME: Unused here; kept for signature compatibility.
            volume: Unused here; kept for signature compatibility.
            lag: Unused here; kept so the shared worker can pass it.

        Returns:
            List of ``Data`` objects, one per task whose file could be
            processed.  Tasks that fail are reported and skipped, so the list
            may be shorter than ``TASK_NAMES`` or empty.
        """
        target_path = "data/raw/HCPState"
        data_list = []

        for y, task in enumerate(Brain_Connectome_Task_Download.TASK_NAMES):
            ts_path = os.path.join(
                target_path + "/time_series_100",
                f"{iid}_tfMRI_{task}_LR_time_series.npy",
            )
            try:
                zd_Ytm = np.load(ts_path)

                # Zero out every sample below the fixed activation threshold.
                positive_threshold_value = 1
                zd_Ytm[zd_Ytm < positive_threshold_value] = 0

                conn = ConnectivityMeasure(kind='correlation')
                zd_fc = conn.fit_transform([zd_Ytm])[0]
                # Self-correlations carry no information; drop them.
                np.fill_diagonal(zd_fc, 0)
                corr = torch.tensor(zd_fc).to(torch.float)

                A = Brain_Connectome_Task_Download.construct_Adj_postive_perc(corr, graph_threshold=5)
                edge_index = A.nonzero().t().to(torch.long)

                data_list.append(Data(x=corr, edge_index=edge_index, y=y))
            except Exception as err:
                # Best effort: one missing or broken file must not abort the
                # subject, but say which file failed and why.
                print("file skipped!", ts_path, repr(err))
        return data_list

    @staticmethod
    def extract_from_3d_no(volume, fmri):
        '''
        Extract time-series data from a 3d atlas with non-overlapping ROIs.

        Inputs:
            volume = 3d array of integer ROI labels; label 0 is ignored
            fmri = 4d array whose first three axes match volume and whose
                   last axis is time

        Output:
            returns extracted time series # volumes x # ROIs
            (ROIs ordered by ascending label value)
        '''
        # Boolean-indexing the three spatial axes yields (n_voxels, n_timepoints);
        # averaging over axis 0 gives the ROI mean for every time point at once.
        subcor_ts = [fmri[volume == label].mean(axis=0)
                     for label in np.unique(volume) if label != 0]
        return np.array(subcor_ts).T

    @staticmethod
    def construct_Adj_postive_perc(corr,graph_threshold):
        """Binarise ``corr`` by keeping the top ``graph_threshold`` percent of positive entries.

        Args:
            corr: Torch tensor of correlation values.
            graph_threshold: Percentage (0-100) of the positive entries to keep.

        Returns:
            A new tensor shaped and typed like ``corr`` holding 1 where the
            entry is at or above the percentile cut-off and 0 elsewhere.
            NaN entries and inputs without any positive entry yield 0.
        """
        positive_values = corr[corr > 0]
        if positive_values.numel() == 0:
            # No positive correlation at all: there is nothing to connect.
            return torch.zeros_like(corr)
        threshold = np.percentile(positive_values.detach().cpu().numpy(), 100 - graph_threshold)
        return (corr >= float(threshold)).to(corr.dtype)

    def process(self):
        """Build every subject's graphs in a worker pool and save the collated result.

        Reads the subject ids from ``data/ids.pkl``, fetches the Schaefer 2018
        atlas, maps ``worker_function`` over all subjects, applies the optional
        ``pre_filter`` / ``pre_transform`` and writes ``(data, slices)`` to
        ``self.processed_paths[0]``.
        """
        BUCKET_NAME = 'hcp-openaccess'
        # NOTE: unpickling runs arbitrary code; ids.pkl must come from a trusted source.
        with open(os.path.join("data/","ids.pkl"),'rb') as f:
            ids = pickle.load(f)
        #ids = ids[:2]

        roi = fetch_atlas_schaefer_2018(n_rois=self.n_rois,yeo_networks=17, resolution_mm=2)
        atlas = load_img(roi['maps'])
        volume = atlas.get_fdata()

        # Always send four values per task: this dataset has no lag, so the
        # fourth slot is None.
        tasks = [(iid, BUCKET_NAME, volume, None) for iid in ids]
        with Pool(self.n_jobs) as pool:
            data_list = pool.map(worker_function, tasks)

        # Defensive: a worker result of None (failed subject) must not break the flattening.
        dataset = list(itertools.chain.from_iterable(
            result for result in data_list if result is not None))

        if self.pre_filter is not None:
            dataset = [data for data in dataset if self.pre_filter(data)]

        if self.pre_transform is not None:
            dataset = [self.pre_transform(data) for data in dataset]

        data, slices = self.collate(dataset)
        print("saving path:",self.processed_paths[0])
        torch.save((data, slices), self.processed_paths[0])

 
# ConnectomeDB (HCP) S3 credentials: fill these in locally and never commit real keys.
ACCESS_KEY = ''
SECRET_KEY = ''
s3 = boto3.client(
    's3',
    aws_access_key_id=ACCESS_KEY,
    aws_secret_access_key=SECRET_KEY,
)


In [3]:

class Brain_Connectome_Task_Download_lag(InMemoryDataset):
    """HCP task-fMRI graph dataset whose node features include time-lagged correlations.

    Each sample's feature matrix is the horizontal concatenation of three
    ROI x ROI blocks: the ordinary correlation matrix, the correlation between
    the series and its ``lag``-shifted copy, and the reverse of that block.
    Edges come from the ordinary correlation matrix only.  The label ``y`` is
    the task's index in ``TASK_NAMES``.
    """

    # Task order defines the integer class label stored in ``Data.y``.
    TASK_NAMES = ("EMOTION", "GAMBLING", "LANGUAGE", "MOTOR", "RELATIONAL", "SOCIAL", "WM")

    def __init__(self, root, dataset_name,n_rois, threshold,path_to_data,n_jobs,s3, lag, transform=None, pre_transform=None, pre_filter=None):
        """Store the configuration, let the base class build the dataset, then load it.

        Args:
            root: Dataset root directory handed to ``InMemoryDataset``.
            dataset_name: Stem of the processed file (``<dataset_name>.pt``).
            n_rois: Number of ROIs requested from the Schaefer 2018 atlas in ``process``.
            threshold: Stored on the instance; not read by the methods of this class.
            path_to_data: Stored as ``self.target_path``; not read by the methods of this class.
            n_jobs: Number of pool workers used by ``process``.
            s3: S3 client; stored on the instance, not read by the methods of this class.
            lag: Shift, in time points, used for the lagged correlation blocks.
            transform, pre_transform, pre_filter: Forwarded to ``InMemoryDataset``.
        """
        self.root, self.dataset_name,self.n_rois,self.threshold,self.target_path,self.n_jobs,self.s3,self.lag = root, dataset_name,n_rois,threshold,path_to_data,n_jobs,s3,lag
        super().__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Name of the single processed file that holds the collated dataset."""
        return [self.dataset_name+'.pt']

    @staticmethod
    def _run_task(args):
        """Pool adapter: unpack ``(iid, BUCKET_NAME, volume, lag)`` and build that subject's graphs."""
        return Brain_Connectome_Task_Download_lag.get_data_obj_task(*args)

    @staticmethod
    def get_data_obj_task(iid,BUCKET_NAME,volume,lag):
        """Build one lag-augmented ``Data`` graph per task for subject ``iid``.

        Time series are read from
        ``data/raw/HCPState/time_series_100/<iid>_tfMRI_<TASK>_LR_time_series.npy``
        and are treated as time points x ROIs.

        Args:
            iid: Subject identifier used in the file name.
            BUCKET_NAME: Unused here; kept for signature compatibility.
            volume: Unused here; kept for signature compatibility.
            lag: Shift, in time points, between the two copies of the series.

        Returns:
            List of ``Data`` objects, one per task whose file could be
            processed.  Tasks that fail are reported and skipped.
        """
        cls = Brain_Connectome_Task_Download_lag
        target_path = "data/raw/HCPState"
        data_list = []

        for y, task in enumerate(cls.TASK_NAMES):
            ts_path = os.path.join(
                target_path + "/time_series_100",
                f"{iid}_tfMRI_{task}_LR_time_series.npy",
            )
            try:
                zd_Ytm = np.load(ts_path)

                # Keep only the top 30% of samples; everything below the
                # 70th percentile is zeroed.
                positive_threshold_percentage = 30
                positive_threshold_value = np.percentile(zd_Ytm, 100 - positive_threshold_percentage)
                zd_Ytm[zd_Ytm < positive_threshold_value] = 0

                expanded_corr_matrix = cls.construct_expanded_lagged_corr(zd_Ytm, lag)

                # The expanded matrix is (2*R x 2*R); take the off-diagonal
                # blocks using the actual ROI count rather than a fixed 400.
                num_rois = zd_Ytm.shape[1]
                lag_corr = expanded_corr_matrix[:num_rois, num_rois:2 * num_rois]
                np.fill_diagonal(lag_corr, 0)

                lag_corr_reverse = expanded_corr_matrix[num_rois:2 * num_rois, :num_rois]
                np.fill_diagonal(lag_corr_reverse, 0)

                conn = ConnectivityMeasure(kind='correlation')
                zd_fc = conn.fit_transform([zd_Ytm])[0]
                np.fill_diagonal(zd_fc, 0)
                corr_original = torch.tensor(zd_fc).to(torch.float)
                A = cls.construct_Adj_postive_perc(corr_original, graph_threshold=5)
                edge_index = A.nonzero().t().to(torch.long)

                concat_matrix = np.concatenate((zd_fc, lag_corr, lag_corr_reverse), axis=1)
                corr = torch.tensor(concat_matrix).to(torch.float)
                data_list.append(Data(x=corr, edge_index=edge_index, y=y))
            except Exception as err:
                # Best effort: one missing or broken file must not abort the
                # subject, but say which file failed and why.
                print("file skipped!", ts_path, repr(err))
        return data_list

    def extract_from_3d_no(self,volume, fmri):
        '''
        Extract time-series data from a 3d atlas with non-overlapping ROIs.

        Inputs:
            volume = 3d array of integer ROI labels; label 0 is ignored
            fmri = 4d array whose first three axes match volume and whose
                   last axis is time

        Output:
            returns extracted time series # volumes x # ROIs
            (ROIs ordered by ascending label value)
        '''
        # Boolean-indexing the three spatial axes yields (n_voxels, n_timepoints);
        # averaging over axis 0 gives the ROI mean for every time point at once.
        subcor_ts = [fmri[volume == label].mean(axis=0)
                     for label in np.unique(volume) if label != 0]
        return np.array(subcor_ts).T

    @staticmethod
    def construct_Adj_postive_perc(corr,graph_threshold):
        """Binarise ``corr`` by keeping the top ``graph_threshold`` percent of positive entries.

        Args:
            corr: Torch tensor of correlation values.
            graph_threshold: Percentage (0-100) of the positive entries to keep.

        Returns:
            A new tensor shaped and typed like ``corr`` holding 1 where the
            entry is at or above the percentile cut-off and 0 elsewhere.
            NaN entries and inputs without any positive entry yield 0.
        """
        positive_values = corr[corr > 0]
        if positive_values.numel() == 0:
            # No positive correlation at all: there is nothing to connect.
            return torch.zeros_like(corr)
        threshold = np.percentile(positive_values.detach().cpu().numpy(), 100 - graph_threshold)
        return (corr >= float(threshold)).to(corr.dtype)

    @staticmethod
    def expand_time_series(time_series, lag):
        """Place a series next to its ``lag``-shifted copy.

        Args:
            time_series: 2-D array, time points x ROIs.
            lag: Non-negative shift in time points, smaller than the series length.

        Returns:
            Array of shape ``(T - lag, 2 * R)``: the first ``R`` columns are
            samples ``0 .. T-lag-1``, the last ``R`` columns samples ``lag .. T-1``.

        Raises:
            ValueError: If ``lag`` is negative or not smaller than the number
                of time points.
        """
        #time_series shape = (1200, 400) i.e. (timepoints, roi)
        num_time_points, _ = time_series.shape
        if not 0 <= lag < num_time_points:
            raise ValueError(
                f"lag must satisfy 0 <= lag < {num_time_points}, got {lag!r}")
        # Slice by explicit length: ``[:-lag]`` would be empty for lag == 0.
        truncated_time_series = time_series[:num_time_points - lag]
        lagged_time_series = time_series[lag:]
        print("lagged_time_series", lagged_time_series.shape)
        print("truncated_time_series", truncated_time_series.shape)
        return np.concatenate([truncated_time_series, lagged_time_series], axis=1)

    @staticmethod
    def construct_expanded_correlation_matrix(expanded_ts):
        """Correlation matrix of the expanded series with its main diagonal zeroed."""
        conn = ConnectivityMeasure(kind='correlation')
        corr_matrix = conn.fit_transform([expanded_ts])[0]
        np.fill_diagonal(corr_matrix, 0)
        return corr_matrix

    @staticmethod
    def construct_expanded_lagged_corr(time_series, lag):
        """Expand ``time_series`` with its lagged copy and correlate all ``2 * R`` columns."""
        cls = Brain_Connectome_Task_Download_lag
        expanded_ts = cls.expand_time_series(time_series, lag)
        expanded_corr_matrix = cls.construct_expanded_correlation_matrix(expanded_ts)
        print("expanded_corr_matrix", expanded_corr_matrix.shape)

        return expanded_corr_matrix


    def process(self):
        """Build every subject's lag-augmented graphs in a worker pool and save the result.

        Reads the subject ids from ``data/ids.pkl``, fetches the Schaefer 2018
        atlas, maps this class's own task builder over all subjects, applies
        the optional ``pre_filter`` / ``pre_transform`` and writes
        ``(data, slices)`` to ``self.processed_paths[0]``.
        """
        BUCKET_NAME = 'hcp-openaccess'
        # NOTE: unpickling runs arbitrary code; ids.pkl must come from a trusted source.
        with open(os.path.join("data/","ids.pkl"),'rb') as f:
            ids = pickle.load(f)
        #ids = ids[:2]

        roi = fetch_atlas_schaefer_2018(n_rois=self.n_rois,yeo_networks=17, resolution_mm=2)
        atlas = load_img(roi['maps'])
        volume = atlas.get_fdata()
        lag = self.lag

        tasks = [(iid, BUCKET_NAME, volume,lag) for iid in ids]
        with Pool(self.n_jobs) as pool:
            # Dispatch to this class's builder so the lag blocks are actually
            # computed; the module-level worker targets the non-lag dataset.
            data_list = pool.map(Brain_Connectome_Task_Download_lag._run_task, tasks)

        print("length of data list:", len(data_list))
        # Defensive: a worker result of None (failed subject) must not break the flattening.
        dataset = list(itertools.chain.from_iterable(
            result for result in data_list if result is not None))

        if self.pre_filter is not None:
            dataset = [data for data in dataset if self.pre_filter(data)]

        if self.pre_transform is not None:
            dataset = [self.pre_transform(data) for data in dataset]

        data, slices = self.collate(dataset)
        print("saving path:",self.processed_paths[0])
        torch.save((data, slices), self.processed_paths[0])


In [3]:
# Configuration for the Pearson connectome dataset and its construction.
n_rois = 1000
root = f"data/state_{n_rois}/state_{n_rois}_pearson/"
name = "HCPState"
threshold = 20
path_to_data = "data/raw/HCPState"
n_jobs = 20
state_dataset_pearson = Brain_Connectome_Task_Download(
    root=root,
    dataset_name=name,
    n_rois=n_rois,
    threshold=threshold,
    path_to_data=path_to_data,
    n_jobs=n_jobs,
    s3=s3,
)

In [None]:
# NOTE(review): this uses the same class as the Pearson dataset, and its task
# builder always requests kind='correlation' -- confirm how the Spearman
# variant stored under this root was produced.
root = f"data/state_{n_rois}/state_{n_rois}_spearman/"
state_dataset_spearman = Brain_Connectome_Task_Download(
    root=root,
    dataset_name=name,
    n_rois=n_rois,
    threshold=threshold,
    path_to_data=path_to_data,
    n_jobs=n_jobs,
    s3=s3,
)

In [None]:
# NOTE(review): this uses the same class as the Pearson dataset, and its task
# builder always requests kind='correlation' -- confirm how the Kendall
# variant stored under this root was produced.
root = f"data/state_{n_rois}/state_{n_rois}_kendall/"
state_dataset_kendall = Brain_Connectome_Task_Download(
    root=root,
    dataset_name=name,
    n_rois=n_rois,
    threshold=threshold,
    path_to_data=path_to_data,
    n_jobs=n_jobs,
    s3=s3,
)

In [4]:
import torch
from torch_geometric.data import Data
from multiprocessing import Pool

# Function to create edge union
def create_edge_union(p_data, s_data, k_data):
    """Return the union of the three graphs' directed edges as one edge index.

    Args:
        p_data, s_data, k_data: Objects exposing ``edge_index`` as a
            ``(2, E)`` integer tensor.

    Returns:
        ``torch.long`` tensor of shape ``(2, U)`` holding every distinct
        ``(source, target)`` pair exactly once, sorted by source then target.
        ``U`` is 0 (shape ``(2, 0)``) when none of the graphs has an edge.
    """
    all_edges = set()
    for graph in (p_data, s_data, k_data):
        all_edges.update(map(tuple, graph.edge_index.t().tolist()))
    # Sorting makes the column order reproducible; reshaping keeps the
    # two-row layout even when there are no edges at all.
    return torch.tensor(sorted(all_edges), dtype=torch.long).reshape(-1, 2).t()

# Function to process a single data set with edge union
def process_single_data(p_data, s_data, k_data, l_data):
    """Combine three graphs of the same sample into one graph with edge flags.

    The result keeps the node features, label and edges of ``p_data``.  Each
    kept edge carries a three-column 0/1 attribute telling whether that edge
    is present in ``p_data``, ``s_data`` and ``k_data`` respectively (the
    first column is therefore always 1).

    Args:
        p_data, s_data, k_data: Graph objects exposing ``edge_index``;
            ``p_data`` additionally provides ``x`` and ``y``.
        l_data: Currently unused; accepted so callers can pass a lag sample.

    Returns:
        A new ``Data`` object with ``x``, ``edge_index``, ``edge_attr`` and ``y``.
    """
    #lag_x = l_data.x[:, 400:1200]
    #lag_x = l_data.x[:, 100:300]
    #node_features = torch.cat([p_data.x, s_data.x, k_data.x, lag_x], dim=1)
    node_features = p_data.x

    # Union of all edges; force the two-row layout so an edge-less union can
    # still be column-indexed below.
    edge_index_union = create_edge_union(p_data, s_data, k_data).reshape(2, -1)

    # Convert the union to Python tuples once instead of once per edge.
    union_edges = [tuple(edge) for edge in edge_index_union.t().tolist()]
    edge_sets = [
        set(map(tuple, graph.edge_index.t().tolist()))
        for graph in (p_data, s_data, k_data)
    ]

    # One presence column per source graph, in the order P, S, K.
    edge_features = torch.zeros((len(union_edges), 3))
    for column, edge_set in enumerate(edge_sets):
        edge_features[:, column] = torch.tensor(
            [edge in edge_set for edge in union_edges], dtype=edge_features.dtype)

    # Keep only the edges that exist in the P graph.
    pearson_mask = edge_features[:, 0] == 1
    filtered_edge_index = edge_index_union[:, pearson_mask]
    filtered_edge_features = edge_features[pearson_mask]

    return Data(x=node_features, edge_index=filtered_edge_index,
                edge_attr=filtered_edge_features, y=p_data.y)

# Parallel processing for combined dataset
def create_combined_dataset_parallel(P, S, K, L, num_workers=20):
    """Combine aligned samples of several datasets in a process pool.

    Args:
        P, S, K: Iterables of graph samples in the same order.
        L: Iterable of lag samples aligned with the others, or ``None`` when
            no lag dataset is available (each sample then receives ``None``).
        num_workers: Number of worker processes.

    Returns:
        List with one ``process_single_data`` result per aligned sample.
    """
    # zip() cannot iterate None; substitute an endless stream of None so the
    # pairing length is still dictated by P, S and K.
    lag_samples = itertools.repeat(None) if L is None else L
    with Pool(num_workers) as pool:
        results = pool.starmap(process_single_data, zip(P, S, K, lag_samples))
    return results

class PSKLL_PSK_Dataset(InMemoryDataset):
    """In-memory dataset that persists already-combined graphs (with edge attributes).

    The processed file is ``<dataset_name>_PSK.pt`` under the dataset root.
    ``dataset`` is only iterated when ``process`` runs.
    """

    def __init__(self, root, dataset_name, dataset, transform=None, pre_transform=None, pre_filter=None):
        """Remember the source samples, let the base class build the dataset, then load it.

        Args:
            root: Dataset root directory handed to ``InMemoryDataset``.
            dataset_name: Prefix of the processed file name.
            dataset: Iterable of graphs exposing ``x``, ``edge_index``,
                ``edge_attr`` and ``y``; may be ``None`` if ``process`` is not run.
            transform, pre_transform, pre_filter: Forwarded to ``InMemoryDataset``.
        """
        self.root, self.dataset_name, self.dataset = root, dataset_name, dataset
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Name of the single processed file that holds the collated dataset."""
        #return [self.dataset_name + '_PSKLL_PSK.pt']
        return [self.dataset_name + '_PSK.pt']

    def process(self):
        """Copy the source graphs into fresh ``Data`` objects, collate and save them.

        Raises:
            TypeError: If no source ``dataset`` was supplied.
        """
        if self.dataset is None:
            # Fail with an actionable message instead of "NoneType is not iterable".
            raise TypeError(
                "PSKLL_PSK_Dataset needs a source dataset to build "
                f"{self.processed_paths[0]}; got dataset=None")

        samples = [
            Data(x=d.x, edge_index=d.edge_index, edge_attr=d.edge_attr, y=d.y)
            for d in self.dataset
        ]

        if self.pre_filter is not None:
            samples = [data for data in samples if self.pre_filter(data)]

        if self.pre_transform is not None:
            samples = [self.pre_transform(data) for data in samples]

        data, slices = self.collate(samples)
        print("saving path:", self.processed_paths[0])
        torch.save((data, slices), self.processed_paths[0])



In [None]:
lag = 1
#root = "data/state_"+str(n_rois)+"/state_"+str(n_rois)+"_thre30_"+str(lag)+"lag/"
#state_dataset_lag = Brain_Connectome_Task_Download_lag(root,name,n_rois, threshold,path_to_data,n_jobs,s3,lag)
# The lag dataset is currently disabled; None is passed in its place.
state_dataset_lag = None
new_dataset = create_combined_dataset_parallel(
    P=state_dataset_pearson,
    S=state_dataset_spearman,
    K=state_dataset_kendall,
    L=state_dataset_lag,
)

# Show the first combined sample.
print(new_dataset[0])

In [None]:
# Persist the combined samples under the PSK root and show the first one.
root = f'data/state_{n_rois}/state_{n_rois}_PSK/'
dataset = PSKLL_PSK_Dataset(root=root, dataset_name=name, dataset=new_dataset)
print(dataset[0])

Processing...


saving path: data/state_400/state_400_PSK/processed/HCPState_PSK.pt


Done!


Data(x=[400, 400], edge_index=[2, 7368], edge_attr=[7368, 3], y=[1])


In [None]:
#root = 'data/state_'+str(roi)+'/state_'+str(roi)+'_thre30_'+str(lag)+'lag_PSKLL_PSK/'
#dataset = PSKLL_PSK_Dataset(root, name, None)

In [None]:
import os
import pickle
import numpy as np
import torch
from torch_geometric.data import Data,InMemoryDataset
import logging

class PSKLL_PSK_Dataset(InMemoryDataset):
    """In-memory dataset of combined graphs stored as ``<dataset_name>_PSK.pt``.

    When the processed file has to be built, every sample of ``dataset`` is
    copied into a fresh ``Data`` object, collated and saved.
    """

    def __init__(self, root,dataset_name, dataset,transform=None, pre_transform=None, pre_filter=None):
        """Remember the source samples, let the base class build the dataset, then load it.

        Args:
            root: Dataset root directory handed to ``InMemoryDataset``.
            dataset_name: Prefix of the processed file name.
            dataset: Iterable of graphs exposing ``x``, ``edge_index``,
                ``edge_attr`` and ``y``; may be ``None`` if ``process`` is not run.
            transform, pre_transform, pre_filter: Forwarded to ``InMemoryDataset``.
        """
        self.root, self.dataset_name, self.dataset = root, dataset_name,dataset
        super().__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Name of the single processed file that holds the collated dataset."""
        return [self.dataset_name+'_PSK.pt']

    def process(self):
        """Collate the source graphs and write them to ``self.processed_paths[0]``.

        Raises:
            TypeError: If no source ``dataset`` was supplied.
        """
        if self.dataset is None:
            # Fail with an actionable message instead of "NoneType is not iterable".
            raise TypeError(
                "PSKLL_PSK_Dataset needs a source dataset to build "
                f"{self.processed_paths[0]}; got dataset=None")

        graphs = [
            Data(x= d.x, edge_index=d.edge_index, edge_attr=d.edge_attr, y=d.y)
            for d in self.dataset
        ]
        if self.pre_filter is not None:
            graphs = [data for data in graphs if self.pre_filter(data)]

        if self.pre_transform is not None:
            graphs = [self.pre_transform(data) for data in graphs]

        data, slices = self.collate(graphs)
        print("saving path:",self.processed_paths[0])
        torch.save((data, slices), self.processed_paths[0])
        

from torch.nn import Linear
from torch import nn
from torch_geometric.nn import global_max_pool
from torch_geometric.nn import aggr
import torch.nn.functional as F
from torch_geometric.nn import APPNP, MLP, GCNConv, GINConv, SAGEConv, GraphConv, TransformerConv, ChebConv, GATConv, SGConv, GeneralConv ,RGCNConv
from torch.nn import Conv1d, MaxPool1d, ModuleList
import math
import random
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader
import time
from torch.optim import Adam
from itertools import product
# Despite its name this module applies log-softmax (over dim 1), not plain softmax.
softmax = nn.LogSoftmax(dim=1)

class Args:
    """Static experiment configuration shared by the models and training helpers."""
    dataset = 'HCPState'
    runs = 3
    # Train on the GPU whenever one is visible to torch.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed = 123
    model_list = ["GCNConv"]
    hidden = 32
    hidden_mlp = 64
    num_layers = 3
    epochs = 100
    echo_epoch = 50
    batch_size = 16
    early_stopping = 50
    lr = 5e-4
    weight_decay = 0.0005
    dropout = 0.5
    # Width of the per-edge attribute vector (one presence flag per source graph).
    edge_feature_dim = 3
    # Default number of target classes: the dataset builders label seven
    # tasks 0-6.  The model constructors read this attribute; it can still
    # be overwritten on the instance.
    num_classes = 7
args = Args()


# Output locations; the two output directories are created if they are missing.
path = "base_params_test_state/"
res_path = "results/"
root = "data/"
os.makedirs(path, exist_ok=True)
os.makedirs(res_path, exist_ok=True)
def logger(info):
    """Append ``info`` as one line to ``results_new.csv`` inside ``res_path``.

    The file is opened in append mode for each call and closed again before
    returning, so every record is flushed to disk immediately.
    """
    with open(os.path.join(res_path, 'results_new.csv'), 'a') as f:
        print(info, file=f)

# Seed every random number generator used in this script with the configured seed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)

import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, GCNConv
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.data import Data
from torch.nn import Linear, Conv1d, MaxPool1d, ModuleList


class GCNConvWithEdgeAttr(MessagePassing):
    """Sum-aggregating graph convolution whose messages also carry projected edge attributes.

    Every message is ``norm * (W x_j + W_e e_ij)``, where ``norm`` is the
    product of the inverse square-root degrees of the edge's two endpoints.
    """

    def __init__(self, in_channels, out_channels, edge_feature_dim):
        super().__init__(aggr='add')  # messages are summed per target node
        # Projections of node features and of edge attributes into the output space.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.edge_lin = torch.nn.Linear(edge_feature_dim, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Project the node features, compute per-edge normalisation and propagate."""
        x = self.lin(x)

        # Degrees are counted from the first row of edge_index.
        row, col = edge_index
        node_degree = degree(row, x.size(0), dtype=x.dtype)
        inv_sqrt_degree = node_degree.pow(-0.5)
        # Isolated nodes have degree 0, whose power is inf; neutralise it.
        inv_sqrt_degree.masked_fill_(inv_sqrt_degree == float('inf'), 0)
        norm = inv_sqrt_degree[row] * inv_sqrt_degree[col]

        return self.propagate(edge_index, x=x, edge_attr=edge_attr, norm=norm)

    def message(self, x_j, edge_attr, norm):
        """Scale the sum of neighbour feature and projected edge attribute by ``norm``."""
        projected_edge = self.edge_lin(edge_attr)
        return (x_j + projected_edge) * norm.view(-1, 1)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out

class GNNsWithEdgeAttr(torch.nn.Module):
    """Graph classifier: edge-aware graph convolutions, sort aggregation, 1-D convolutions, MLP.

    ``num_layers`` hidden convolutions plus one single-channel convolution are
    applied in sequence; their tanh outputs are concatenated per node, the
    nodes of each graph are reduced to a fixed number ``k`` by
    ``SortAggregation`` and the result is classified by two 1-D convolutions
    followed by an MLP with ``args.num_classes`` outputs.
    """

    def __init__(self, args, train_dataset, hidden_channels, num_layers, GNN, edge_feature_dim, k=0.6):
        """
        Args:
            args: Configuration object; only ``args.num_classes`` is read here.
            train_dataset: Training graphs; provides ``num_features`` and, when
                ``k < 1``, the node counts used to derive ``k``.
            hidden_channels: Width of the hidden graph convolutions.
            num_layers: Number of hidden graph convolutions.
            GNN: Unused; kept for signature compatibility.
            edge_feature_dim: Size of each edge attribute vector.
            k: Number of nodes kept per graph, or, if below 1, the fraction of
                training graphs whose node count ``k`` should cover (at least 10).
        """
        super().__init__()
        if k < 1:  # Transform percentile to number.
            num_nodes = sorted([data.num_nodes for data in train_dataset])
            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
            k = max(10, k)
        self.k = int(k)
        self.sort_aggr = aggr.SortAggregation(self.k)
        self.convs = ModuleList()
        self.convs.append(GCNConvWithEdgeAttr(train_dataset.num_features, hidden_channels, edge_feature_dim))
        for i in range(0, num_layers - 1):
            self.convs.append(GCNConvWithEdgeAttr(hidden_channels, hidden_channels, edge_feature_dim))
        self.convs.append(GCNConvWithEdgeAttr(hidden_channels, 1, edge_feature_dim))

        conv1d_channels = [16, 32]
        # Per-node width after concatenating all layer outputs.
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        # Kernel == stride == per-node width: one output position per kept node.
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1)
        dense_dim = int((self.k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.mlp = MLP([dense_dim, 32, args.num_classes], dropout=0.5, batch_norm=False)

    def reset_parameters(self):
        """Re-initialise every learnable layer, including the graph convolutions."""
        for conv in self.convs:
            # Reset the two linear maps directly; they hold all of the
            # convolution's learnable weights.
            conv.lin.reset_parameters()
            conv.edge_lin.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.mlp.reset_parameters()

    def forward(self, data):
        """Return class scores of shape ``[num_graphs, args.num_classes]`` for a batch."""
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        xs = [x]
        for conv in self.convs:
            xs += [conv(xs[-1], edge_index, edge_attr).tanh()]
        # Skip the raw input; keep only the convolution outputs.
        x = torch.cat(xs[1:], dim=-1)

        x = self.sort_aggr(x, batch)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * total_latent_dim]
        x = self.conv1(x).relu()
        x = self.maxpool1d(x)
        x = self.conv2(x).relu()
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]
        x = self.mlp(x)
        return x

class GNNsWithEdgeAttrGAT(torch.nn.Module):
    """Graph classifier built from ``GATConv`` layers that consume edge attributes.

    The tanh outputs of all attention layers are concatenated per node, the
    nodes of each graph are reduced to ``k`` by ``SortAggregation`` and the
    result is classified by two 1-D convolutions followed by an MLP.
    """

    def __init__(self, args, train_dataset, hidden_channels, num_layers, edge_feature_dim, heads=1, k=0.6):
        super().__init__()
        if k < 1:
            # A fractional k is a quantile of the training graphs' node counts.
            node_counts = sorted([graph.num_nodes for graph in train_dataset])
            k = max(10, node_counts[int(math.ceil(k * len(node_counts))) - 1])
        self.k = int(k)
        self.sort_aggr = aggr.SortAggregation(self.k)

        # Attention stack: input layer, (num_layers - 1) hidden layers and a
        # final single-channel layer.
        gat_options = dict(concat=False, edge_dim=edge_feature_dim)
        self.convs = ModuleList()
        self.convs.append(GATConv(train_dataset.num_features, hidden_channels, heads=heads, **gat_options))
        self.convs.extend(
            GATConv(hidden_channels, hidden_channels, heads=heads, **gat_options)
            for _ in range(num_layers - 1)
        )
        self.convs.append(GATConv(hidden_channels, 1, heads=1, **gat_options))

        # Read-out head: 1-D convolutions over the sorted node sequence, then an MLP.
        first_channels, second_channels = 16, 32
        second_kernel = 5
        total_latent_dim = hidden_channels * num_layers + 1
        self.conv1 = Conv1d(1, first_channels, total_latent_dim, total_latent_dim)
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(first_channels, second_channels, second_kernel, 1)
        pooled_length = int((self.k - 2) / 2 + 1)
        dense_dim = (pooled_length - second_kernel + 1) * second_channels
        self.mlp = MLP([dense_dim, 32, args.num_classes], dropout=0.5, batch_norm=False)

    def reset_parameters(self):
        """Re-initialise the attention layers, both 1-D convolutions and the MLP."""
        for layer in (*self.convs, self.conv1, self.conv2, self.mlp):
            layer.reset_parameters()

    def forward(self, data):
        """Return class scores of shape ``[num_graphs, args.num_classes]`` for a batch."""
        hidden = data.x
        layer_outputs = []
        for conv in self.convs:
            hidden = conv(hidden, data.edge_index, edge_attr=data.edge_attr).tanh()
            layer_outputs.append(hidden)
        node_repr = torch.cat(layer_outputs, dim=-1)

        pooled = self.sort_aggr(node_repr, data.batch).unsqueeze(1)  # [num_graphs, 1, k * total_latent_dim]
        pooled = self.maxpool1d(self.conv1(pooled).relu())
        pooled = self.conv2(pooled).relu()
        flat = pooled.view(pooled.size(0), -1)  # [num_graphs, dense_dim]
        return self.mlp(flat)


criterion = torch.nn.CrossEntropyLoss()
def train(train_loader):
    """Run one training epoch over ``train_loader`` with the global model and optimizer.

    Returns:
        The sum of the per-batch losses divided by the number of samples in
        ``train_loader.dataset``, as a tensor detached from autograd (``0``
        divided by the dataset size if the loader yields no batch).  Because
        each batch loss is already a batch mean, this is a scaled quantity
        rather than a true per-sample mean; the scaling is kept so logged
        values stay comparable.
    """
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(args.device)
        # Start every step from clean gradients, whatever ran before.
        optimizer.zero_grad()
        out = model(data)  # Perform a single forward pass.
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        # Detach so the running total does not keep each batch's autograd graph alive.
        total_loss += loss.detach()
    return total_loss/len(train_loader.dataset)

def test(loader):
    """Return the global model's classification accuracy on ``loader``.

    Accuracy is the number of correct predictions divided by
    ``len(loader.dataset)``.
    """
    model.eval()
    correct = 0
    # Evaluation needs no gradients; skipping them saves memory and time.
    with torch.no_grad():
        for data in loader:
            data = data.to(args.device)
            out = model(data)
            pred = out.argmax(dim=1)  # Use the class with highest probability.
            correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)

# Run-level constants.
name, threshold, n_rois = "HCPState", 5, 100
n_jobs = 15 # this script runs in parallel and requires the number of jobs is an input

from datetime import datetime

# Send log records to a time-stamped file so every run gets its own log.
log_filename = "training_log_{}.log".format(datetime.now().strftime('%Y%m%d_%H%M%S'))
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename=log_filename,
)

# Architectures this script knows how to construct. Grid entries outside this
# tuple are skipped with a warning instead of silently reusing another model.
_SUPPORTED_MODELS = ("GCNConv", "GATConv")


def _build_model(train_dataset):
    """Create a freshly initialised model for the architecture in ``args.model``.

    Hidden size, depth and edge-feature width are read from the global
    ``args``; the model is moved to ``args.device``.

    Raises:
        ValueError: if ``args.model`` is not one of ``_SUPPORTED_MODELS``.
    """
    if args.model == "GCNConv":
        return GNNsWithEdgeAttr(
            args, train_dataset, hidden_channels=args.hidden,
            num_layers=args.num_layers, GNN=GCNConvWithEdgeAttr,
            edge_feature_dim=args.edge_feature_dim
        ).to(args.device)
    if args.model == "GATConv":
        return GNNsWithEdgeAttrGAT(
            args, train_dataset, hidden_channels=args.hidden,
            num_layers=args.num_layers, edge_feature_dim=args.edge_feature_dim,
            heads=4
        ).to(args.device)
    raise ValueError("Unsupported model: {}".format(args.model))


def _evaluate_loss(loader):
    """Return the loss of the global ``model`` on ``loader`` without training.

    The model is put in evaluation mode and no optimizer step is taken. The
    value is the sum of per-batch losses divided by the number of samples in
    ``loader.dataset``.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for data in loader:
            data = data.to(args.device)
            total += criterion(model(data), data.y).item()
    return total / len(loader.dataset)


# Hyperparameter grid for the grid search (identical for every ROI setting).
param_grid = {
    'model': ['GCNConv', 'GATConv', 'SAGEConv'],  # Model types to search
    'hidden': [32],               # Hidden layer sizes
    #'hidden_mlp': [32, 64, 128],         # Hidden layer sizes for MLP
    #'batch_size': [16, 32, 64],          # Batch sizes to try
    'num_layers': [3],          # Number of layers
    'lr': [5e-4],         # Learning rates to try
    'dropout': [0.5],            # Dropout rates
    'weight_decay': [0.0005],     # Weight decay values
}
# One seed per repeated run; ``args.runs`` must not exceed the list length.
seeds = [123,124,125,126,127,128,129,221,223,224,228,229]

for roi in [100, 400]:
    root = 'data/state_'+str(roi)+'/state_'+str(roi)+'_PSK/'
    dataset = PSKLL_PSK_Dataset(root, name, None)
    labels = [d.y.item() for d in dataset]

    # Stratified split: 20% test, then 12.5% of the remainder as validation
    # (i.e. 70% / 10% / 20% of the full dataset).
    train_tmp, test_indices = train_test_split(
        list(range(len(labels))),
        test_size=0.2, stratify=labels, random_state=123, shuffle=True)
    tmp = dataset[train_tmp]
    train_labels = [d.y.item() for d in tmp]
    train_indices, val_indices = train_test_split(
        list(range(len(train_labels))),
        test_size=0.125, stratify=train_labels, random_state=123, shuffle=True)
    train_dataset = tmp[train_indices]
    val_dataset = tmp[val_indices]
    test_dataset = dataset[test_indices]
    logging.info("Dataset %s loaded with train %d val %d test %d splits", args.dataset, len(train_dataset), len(val_dataset), len(test_dataset))

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=False)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False)
    args.num_features, args.num_classes = dataset.num_features, dataset.num_classes

    for params in product(*param_grid.values()):
        args.model, args.hidden, args.num_layers, args.lr, args.dropout, args.weight_decay = params

        if args.model not in _SUPPORTED_MODELS:
            # Without this guard the previous configuration's (already
            # trained) model would be reused and reported under this name.
            logging.warning("Skipping unsupported model %s", args.model)
            continue

        val_acc_history, test_acc_history, test_loss_history = [], [], []
        checkpoint_path = path + args.dataset+args.model+'task-checkpoint-best-acc.pkl'

        # Time the whole configuration, not just its last run.
        start = time.time()
        for index in range(args.runs):
            seed = seeds[index]
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            # Build the model after seeding so that every run starts from its
            # own seed-determined initialisation instead of continuing to
            # train the model left over from the previous run.
            model = _build_model(train_dataset)
            logging.info(model)
            total_params = sum(p.numel() for p in model.parameters())
            logging.info("Total number of parameters is: %d", total_params)

            optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

            # Start below any reachable accuracy so the first epoch always
            # writes a checkpoint for this run before it is loaded below.
            best_val_acc = -1.0
            for epoch in range(args.epochs):
                loss = train(train_loader)
                val_acc = test(val_loader)
                if epoch % 10 == 0:
                    # Test accuracy is only reported here, so only compute it here.
                    test_acc = test(test_loader)
                    logging.info("Epoch: %d, Loss: %.6f, Val Acc: %.2f, Test Acc: %.2f", epoch, loss.item(), val_acc, test_acc)
                    print("epoch: {}, loss: {}, val_acc:{}, test_acc:{}".format(epoch, np.round(loss.item(),6), np.round(val_acc,2),np.round(test_acc,2)))
                val_acc_history.append(val_acc)
                if val_acc > best_val_acc:
                    best_val_acc = val_acc

                    print("best val acc:",best_val_acc)
                    logging.info("Best Val Acc: %.2f", best_val_acc)
                    torch.save(model.state_dict(), checkpoint_path)

            # Score the checkpoint with the best validation accuracy.
            model.load_state_dict(torch.load(checkpoint_path))
            model.eval()
            test_acc = test(test_loader)
            # Measure the test loss without optimizer steps; calling train()
            # here would fit the model to the test set.
            test_loss = _evaluate_loss(test_loader)
            test_acc_history.append(test_acc)
            test_loss_history.append(test_loss)

        end = time.time()

        log_msg = "model:{}, Hidden: {}, layers: {}, lr: {}, dropout:{}, weightdecay:{}, Loss: {:.4f}, Acc: {:.2f}, Std: {:.2f}, Running Time: {:.2f}".format(
            args.model, args.hidden, args.num_layers, args.lr, args.dropout, args.weight_decay,
            np.mean(test_loss_history), np.mean(test_acc_history) * 100, np.std(test_acc_history) * 100, end - start
        )
        logging.info(log_msg)