In [None]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import neurotools.plotting as ntp

from abcd_tools.utils.ConfigLoader import load_yaml
# Project configuration (paths, label maps, colours); path is relative to the
# kernel's working directory.
PARAMS_FILE = '../parameters.yaml'
params = load_yaml(PARAMS_FILE)

In [None]:
# Ridge-model outputs. The CSV path gets its own name so `summary` is only ever
# a DataFrame (it used to be a path string first, then rebound to the frame).
summary_path = params['model_results_path'] + 'all_vertex_contrasts_ridge_summary.csv'
# Path to the full results pickle; it is only referenced, not loaded, here.
res = params['model_results_path'] + 'all_vertex_contrasts_ridge_results.pkl'

summary = pd.read_csv(summary_path)

In [None]:
# Display the per-target summary table as loaded from the CSV above.
summary

In [None]:
# Expand each target into three rows -- mean, mean + sd, mean - sd -- all stored
# in `mean_scores_r2`, so a bar plot that aggregates over rows recovers the mean
# as the bar height and can derive the spread from the three rows.
# NOTE: this rebinds `summary` and drops `std_scores_r2`, so the cell can only be
# run once per load; re-run the loading cell before running it again.
summary = (
    summary
    .assign(
        max_r2=summary['mean_scores_r2'].add(summary['std_scores_r2']),
        min_r2=summary['mean_scores_r2'].sub(summary['std_scores_r2']),
    )
    .filter(items=['target', 'mean_scores_r2', 'max_r2', 'min_r2'])
    .melt(id_vars='target')
    .loc[:, ['target', 'value']]
    .rename(columns={'value': 'mean_scores_r2'})
)

In [None]:
# Display the long-format summary: three rows per target (mean, mean + sd,
# mean - sd), all in the `mean_scores_r2` column.
summary

In [None]:
# Lookups from the project parameters: target -> process, target -> display
# label, and process -> colour (as used by the plotting cell below).
process_map, target_map, color_map = (
    params[key] for key in ('process_map', 'target_map', 'color_map')
)

def quick_effect_plot(df, ax, errorbar='sd'):
    """Draw a grouped bar chart of model R^2 per target, coloured by process.

    Args:
        df (pd.DataFrame): Long-format table with columns ``target``,
            ``mean_scores_r2``, ``process`` and ``color``. Rows sharing a
            ``target`` are aggregated by seaborn into a single bar.
        ax (matplotlib.axes.Axes): Axes to draw into.
        errorbar: Error-bar spec forwarded to ``sns.barplot`` (seaborn >= 0.12).
            Defaults to ``'sd'``. When each target is represented by the rows
            ``mean - sd``, ``mean`` and ``mean + sd``, the sample SD of those rows
            equals ``sd``, so the bar shows exactly mean +/- sd. Seaborn's default
            is an unseeded bootstrap CI over three pseudo-points, which is
            neither meaningful nor reproducible.

    Returns:
        matplotlib.axes.Axes: ``ax``, so the function composes with ``.pipe``.
    """
    # Map each process to its colour explicitly. A positional list from
    # pd.unique(df['color']) is shorter than the number of hue levels whenever two
    # processes share a colour, which shifts colours onto the wrong processes.
    palette = dict(zip(df['process'], df['color']))

    sns.barplot(
        data=df,
        x='target',
        y='mean_scores_r2',
        hue='process',
        palette=palette,
        errorbar=errorbar,
        ax=ax,
    )

    ax.set_ylabel(r"Model $R^2$")
    ax.set_xlabel("")

    # Rotate *this* axes' tick labels; plt.xticks acts on whichever axes pyplot
    # considers current, which need not be `ax`.
    for label in ax.get_xticklabels():
        label.set_rotation(30)
        label.set_horizontalalignment('right')

    return ax


# Set the seaborn context *before* building the figure. The context only styles
# artists created after it is set, so calling it after plotting left this figure
# unstyled on a fresh kernel but restyled it on every re-run (hidden state).
sns.set_context('paper', font_scale=1.75)

fig, ax = plt.subplots(figsize=(15, 8))

(summary
    .assign(
        process=lambda x: x['target'].replace(process_map),
        color=lambda x: x['process'].replace(color_map),
        # Relabel only the target column (after process/color were derived from
        # the raw names). Replacing across the whole frame could also rewrite
        # `process` or `color` values that happen to match a target_map key.
        target=lambda x: x['target'].replace(target_map),
    )
    .sort_values(['process', 'mean_scores_r2'], ascending=[True, False])
    .pipe(quick_effect_plot, ax)
)

fig.savefig(params['plot_output_path'] + "ridge_contrasts_effectsize.pdf", bbox_inches='tight')

In [None]:
FEATURE_IMPORTANCE_FILE = "all_vertex_contrasts_ridge_feature_importance.pkl"
fpath = params['model_results_path'] + FEATURE_IMPORTANCE_FILE

# The pickle holds five objects; unpack them in their saved order.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted result files.
(fis, best_fis, avg_fis,
 haufe_avg, haufe_fis) = pd.read_pickle(fpath)

In [None]:
# Display the Haufe-transformed average feature importances; the plotting cells
# below select its 'EEA' and 'tf' entries.
haufe_avg

In [None]:
def broadcast_to_fsaverage(fis_agg: pd.Series, n_vertices=10242) -> "tuple[pd.DataFrame, pd.DataFrame]":
    """Broadcast per-vertex feature importance onto a full fsaverage5 vertex grid.

    The index of ``fis_agg`` is split on ``"_"`` into four parts. The third part
    is the vertex number and the fourth the hemisphere (``"lh"`` / ``"rh"``).
    Vertices absent from ``fis_agg`` (e.g. masked ones) are filled with NaN, so
    every row covers vertices ``1..n_vertices`` in ascending order.

    Args:
        fis_agg (pd.Series): Feature importance indexed by strings of the form
            ``"<correct>_<condition>_<vertex>_<hemisphere>"``.
        n_vertices (int, optional): Number of vertices per hemisphere; vertex
            ids are 1-based. Defaults to 10242 (fsaverage5).

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(lh, rh)``. Each has columns
        ``correct`` and ``condition``, followed by the integer vertex columns
        ``1..n_vertices`` in order.

    Raises:
        ValueError: If a vertex id lies outside ``1..n_vertices``.
    """

    def _split_hemisphere(df):
        # Turn the remaining index levels into columns, then separate hemispheres.
        df = df.reset_index(names=["correct", "condition", "hemisphere"])
        lh = df[df["hemisphere"] == "lh"].drop(columns="hemisphere")
        rh = df[df["hemisphere"] == "rh"].drop(columns="hemisphere")

        return lh, rh

    fis = fis_agg.copy()

    fis.index = fis.index.str.split("_", expand=True)
    # Move the vertex level (position 2) into the columns.
    fis = fis.unstack(level=2)
    fis.columns = fis.columns.astype(int)

    vertex_names = list(range(1, n_vertices + 1))
    out_of_range = [int(v) for v in fis.columns if not 1 <= v <= n_vertices]
    if out_of_range:
        raise ValueError(
            f"vertex ids outside 1..{n_vertices}: {sorted(out_of_range)[:10]}"
        )

    # reindex both inserts NaN columns for missing vertices AND puts every column
    # in vertex order. Joining a NaN frame onto the existing columns appended the
    # missing vertices at the end, scrambling vertex order when rows are flattened.
    fis = fis.reindex(columns=vertex_names)

    return _split_hemisphere(fis)

def table_to_dict(df: pd.DataFrame, idx=['correct', 'condition']):
    """Collapse a hemisphere table into a ``{label: values}`` dictionary.

    Each row is labelled ``"<idx[0] value>_<idx[1] value>"``. The ``idx`` columns
    are then dropped, and the remaining values are grouped by label and flattened
    into a 1-D NumPy array. Labels come out in sorted order, because ``groupby``
    sorts its keys.

    Args:
        df (pd.DataFrame): Table holding the ``idx`` columns plus value columns.
        idx (list, optional): Names of the two identifier columns.

    Returns:
        dict: Label -> flattened NumPy array of that label's values.
    """
    first_col, second_col = idx[0], idx[1]
    labels = df[first_col] + '_' + df[second_col]

    values_by_label = df.assign(cond=labels).drop(columns=idx).set_index('cond')
    flattened = values_by_label.groupby(level=0).apply(lambda group: group.values.flatten())

    return flattened.to_dict()

# Inspection only: hemisphere tables for EEA (generate_plot recomputes its own).
lh, rh = broadcast_to_fsaverage(fis_agg=haufe_avg['EEA'])

In [None]:
def make_masked_figure(lh, rh, conditions, target_map, title):
    """Plot one cortical-surface panel per condition, side by side in one row.

    Args:
        lh (dict): Condition label -> left-hemisphere vertex values.
        rh (dict): Condition label -> right-hemisphere vertex values.
        conditions (Iterable[str]): Condition labels to plot, in panel order.
        target_map (dict): Condition label -> panel title.
        title (str): Figure-level title.

    Returns:
        matplotlib.figure.Figure: The created figure.

    Raises:
        ValueError: If ``conditions`` is empty.
    """
    # Materialise so len() works for any iterable (e.g. dict keys or generators).
    conditions = list(conditions)
    if not conditions:
        raise ValueError("make_masked_figure: no conditions to plot")

    # squeeze=False keeps `axs` a 2-D array even for one condition; by default a
    # single column returns a bare Axes, and the zip below would fail.
    fig, axs = plt.subplots(ncols=len(conditions), figsize=(30, 4), squeeze=False)

    for ax, condition in zip(axs.ravel(), conditions):

        lh_plot = lh[condition]
        rh_plot = rh[condition]

        ax.set_title(target_map[condition])

        ntp.plot({'lh': lh_plot, 'rh': rh_plot},
            threshold=0,
            cmap='seismic',
            colorbar=False,
            ax=ax
        )

    # Title the figure created here, not whatever figure pyplot considers current.
    fig.suptitle(title, size=20, x=0.15)

    return fig

def generate_plot(fis, target_map, title):
    """Render one surface panel per contrast for a feature-importance series.

    Args:
        fis (pd.Series): Feature importance in the index format accepted by
            ``broadcast_to_fsaverage``.
        target_map (dict): Contrast label -> panel title.
        title (str): Figure title.
    """
    hemisphere_tables = broadcast_to_fsaverage(fis)
    lh_masked, rh_masked = (table_to_dict(table) for table in hemisphere_tables)

    # Panel order follows the left-hemisphere dictionary's keys.
    make_masked_figure(lh_masked, rh_masked, lh_masked.keys(), target_map, title)

# Panel titles for the four contrasts. NOTE: this rebinds `target_map`, which the
# earlier plotting cell loaded from params; later cells use this version.
target_map = dict(
    correctstop_correctgo='Correct Stop vs. Correct Go',
    correctstop_incorrectgo="Correct Stop vs. Incorrect Go",
    incorrectstop_correctstop="Incorrect Stop vs. Correct Stop",
    incorrectstop_incorrectgo="Incorrect Stop vs. Incorrect Go",
)

eea_png = params['plot_output_path'] + 'EEA_contrasts.png'
generate_plot(haufe_avg['EEA'], target_map, "EEA")
plt.savefig(eea_png, dpi=300, bbox_inches='tight')


In [None]:
# Same contrast panels as above, for the pTF feature importances.
generate_plot(fis=haufe_avg['tf'], target_map=target_map, title=r"$p$TF")
plt.savefig(params['plot_output_path'] + 'tf_contrasts.png', bbox_inches='tight', dpi=300)