In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import scipy.io as sio
import xgboost as xgb
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.model_selection import LeaveOneGroupOut, LeaveOneOut
from sklearn.utils import resample
from tqdm import tqdm

Mounted at /content/drive


In [None]:
def ml_classifier(input_table, target, n_iterations, sample_fraction=0.9, random_state=42):
    """
    Train an XGBoost classifier with subject-wise leave-one-out CV on bootstrap samples.

    Each iteration draws a stratified bootstrap sample (with replacement) of
    ``sample_fraction * len(input_table)`` rows. It then runs leave-one-out
    cross-validation in which every copy of the held-out subject is removed
    from the training fold. Holding out all copies matters: with replacement,
    the same subject can be drawn several times. Plain LeaveOneOut would then
    train on an exact duplicate of the test row and inflate the AUC.

    Args:
        input_table (pd.DataFrame): Features plus the label column; every other
            column is used as a feature.
        target (str): Name of the label column (expected to be 0/1 encoded).
        n_iterations (int): Number of bootstrap iterations to perform.
        sample_fraction (float): Bootstrap sample size as a fraction of the
            table length (default 0.9).
        random_state (int | None): Seed for NumPy's global RNG, which
            ``resample`` draws from (default 42).

    Returns:
        dict: ``{'main': {'auc', 'fpr', 'tpr', 'thresholds', 'feature_importances'}}``,
        each a list with one entry per iteration. ``feature_importances`` holds
        the per-fold mean of ``clf.feature_importances_``.
    """
    # resample() uses the global NumPy RNG when no random_state is passed to it.
    np.random.seed(random_state)

    # Positional row ids let us recognise bootstrap duplicates of the same subject.
    data = input_table.reset_index(drop=True)
    n_size = int(len(data) * sample_fraction)

    metrics = ['auc', 'fpr', 'tpr', 'thresholds', 'feature_importances']
    results = {'main': {m: [] for m in metrics}}
    logo = LeaveOneGroupOut()

    for _ in tqdm(range(n_iterations), desc="Bootstrap iterations"):
        subsampled_data = resample(data, n_samples=n_size, stratify=data[target].values)
        y = subsampled_data[target].values
        X = subsampled_data.drop(columns=[target]).values
        # resample() keeps the index labels, so duplicates share a group id.
        groups = subsampled_data.index.values

        labels = []
        probabilities = []
        importances = []

        for train_index, test_index in logo.split(X, y, groups):
            clf = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss')
            clf.fit(X[train_index], y[train_index])

            labels.append(y[test_index])
            probabilities.append(clf.predict_proba(X[test_index])[:, 1])
            importances.append(clf.feature_importances_)

        # Each copy of a held-out subject contributes one point, preserving bootstrap weights.
        labels = np.concatenate(labels)
        probabilities = np.concatenate(probabilities)

        fpr, tpr, thresholds = roc_curve(labels, probabilities)
        results['main']['fpr'].append(fpr)
        results['main']['tpr'].append(tpr)
        results['main']['thresholds'].append(thresholds)
        results['main']['auc'].append(roc_auc_score(labels, probabilities))
        results['main']['feature_importances'].append(np.mean(importances, axis=0))

    return results

def plot_roc_all_features(results, n_iterations):
    """
    Plot the mean bootstrap ROC curve with a pointwise ±1.96·SD band.

    Args:
        results (dict): Output of ``ml_classifier``; reads ``results['main']['fpr']``,
            ``['tpr']`` and ``['auc']`` (one entry per bootstrap iteration).
        n_iterations (int): Number of bootstrap curves to use (clipped to the
            number available in ``results``). It also sets the resolution of the
            common FPR grid the curves are interpolated onto.

    Returns:
        None: The figure is displayed with ``fig.show()``; nothing is saved to disk.

    Raises:
        ValueError: If there are no ROC curves to plot.
    """
    # Set plot parameters
    colors = {
        'filla': 'rgba(52, 152, 219, 0.2)',
        'linea': 'rgba(52, 152, 219, 0.5)',
        'maina': 'rgba(41, 128, 185, 1.0)',
        'grid': 'rgba(189, 195, 199, 0.5)',
        'annot': 'rgba(149, 165, 166, 0.5)',
        'highlight': 'rgba(192, 57, 43, 1.0)'
    }

    main = results['main']
    # Never index past the curves actually stored in results.
    n_curves = min(n_iterations, len(main['fpr']))
    if n_curves <= 0:
        raise ValueError("results contains no ROC curves to plot")

    # At least two grid points so that the first/last anchors are distinct.
    fpr_grid = np.linspace(0, 1, max(n_curves, 2))
    interp_tprs = []
    for fpr, tpr in zip(main['fpr'][:n_curves], main['tpr'][:n_curves]):
        interp_tpr = np.interp(fpr_grid, fpr, tpr)
        interp_tpr[0] = 0.0
        interp_tprs.append(interp_tpr)

    tpr_mean = np.mean(interp_tprs, axis=0)
    tpr_mean[-1] = 1.0

    # Pointwise band; clip both sides so it stays inside the valid TPR range.
    tpr_ci = np.std(interp_tprs, axis=0) * 1.96
    tpr_upper = np.clip(tpr_mean + tpr_ci, 0, 1)
    tpr_lower = np.clip(tpr_mean - tpr_ci, 0, 1)

    # Bootstrap percentile interval of the AUC itself. The mean of the TPR band
    # edges is not an AUC confidence interval.
    aucs = np.asarray(main['auc'][:n_curves], dtype=float)
    auc_mean = aucs.mean()
    auc_low, auc_high = np.percentile(aucs, [2.5, 97.5])

    plot_data = [
        go.Scatter(x=fpr_grid, y=tpr_upper, line=dict(color=colors['linea'], width=1), hoverinfo="skip", showlegend=False, name='upper'),
        go.Scatter(x=fpr_grid, y=tpr_lower, fill='tonexty', fillcolor=colors['filla'], line=dict(color=colors['linea'], width=1), hoverinfo="skip", showlegend=False, name='lower'),
        go.Scatter(x=fpr_grid, y=tpr_mean, line=dict(color=colors['maina'], width=2), hoverinfo="skip", showlegend=True, name=f'AUC = {auc_mean:.3f} [{auc_low:.3f} {auc_high:.3f}]')
    ]

    fig = go.Figure(plot_data)
    # Chance diagonal.
    fig.add_shape(type='line', line=dict(dash='dash'), x0=0, x1=1, y0=0, y1=1)

    fig.update_layout(
        template='plotly_white',
        title_x=0.5,
        xaxis_title="1 - Specificity",
        yaxis_title="Sensitivity",
        width=600,
        height=600,
        legend=dict(yanchor="bottom", xanchor="right", x=0.95, y=0.01),
        font=dict(family="Arial", size=22, color="black")
    )

    fig.update_yaxes(range=[0, 1], gridcolor=colors['grid'], scaleanchor="x", scaleratio=1, linecolor='black')
    fig.update_xaxes(range=[0, 1], gridcolor=colors['grid'], constrain='domain', linecolor='black')

    fig.show()

def plot_feature_importances(results, input_table, target, top_n=10):
    """
    Plot the ``top_n`` largest mean feature importances from ``ml_classifier`` as horizontal bars.

    Error bars are the standard error of the mean across bootstrap iterations
    (std / sqrt(n_iterations)).

    Args:
        results (dict): The results dictionary from ml_classifier; reads
            ``results['main']['feature_importances']`` (one vector per iteration).
        input_table (pd.DataFrame): The table passed to ml_classifier. Its
            non-target columns give the feature names, in model column order.
        target (str): Target column (excluded from the feature names).
        top_n (int): Number of most important features to show (default 10).

    Returns:
        None: Shows the feature importance plot.

    Raises:
        ValueError: If the number of feature names does not match the length
            of the importance vectors.
    """
    importances = np.asarray(results['main']['feature_importances'], dtype=float)
    feature_importance_mean = np.mean(importances, axis=0)
    feature_importance_std = np.std(importances, axis=0)

    # SEM across bootstrap iterations. NOTE(review): with many iterations this is
    # tiny; the SD may be a more informative spread for bootstrap estimates.
    n = len(importances)
    sem = feature_importance_std / np.sqrt(n)

    # Same column order as ml_classifier's `drop(columns=[target])` feature matrix.
    feature_names = input_table.drop(columns=[target]).columns.tolist()
    if len(feature_names) != feature_importance_mean.shape[0]:
        raise ValueError(
            f"{len(feature_names)} feature names but "
            f"{feature_importance_mean.shape[0]} importance values"
        )

    def rearrange_name(name):
        """Turn e.g. 'GE wPLI N2-delta' into 'N2-delta-wPLI'; leave other names unchanged."""
        parts = name.split()
        if len(parts) < 3:
            return name
        type_ = parts[1]    # e.g. 'wPLI', 'AEC-c'
        feature = parts[2]  # e.g. 'N2-delta', 'N2-theta'
        return f"{feature}-{type_}"

    feature_names = [rearrange_name(name) for name in feature_names]

    # Ascending order, so the most important feature ends up at the top of the h-bar plot.
    sorted_idx = feature_importance_mean.argsort()
    sorted_idx = sorted_idx[max(len(sorted_idx) - top_n, 0):]
    feature_importance_mean = feature_importance_mean[sorted_idx]
    sem = sem[sorted_idx]
    feature_names_sorted = [feature_names[i] for i in sorted_idx]

    fig = go.Figure()

    # Horizontal bars: feature names on y, importance values on x.
    fig.add_trace(
        go.Bar(
            y=feature_names_sorted,
            x=feature_importance_mean,
            orientation='h',
            error_x=dict(type='data', array=sem, visible=True)
        )
    )

    fig.update_layout(
        xaxis_title="Feature Importance",
        yaxis_title="GE features",
        template='plotly_white',
        width=600,
        height=600,
        font=dict(family="Arial", size=22, color="black")
    )

    fig.show()

In [None]:
## Run the AD-NoEp vs HC classification with all the GE AEC-c and GE wPLI features
# Per-subject table of GE AEC-c and GE wPLI features; shown below for inspection.
features_csv = './all_GE_AEC_wPLI_features.csv'
groups_input = pd.read_csv(features_csv)
groups_input

Unnamed: 0,ID,Class,GE AEC-c N2-delta,GE wPLI N2-delta,GE AEC-c N2-theta,GE wPLI N2-theta,GE AEC-c N2-alpha,GE wPLI N2-alpha,GE AEC-c N2-beta,GE wPLI N2-beta,...,GE AEC-c REM-beta,GE wPLI REM-beta,GE AEC-c REM-gamma,GE wPLI REM-gamma,GE AEC-c Awake-delta,GE wPLI Awake-delta,GE AEC-c Awake-theta,GE wPLI Awake-theta,GE AEC-c Awake-alpha,GE wPLI Awake-alpha
0,ADEX_026,ADNoEp,0.128770,0.210655,0.083121,0.365524,0.079391,0.263201,0.080570,0.143701,...,0.059421,0.130546,0.042475,0.129062,,,,,,
1,ADEX_138,ADNoEp,0.157520,0.225476,0.088631,0.308320,0.085964,0.304171,0.062625,0.112467,...,0.037588,0.143366,0.028960,0.170873,0.267162,0.425184,0.143330,0.431266,0.100214,0.414539
2,ADEX_019,ADNoEp,0.170782,0.212405,0.084221,0.260896,0.091610,0.190766,0.064935,0.120135,...,0.047967,0.151273,0.023069,0.140361,0.098903,0.349207,0.062445,0.381927,0.067062,0.387860
3,ADEX_073,ADNoEp,0.098449,0.139083,0.077275,0.227344,0.100617,0.189084,0.129735,0.158784,...,0.078969,0.169277,0.038257,0.196619,0.205633,0.487886,0.129054,0.460271,0.085699,0.443449
4,ADEX_102,ADNoEp,0.150391,0.181301,0.085158,0.235557,0.093279,0.225444,0.084884,0.130535,...,0.066647,0.150543,0.049337,0.141984,0.078963,0.425201,0.085181,0.427421,0.099795,0.552120
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
84,ADEX_087,HC,0.135703,0.181405,0.105062,0.251805,0.089587,0.212933,0.058670,0.115248,...,0.054493,0.109224,0.034859,0.111385,0.085306,0.329474,0.070009,0.385339,0.069846,0.401191
85,ADEX_101,HC,0.100032,0.194700,0.090840,0.249370,0.104158,0.400345,0.062218,0.132396,...,0.062848,0.171623,0.036318,0.174285,0.154294,0.398091,0.136786,0.433232,0.103932,0.416147
86,ADEX_070,HC,0.141576,0.197467,0.085023,0.212583,0.112732,0.246422,0.099114,0.140721,...,0.108359,0.173262,0.035506,0.197693,0.119004,0.435142,0.064960,0.468404,0.089017,0.443442
87,ADEX_020,HC,0.113023,0.165595,0.083541,0.234830,0.136780,0.194487,0.149657,0.152425,...,0.106216,0.162013,0.104614,0.258386,0.254564,0.320868,0.059974,0.270267,0.107796,0.328306


In [None]:
# Subject IDs excluded from every analysis below
ids_to_remove = [
    'ADEX_025', 'ADEX_103', 'ADEX_084', 'ADEX_080',
    'ADEX_079', 'ADEX_068', 'ADEX_048',
]

# Keep only the rows whose ID is not on the exclusion list
groups_input = groups_input.loc[lambda frame: ~frame['ID'].isin(ids_to_remove)]

groups_input

Unnamed: 0,ID,Class,GE AEC-c N2-delta,GE wPLI N2-delta,GE AEC-c N2-theta,GE wPLI N2-theta,GE AEC-c N2-alpha,GE wPLI N2-alpha,GE AEC-c N2-beta,GE wPLI N2-beta,...,GE AEC-c REM-beta,GE wPLI REM-beta,GE AEC-c REM-gamma,GE wPLI REM-gamma,GE AEC-c Awake-delta,GE wPLI Awake-delta,GE AEC-c Awake-theta,GE wPLI Awake-theta,GE AEC-c Awake-alpha,GE wPLI Awake-alpha
0,ADEX_026,ADNoEp,0.128770,0.210655,0.083121,0.365524,0.079391,0.263201,0.080570,0.143701,...,0.059421,0.130546,0.042475,0.129062,,,,,,
1,ADEX_138,ADNoEp,0.157520,0.225476,0.088631,0.308320,0.085964,0.304171,0.062625,0.112467,...,0.037588,0.143366,0.028960,0.170873,0.267162,0.425184,0.143330,0.431266,0.100214,0.414539
2,ADEX_019,ADNoEp,0.170782,0.212405,0.084221,0.260896,0.091610,0.190766,0.064935,0.120135,...,0.047967,0.151273,0.023069,0.140361,0.098903,0.349207,0.062445,0.381927,0.067062,0.387860
3,ADEX_073,ADNoEp,0.098449,0.139083,0.077275,0.227344,0.100617,0.189084,0.129735,0.158784,...,0.078969,0.169277,0.038257,0.196619,0.205633,0.487886,0.129054,0.460271,0.085699,0.443449
4,ADEX_102,ADNoEp,0.150391,0.181301,0.085158,0.235557,0.093279,0.225444,0.084884,0.130535,...,0.066647,0.150543,0.049337,0.141984,0.078963,0.425201,0.085181,0.427421,0.099795,0.552120
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
84,ADEX_087,HC,0.135703,0.181405,0.105062,0.251805,0.089587,0.212933,0.058670,0.115248,...,0.054493,0.109224,0.034859,0.111385,0.085306,0.329474,0.070009,0.385339,0.069846,0.401191
85,ADEX_101,HC,0.100032,0.194700,0.090840,0.249370,0.104158,0.400345,0.062218,0.132396,...,0.062848,0.171623,0.036318,0.174285,0.154294,0.398091,0.136786,0.433232,0.103932,0.416147
86,ADEX_070,HC,0.141576,0.197467,0.085023,0.212583,0.112732,0.246422,0.099114,0.140721,...,0.108359,0.173262,0.035506,0.197693,0.119004,0.435142,0.064960,0.468404,0.089017,0.443442
87,ADEX_020,HC,0.113023,0.165595,0.083541,0.234830,0.136780,0.194487,0.149657,0.152425,...,0.106216,0.162013,0.104614,0.258386,0.254564,0.320868,0.059974,0.270267,0.107796,0.328306


In [None]:
# Drop the unnecessary identifier column; errors='ignore' keeps this cell re-runnable.
groups_input = groups_input.drop(columns='ID', errors='ignore')

# Select the groups to be compared and encode the class names as integers
## Run for AD-NoEp vs HC  (AD-NoEp = 1, HC = 0)
filtered_df = (
    groups_input.loc[groups_input['Class'].isin(['ADNoEp', 'HC'])]
    # Map only the label column to an integer dtype. A frame-wide str->int
    # .replace relies on deprecated implicit downcasting.
    .assign(Class=lambda frame: frame['Class'].map({'ADNoEp': 1, 'HC': 0}).astype(int))
)
assert set(filtered_df['Class'].unique()) == {0, 1}, "expected both ADNoEp and HC subjects"

In [None]:
# How many bootstrap resamples to evaluate
n_iterations = 1000

# Fit and evaluate the classifier on the ADNoEp-vs-HC table
results_adnoep_hc = ml_classifier(
    input_table=filtered_df,
    target="Class",
    n_iterations=n_iterations,
)

Bootstrap iterations: 100%|██████████| 1000/1000 [36:18<00:00,  2.18s/it]


In [None]:
plot_roc_all_features(results=results_adnoep_hc, n_iterations=n_iterations)

In [None]:
plot_feature_importances(results=results_adnoep_hc, input_table=filtered_df, target="Class")

In [None]:
## Sleep vs Awake Analysis

# Partition the feature columns: names containing "Awake" vs. everything else
# (the label and any ID column are kept out of the feature lists).
awake_columns = [name for name in filtered_df.columns if "Awake" in name]
non_awake_columns = [
    name for name in filtered_df.columns
    if name not in awake_columns and name not in ["ID", "Class"]
]

# Build each subset with "Class" first, then drop subjects with missing values
awake_features = filtered_df[["Class"] + awake_columns].dropna()
sleep_features = filtered_df[["Class"] + non_awake_columns].dropna()

In [None]:
# Awake ADNoEp vs HC
n_iterations = 1000
# Distinct, accurately named result so it is not confused with the ADNoEp-vs-ADEp runs.
results_adnoep_hc_awake = ml_classifier(awake_features, "Class", n_iterations)
plot_roc_all_features(results_adnoep_hc_awake, n_iterations)
plot_feature_importances(results_adnoep_hc_awake, awake_features, "Class")

Bootstrap iterations: 100%|██████████| 1000/1000 [31:32<00:00,  1.89s/it]


In [None]:
# Sleep ADNoEp vs HC
n_iterations = 1000
# Distinct, accurately named result so it is not confused with the ADNoEp-vs-ADEp runs.
results_adnoep_hc_sleep = ml_classifier(sleep_features, "Class", n_iterations)
plot_roc_all_features(results_adnoep_hc_sleep, n_iterations)
plot_feature_importances(results_adnoep_hc_sleep, sleep_features, "Class")

Bootstrap iterations: 100%|██████████| 1000/1000 [29:03<00:00,  1.74s/it]


In [None]:
## Run for AD-NoEp vs AD-Ep with all the GE AEC-c and GE wPLI features  (AD-Ep = 1, AD-NoEp = 0)
filtered_df = (
    groups_input.loc[groups_input['Class'].isin(['ADNoEp', 'ADEp'])]
    # Map only the label column to an integer dtype. A frame-wide str->int
    # .replace relies on deprecated implicit downcasting.
    .assign(Class=lambda frame: frame['Class'].map({'ADEp': 1, 'ADNoEp': 0}).astype(int))
)
assert set(filtered_df['Class'].unique()) == {0, 1}, "expected both ADEp and ADNoEp subjects"

In [None]:
# How many bootstrap resamples to evaluate
n_iterations = 1000

# Fit and evaluate the classifier on the ADNoEp-vs-ADEp table
results_adnoep_adep = ml_classifier(
    input_table=filtered_df,
    target="Class",
    n_iterations=n_iterations,
)

Bootstrap iterations: 100%|██████████| 1000/1000 [22:57<00:00,  1.38s/it]


In [None]:
plot_roc_all_features(results=results_adnoep_adep, n_iterations=n_iterations)

In [None]:
plot_feature_importances(results=results_adnoep_adep, input_table=filtered_df, target="Class")

In [None]:
## Sleep vs Awake Analysis

# Partition the feature columns: names containing "Awake" vs. everything else
# (the label and any ID column are kept out of the feature lists).
awake_columns = [name for name in filtered_df.columns if "Awake" in name]
non_awake_columns = [
    name for name in filtered_df.columns
    if name not in awake_columns and name not in ["ID", "Class"]
]

# Build each subset with "Class" first, then drop subjects with missing values
awake_features = filtered_df[["Class"] + awake_columns].dropna()
sleep_features = filtered_df[["Class"] + non_awake_columns].dropna()

In [None]:
# Awake ADNoEp vs ADEp
n_iterations = 1000
# Separate name so the all-features `results_adnoep_adep` is not overwritten.
results_adnoep_adep_awake = ml_classifier(awake_features, "Class", n_iterations)
plot_roc_all_features(results_adnoep_adep_awake, n_iterations)
plot_feature_importances(results_adnoep_adep_awake, awake_features, "Class")

Bootstrap iterations: 100%|██████████| 1000/1000 [21:10<00:00,  1.27s/it]


In [None]:
# Sleep ADNoEp vs ADEp
n_iterations = 1000
# Separate name so the all-features `results_adnoep_adep` is not overwritten.
results_adnoep_adep_sleep = ml_classifier(sleep_features, "Class", n_iterations)
plot_roc_all_features(results_adnoep_adep_sleep, n_iterations)
plot_feature_importances(results_adnoep_adep_sleep, sleep_features, "Class")

Bootstrap iterations: 100%|██████████| 1000/1000 [21:22<00:00,  1.28s/it]
