### Loading Data

In [105]:
import os
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv, BatchNorm, global_mean_pool, global_add_pool, global_max_pool
from sklearn.model_selection import train_test_split
# from sklearn.metrics import accuracy_score
from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import KFold
import networkx as nx
from torch_geometric.utils import from_networkx

# Define folder path and TSV file containing labels.
# The paths are assembled with os.path.join so they resolve on any operating
# system; on Windows they evaluate to the same '.\data\...' strings as before.
# NOTE(review): the '.xsls' extension looks like a typo for '.xlsx'. The variable
# is not used in the visible cells, so the file name is left untouched - confirm.
participants_xlsx_file = os.path.join('.', 'data', 'participants.xsls')
participants_tsv_file = os.path.join('.', 'data', 'participants.tsv')
node2vec_folder = os.path.join('.', 'data', 'graph_embeddings_4')
# Despite the "_file" suffix this is the folder that holds the per-subject GML graphs.
graph_shapes_file = os.path.join('.', 'data', 'graph_shapes')

# Load TSV file
df_labels = pd.read_csv(participants_tsv_file, sep='\t')

# Filter out rows where 'group' is missing or the literal string 'n/a'
df_labels = df_labels[df_labels['group'].notna() & (df_labels['group'] != 'n/a')]

# Extract labels as a dictionary (participant_id -> integer group)
labels_dict = dict(zip(df_labels['participant_id'], df_labels['group'].astype(int)))

# Function to load data
def load_data(node2vec_folder, labels_dict, graph_shapes_folder=None):
    """Build one PyTorch Geometric ``Data`` object per subject.

    For every subject in ``labels_dict`` the node features are read from
    ``Node2Vec_PCMCI_<Sub>_Harvard.csv`` and the edges from
    ``GNN input_PCMCI_<Sub>_Harvard.gml``. Subjects for which either file is
    missing are skipped and reported once at the end.

    Args:
        node2vec_folder: Folder containing the per-subject feature CSV files.
        labels_dict: Mapping from participant id (e.g. ``'sub-01'``) to an
            integer group label.
        graph_shapes_folder: Folder containing the per-subject GML files.
            Defaults to the notebook-level ``graph_shapes_file`` so existing
            two-argument calls keep working.

    Returns:
        list: The ``Data`` objects of all subjects whose files were found.
    """
    if graph_shapes_folder is None:
        # Backward-compatible default: fall back to the notebook-level folder.
        graph_shapes_folder = graph_shapes_file

    data_list = []     # Loaded graphs
    missing_subs = []  # Subjects skipped because a file was not found

    for sub_id, label in labels_dict.items():
        # File names use 'SubXX' while the participants table uses 'sub-XX'
        sub_id_corrected = sub_id.replace('sub-', 'Sub')
        csv_file = os.path.join(node2vec_folder, f"Node2Vec_PCMCI_{sub_id_corrected}_Harvard.csv")
        gml_file = os.path.join(graph_shapes_folder, f"GNN input_PCMCI_{sub_id_corrected}_Harvard.gml")

        if not (os.path.exists(csv_file) and os.path.exists(gml_file)):
            missing_subs.append(sub_id)
            continue

        # Node features: first CSV row is the header, first column the index
        node_features = pd.read_csv(csv_file, header=0, index_col=0).values

        # Graph connectivity comes from the GML file.
        # NOTE(review): CSV rows are assumed to follow the node order of the
        # GML graph - confirm against the files that produce them.
        nx_graph = nx.read_gml(gml_file)
        edge_index = from_networkx(nx_graph).edge_index

        x = torch.tensor(node_features, dtype=torch.float)
        # Shift labels down by one; presumably groups are coded 1/2 so this
        # yields 0/1 - TODO confirm against participants.tsv.
        y = torch.tensor([label - 1], dtype=torch.long)

        data_list.append(Data(x=x, edge_index=edge_index, y=y))

    if missing_subs:
        # Surface skipped subjects instead of dropping them silently.
        print(f"Skipped {len(missing_subs)} subject(s) with missing CSV/GML files: {missing_subs}")

    return data_list

# Build the dataset from the embedding folder and the label mapping
data_list = load_data(node2vec_folder, labels_dict)

# Report how many subjects were loaded
print(f"Number of loaded samples: {len(data_list)}")

# An empty dataset makes everything downstream meaningless, so stop here
if not data_list:
    raise ValueError("No data samples were loaded. Please check the file path and format.")




Number of loaded samples: 134


## Early stopping

In [106]:
# Define function for early stopping
def early_stopping(val_accuracies, patience):
    """Decide whether training should stop.

    Returns True when none of the last ``patience`` values exceeds the value
    recorded immediately before that window; returns False while the history
    is not longer than ``patience``.
    """
    if len(val_accuracies) <= patience:
        return False

    window = val_accuracies[-patience:]
    best_in_window = max(window)
    baseline = val_accuracies[-patience - 1]
    return bool(best_in_window <= baseline)

## Hyperparameters

In [107]:

# Hyperparameter search with Grid Search
# Candidate values for each tunable hyperparameter
param_grid = dict(
    hidden_channels=[16, 32, 64],
    dropout=[0.2, 0.4],
    lr=[0.01, 0.0001],
    weight_decay=[1e-2, 1e-4],
)
# Iterable over every combination of the values above
grid = ParameterGrid(param_grid)

## Evaluation

In [108]:
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score
from sklearn.metrics import confusion_matrix, average_precision_score, mean_squared_error, precision_recall_curve, auc as sklearn_auc


def compute_metrics(preds, labels):
    """Compute binary-classification metrics from predictions and labels.

    Args:
        preds: Tensor of predicted scores, thresholded at 0.5 to obtain hard
            predictions (assumes the scores are probabilities - the models in
            this notebook end with a sigmoid).
        labels: Tensor of ground-truth 0/1 labels.

    Returns:
        dict: ``precision``, ``recall``, ``f1``, ``auc``, ``accuracy``,
        ``conf_matrix`` (always 2x2, rows/columns ordered 0 then 1), ``ap``
        (average precision) and ``mse`` (mean squared error between the
        labels and the raw scores). ``auc`` is NaN when ``labels`` contains
        a single class, because ROC AUC is undefined in that case.
    """
    labels = labels.cpu().detach().numpy()
    preds = preds.cpu().detach().numpy()
    preds_binary = (preds > 0.5).astype(int)

    accuracy = accuracy_score(labels, preds_binary)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds_binary, average='binary', zero_division=0)
    mse = mean_squared_error(labels, preds)  # MSE computation
    # Pin the label set so the matrix stays 2x2 even if a class is absent.
    conf_matrix = confusion_matrix(labels, preds_binary, labels=[0, 1])
    average_precision = average_precision_score(labels, preds)

    # ROC AUC needs both classes; small validation folds may contain only one.
    if len(set(labels.ravel().tolist())) < 2:
        auc = float('nan')
    else:
        auc = roc_auc_score(labels, preds)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": auc,
        "accuracy": accuracy,
        "conf_matrix": conf_matrix,
        "ap": average_precision,
        "mse": mse,
    }


# Evaluation function to calculate accuracy
def evaluate(loader, model):
    """Run ``model`` over every batch of ``loader`` and compute metrics once.

    Outputs and labels of all batches are concatenated before the metrics are
    computed, so the result covers the whole loader rather than only its last
    batch. For a single-batch loader this is the same as scoring that batch.

    Args:
        loader: Iterable of batches exposing a ``y`` attribute.
        model: Callable module mapping a batch to a 1-D tensor of scores.

    Returns:
        dict: The metrics dictionary produced by ``compute_metrics``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    outputs, targets = [], []
    with torch.no_grad():
        for data in loader:
            outputs.append(model(data))
            targets.append(data.y)

    if not outputs:
        raise ValueError("Cannot evaluate on an empty loader.")

    return compute_metrics(torch.cat(outputs), torch.cat(targets))


In [109]:
import csv
import datetime
import numpy as np
import matplotlib.pyplot as plt


# Timestamp taken once when this cell runs; used to tag every result file of the session
current_time = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"

def save_Results(fileName, test_metrics):
    """Write a metrics dictionary to a two-column CSV file under ``./results``.

    The file is named ``test_metrics_<fileName>_<current_time>.csv`` and holds
    a ``Metric,Value`` header followed by one row per dictionary entry.

    Args:
        fileName: Tag inserted into the output file name.
        test_metrics: Mapping from metric name to value.
    """
    # Create the output folder on first use instead of failing in open().
    os.makedirs('./results', exist_ok=True)

    with open(f'./results/test_metrics_{fileName}_{current_time}.csv', mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write the header
        writer.writerow(["Metric", "Value"])
        # One (metric, value) row per entry
        writer.writerows(test_metrics.items())


def plot_average_auc(fileName, train_aucs_all, val_aucs_all):
    """Plot the per-epoch AUC averaged across folds and save it as a PNG.

    Folds may have different numbers of epochs (early stopping ends them at
    different times). Each epoch is averaged over the folds that reached it,
    which for equally long folds is the plain mean across folds.

    Args:
        fileName: Tag inserted into the output file name.
        train_aucs_all: One sequence of per-epoch training AUCs per fold.
        val_aucs_all: One sequence of per-epoch validation AUCs per fold.

    Raises:
        ValueError: If either argument contains no folds.
    """
    if len(train_aucs_all) == 0 or len(val_aucs_all) == 0:
        raise ValueError("plot_average_auc needs at least one fold of AUC values.")

    def _mean_curve(curves):
        # Pad shorter folds with NaN so nanmean ignores them for later epochs.
        longest = max(len(curve) for curve in curves)
        padded = np.full((len(curves), longest), np.nan, dtype=float)
        for row, curve in enumerate(curves):
            padded[row, :len(curve)] = curve
        return np.nanmean(padded, axis=0)

    mean_train_aucs = _mean_curve(train_aucs_all)
    mean_val_aucs = _mean_curve(val_aucs_all)

    # Explicit figure so the curves never land on an unrelated open figure.
    fig, ax = plt.subplots()
    ax.plot(range(1, len(mean_train_aucs) + 1), mean_train_aucs, label='Average Train AUC', color='blue')
    ax.plot(range(1, len(mean_val_aucs) + 1), mean_val_aucs, label='Average Validation AUC', color='orange')

    # Add labels, title, and legend
    ax.set_xlabel('Epochs')
    ax.set_ylabel('AUC')
    ax.set_title('Average AUC Curves Across Folds')
    ax.legend(loc='best')
    ax.grid(True)

    # Create the output folder on first use instead of failing in savefig().
    os.makedirs('./results', exist_ok=True)
    fig.savefig(f'./results/AverageAUC_{fileName}_{current_time}.png')
    plt.close(fig)


## Different GNN Models

In [110]:
class GATModel(nn.Module):
    """Three-layer graph attention network producing one sigmoid score per graph.

    The first two attention layers are followed by batch normalization, ReLU
    and dropout, and each adds a residual connection (the first through a
    linear projection of the input features). The third layer maps to
    ``out_channels``; node outputs are mean-pooled per graph and passed
    through a sigmoid.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.3):
        super().__init__()

        # Feature width used by the batch norms and projections of layers 1-2
        wide = hidden_channels * heads

        # Layer 1: attention, batch norm, and a projection for the skip path
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.bn1 = BatchNorm(wide)
        self.res_proj1 = nn.Linear(in_channels, wide)

        # Layer 2: same width in and out, so the skip path needs no projection
        self.conv2 = GATConv(wide, hidden_channels, heads=heads, dropout=dropout)
        self.bn2 = BatchNorm(wide)

        # Layer 3: single head without concatenation, mapping to the output size
        self.conv3 = GATConv(wide, out_channels, heads=1, concat=False, dropout=dropout)
        self.bn3 = BatchNorm(out_channels)

        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Return the per-graph sigmoid output with its last dimension squeezed."""
        h, edge_index, batch = data.x, data.edge_index, data.batch

        # Block 1: projected input is added back after dropout
        skip = self.res_proj1(h)
        h = self.dropout(self.bn1(self.conv1(h, edge_index)).relu()) + skip

        # Block 2: identity skip connection
        skip = h
        h = self.dropout(self.bn2(self.conv2(h, edge_index)).relu()) + skip

        # Output layer, normalized but without activation or dropout
        h = self.bn3(self.conv3(h, edge_index))

        # Average node outputs within each graph, then squash to (0, 1)
        pooled = global_mean_pool(h, batch)
        return torch.sigmoid(pooled).squeeze(-1)

class GCNModel(nn.Module):
    """Four-layer graph convolutional network producing one sigmoid score per graph.

    Every convolution is followed by ReLU, batch normalization, ReLU and
    dropout; the first and third layers add a linearly projected residual
    before dropout. Node outputs are mean-pooled per graph and passed through
    a two-layer fully connected head ending in a sigmoid.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout):
        super().__init__()

        # Layer 1 and the projection feeding its skip path
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = BatchNorm(hidden_channels)
        self.res_proj1 = nn.Linear(in_channels, hidden_channels)

        # Layer 2 (no skip path)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.bn2 = BatchNorm(hidden_channels)

        # Layer 3 and the projection feeding its skip path
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.bn3 = BatchNorm(hidden_channels)
        self.res_proj3 = nn.Linear(hidden_channels, hidden_channels)

        # Layer 4 maps node features to the output size
        self.conv_out = GCNConv(hidden_channels, out_channels)
        self.bn_out = BatchNorm(out_channels)

        # Graph-level head applied after pooling
        self.fc1 = nn.Linear(out_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, out_channels)

        self.dropout = nn.Dropout(dropout)

    def _conv_bn(self, conv, bn, h, edge_index):
        """Apply convolution -> ReLU -> batch norm -> ReLU."""
        return bn(conv(h, edge_index).relu()).relu()

    def forward(self, data):
        """Return the per-graph sigmoid output with its last dimension squeezed."""
        h, edge_index, batch = data.x, data.edge_index, data.batch

        # Layer 1 with projected residual
        skip = self.res_proj1(h)
        h = self.dropout(self._conv_bn(self.conv1, self.bn1, h, edge_index) + skip)

        # Layer 2
        h = self.dropout(self._conv_bn(self.conv2, self.bn2, h, edge_index))

        # Layer 3 with projected residual
        skip = self.res_proj3(h)
        h = self.dropout(self._conv_bn(self.conv3, self.bn3, h, edge_index) + skip)

        # Layer 4
        h = self.dropout(self._conv_bn(self.conv_out, self.bn_out, h, edge_index))

        # Average node outputs within each graph
        pooled = global_mean_pool(h, batch)

        # Fully connected head
        hidden = self.dropout(self.fc1(pooled).relu())
        return torch.sigmoid(self.fc2(hidden)).squeeze(-1)



## Run Model

In [111]:
def run_model(model_name, n_splits=5, patience=5, num_epochs=100, test_size=0.1, seed=42):
    """Grid-search hyperparameters with k-fold cross-validation.

    The notebook-level ``data_list`` is split once into a training part and a
    held-out test part. For every combination in the notebook-level ``grid``
    a fresh model is trained on each fold of the training part with early
    stopping on the validation AUC. After each combination the model of the
    last fold is scored on the test set and the metrics are written with
    ``save_Results``. The best combination (highest mean validation AUC) is
    printed at the end.

    Args:
        model_name: The model *class* to instantiate (despite the name), called
            with ``in_channels``, ``hidden_channels``, ``out_channels`` and
            ``dropout`` keyword arguments.
        n_splits: Number of cross-validation folds.
        patience: Patience passed to ``early_stopping``.
        num_epochs: Maximum number of training epochs per fold.
        test_size: Fraction of ``data_list`` held out as the test set.
        seed: Seed for Python, NumPy, PyTorch and the data splits.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    train_data, test_data = train_test_split(data_list, test_size=test_size, random_state=seed)
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False)

    # Define k-fold cross-validation
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

    # Store overall best results
    best_overall_train_auc = 0.0
    best_overall_val_auc = 0.0
    best_overall_params = None

    # Hyperparameter tuning with k-fold cross-validation
    for params in grid:
        print(f"\nTesting parameters: {params}")

        # Per-fold AUC histories (one list of per-epoch values per fold)
        fold_train_aucs = []
        fold_val_aucs = []

        for fold, (train_index, val_index) in enumerate(kf.split(train_data)):
            print(f"\nFold {fold + 1}/{n_splits}")
            # The indices refer to train_data, so index train_data. Indexing
            # data_list would pick the wrong samples and leak test graphs
            # into training.
            fold_train_data = [train_data[i] for i in train_index]
            fold_val_data = [train_data[i] for i in val_index]

            # Full-batch loaders for the training and validation folds
            train_loader = DataLoader(fold_train_data, batch_size=len(fold_train_data), shuffle=False)
            val_loader = DataLoader(fold_val_data, batch_size=len(fold_val_data), shuffle=False)

            # Initialize the model with current hyperparameters
            model = model_name(in_channels=data_list[0].num_features,
                            hidden_channels=params['hidden_channels'],
                            out_channels=1,
                            dropout=params['dropout'])

            # The models already end with a sigmoid, so the matching loss is
            # BCELoss; BCEWithLogitsLoss would apply a second sigmoid.
            criterion = nn.BCELoss()
            optimizer = torch.optim.Adam(model.parameters(),
                                        lr=params['lr'],
                                        weight_decay=params['weight_decay'])

            train_aucs = []
            val_aucs = []

            for epoch in range(num_epochs):
                model.train()
                total_loss = 0.0
                for data in train_loader:
                    optimizer.zero_grad()
                    out = model(data)
                    loss = criterion(out, data.y.float())
                    loss.backward()
                    optimizer.step()
                    # Accumulate so the printed value is the mean batch loss
                    total_loss += loss.item()

                # Calculate training and validation metrics for each epoch
                train_metrics = evaluate(train_loader, model)
                train_aucs.append(train_metrics['auc'])

                val_metrics = evaluate(val_loader, model)
                val_aucs.append(val_metrics['auc'])

                if epoch % 10 == 0:
                    print(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}, "
                        f"Train Acc: {train_metrics['accuracy']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")

                # Check for early stopping
                if early_stopping(val_aucs, patience):
                    print("Early stopping triggered...")
                    break

            fold_train_aucs.append(train_aucs)
            fold_val_aucs.append(val_aucs)

        fileName = f"lr_{params['lr']}_wd_{params['weight_decay']}_drout_{params['dropout']}_hidChannels{params['hidden_channels']}"
        # plot_average_auc(fileName, fold_train_aucs, fold_val_aucs)

        # Score a configuration by the AUC of the trained model of each fold
        # (its last epoch), averaged over folds. Averaging over every epoch
        # would mix in untrained early epochs and weight longer folds more.
        avg_train_auc = np.mean([aucs[-1] for aucs in fold_train_aucs if aucs])
        avg_val_auc = np.mean([aucs[-1] for aucs in fold_val_aucs if aucs])

        print(f"\nAverage Train AUC with current parameters: {avg_train_auc:.4f}")
        print(f"Average Validation AUC with current parameters: {avg_val_auc:.4f}")

        # Update best overall results
        if avg_val_auc > best_overall_val_auc:
            best_overall_val_auc = avg_val_auc
            best_overall_train_auc = avg_train_auc
            best_overall_params = params

        # Test metrics for this configuration, computed with the model that
        # was trained on the last fold.
        final_test_metrics = evaluate(test_loader, model)
        save_Results(fileName, final_test_metrics)

    print(f"\nBest overall parameters: {best_overall_params}")
    # The tracked quantities are AUC values, so label them as such.
    print(f"Best overall train AUC: {best_overall_train_auc:.4f}")
    print(f"Best overall validation AUC: {best_overall_val_auc:.4f}")


In [None]:
# Registry of the available model classes, keyed by name
model_dict = dict(
    GATModel=GATModel,
    GCNModel=GCNModel,
)

# Name of the model to train (set here by hand)
model_name = 'GATModel'

# Resolve the name to a class; an unknown name yields None
model_class = model_dict.get(model_name)

if model_class is not None:
    run_model(model_class)
else:
    print(f"Model '{model_name}' not found. Please enter a valid model name.")