In [1]:
import numpy as np
import os
import pandas as pd
from pathlib import Path
import scipy.io
import h5py
import sklearn
import copy
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
import scipy.stats as stats
from scipy.stats import ttest_ind, mannwhitneyu

In [None]:
import torch
from torch_geometric.data import Data, Dataset, InMemoryDataset
from torch_geometric.loader import DataLoader
import os
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
from torch import Tensor
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import random
from itertools import product
from sklearn.model_selection import train_test_split, KFold
import torch.nn as nn

# Dataset, model, training and harmonization helpers

In [None]:
class MindDataset(InMemoryDataset):
    """In-memory PyG dataset holding one graph per subject.

    For subject ``i`` the graph is built as follows:

    * node features: row ``j`` of ``matrices_harm[i]`` for node ``j``, followed
      by the subject's ``Age`` (same value on every node) and the subject's
      ``roi_data`` value for that node;
    * edges: every ``(row, col)`` position where ``adjacency_matrices[i] == 1``;
    * label ``y``: the subject's ``strat_covars['Dx']`` value.

    NOTE(review): PyG's ``InMemoryDataset`` only runs ``process()`` when the
    processed file (``<root>/processed/data.pt``) is missing; otherwise the
    cached graphs are loaded and the arrays passed here are not used.

    Args:
        root: Dataset root directory.
        matrices_harm: Per-subject square matrices, indexable by subject.
        strat_covars: DataFrame with ``Age`` and ``Dx`` columns, row-aligned
            with ``matrices_harm``.
        adjacency_matrices: Per-subject adjacency matrices (1 marks an edge).
        roi_data: DataFrame with one row per subject and one column per node.
        transform: Optional PyG transform applied on access.
        pre_transform: Optional PyG pre-transform.
    """

    def __init__(self, root, matrices_harm, strat_covars, adjacency_matrices, roi_data, transform=None, pre_transform=None):
        # Store the inputs *before* the base constructor runs: it may call
        # self.process(), which reads them.
        self.matrices_harm = matrices_harm
        self.strat_covars = strat_covars
        self.adjacency_matrices = adjacency_matrices
        self.roi_data = roi_data
        super().__init__(root, transform, pre_transform)

        # The processed file holds a pickled (Data, slices) tuple, hence weights_only=False.
        collated, slice_map = torch.load(self.processed_paths[0], weights_only=False)
        self.data, self.slices = collated, slice_map

    @property
    def raw_file_names(self):
        # No raw files: all inputs are handed to the constructor in memory.
        return list()

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """Build one ``Data`` graph per subject and save the collated result."""
        graphs = []

        for idx, matrix in enumerate(self.matrices_harm):
            n_nodes = matrix.shape[0]

            connectivity = torch.tensor(matrix, dtype=torch.float)
            roi_column = torch.tensor(self.roi_data.iloc[idx].values, dtype=torch.float).unsqueeze(1)
            subject_age = self.strat_covars['Age'].iloc[idx]
            age_column = torch.tensor([[subject_age]] * n_nodes, dtype=torch.float)
            x = torch.cat([connectivity, age_column, roi_column], dim=1)

            # (2, num_edges) COO index of all positions equal to 1.
            coo = np.array(np.where(self.adjacency_matrices[idx] == 1))
            edge_index = torch.tensor(coo, dtype=torch.long)

            label = torch.tensor(self.strat_covars['Dx'].iloc[idx], dtype=torch.long)
            graphs.append(Data(x=x, edge_index=edge_index, y=label))

        collated, slice_map = self.collate(graphs)
        torch.save((collated, slice_map), self.processed_paths[0])

    def len(self):
        # Explicit pass-through to the InMemoryDataset implementation.
        return super().len()

    def get(self, idx):
        # Explicit pass-through to the InMemoryDataset implementation.
        return super().get(idx)

class NodeNorm(nn.Module):
    """Per-node feature normalization (NodeNorm).

    Statistics are taken across the feature dimension (``dim=1``) of each row
    of ``x``, so every node is normalized independently of the others.

    Modes (``nn_type``), with ``std = sqrt(var + eps)``:
        "n":   (x - mean) / std
        "v":   x / std
        "m":   x - mean
        "srv": x / sqrt(std)
        "pr":  x / std ** (1 / power_root)

    Args:
        nn_type: One of the modes above.
        unbiased: Use the unbiased (N-1) variance estimator.
        eps: Added to the variance before the square root for stability.
        power_root: Root applied to ``std`` in "pr" mode.

    Raises:
        ValueError: If ``nn_type`` is not a supported mode (an unknown mode
            used to make the layer a silent identity).
    """

    _MODES = ("n", "v", "m", "srv", "pr")

    def __init__(self, nn_type="n", unbiased=False, eps=1e-5, power_root=2):
        super().__init__()
        if nn_type not in self._MODES:
            raise ValueError(f"Unsupported nn_type {nn_type!r}; expected one of {self._MODES}")
        self.unbiased = unbiased
        self.eps = eps
        self.nn_type = nn_type
        self.power = 1 / power_root

    def _std(self, x):
        # Per-row standard deviation with eps added inside the square root.
        return (torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True) + self.eps).sqrt()

    def forward(self, x):
        if self.nn_type == "m":
            return x - torch.mean(x, dim=1, keepdim=True)
        if self.nn_type == "n":
            mean = torch.mean(x, dim=1, keepdim=True)
            return (x - mean) / self._std(x)
        std = self._std(x)
        if self.nn_type == "v":
            return x / std
        if self.nn_type == "srv":
            return x / torch.sqrt(std)
        # "pr": divide by the power_root-th root of std.
        return x / torch.pow(std, self.power)

    def extra_repr(self):
        # nn.Module.__repr__ renders this as "NodeNorm(nn_type=...)".
        return f"nn_type={self.nn_type}"

def get_normalization(norm_type, num_channels=None):
    """Build the normalization layer named by ``norm_type``.

    Supported values:
        None       -> no normalization (returns None)
        "batch"    -> nn.BatchNorm1d(num_channels)
        "layer"    -> nn.LayerNorm(num_channels)
        "node_n" / "node_v" / "node_m" / "node_srv" -> NodeNorm of that type
        any string containing "node_pr" -> NodeNorm("pr") whose power_root is
            the integer after the last "_" (e.g. "node_pr_3" -> 3)

    Args:
        norm_type: Name of the normalization, or None.
        num_channels: Feature width; needed for "batch" and "layer".

    Raises:
        NotImplementedError: For an unrecognised ``norm_type``.
        ValueError: If a "node_pr" spec does not end in an integer.
    """
    if norm_type is None:
        return None
    if norm_type == "batch":
        return nn.BatchNorm1d(num_features=num_channels)
    if norm_type == "layer":
        return nn.LayerNorm(normalized_shape=num_channels)

    node_variants = {"node_n": "n", "node_v": "v", "node_m": "m", "node_srv": "srv"}
    if norm_type in node_variants:
        return NodeNorm(nn_type=node_variants[norm_type])
    if "node_pr" in norm_type:
        root = int(norm_type.split("_")[-1])
        return NodeNorm(nn_type="pr", power_root=root)

    raise NotImplementedError

class GCN(torch.nn.Module):
    """Graph-level classifier: stacked GCNConv layers, mean pooling, linear head.

    Each layer applies ``GCNConv -> ReLU -> normalization``. The normalization
    comes from ``get_normalization(norm_type)``; when that returns ``None``
    (``norm_type=None``) the layer simply has no normalization. Node embeddings
    are averaged per graph with ``global_mean_pool`` and mapped to
    ``output_dim`` logits by ``self.fc``.

    Args:
        input_dim: Number of node features.
        hidden_dim: Width of every GCN layer.
        output_dim: Number of output logits.
        num_layers: Number of GCNConv layers (values < 1 still build one layer).
        norm_type: Normalization spec understood by ``get_normalization``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, norm_type="node_n"):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        # Layers are created conv, norm, conv, norm, ... so parameter
        # initialization consumes the RNG in the same order as before.
        for layer in range(max(num_layers, 1)):
            in_dim = input_dim if layer == 0 else hidden_dim
            self.convs.append(GCNConv(in_dim, hidden_dim))
            self.bns.append(get_normalization(norm_type=norm_type, num_channels=hidden_dim))

        self.fc = Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch, return_embeddings=False):
        """Compute graph-level logits.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity (COO).
            batch: Graph id of each node.
            return_embeddings: If True, also return the node embeddings of the
                last GCN layer (before pooling), as used for class-activation
                style node importance.

        Returns:
            ``logits``, or ``(logits, node_embeddings)`` when
            ``return_embeddings`` is True.
        """
        for conv, norm in zip(self.convs, self.bns):
            x = F.relu(conv(x, edge_index))
            # ModuleList keeps None entries when norm_type=None; skip them.
            if norm is not None:
                x = norm(x)

        node_embeddings = x
        logits = self.fc(global_mean_pool(node_embeddings, batch))
        if return_embeddings:
            return logits, node_embeddings
        return logits

import Confounder_Correction_Classes
from Confounder_Correction_Classes import ComBatHarmonization

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def train(net=None, loader=None, opt=None):
    """Run one training epoch and return metrics gathered during it.

    With no arguments the notebook globals ``model``, ``train_loader`` and
    ``optimizer`` are used, so the existing ``train()`` call keeps working.
    Pass them explicitly to avoid depending on hidden notebook state.

    Args:
        net: Model called as ``net(x, edge_index, batch)`` that returns logits.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``, ``y``.
        opt: Optimizer over ``net``'s parameters.

    Returns:
        ``(accuracy, mean_batch_loss, f1)``. Predictions are taken from each
        batch's forward pass during the epoch, i.e. while weights are changing.

    Raises:
        ValueError: If the loader yields no batches.
    """
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt

    net.train()
    epoch_loss = 0.0
    n_batches = 0
    all_preds = []
    all_labels = []
    for data in loader:
        opt.zero_grad()
        out = net(data.x, data.edge_index, data.batch)
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
        n_batches += 1
        all_preds.append(out.argmax(dim=1).cpu().numpy())
        all_labels.append(data.y.cpu().numpy())

    if n_batches == 0:
        raise ValueError("train() received an empty loader")

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    accuracy = (all_preds == all_labels).mean()
    # Same zero_division policy as evaluate(), so train and validation F1 are comparable.
    f1 = f1_score(all_labels, all_preds, zero_division=1.0)
    train_loss = epoch_loss / n_batches

    return accuracy, train_loss, f1

def evaluate(loader, net=None):
    """Evaluate a model on ``loader`` without gradient tracking.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``, ``y``.
        net: Model to evaluate. Defaults to the notebook-global ``model`` so
            existing ``evaluate(loader)`` calls keep working.

    Returns:
        ``(accuracy, mean_batch_loss, f1, labels, preds)``. ``labels`` and
        ``preds`` are 1-D NumPy arrays in loader order. ``mean_batch_loss`` is
        the mean of per-batch mean losses, so a smaller last batch gets the
        same weight as the others.

    Raises:
        ValueError: If the loader yields no batches.
    """
    net = model if net is None else net
    net.eval()
    total_loss = 0.0
    n_batches = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data in loader:
            out = net(data.x, data.edge_index, data.batch)
            total_loss += F.cross_entropy(out, data.y).item()
            n_batches += 1
            all_preds.append(out.argmax(dim=1).cpu().numpy())
            all_labels.append(data.y.cpu().numpy())

    if n_batches == 0:
        raise ValueError("evaluate() received an empty loader")

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    accuracy = (all_preds == all_labels).mean()
    # zero_division=1.0 is what sklearn reports when F1 is ill-defined.
    f1 = f1_score(all_labels, all_preds, zero_division=1.0)
    val_loss = total_loss / n_batches

    return accuracy, val_loss, f1, all_labels, all_preds

def split_data(data, strat_covars, roi_data, indices_1, indices_2):
    """Split three row-aligned tables into two groups by index label.

    Group 1 is every row *not* labelled in ``indices_2``; group 2 is every row
    *not* labelled in ``indices_1``. Each resulting frame gets a fresh
    RangeIndex. Missing labels raise ``KeyError`` (``DataFrame.drop``).

    Returns:
        ``(data_1, strat_covars_1, roi_1, data_2, strat_covars_2, roi_2)``.
    """
    def _without(frame, labels):
        return frame.drop(labels).reset_index(drop=True)

    tables = (data, strat_covars, roi_data)
    group_1 = tuple(_without(table, indices_2) for table in tables)
    group_2 = tuple(_without(table, indices_1) for table in tables)
    return group_1 + group_2

def reconstruct_and_create_adjacency(data_harm, threshold_percentile=90):
    """Rebuild symmetric matrices from upper-triangle rows and threshold them.

    Each row of ``data_harm`` holds the strictly upper-triangular entries of an
    N x N symmetric matrix, in the order of ``np.triu_indices(N, k=1)``. The
    matrix is mirrored and its diagonal set to 1. Its adjacency matrix keeps
    every entry ``>=`` the ``threshold_percentile``-th percentile of that
    matrix's values. The percentile is computed over the full matrix,
    including the diagonal, so the diagonal is kept whenever the threshold
    is <= 1.

    Args:
        data_harm: 2-D array-like (DataFrame or ndarray) of shape
            ``(n_matrices, N*(N-1)/2)``.
        threshold_percentile: Percentile in [0, 100] used per matrix.

    Returns:
        ``(matrices, adjacency_matrices)``: a float array of shape
        ``(n, N, N)`` and an int 0/1 array of the same shape.

    Raises:
        ValueError: If ``data_harm`` is not 2-D or its column count is not
            a triangular number ``N*(N-1)/2``.
    """
    values = np.asarray(data_harm, dtype=float)
    if values.ndim != 2:
        raise ValueError(f"data_harm must be 2-D, got shape {values.shape}")
    n_matrices, upper_triangle_size = values.shape

    # Invert k = N(N-1)/2 for N, then check it really was triangular.
    N = int((1 + np.sqrt(1 + 8 * upper_triangle_size)) // 2)
    if N * (N - 1) // 2 != upper_triangle_size:
        raise ValueError(
            f"{upper_triangle_size} columns is not N*(N-1)/2 for any integer N"
        )

    matrices = np.zeros((n_matrices, N, N))
    adjacency_matrices = np.zeros((n_matrices, N, N), dtype=int)
    upper_indices = np.triu_indices(N, k=1)  # loop-invariant

    for i in range(n_matrices):
        matrix = np.zeros((N, N))
        matrix[upper_indices] = values[i]
        matrix = matrix + matrix.T
        np.fill_diagonal(matrix, 1)
        matrices[i] = matrix

        threshold = np.percentile(matrix, threshold_percentile)
        adjacency_matrices[i] = (matrix >= threshold).astype(int)

    return matrices, adjacency_matrices

def harmonize_data_2(data_1_raw, strat_covars_1, data_2_raw, strat_covars_2, ext_batch):
    """ComBat-harmonize set 1, then align the external set 2 to the harmonized set 1.

    Step 1 fits ComBat on ``data_1_raw`` alone (all of its batches).
    Step 2 stacks the harmonized set 1 on top of the raw set 2 and relabels
    every row whose ``batch`` is not ``ext_batch`` as batch 0. It then fits a
    second ComBat with ``ref_batch=0`` on the stacked data and returns only
    the rows that came from set 2. Both fits use ``Gender`` (categorical) and
    ``Age`` (continuous) as covariates.

    NOTE(review): assumes ``ComBatHarmonization`` follows the usual ComBat
    ``ref_batch`` semantics (reference batch left unchanged) — confirm.

    Args:
        data_1_raw: Feature table of set 1 (rows = subjects, columns = features).
        strat_covars_1: Covariates of set 1 (``Gender``, ``Age``, ``batch``).
        data_2_raw: Feature table of set 2. Its column labels must be
            convertible to ``int`` so they line up with the integer columns of
            the harmonized set 1.
        strat_covars_2: Covariates of set 2.
        ext_batch: ``batch`` value of the external site.

    Returns:
        ``(data_1_harm, data_2_harm)`` as DataFrames with a fresh RangeIndex.
    """
    volumes_columns = np.arange(0, data_1_raw.shape[1])
    feat_detail = {'volumes': {'id': volumes_columns,
                               'categorical': ['Gender'],
                               'continuous': ['Age']}}

    # Step 1: harmonize set 1 on its own.
    combat_function = ComBatHarmonization(cv_method=None, ref_batch=None,
                                          regression_fit=0,
                                          feat_detail=feat_detail,
                                          feat_of_no_interest=None)
    data_1_dict = {'data': data_1_raw, 'covariates': strat_covars_1}
    data_1_harm = pd.DataFrame(combat_function.fit_transform(data_1_dict))

    # Step 2: harmonize the test set against the already-harmonized set 1.
    # reset_index returns a new frame, so relabelling its columns cannot leak
    # back into the caller's DataFrame.
    data_2 = pd.DataFrame(data_2_raw).reset_index(drop=True)
    data_2.columns = data_2.columns.astype(int)
    all_data = pd.concat([data_1_harm, data_2], ignore_index=True)
    all_strat_covars = pd.concat([strat_covars_1, strat_covars_2], ignore_index=True)
    # Every non-external row becomes reference batch 0.
    all_strat_covars.loc[all_strat_covars["batch"] != ext_batch, "batch"] = 0

    test_combat_function = ComBatHarmonization(cv_method=None, ref_batch=0,
                                               regression_fit=0,
                                               feat_detail=feat_detail,
                                               feat_of_no_interest=None)
    all_data_dict = {'data': all_data, 'covariates': all_strat_covars}
    all_data_harm = pd.DataFrame(test_combat_function.fit_transform(all_data_dict))

    # Set 2 occupies the rows after set 1. Split positionally so the result
    # does not depend on which index labels ComBat returns.
    data_2_harm = all_data_harm.iloc[len(data_1_harm):].reset_index(drop=True)

    return data_1_harm, data_2_harm

def harmonize_data(train_data_raw, train_strat_covars, test_data_raw, test_strat_covars):
    """Fit ComBat on the training set and apply the fitted model to the test set.

    All columns of ``train_data_raw`` are harmonized, with ``Gender``
    (categorical) and ``Age`` (continuous) as covariates.

    Returns:
        ``(train_data_harm, test_data_harm)`` as DataFrames.
    """
    feature_spec = {
        'volumes': {
            'id': np.arange(0, train_data_raw.shape[1]),
            'categorical': ['Gender'],
            'continuous': ['Age'],
        }
    }
    harmonizer = ComBatHarmonization(
        cv_method=None,
        ref_batch=None,
        regression_fit=0,
        feat_detail=feature_spec,
        feat_of_no_interest=None,
    )

    train_payload = {'data': train_data_raw, 'covariates': train_strat_covars}
    train_harm = pd.DataFrame(harmonizer.fit_transform(train_payload))

    # Test data only goes through transform(): no refit on held-out subjects.
    test_payload = {'data': test_data_raw, 'covariates': test_strat_covars}
    test_harm = pd.DataFrame(harmonizer.transform(test_payload))

    return train_harm, test_harm

In [28]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
from torch import Tensor
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import random
from itertools import product
from sklearn.model_selection import train_test_split, KFold
from torch_geometric.explain import Explainer, GNNExplainer

# Leave-one-batch-out training and external testing

In [None]:
# Leave-one-batch-out evaluation. For each external test batch:
#   1. drop the subjects listed in `indices` and split off the external batch;
#   2. ComBat-harmonize the remaining (CV) batches and map the external batch onto them;
#   3. build connectivity graphs and train a GCN on a 70/30 train/validation split
#      of the CV data, early-stopping on validation loss;
#   4. evaluate on the external batch and compute GNNExplainer / CAM importances.
#
# NOTE(review): `indices_to_remove_<k>` and the `best_*_dict` lookups come from
# cells not shown here; this cell fails on a fresh kernel without them.
ext_test_batches = [1, 3, 4, 5, 6, 7]
indices_to_remove = [indices_to_remove_1, indices_to_remove_3, indices_to_remove_4, indices_to_remove_5, indices_to_remove_6, indices_to_remove_7]
final_results = []
all_node_feature_importances = []
all_feature_importances = []
all_node_importances = []
all_feature_importances_2 = []

# The raw tables are the same for every batch: read once, filter per iteration.
data_raw_all = pd.read_csv('data_raw_nmm_new.csv')
strat_covars_all = pd.read_csv('MatchedData01.csv')
roi_raw_all = pd.read_csv('roi_data_raw_nmm_norm.csv')

for ext_test_batch, indices in zip(ext_test_batches, indices_to_remove):
    print(f"Test batch: {ext_test_batch}")

    data_raw = data_raw_all.drop(indices)
    strat_covars = strat_covars_all.drop(indices)
    roi_raw = roi_raw_all.drop(indices)

    # Separate cross-validation data from the external test batch.
    ext_test_indices = strat_covars[strat_covars['batch'] == ext_test_batch].index
    cv_indices = strat_covars[strat_covars["batch"] != ext_test_batch].index

    cv_data_raw, cv_strat_covars, cv_roi_raw, ext_test_data_raw, ext_test_strat_covars, ext_test_roi_raw = split_data(
        data_raw, strat_covars, roi_raw, cv_indices, ext_test_indices
    )

    # Hyper-parameters previously selected for this external batch.
    best_num_layers = best_num_layers_dict[ext_test_batch]
    best_hidden_dim = best_hidden_dim_dict[ext_test_batch]
    best_norm_type = best_norm_type_dict[ext_test_batch]
    best_batch_size = best_batch_size_dict[ext_test_batch]

    print(f"Best parameters: num_layers={best_num_layers}, hidden_dim={best_hidden_dim}, norm_type={best_norm_type}, batch_size={best_batch_size}")

    cv_data_harm, ext_test_data_harm = harmonize_data_2(cv_data_raw, cv_strat_covars, ext_test_data_raw, ext_test_strat_covars, ext_test_batch)
    cv_roi_harm, ext_test_roi_harm = harmonize_data_2(cv_roi_raw, cv_strat_covars, ext_test_roi_raw, ext_test_strat_covars, ext_test_batch)

    ext_test_matrices_harm, ext_test_adjacency_matrices = reconstruct_and_create_adjacency(ext_test_data_harm)
    # Nodes per graph (atlas regions), taken from the data instead of hard-coding it.
    num_nodes = ext_test_matrices_harm.shape[1]

    # NOTE: MindDataset (a PyG InMemoryDataset) reuses an existing processed file
    # under `root`; delete these directories after changing the preprocessing.
    ext_test_dataset_path = f"Datasets/CV/MindDatasetExtTestNMM90_batch{ext_test_batch}"
    ext_test_dataset = MindDataset(
        root=ext_test_dataset_path,
        matrices_harm=ext_test_matrices_harm,
        strat_covars=ext_test_strat_covars,
        adjacency_matrices=ext_test_adjacency_matrices,
        roi_data=ext_test_roi_harm,
    )

    cv_matrices_harm, cv_adjacency_matrices = reconstruct_and_create_adjacency(cv_data_harm)
    cv_dataset_path = f"Datasets/CV/MindDatasetCVNMM90_batch{ext_test_batch}"
    cv_dataset = MindDataset(
        root=cv_dataset_path,
        matrices_harm=cv_matrices_harm,
        strat_covars=cv_strat_covars,
        adjacency_matrices=cv_adjacency_matrices,
        roi_data=cv_roi_harm,
    )
    train_indices, val_indices = train_test_split(range(len(cv_dataset)), test_size=0.3, random_state=3)
    train_loader = DataLoader(cv_dataset[train_indices], batch_size=best_batch_size, shuffle=True)
    val_loader = DataLoader(cv_dataset[val_indices], batch_size=best_batch_size, shuffle=False)
    test_loader = DataLoader(ext_test_dataset, batch_size=best_batch_size, shuffle=False)

    set_seed(42)

    model = GCN(cv_dataset.num_features, hidden_dim=best_hidden_dim, output_dim=cv_dataset.num_classes, num_layers=best_num_layers, norm_type=best_norm_type)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Early stopping on validation loss; the best checkpoint is reloaded afterwards.
    patience = 20
    best_val_loss = float("inf")
    patience_counter = 0

    for epoch in range(100):

        # train() reads the notebook globals model / train_loader / optimizer.
        train_acc, train_loss, train_f1 = train()
        val_acc, val_loss, val_f1, _, _ = evaluate(val_loader)  # Validation set metrics
        print(
            f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}, Train F1: {train_f1:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}, Val F1: {val_f1:.4f}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), "best_model_123.pt")
            print(f"Saved best model at epoch {epoch} with Val Loss: {val_loss:.4f}")
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch} as validation loss did not improve for {patience} consecutive epochs.")
            break

    # Evaluate the best checkpoint on the external test batch. The file holds a plain
    # state_dict, so the safer weights_only=True loader is sufficient.
    model.load_state_dict(torch.load("best_model_123.pt", weights_only=True))
    test_acc, test_loss, test_f1, all_labels, all_preds = evaluate(test_loader)

    test_prec = precision_score(all_labels, all_preds, zero_division=1.0)
    test_recall = recall_score(all_labels, all_preds, zero_division=1.0)

    tn, fp, fn, tp = confusion_matrix(all_labels, all_preds).ravel()

    final_results.append({
        "batch": ext_test_batch,
        "num_layers": best_num_layers,
        "hidden_dim": best_hidden_dim,
        "norm_type": best_norm_type,
        "batch_size": best_batch_size,
        "accuracy": test_acc*100,
        "loss": test_loss,
        "f1": test_f1*100,
        "precision": test_prec*100,
        "recall": test_recall*100,
        "true_negative": tn,
        "false_positive": fp,
        "false_negative": fn,
        "true_positive": tp
    })

    print(f"External test {ext_test_batch} Accuracy: {test_acc:.4f}, Test F1: {test_f1:.4f}, Precision:{test_prec:.4f}, Recall:{test_recall:.4f}, True Negative {tn}, False Positive {fp}, False Negative {fn}, True Positive {tp}")

    ### NODE FEATURE IMPORTANCE (GNNExplainer) ###
    # The GCN emits one logit per class (output_dim=num_classes), i.e. softmax-style
    # logits, so the explainer needs multiclass mode. In 'binary_classification' mode
    # PyG treats each logit as an independent binary output and fits the mask to the
    # wrong target.
    explainer = Explainer(
        model=model,
        algorithm=GNNExplainer(epochs=200),
        explanation_type='model',
        node_mask_type='attributes',
        model_config=dict(
            mode='multiclass_classification',
            task_level='graph',
            return_type='raw',
        ),
    )

    explanations = [explainer(data.x, data.edge_index, batch=data.batch) for data in test_loader]
    all_node_masks = torch.cat([exp.node_mask for exp in explanations], dim=0)

    num_features = all_node_masks.shape[1]
    num_graphs = all_node_masks.shape[0] // num_nodes

    # (graphs, nodes, features); assumes every graph has exactly num_nodes nodes.
    reshaped = all_node_masks.view(num_graphs, num_nodes, num_features)

    avg_node_feature_importance = reshaped.mean(dim=0)

    plt.figure(figsize=(8, 5))
    sns.heatmap(avg_node_feature_importance.numpy(), cmap="viridis", xticklabels=10, yticklabels=10)
    plt.xlabel("Feature")
    plt.ylabel("Node")
    plt.title("Average Node Feature Importance")
    plt.tight_layout()
    plt.show()

    node_ids, feature_ids = torch.meshgrid(
        torch.arange(avg_node_feature_importance.size(0)),
        torch.arange(avg_node_feature_importance.size(1)),
        indexing='ij'
    )

    df_node_feature_importance = pd.DataFrame({
        "node": node_ids.flatten().numpy(),
        "feature": feature_ids.flatten().numpy(),
        "mean_importance": avg_node_feature_importance.flatten().numpy()
    })

    df_node_feature_importance = df_node_feature_importance.sort_values(by="mean_importance", ascending=False)

    df_node_feature_importance.to_csv(f"node_feature_importance_batch{ext_test_batch}.csv", index=False)

    all_node_feature_importances.append(df_node_feature_importance)

    ### FEATURE IMPORTANCE ###

    # Alternative method: sum the mask over nodes within each subject, then average over subjects.
    subject_feature_sums = reshaped.sum(dim=1)
    feature_importance_2 = subject_feature_sums.mean(dim=0)
    df_feature_importance_2 = pd.DataFrame({'Feature': range(len(feature_importance_2)), 'Importance': feature_importance_2.numpy()})
    df_feature_importance_2 = df_feature_importance_2.sort_values(by='Importance', ascending=False)

    all_feature_importances_2.append(df_feature_importance_2.set_index('Feature'))

    # Main method: total mask mass per feature over all nodes of all test graphs.
    feature_importance = all_node_masks.sum(dim=0)
    df_feature_importance = pd.DataFrame({'Feature': range(len(feature_importance)), 'Importance': feature_importance.numpy()})
    df_feature_importance = df_feature_importance.sort_values(by='Importance', ascending=False)

    df_feature_importance.to_csv(f"feature_importance_batch{ext_test_batch}.csv", index=False)

    all_feature_importances.append(df_feature_importance.set_index('Feature'))

    feature_importance = feature_importance.numpy()

    top_k = min(20, len(feature_importance))
    sorted_idx = np.argsort(feature_importance)[::-1][:top_k]
    top_features = feature_importance[sorted_idx]

    plt.figure(figsize=(10, 6))
    plt.barh(range(top_k), top_features[::-1], color='steelblue')
    plt.yticks(range(top_k), [f'{i}' for i in sorted_idx[::-1]])
    plt.xlabel("Importance Score")
    plt.title(f"Top {top_k} Most Important Features")

    for i, val in enumerate(top_features[::-1]):
        plt.text(val + 0.01, i, f'{val:.0f}', va='center', fontsize=10)

    plt.tight_layout()
    plt.show()

    # Features appended after the per-node connectivity row (age and ROI value).
    print(df_feature_importance[df_feature_importance["Feature"].isin(list(range(num_nodes, num_features)))])

    ### NODE IMPORTANCE (class activation map) ###
    # Node score = node embedding . fc weight row of the graph's predicted class.

    all_importance_scores = []

    model.eval()

    with torch.no_grad():
        for data in test_loader:
            # One forward pass per batch (previously repeated for every graph in it).
            out, node_embeddings = model(data.x, data.edge_index, data.batch, return_embeddings=True)
            predicted_classes = out.argmax(dim=1)
            for graph_idx in torch.unique(data.batch).tolist():
                predicted_class = predicted_classes[graph_idx].item()
                class_weights = model.fc.weight[predicted_class]
                mask = data.batch == graph_idx  # nodes of this graph
                importance_scores = node_embeddings[mask] @ class_weights
                all_importance_scores.append(importance_scores.cpu().numpy())

    all_importance_scores = np.vstack(all_importance_scores)
    mean_importance_per_node = np.mean(all_importance_scores, axis=0)

    df_node_importance = pd.DataFrame({'Feature': range(len(mean_importance_per_node)), 'Importance': mean_importance_per_node})

    df_node_importance.to_csv(f"node_importance_batch{ext_test_batch}.csv", index=False)

    all_node_importances.append(df_node_importance.set_index('Feature'))

    plt.figure(figsize=(10, 6))
    sns.heatmap(mean_importance_per_node.reshape(-1, 1), cmap="Reds")
    plt.title("Node Importance via CAM")
    plt.xlabel("Importance Score")
    plt.ylabel("Nodes")
    plt.show()

# Average the (node, feature) importances across all external test batches.
all_dfs_concat = pd.concat(all_node_feature_importances)
avg_node_feature_df = all_dfs_concat.groupby(['node', 'feature'], as_index=False)['mean_importance'].mean()
avg_node_feature_df.to_csv("avg_node_feature_df.csv", index=False)

# Importance Analysis

In [31]:
# Neuromorphometrics atlas lookup table (semicolon-separated).
nmm = pd.read_csv("neuromorphometrics_original.csv", delimiter=";")

In [32]:
# Display the Neuromorphometrics ROI table loaded above.
nmm

Unnamed: 0,ROIid,ROIname,Vgm,Vwm,Vcsf,ROIcolor,CircNumber,CircRegion,CircLabel,Hemisphere
0,23,Right Accumbens Area,1,0,0,102 102 255,116,Basal ganglia,Accumbens,right
1,30,Left Accumbens Area,1,0,0,102 102 255,52,Basal ganglia,Accumbens,left
2,31,Right Amygdala,1,0,0,255 177 100,114,Amygdala,Amygdala,right
3,32,Left Amygdala,1,0,0,255 177 100,50,Amygdala,Amygdala,left
4,36,Right Caudate,1,0,0,255 0 255,117,Basal ganglia,Caudate,right
...,...,...,...,...,...,...,...,...,...,...
117,203,Left TMP temporal pole,1,0,0,102 0 0,27,Temporal lobe,lateral,left
118,204,Right TrIFG triangular part of the inferior fr...,1,0,0,0 102 0,77,Frontal lobe,lateral,right
119,205,Left TrIFG triangular part of the inferior fro...,1,0,0,0 102 0,13,Frontal lobe,lateral,left
120,206,Right TTG transverse temporal gyrus,1,0,0,255 215 0,94,Temporal lobe,supra-temporal,right


In [33]:
# Index -> label lookups for nodes and node features.
# Node features are the per-region connectivity values (one per atlas region),
# then Age, then the ROI value. Assumes `nmm` rows are in the same order as the
# matrix rows — TODO confirm.
extra_features = ['Age', 'ROI Volume']
extra_regions = ['None'] * len(extra_features)

node_names = nmm['ROIname'].tolist()
circ_regions = nmm['CircRegion'].tolist()

feature_names = [*node_names, *extra_features]
all_circ_regions = [*circ_regions, *extra_regions]

node_name_map = dict(enumerate(node_names))
feature_name_map = dict(enumerate(feature_names))
circ_region_map = dict(enumerate(all_circ_regions))

## Node-feature importance

In [None]:
# Reload the cross-batch aggregate so this analysis runs without re-training.
avg_node_feature_df = pd.read_csv(Path("avg_node_feature_df.csv"))

In [None]:
# Long-format table: one row per (node, feature) with its mean importance.
avg_node_feature_df

In [None]:
# Reshape the long (node, feature, mean_importance) table into a node x feature matrix.
avg_matrix_df = avg_node_feature_df.pivot(index='node', columns='feature', values='mean_importance')
avg_matrix = avg_matrix_df.to_numpy()

fig, ax = plt.subplots(figsize=(8, 5))
sns.heatmap(avg_matrix, cmap="viridis", xticklabels=10, yticklabels=10, ax=ax)
ax.set_xlabel("Feature")
ax.set_ylabel("Node")
fig.tight_layout()
fig.savefig("node_feature_importance.png", dpi=300)
plt.show()

In [None]:
# Attach readable names and macro-regions, then rank node-feature pairs by importance.
avg_node_feature_df = (
    avg_node_feature_df
    .assign(
        node_name=avg_node_feature_df['node'].map(node_name_map),
        feature_name=avg_node_feature_df['feature'].map(feature_name_map),
        CircNode=avg_node_feature_df['node'].map(circ_region_map),
        CircFeature=avg_node_feature_df['feature'].map(circ_region_map),
    )
    .sort_values(by="mean_importance", ascending=False)
    .reset_index(drop=True)
)
avg_node_feature_df

## Feature importance

In [None]:
# Mean importance of each feature across all nodes, in percent.
mean_per_feature = avg_matrix_df.mean(axis=0).sort_values(ascending=False) * 100

df_mean_importance = (
    mean_per_feature
    .reset_index()
    .set_axis(['Feature', 'AverageImportance'], axis=1)
    .sort_values(by='AverageImportance', ascending=False)
    .reset_index(drop=True)
)
df_mean_importance

In [35]:
# Readable feature names and their brain macro-region.
df_mean_importance = df_mean_importance.assign(
    FeatureName=df_mean_importance['Feature'].map(feature_name_map),
    CircRegion=df_mean_importance['Feature'].map(circ_region_map),
)

In [None]:
# Feature ranking with readable names and macro-regions attached.
df_mean_importance

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch

# Top-k features by average importance, bars coloured by brain macro-region.
sns.set_theme(style="ticks")

top_k = 20

# nlargest is descending; reverse so the largest bar is drawn at the top of barh.
top_df = df_mean_importance.nlargest(top_k, 'AverageImportance').iloc[::-1].copy()
n_bars = len(top_df)  # can be smaller than top_k when there are fewer features

unique_regions = top_df['CircRegion'].unique()
# "deep" has only 10 distinct colours and cycles beyond that, which would give
# different regions the same colour; use an evenly spaced palette instead then.
palette_name = "deep" if len(unique_regions) <= 10 else "husl"
palette = sns.color_palette(palette_name, len(unique_regions))[::-1]
region_color_map = {region: palette[i] for i, region in enumerate(unique_regions)}
bar_colors = top_df['CircRegion'].map(region_color_map).tolist()

plt.figure(figsize=(10, 6))
bars = plt.barh(range(n_bars), top_df['AverageImportance'], color=bar_colors)

sns.despine()

plt.yticks(range(n_bars), top_df['FeatureName'], fontsize=10)

# Numeric label just right of each bar.
for i, val in enumerate(top_df['AverageImportance']):
    plt.text(val + 0.01, i, f'{val:.2f}', va='center', fontsize=9)

plt.xlabel("Average Importance", fontsize=12)
plt.title(f"Top {n_bars} Most Important Features", fontsize=14)

legend_elements = [Patch(facecolor=color, label=region) for region, color in region_color_map.items()]
plt.legend(handles=legend_elements, title="Brain Macroregion", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.show()

In [None]:
# Concatenate all DataFrames along columns
concat_df = pd.concat(all_feature_importances_2, axis=1)

# Average importance scores across all batches
mean_importance = concat_df.mean(axis=1)

# Create final DataFrame
df_mean_importance = mean_importance.reset_index()
df_mean_importance.columns = ['Feature', 'AverageImportance']

# Sort descending by importance
df_mean_importance = df_mean_importance.sort_values(by='AverageImportance', ascending=False)
df_mean_importance.reset_index(drop=True, inplace=True)
df_mean_importance

In [82]:
# Readable feature names and their brain region.
df_mean_importance = df_mean_importance.assign(
    FeatureName=df_mean_importance['Feature'].map(feature_name_map),
    CircRegion=df_mean_importance['Feature'].map(circ_region_map),
)

In [None]:
from matplotlib.patches import Patch

# Top-k features (alternative aggregation), bars coloured by brain region.
sns.set_theme(style="ticks")

top_k = 20

# nlargest is descending; reverse so the largest bar is drawn at the top of barh.
top_df = df_mean_importance.nlargest(top_k, 'AverageImportance').iloc[::-1].copy()
n_bars = len(top_df)  # can be smaller than top_k when there are fewer features

unique_regions = top_df['CircRegion'].unique()
# "deep" has only 10 distinct colours and cycles beyond that, which would give
# different regions the same colour; use an evenly spaced palette instead then.
palette_name = "deep" if len(unique_regions) <= 10 else "husl"
palette = sns.color_palette(palette_name, len(unique_regions))[::-1]
region_color_map = {region: palette[i] for i, region in enumerate(unique_regions)}
bar_colors = top_df['CircRegion'].map(region_color_map).tolist()

plt.figure(figsize=(6, 5))
bars = plt.barh(range(n_bars), top_df['AverageImportance'], color=bar_colors)

plt.yticks(range(n_bars), top_df['FeatureName'], fontsize=10)

# Numeric label just right of each bar.
for i, val in enumerate(top_df['AverageImportance']):
    plt.text(val + 0.01, i, f'{val:.2f}', va='center', fontsize=9)

plt.xlabel("Average Importance", fontsize=12)
plt.title(f"Top {n_bars} Most Important Features", fontsize=14)

legend_elements = [Patch(facecolor=color, label=region) for region, color in region_color_map.items()]
plt.legend(handles=legend_elements, title="Brain Region", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.show()


## Node importance

In [4]:
import pandas as pd

# External test batches whose per-batch node-importance CSVs the training cell wrote.
ext_test_batches = [1, 3, 4, 5, 6, 7]

# One DataFrame per batch, indexed by node id (the 'Feature' column).
all_node_importances = [
    pd.read_csv(f"node_importance_batch{batch}.csv").set_index('Feature')
    for batch in ext_test_batches
]


In [None]:
# Concatenate all DataFrames along columns
concat_df = pd.concat(all_feature_importances, axis=1)

# Average importance scores across all batches
mean_importance = concat_df.mean(axis=1)

# Create final DataFrame
df_mean_importance = mean_importance.reset_index()
df_mean_importance.columns = ['Feature', 'AverageImportance']

# Sort descending by importance
df_mean_importance = df_mean_importance.sort_values(by='AverageImportance', ascending=False)
df_mean_importance.reset_index(drop=True, inplace=True)
df_mean_importance

In [73]:
# Readable feature names and their brain region.
df_mean_importance = df_mean_importance.assign(
    FeatureName=df_mean_importance['Feature'].map(feature_name_map),
    CircRegion=df_mean_importance['Feature'].map(circ_region_map),
)

In [6]:
# Element-wise mean of the per-batch node importances: (nodes, 1, batches) -> (nodes, 1).
avg_importance = np.stack([frame.to_numpy() for frame in all_node_importances], axis=2).mean(axis=2)

df_node_importance = pd.DataFrame({
    "Feature": np.arange(len(avg_importance)),
    "AverageImportance": avg_importance.ravel(),
})



In [13]:
# Readable node (region) names and their brain region.
df_node_importance = df_node_importance.assign(
    FeatureName=df_node_importance['Feature'].map(feature_name_map),
    CircRegion=df_node_importance['Feature'].map(circ_region_map),
)

In [None]:
# Rank nodes from most to least important.
df_node_importance = (
    df_node_importance
    .sort_values(by='AverageImportance', ascending=False)
    .reset_index(drop=True)
)
df_node_importance

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch

# Top-k nodes by CAM importance, bars coloured by brain region.
sns.set_theme(style="ticks")

top_k = 10

# nlargest is descending; reverse so the largest bar is drawn at the top of barh.
top_df = df_node_importance.nlargest(top_k, 'AverageImportance').iloc[::-1].copy()
n_bars = len(top_df)  # can be smaller than top_k when there are fewer nodes

unique_regions = top_df['CircRegion'].unique()
# "deep" has only 10 distinct colours and cycles beyond that, which would give
# different regions the same colour; use an evenly spaced palette instead then.
palette_name = "deep" if len(unique_regions) <= 10 else "husl"
palette = sns.color_palette(palette_name, len(unique_regions))[::-1]
region_color_map = {region: palette[i] for i, region in enumerate(unique_regions)}
bar_colors = top_df['CircRegion'].map(region_color_map).tolist()

plt.figure(figsize=(6, 5))
bars = plt.barh(range(n_bars), top_df['AverageImportance'], color=bar_colors)

sns.despine()

plt.yticks(range(n_bars), top_df['FeatureName'], fontsize=10)

# Numeric label just right of each bar.
for i, val in enumerate(top_df['AverageImportance']):
    plt.text(val + 0.01, i, f'{val:.2f}', va='center', fontsize=9)

plt.xlabel("Average Importance", fontsize=12)
plt.title(f"Top {n_bars} Most Important Nodes", fontsize=14)

legend_elements = [Patch(facecolor=color, label=region) for region, color in region_color_map.items()]
plt.legend(handles=legend_elements, title="Brain Region", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.show()

In [30]:
# Back to node-id order so the heatmap rows follow the atlas order.
df_node_importance = df_node_importance.sort_values('Feature')

In [None]:
# Single-column heatmap of the average CAM importance per node.
node_scores = df_node_importance['AverageImportance'].to_numpy()[:, None]

fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(node_scores, cmap="Reds", ax=ax)
ax.set_xlabel("Importance Score")
ax.set_ylabel("Nodes")
fig.savefig("node_importance_cam.png", dpi=300)
plt.show()