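Follow-up to the `rdex-prediction` models: is variance in empirically derived SSRT explained by the RDEX model parameters? A second section compares EEA estimates from the SST and N-Back tasks, and how well task-fMRI predicts them.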

In [None]:
import pandas as pd
import numpy as np

import BPt as bp

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import seaborn as sns

from sklearn.linear_model import ElasticNet

from abcd_tools.utils.ConfigLoader import load_yaml
from abcd_tools.utils.io import load_tabular

from scipy.stats import pearsonr, spearmanr, gaussian_kde

In [None]:
PARAMETERS_YAML = "../parameters.yaml"  # project config: data/model/plot paths and label/colour maps
params = load_yaml(PARAMETERS_YAML)

In [None]:
# Behavioural targets, minus the standard go-trial RT summaries that this analysis does not use.
STANDARD_GO_METRICS = ["correct_go_mrt", "correct_go_stdrt"]
behavioral = load_tabular(params["targets_path"]).drop(columns=STANDARD_GO_METRICS)

Also drop the model-derived SSRT, as well as the match (`vT`) and mismatch (`vF`) accumulators, since EEA is defined as a composite of them.

In [None]:
MODEL_DERIVED_COLUMNS = ["SSRT", "vT", "vF"]
behavioral = behavioral.drop(columns=MODEL_DERIVED_COLUMNS)

In [None]:
# Restrict to the subjects in the training split of the rdex_prediction dataset.
rdex_predict_ds = pd.read_pickle(params["sst_dataset_path"])
predict_train_idx = rdex_predict_ds.train_subjects

in_train_split = behavioral.index.isin(predict_train_idx)
behavioral = behavioral.loc[in_train_split]
behavioral

In [None]:
# 'issrt' is the only column passed as a target to the BPt dataset.
target_columns = behavioral.loc[:, ['issrt']].columns
ds = bp.Dataset(behavioral, targets=target_columns)
ds

In [None]:
def define_crosspredict_pipeline(ds: bp.Dataset) -> bp.Pipeline:
    """Build the ridge-regression pipeline used to cross-predict empirical SSRT.

    Steps, in order: robust scaling of float-typed features, normalization
    of float-typed features, then a ridge model whose ``alpha`` is tuned by
    a 100-iteration Hammersley search.

    Args:
        ds: Dataset to be modelled. Not referenced by the current
            implementation.

    Returns:
        A ``bp.Pipeline`` of ``[robust scaler, normalizer, ridge model]``.
    """
    # Tune alpha on a log scale over [1e-5, 1e5].
    # (An ElasticNet variant would use obj=ElasticNet() and also search
    #  'l1_ratio' via bp.p.Scalar(lower=0.001, upper=1).)
    search = bp.ParamSearch('HammersleySearch', n_iter=100, cv='default')
    ridge = bp.Model(
        obj="ridge",
        params={'alpha': bp.p.Log(lower=1e-5, upper=1e5)},
        param_search=search,
    )

    # Only float-typed columns are rescaled.
    steps = [
        bp.Scaler('robust', scope='float'),
        bp.Scaler('normalize', scope='float'),
        ridge,
    ]
    return bp.Pipeline(steps)

def fit_crosspredict_model(ds: bp.Dataset) -> bp.CompareDict:
    """Evaluate the cross-prediction pipeline on ``ds`` with 5-fold CV (one repeat).

    Args:
        ds: Dataset holding the features and the SSRT target.

    Returns:
        Whatever ``bp.evaluate`` returns. Annotated as ``bp.CompareDict``;
        NOTE(review): downstream code treats the result as ``bp.EvalResults``,
        so confirm which annotation is correct.
    """
    pipeline = define_crosspredict_pipeline(ds)
    five_fold = bp.CV(splits=5, n_repeats=1)
    spec = bp.ProblemSpec(n_jobs=8, random_state=42)  # fixed seed

    quiet = {'mute_warnings': True, 'progress_bar': False, 'verbose': 0}
    return bp.evaluate(pipeline=pipeline, dataset=ds, problem_spec=spec,
                       cv=five_fold, **quiet)

# results = fit_crosspredict_model(ds)
# pd.to_pickle(results, params["model_results_path"] + "crosspredict_model_results.pkl")
# results

In [None]:
# Load the cached cross-prediction results (the fitting call above is commented out).
results = pd.read_pickle(f"{params['model_results_path']}crosspredict_model_results.pkl")

In [None]:
def make_plot_df(results: 'bp.EvalResults', params: dict) -> pd.DataFrame:
    """Reshape feature importances into a long, plot-ready frame.

    Args:
        results: Evaluation results exposing ``get_fis()``. That call returns
            a frame with one column per feature; rows are presumably one per
            CV fold.
        params: Config dict with ``'process_map'`` (feature -> process label)
            and ``'target_map'`` (feature -> display name).

    Returns:
        Long frame with columns ``variable`` (display name), ``value``
        (importance) and ``process``. Rows are grouped by feature, and
        features are ordered by ascending mean importance.
    """
    fis_long = results.get_fis().melt()
    # Derive process labels from the raw feature names before renaming them for display.
    fis_long['process'] = fis_long['variable'].replace(params['process_map'])
    fis_long['variable'] = fis_long['variable'].replace(params['target_map'])

    # Average only the numeric 'value' column. 'process' holds strings, and
    # groupby().mean() raises TypeError on non-numeric columns in pandas >= 2.0.
    order = fis_long.groupby('variable')['value'].mean().sort_values().index
    return fis_long.set_index('variable').loc[order].reset_index()

fis_long = make_plot_df(results=results, params=params)  # long-format importances for plotting

In [None]:
def make_crosspredict_plot(fis_long: pd.DataFrame, params: dict, metrics: tuple) -> None:
    """Bar-plot cross-prediction feature importances and save the figure as a PNG.

    Args:
        fis_long: Long frame with ``variable``, ``value`` and ``process``
            columns (as produced by ``make_plot_df``).
        params: Config dict with ``'color_map'`` (seaborn palette,
            presumably keyed by process) and ``'plot_output_path'`` (prefix
            for the output file).
        metrics: ``(mean R^2, std R^2)``, shown in the title as percentages.
    """
    sns.set_theme(style="whitegrid")
    palette = params['color_map']

    # Use '\\pm' rather than '\pm': '\p' is an invalid escape sequence in a
    # regular string literal (SyntaxWarning on Python 3.12+). A raw string
    # can't be used here because the title needs a real '\n' line break.
    title = (
        'Feature Importance Predicting Empirical SSRT\n'
        f'Avg. $R^2$: {metrics[0]:.2%} $\\pm$ {metrics[1]:.2%}'
    )

    fig, ax = plt.subplots()
    sns.barplot(x='value', y='variable', hue='process',
                data=fis_long, palette=palette, dodge=False, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Avg. Feature Importance')
    ax.set_ylabel('')
    ax.legend(title='')

    # Save this figure explicitly instead of whatever pyplot considers current.
    fig.savefig(params['plot_output_path'] + 'crosspredict_feature_importance.png',
                dpi=300, bbox_inches='tight')

r2_mean = results.mean_scores['r2']
r2_std = results.std_scores['r2']
metrics = (r2_mean, r2_std)  # shown as "mean ± std" in the plot title

make_crosspredict_plot(fis_long=fis_long, params=params, metrics=metrics)

In [None]:
# Per-feature mean and std of importance. Select 'value' explicitly: the string
# 'process' column makes agg('mean') raise TypeError in pandas >= 2.0.
# Using [['value']] (a list) keeps the ('value', 'mean'/'std') column layout.
fis_avg = fis_long.groupby('variable')[['value']].agg(['mean', 'std'])
fis_avg.to_csv(params['model_results_path'] + 'crosspredict_feature_importance.csv')
fis_avg

## Examine N-back EEA

In [None]:
# Cached vertex-wise ridge model results.
results_dir = params['model_results_path']
nback_sst_res = pd.read_pickle(results_dir + 'all_vertex_sst_nback_ridge_results.pkl')
sst_res = pd.read_pickle(results_dir + 'all_vertex_ridge_results.pkl')
nback_res = pd.read_pickle(results_dir + 'all_vertex_nback_ridge_results.pkl')

# N-Back EEA estimates (column 'e'), indexed by subject and event.
nback_eea = (
    pd.read_csv(params['nback_targets_path'])
    .set_index(keys=['src_subject_id', 'eventname'])
    .loc[:, 'e']
)

def assemble_model_data(res, eea_var):
    """Unpack a results object into its dataset and its prediction frames.

    Args:
        res: Results object exposing ``_dataset`` and ``get_preds_dfs()``.
            The returned frames are presumably one per CV fold.
        eea_var: Name of the EEA target. Not referenced by the current
            implementation.

    Returns:
        Tuple ``(dataset, prediction_dfs)``.
    """
    return res._dataset, res.get_preds_dfs()

# Unpack each model's dataset and prediction frames:
#   sst_*       - SST EEA predicted from SST task-fMRI
#   nback_sst_* - N-Back EEA predicted from SST task-fMRI
#   nback_*     - N-Back EEA predicted from N-Back task-fMRI
sst_ds, sst_preds = assemble_model_data(res=sst_res['EEA'], eea_var='EEA')
nback_sst_ds, nback_sst_preds = assemble_model_data(res=nback_sst_res['e'], eea_var='e')
nback_ds, nback_preds = assemble_model_data(res=nback_res['e'], eea_var='e')


In [None]:
def _pool_fold_preds(fold_dfs):
    """Concatenate prediction frames, keep the first row per index label, and drop rows with NaNs."""
    pooled = pd.concat(fold_dfs)
    is_first = ~pooled.index.duplicated(keep='first')
    return pooled[is_first].dropna()


# SST predictions are joined with N-Back EEA before pooling; the other two are pooled as-is.
sst_preds = _pool_fold_preds([fold_df.join(nback_eea) for fold_df in sst_preds])
nback_sst_preds = _pool_fold_preds(nback_sst_preds)
nback_preds = _pool_fold_preds(nback_preds)

# Rows with both an SST and an N-Back EEA estimate.
eea = pd.concat([sst_ds['EEA'], nback_eea], axis=1).dropna()
eea.columns = ['sst_eea', 'nback_eea']

# Display label and process group for each N-Back summary target.
target_map = dict(e='N-Back EEA', nback_mrt='RT', nback_stdrt='RT Variability')
process_map = dict(e='EEA', nback_mrt='empirical', nback_stdrt='empirical')

# 'process' is derived from the raw target names before 'target' is relabelled;
# assign() evaluates its callables in this order.
nback_sst_summary = (
    pd.read_csv(params['model_results_path'] + 'vertex_sst_nback_ridge_models_summary.csv')
    .assign(
        process=lambda frame: frame['target'].replace(process_map),
        target=lambda frame: frame['target'].replace(target_map),
    )
    .loc[lambda frame: ~frame['target'].str.contains('dprime')]
)

In [None]:
def make_residual_plot(xvar, yvar, df, xlab, ylab, ax):
        values = np.vstack([df[xvar], df[yvar]])
        residual_kernel = gaussian_kde(values)(values)

        rho, p = spearmanr(df[xvar], df[yvar])

        if p < 0.001:
            p_str = rf"$\rho = {rho:.2f}, p < 0.001$"
        else:
            p_str = rf"$\rho = {rho:.2f}, p = {p:.2f}"


        # sns.scatterplot(x=xvar, y=yvar, data=df,
        #         alpha=0.5, s=8, ax=ax, c=residual_kernel, lowess=True)

        sns.regplot(x=xvar, y=yvar, data=df, ax=ax, lowess=True,
                scatter_kws={'alpha': 0.0})
        scatter = ax.scatter(df[xvar], df[yvar],
                alpha=0.25, s=5, c=residual_kernel, cmap='viridis')

        # need line to keep things equal, apparently
        line_min = min(df[xvar].min(), df[yvar].min())
        line_max = max(df[xvar].max(), df[yvar].max())

        ax.plot([line_min, line_max], [line_min, line_max], color='black', alpha=0.0)

        ax.set_xlabel(xlab)
        ax.set_ylabel(ylab)
        ax.text(0.05, 0.95, "Residuals", ha='left', va='top', transform=ax.transAxes)
        ax.text(0.05, 0.88, p_str, ha='left', va='top', transform=ax.transAxes)

        # ax.set_aspect('equal', 'box')

        return residual_kernel

def make_correlation_plot(xvar, yvar, df, xlab, ylab, ax):
        values = np.vstack([df[xvar], df[yvar]])
        parameter_kernel = gaussian_kde(values)(values)

        rho, p = spearmanr(df[xvar], df[yvar])
        if p < 0.001:
            p_str = rf"$\rho = {rho:.2f}, p < 0.001$"
        else:
            p_str = rf"$\rho = {rho:.2f}, p = {p:.2f}"

        sns.regplot(x=xvar, y=yvar, data=df, ax=ax, lowess=True,
                scatter_kws={'alpha': 0.0})

        scatter = ax.scatter(df[xvar], df[yvar],
                alpha=0.25, s=5, c=parameter_kernel, cmap='magma')

        ax.set_xlabel(xlab)
        ax.set_ylabel(ylab)
        ax.text(0.05, 0.95, p_str, ha='left', va
        ='top', transform=ax.transAxes)

        return parameter_kernel

In [None]:
def label_plot(ax, label):
    """Write a bold, left-aligned panel label (e.g. "a)") as the axes title."""
    panel_font = {'fontsize': 14, 'fontweight': 'bold'}
    ax.set_title(label, fontdict=panel_font, loc='left')

In [None]:
# Figure panels:
#   (a)   N-Back vs. SST EEA estimates
#   (b-d) observed vs. predicted EEA for the three fMRI models
sns.set_context("paper", font_scale=1.25)
grid = {'width_ratios': [1, .25, 1, 1, 1]}  # axs[1] is a narrow blank spacer column
fig, axs = plt.subplots(nrows=1, ncols=5, figsize=(14, 4),
                        layout='constrained', gridspec_kw=grid)

# (a) Agreement between the two tasks' EEA estimates.
parameter_kernel = make_correlation_plot('nback_eea', 'sst_eea', eea, 'N-Back EEA Estimates', 'SST EEA Estimates', axs[0])

# Colour bars are driven by a standalone ScalarMappable. An empty
# plt.scatter() would attach a stray artist to pyplot's current axes.
vmin = parameter_kernel.min()
vmax = parameter_kernel.max()
density_a = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=vmin, vmax=vmax), cmap='magma')
cbar = fig.colorbar(density_a, ax=axs[0], orientation='vertical', fraction=0.025, ticks=[])
cbar.set_label('Point Density')
label_plot(axs[0], "a)")

axs[1].axis('off')

# (b-d) Observed vs. predicted EEA.
sst_kernel = make_residual_plot('y_true', 'predict', sst_preds, 'SST EEA Estimates', 'SST EEA Predicted from SST task-fMRI', axs[2])
label_plot(axs[2], "b)")

nback_kernel = make_residual_plot('y_true', 'predict', nback_preds, 'N-Back EEA Estimates', 'N-Back EEA Predicted from N-Back task-fMRI', axs[3])
label_plot(axs[3], "c)")

nback_sst_kernel = make_residual_plot('y_true', 'predict', nback_sst_preds, 'N-Back EEA Estimates', 'N-Back EEA Predicted from SST task-fMRI', axs[4])
label_plot(axs[4], "d)")

# One colour bar for panels b-d, spanning the pooled density range.
vmin = min(sst_kernel.min(), nback_kernel.min(), nback_sst_kernel.min())
vmax = max(sst_kernel.max(), nback_kernel.max(), nback_sst_kernel.max())
density_bcd = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=vmin, vmax=vmax), cmap='viridis')
cbar = fig.colorbar(density_bcd, ax=axs[4], orientation='vertical', fraction=0.025, ticks=[])
cbar.set_label('Point Density')

# Aspect and integer tick locators, set once after all data is drawn.
# Panel (a) is squared on its own data ratio; panels (b-d) use equal aspect.
axs[0].set_aspect(1./axs[0].get_data_ratio())
axs[0].xaxis.set_major_locator(MaxNLocator(integer=True))
for i in [2, 3, 4]:
    axs[i].set_aspect('equal', 'box')
    axs[i].xaxis.set_major_locator(MaxNLocator(integer=True))
    axs[i].yaxis.set_major_locator(MaxNLocator(integer=True))
axs[0].sharey(axs[2])
sns.despine()

fig.savefig(params['plot_output_path'] + 'nback_sst_eea_correlations.png', dpi=300, bbox_inches='tight')
plt.show()
