In [26]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colormaps
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc
from sklearn.svm import SVC
from pathlib import Path
import numpy as np
import seaborn as sns
base_folder = 'results/schaefer400'  # NOTE(review): not referenced below; `output` is rebuilt to the same path
atlas_basename = 'schaefer400'
output = Path('results')
# Result folders for this atlas: every sub-directory of results/ whose name contains
# atlas_basename. Path.iterdir() yields entries in arbitrary order and the panel order
# of every figure follows this list, so sort it to make figures reproducible.
atlas_networks = sorted(dir_ for dir_ in output.iterdir()
                        if dir_.is_dir() and atlas_basename in dir_.name)
output = output / atlas_basename
# Clinical table indexed by integer subject id. Downstream code reads its 'group'
# column and the score columns named in networks_nce.
subjects_df = pd.read_csv('clinical_data.csv')
subjects_df = subjects_df.astype({'id': int})
subjects_df = subjects_df.set_index('id')
filename = Path('global_measures.csv')

# Atlas network abbreviation -> display name used in plot titles.
networks_names = {
    "SalVentAttnLH": "Salience/Ventral Attention (Left Hemisphere)", "SalVentAttn": "Salience/Ventral Attention",
    "DorsAttn": "Dorsal Attention", "Cont": "Frontoparietal", "SomMot": "Somatomotor", "Default": "Default", "Vis": "Visual",
    "Limbic": "Limbic", "Global": "Global"
}
# Atlas network abbreviation -> column of clinical_data.csv holding the score paired with that network.
networks_nce = {
    "SalVentAttnLH": "language", "SalVentAttn": "executive", "DorsAttn": "attention", "SomMot": "executive",
    "Cont": "visuoespatial", "Default": "memory", "Vis": "visuoespatial", "Limbic": "memory", "Global": "attention"
}

In [77]:
def add_curve(graph_densities, measure, lower_error, upper_error, group, color_index, ax):
    """Plot one group's measure curve with a shaded error band on ``ax``.

    Parameters
    ----------
    graph_densities : array-like
        x values (connection densities).
    measure : array-like
        Curve values, drawn with ``label=group`` so it appears in the legend.
    lower_error, upper_error : array-like
        Lower and upper edges of the error band.
    group : str
        Legend label for the curve.
    color_index : int
        Index into matplotlib's default colour cycle (``'C{color_index}'``).
    ax : matplotlib Axes
        Target axes; its legend is refreshed.
    """
    color = f'C{color_index}'
    ax.plot(graph_densities, measure, label=group, color=color)
    # Faint, unlabelled outlines of the band edges (kept out of the legend).
    ax.plot(graph_densities, lower_error, alpha=0.1, color=color)
    ax.plot(graph_densities, upper_error, alpha=0.1, color=color)
    # Pin the band to the curve colour: without an explicit colour, fill_between
    # draws from the Axes' separate patch colour cycle, which need not match color_index.
    ax.fill_between(graph_densities, lower_error, upper_error, alpha=0.2, color=color)
    ax.legend()

    
def add_statistical_significance(p_at_thresholds, ax, significance_levels, eps=1e-4):
    """Draw significance bars for the p-values in ``p_at_thresholds`` on ``ax``.

    Parameters
    ----------
    p_at_thresholds : pandas.DataFrame
        Indexed by connection-density threshold; the first column holds the p-values.
    ax : matplotlib Axes
        Target axes.
    significance_levels : sequence of float
        Ascending p-value cut-offs, e.g. ``[0.001, 0.005, 0.01]``. The smallest
        cut-off gets the most stars. The sequence is not modified.
    eps : float
        Slack added to the threshold step, so neighbouring thresholds whose gap
        differs only by floating-point noise still count as one contiguous run.
    """
    # Sort so consecutive index entries are neighbouring thresholds; the run
    # grouping and the step computed below both rely on that order.
    pvalues = p_at_thresholds[p_at_thresholds.columns[0]].sort_index()
    labels = ['*' * i for i in range(len(significance_levels), 0, -1)] + ['ns']
    # Build the bin edges on a new list so the caller's list is left untouched.
    # The top edge is +inf so that p == 1.0 still falls in 'ns' (bins are
    # left-closed, so a 1.0 edge would leave it uncategorised).
    bins = [0.0, *significance_levels, np.inf]
    categorized_pvalues = pd.cut(pvalues, bins, right=False, labels=labels)
    spacing = 0.1
    if len(pvalues) > 1:
        spacing = pvalues.index[1] - pvalues.index[0] + eps

    significance_bar(ax, categorized_pvalues, labels, spacing)
    
    
def significance_bar(ax, categorized_pvalues, labels, spacing):
    """Draw a horizontal bar near the top of ``ax`` over each run of thresholds sharing a significance label.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; bars are drawn at 98% of the current upper y-limit.
    categorized_pvalues : pandas.Series
        Significance label per threshold (index = threshold).
    labels : list of str
        Labels ordered from most to least significant; ``'ns'`` is never drawn.
    spacing : float
        Largest gap between consecutive thresholds that still belongs to the same
        run. Each run is also padded by ``spacing`` on both sides, then clipped
        to the plotted threshold range.
    """
    categorized_pvalues = categorized_pvalues.sort_index()
    if categorized_pvalues.empty:
        return
    line_y = ax.get_ylim()[1] * 0.98
    min_threshold, max_threshold = categorized_pvalues.index[0], categorized_pvalues.index[-1]
    # Greys map positions 0.8 -> 0.2: the most-starred label is darkest, fewer stars are lighter.
    shades = colormaps.get_cmap('Greys')(np.linspace(0.8, 0.2, len(labels)))
    colors = dict(zip(labels, shades))
    for label in labels:
        if label == 'ns':
            continue
        thresholds = categorized_pvalues[categorized_pvalues == label].index
        if len(thresholds) == 0:
            continue
        # Group consecutive thresholds into [start, end] runs.
        regions = [[thresholds[0], thresholds[0]]]
        for previous, threshold in zip(thresholds[:-1], thresholds[1:]):
            if threshold - previous > spacing:
                regions.append([threshold, threshold])
            else:
                regions[-1][1] = threshold
        for start, end in regions:
            start = max(start - spacing, min_threshold)
            end = min(end + spacing, max_threshold)
            ax.plot((start, end), [line_y, line_y], linewidth=2, color=colors[label])

def get_network_name(atlas_basename, network):
    """Return the network label encoded in a results-folder name.

    Network folders are expected to be named ``f'{atlas_basename}_{network}'``.
    A name without an underscore is treated as the whole-brain folder and maps
    to ``'Global'``. If the name has an underscore but not the expected prefix,
    it is returned unchanged.
    """
    if not is_network(network):
        return 'Global'
    prefix = f'{atlas_basename}_'
    # Remove the prefix as a whole string. str.lstrip would strip any leading
    # characters that occur anywhere in the prefix (e.g. 'schaefer400_cont' -> 'ont').
    return network[len(prefix):] if network.startswith(prefix) else network


def is_network(atlas_name):
    """Return True if ``atlas_name`` carries a network suffix, i.e. contains an underscore."""
    return '_' in atlas_name
                

In [4]:
def meshgrid(x, y, h=.02, offset=0.07):
    """Build a regular 2-D grid that spans ``x`` and ``y``, padded by ``offset`` on every side.

    Returns
    -------
    tuple of ndarray
        ``(xx, yy)`` coordinate matrices from ``np.meshgrid``, sampled with step ``h``.
    """
    x_axis = np.arange(x.min() - offset, x.max() + offset, h)
    y_axis = np.arange(y.min() - offset, y.max() + offset, h)
    grid_x, grid_y = np.meshgrid(x_axis, y_axis)
    return grid_x, grid_y


def add_svm_contours(ax, clf, xx, yy, **params):
    """Fill ``ax`` with ``clf``'s predicted class over the grid ``(xx, yy)``.

    Every grid point is passed to ``clf.predict`` as an ``(x, y)`` row. Extra
    keyword arguments go straight to ``ax.contourf``, whose return value is returned.
    """
    grid_points = np.column_stack((xx.ravel(), yy.ravel()))
    predictions = clf.predict(grid_points).reshape(xx.shape)
    return ax.contourf(xx, yy, predictions, **params)


def normalize_values(df, columns):
    """Return a copy of ``df`` with ``columns`` cast to float and z-scored.

    Each column is centred on its mean and divided by its sample standard
    deviation (pandas default, ddof=1). A column whose std is zero or undefined
    (constant values, or a single row) is only centred, so it becomes 0 instead
    of NaN/inf. The input frame is not modified.
    """
    normalized = df.copy()
    normalized[columns] = normalized[columns].astype(float)
    for column in columns:
        values = normalized[column]
        std = values.std()
        centered = values - values.mean()
        normalized[column] = centered / std if pd.notna(std) and std > 0 else centered
    return normalized


def get_measure_at_threshold(subjects_df, groups, measure_label, network, network_nce, filename):
    """Pair each subject's NCE score with a graph measure at the densest stored threshold.

    For every group, reads ``network / f'{filename.stem}_{group}.pkl'``, keeps the
    row with the largest ``threshold`` and takes its ``measure_label`` entry as the
    group's per-subject values. Groups whose row lacks ``measure_label`` are skipped.

    Returns
    -------
    graph_density : float
        Threshold of the selected row (from the last group that had the measure;
        0.0 if none did).
    df : pandas.DataFrame
        Columns ``nce`` and ``measure`` (both z-scored) and ``group``; rows with
        missing values dropped.
    group_mapping : dict
        Group name -> integer code, assigned in the order of ``groups``.

    Raises
    ------
    ValueError
        If a group's number of measure values differs from its number of subjects.
    """
    graph_density, nces, values, categories, group_mapping = 0.0, [], [], [], {}
    for j, group in enumerate(groups):
        group_df = subjects_df[subjects_df['group'] == group]
        group_mapping[group] = j
        # NOTE(review): read_pickle can execute arbitrary code; only load trusted, locally produced files.
        group_network_measures = pd.read_pickle(network / f'{filename.stem}_{group}.pkl')
        measures_at_threshold = group_network_measures.sort_values(by='threshold').iloc[-1]
        if measure_label not in measures_at_threshold.index:
            continue
        group_values = list(measures_at_threshold[measure_label])
        # Values are paired with subjects by position. A per-group length mismatch
        # would shift every later pairing (and can go unnoticed if mismatches cancel
        # out across groups), so fail loudly here.
        # NOTE(review): assumes stored values follow the row order of group_df — confirm.
        if len(group_values) != len(group_df):
            raise ValueError(
                f'{network.name}/{group}: {len(group_values)} {measure_label} values '
                f'for {len(group_df)} subjects')
        graph_density = measures_at_threshold['threshold']
        nces.extend(group_df[network_nce].values)
        values.extend(group_values)
        categories.extend([group] * len(group_df))
    df = pd.DataFrame({'nce': nces, 'measure': values, 'group': categories}).dropna()
    df = normalize_values(df, ['nce', 'measure'])
    return graph_density, df, group_mapping


def fit_and_plot_svm(df, cats_mapping, ax):
    """Compare leave-one-out SVM accuracy of NCE alone vs. NCE + measure, and plot decision regions.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns ``nce``, ``measure`` and ``group``.
    cats_mapping : dict
        Group name -> integer class label; must cover every value of ``df['group']``.
    ax : matplotlib Axes
        Receives the decision-region contours of an SVM fitted on all rows.

    Returns
    -------
    tuple of float
        ``(features_mean - nce_mean, sqrt(nce_std**2 + features_std**2))``, where the
        mean and std are taken over the per-sample leave-one-out accuracies (each 0 or 1).
    """
    # Series.map gives an integer label column directly. DataFrame.replace relies on
    # pandas' deprecated object->int downcasting; without it the labels stay
    # object-dtype, which scikit-learn's classifiers do not accept.
    categories = df['group'].map(cats_mapping).to_numpy()
    features = df[['nce', 'measure']].to_numpy()
    nces = features[:, :1]
    clf_nces, clf = SVC(), SVC()
    # Leave-one-out: train on every row but i, then score the held-out row i.
    accuracies_nce, accuracies_features = [], []
    for i in range(len(features)):
        train_y = np.delete(categories, i)
        clf_nces.fit(np.delete(nces, i, axis=0), train_y)
        accuracies_nce.append(clf_nces.score(nces[i:i + 1], categories[i:i + 1]))
        clf.fit(np.delete(features, i, axis=0), train_y)
        accuracies_features.append(clf.score(features[i:i + 1], categories[i:i + 1]))
    # NOTE(review): std of 0/1 LOO outcomes is not a standard error of the mean accuracy.
    nce_mean, nce_std = np.mean(accuracies_nce), np.std(accuracies_nce)
    features_mean, features_std = np.mean(accuracies_features), np.std(accuracies_features)
    print(f'Mean NCE accuracy: {nce_mean} +/- {nce_std}')
    print(f'Mean features accuracy: {features_mean} +/- {features_std}')
    # Refit on all rows to draw the decision regions.
    clf.fit(features, categories)
    xx, yy = meshgrid(features[:, 0], features[:, 1])
    add_svm_contours(ax, clf, xx, yy, cmap='coolwarm', alpha=0.1)

    return features_mean - nce_mean, np.sqrt(nce_std ** 2 + features_std ** 2)
    


def add_group_to_plot(measures_values, group, color_index, measure_label, ax):
    """Plot ``measure_label`` against connection density for one group and return its AUC.

    The rows of ``measures_values`` whose ``group`` equals ``group`` are drawn via
    ``add_curve``. The mean curve comes from ``measure_label`` and the band is
    ± ``f'{measure_label}_ste'``. The legend shows 'long-COVID' for ``'covid'`` and
    the capitalised name otherwise.

    Returns
    -------
    float
        ``sklearn.metrics.auc`` over the densities. It is 0.0 when ``measure_label``
        is not a column or fewer than two densities are present.
    """
    # Sort by density once, so the curve, the error band and the AUC all see
    # x in increasing order (unsorted rows would make the line zig-zag).
    group_values = measures_values[measures_values['group'] == group].sort_values('threshold')
    if measure_label not in group_values.columns:
        return 0.0
    densities = group_values['threshold'].to_numpy()
    measure_values = group_values[measure_label].to_numpy()
    ste = group_values[f'{measure_label}_ste'].to_numpy()
    display_name = 'long-COVID' if group == 'covid' else group.capitalize()
    add_curve(densities, measure_values, measure_values - ste, measure_values + ste,
              display_name, color_index, ax)
    return auc(densities, measure_values) if len(densities) > 1 else 0.0


In [75]:
def plot_nce_to_measure(atlas_basename, networks_dirs, networks_names, subjects_df, measure_label, measure_desc,
                        networks_nce, output, filename):
    """Scatter each network's NCE score against a graph measure, with SVM decision regions.

    There is one panel per entry of ``networks_dirs``, two per row. A network whose
    basename is not a key of ``networks_nce`` is skipped and its panel hidden, as
    is the spare panel when the count is odd. The figure is saved to
    ``output / f'NCE_to_{measure_label}.png'`` and then shown.

    Returns
    -------
    dict
        ``network.name`` -> result of ``fit_and_plot_svm``. The value is ``{}`` for
        networks that were skipped or had no rows.
    """
    gains = {network.name: {} for network in networks_dirs}
    if len(networks_dirs) == 0:
        # plt.subplots rejects nrows=0; nothing to draw.
        return gains
    ncols, nrows = 2, -(-len(networks_dirs) // 2)  # ceil(len / 2)
    # squeeze=False keeps `axes` 2-D even for a single row, so one flat index works.
    fig, axes = plt.subplots(figsize=(15, 5 * nrows), nrows=nrows, ncols=ncols, squeeze=False)
    axes = axes.ravel()
    # The group list depends only on subjects_df, so compute it once.
    groups = sorted(subjects_df['group'].unique())
    for ax, network in zip(axes, networks_dirs):
        print(f'Processing {network.name}')
        network_basename = get_network_name(atlas_basename, network.name)
        if network_basename not in networks_nce:
            ax.set_visible(False)
            continue
        network_nce = networks_nce[network_basename]
        connection_density, measure_df, group_mapping = get_measure_at_threshold(
            subjects_df, groups, measure_label, network, network_nce, filename)
        # Display names for the scatter only; the SVM below uses the raw group labels.
        plot_df = measure_df.replace({'group': {'covid': 'long-COVID', 'control': 'Control'}})
        sns.scatterplot(data=plot_df, x='nce', y='measure', hue='group', ax=ax)
        if not measure_df.empty:
            gains[network.name] = fit_and_plot_svm(measure_df, group_mapping, ax)
            ax.legend()
        ax.set_title(networks_names.get(network_basename, network_basename))
        ax.set_ylabel(f'{measure_desc} at t={connection_density * 100:.0f}%')
        ax.set_xlabel(f'{network_nce} score')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    # Hide the spare panel left over when the number of networks is odd.
    for ax in axes[len(networks_dirs):]:
        ax.set_visible(False)
    fig.suptitle(measure_desc)
    fig.savefig(output / f'NCE_to_{measure_label}.png')
    plt.show()

    return gains

In [19]:
def classify_with_efficiency(atlas_basename, networks_dirs, subjects_df,
                        networks_nce, filename):
    """Per network, compare LOOCV SVM accuracy of NCE alone vs. NCE + local/global efficiency.

    A network whose basename is not a key of ``networks_nce`` is skipped and keeps
    an empty dict, as does a network whose efficiency table ends up empty.

    Returns
    -------
    dict
        ``network.name`` -> result of ``fit_svm``, or ``{}``.
    """
    results = {network_dir.name: {} for network_dir in networks_dirs}
    for network_dir in networks_dirs:
        print(f'Processing {network_dir.name}')
        basename = get_network_name(atlas_basename, network_dir.name)
        if basename in networks_nce:
            nce_column = networks_nce[basename]
            group_names = sorted(subjects_df['group'].unique())
            efficiency_df, mapping = get_efficiency_at_threshold(subjects_df, group_names, network_dir,
                                                                 nce_column, filename)
            if not efficiency_df.empty:
                results[network_dir.name] = fit_svm(efficiency_df, mapping)
    return results


def fit_svm(df, cats_mapping):
    """Compare leave-one-out SVM accuracy of NCE alone vs. NCE + local and global efficiency.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns ``nce``, ``local_eff``, ``global_eff`` and ``group``.
    cats_mapping : dict
        Group name -> integer class label; must cover every value of ``df['group']``.

    Returns
    -------
    tuple of float
        ``(features_mean - nce_mean, sqrt(nce_std**2 + features_std**2))``, where the
        mean and std are taken over the per-sample leave-one-out accuracies (each 0 or 1).
    """
    # Series.map gives an integer label column directly. DataFrame.replace relies on
    # pandas' deprecated object->int downcasting; without it the labels stay
    # object-dtype, which scikit-learn's classifiers do not accept.
    categories = df['group'].map(cats_mapping).to_numpy()
    features = df[['nce', 'local_eff', 'global_eff']].to_numpy()
    nces = features[:, :1]
    clf_nces, clf = SVC(), SVC()
    # Leave-one-out: train on every row but i, then score the held-out row i.
    accuracies_nce, accuracies_features = [], []
    for i in range(len(features)):
        train_y = np.delete(categories, i)
        clf_nces.fit(np.delete(nces, i, axis=0), train_y)
        accuracies_nce.append(clf_nces.score(nces[i:i + 1], categories[i:i + 1]))
        clf.fit(np.delete(features, i, axis=0), train_y)
        accuracies_features.append(clf.score(features[i:i + 1], categories[i:i + 1]))
    # NOTE(review): std of 0/1 LOO outcomes is not a standard error of the mean accuracy.
    nce_mean, nce_std = np.mean(accuracies_nce), np.std(accuracies_nce)
    features_mean, features_std = np.mean(accuracies_features), np.std(accuracies_features)
    print(f'Mean NCE accuracy: {nce_mean} +/- {nce_std}')
    print(f'Mean features accuracy: {features_mean} +/- {features_std}')

    return features_mean - nce_mean, np.sqrt(nce_std ** 2 + features_std ** 2)

def get_efficiency_at_threshold(subjects_df, groups, network, network_nce, filename):
    """Pair each subject's NCE score with local and global efficiency at the densest stored threshold.

    For every group, reads ``network / f'{filename.stem}_{group}.pkl'`` and keeps the
    row with the largest ``threshold``. Its ``avg_local_efficiency`` and
    ``global_efficiency`` entries become the group's per-subject values. Groups
    lacking either entry are skipped.

    Returns
    -------
    df : pandas.DataFrame
        Columns ``nce``, ``local_eff``, ``global_eff`` (all z-scored) and ``group``;
        rows with missing values dropped.
    group_mapping : dict
        Group name -> integer code, assigned in the order of ``groups``.

    Raises
    ------
    ValueError
        If a group's efficiency value counts differ from its number of subjects.
    """
    nces, local_eff_values, global_eff_values, categories, group_mapping = [], [], [], [], {}
    for j, group in enumerate(groups):
        group_df = subjects_df[subjects_df['group'] == group]
        group_mapping[group] = j
        # NOTE(review): read_pickle can execute arbitrary code; only load trusted, locally produced files.
        group_network_measures = pd.read_pickle(network / f'{filename.stem}_{group}.pkl')
        measures_at_threshold = group_network_measures.sort_values(by='threshold').iloc[-1]
        if 'global_efficiency' not in measures_at_threshold.index or 'avg_local_efficiency' not in measures_at_threshold.index:
            continue
        local_eff = list(measures_at_threshold['avg_local_efficiency'])
        global_eff = list(measures_at_threshold['global_efficiency'])
        # Values are paired with subjects by position. A per-group length mismatch
        # would silently misalign every later row, so fail loudly here.
        # NOTE(review): assumes stored values follow the row order of group_df — confirm.
        if not len(local_eff) == len(global_eff) == len(group_df):
            raise ValueError(
                f'{network.name}/{group}: {len(local_eff)} local / {len(global_eff)} global '
                f'efficiency values for {len(group_df)} subjects')
        nces.extend(group_df[network_nce].values)
        local_eff_values.extend(local_eff)
        global_eff_values.extend(global_eff)
        categories.extend([group] * len(group_df))
    df = pd.DataFrame({'nce': nces, 'local_eff': local_eff_values, 'global_eff': global_eff_values, 'group': categories}).dropna()
    df = normalize_values(df, ['nce', 'local_eff', 'global_eff'])

    return df, group_mapping

In [22]:
classify_with_efficiency(
    atlas_basename=atlas_basename,
    networks_dirs=atlas_networks,
    subjects_df=subjects_df,
    networks_nce=networks_nce,
    filename=filename,
)

In [23]:
plot_nce_to_measure(
    atlas_basename, atlas_networks, networks_names, subjects_df,
    measure_label='global_efficiency', measure_desc='Global Efficiency',
    networks_nce=networks_nce, output=output, filename=filename,
)

In [24]:
plot_nce_to_measure(
    atlas_basename, atlas_networks, networks_names, subjects_df,
    measure_label='avg_local_efficiency', measure_desc='Avg. Local Efficiency',
    networks_nce=networks_nce, output=output, filename=filename,
)

In [73]:
def plot_measure(atlas_basename, networks_dirs, networks_names, measure_label, measure_desc, output, filename):
    """Plot a graph measure against connection density for every network, one curve per group.

    Each network directory must contain ``filename.name`` (a CSV with ``group`` and
    ``threshold`` columns, the measure, and ``f'{measure_label}_ste'``). When the CSV
    also has ``f'{measure_label}_p'``, significance bars are drawn. Panels are laid
    out two per row; the spare panel for an odd count is hidden. The figure is saved
    to ``output / f'{measure_label}.png'`` and shown.

    Returns
    -------
    dict
        ``network.name`` -> {group: AUC from ``add_group_to_plot``}.
    """
    aucs = {network.name: {} for network in networks_dirs}
    if len(networks_dirs) == 0:
        # plt.subplots rejects nrows=0; nothing to draw.
        return aucs
    ncols, nrows = 2, -(-len(networks_dirs) // 2)  # ceil(len / 2)
    # squeeze=False keeps `axes` 2-D even for a single row, so one flat index works.
    fig, axes = plt.subplots(figsize=(15, 5 * nrows), nrows=nrows, ncols=ncols, squeeze=False)
    axes = axes.ravel()
    for ax, network in zip(axes, networks_dirs):
        measures_values = pd.read_csv(network / filename.name, index_col=0)
        groups = sorted(measures_values['group'].unique())
        for color_index, group in enumerate(groups):
            aucs[network.name][group] = add_group_to_plot(measures_values, group, color_index, measure_label, ax)
        p_column = f'{measure_label}_p'
        if p_column in measures_values.columns:
            # One p-value per threshold, ordered by threshold so the significance
            # bars treat neighbouring densities as neighbours.
            p_at_thresholds = (measures_values[['threshold', p_column]]
                               .drop_duplicates()
                               .set_index('threshold')
                               .sort_index())
            add_statistical_significance(p_at_thresholds, ax, significance_levels=[0.001, 0.005, 0.01])
        network_basename = get_network_name(atlas_basename, network.name)
        ax.set_title(networks_names.get(network_basename, network_basename))
        ax.set_xlabel('Connection density (%)', fontsize=12)
        ax.set_ylabel(measure_desc, fontsize=12)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # Drop the outermost automatic ticks, then relabel: y with two decimals,
        # x as a percentage (densities are stored as fractions).
        ax.set_xticks(ax.get_xticks()[1:-1])
        ax.set_yticks(ax.get_yticks()[1:-1])
        ax.set_yticklabels([f'{tick:.2f}' for tick in ax.get_yticks()])
        ax.set_xticklabels([f'{tick * 100:.0f}' for tick in ax.get_xticks()])
    # Hide the spare panel left over when the number of networks is odd.
    for ax in axes[len(networks_dirs):]:
        ax.set_visible(False)
    fig.suptitle(measure_desc)
    fig.savefig(output / f'{measure_label}.png')
    plt.show()

    return aucs


In [80]:
plot_measure(
    atlas_basename, atlas_networks, networks_names,
    measure_label='avg_clustering', measure_desc='Avg. Clustering Coefficient',
    output=output, filename=filename,
)

In [81]:
plot_measure(
    atlas_basename, atlas_networks, networks_names,
    measure_label='avg_local_efficiency', measure_desc='Avg. Local efficiency',
    output=output, filename=filename,
)