Follows from `rdex-prediction` models: is variance in empirically-derived SSRT explained by RDEX model parameters?

In [None]:
import pandas as pd

import BPt as bp

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.linear_model import ElasticNet

from abcd_tools.utils.ConfigLoader import load_yaml
from abcd_tools.utils.io import load_tabular


In [None]:
# Load the notebook configuration; later cells read input/output paths,
# label maps and the color map from this mapping.
params = load_yaml("../parameters.yaml")

In [None]:
# Load the behavioral targets table without the standard metrics.
behavioral = load_tabular(params["targets_path"]).drop(
    columns=["correct_go_mrt", "correct_go_stdrt"]
)

Also drop the model-derived SSRT, as well as the match and mismatch accumulators (`vT`, `vF`), since both are already represented compositely by EEA.

In [None]:
# Remove the model-derived SSRT and the vT / vF accumulator columns.
behavioral = behavioral.drop(columns=["SSRT", "vT", "vF"])

In [None]:
# Restrict the behavioral table to the training subjects of the
# rdex_prediction model.
rdex_predict_ds = pd.read_pickle(params["dataset_path"])
predict_train_idx = rdex_predict_ds.train_subjects

# Keep only the rows whose index label is one of the training subjects.
behavioral = behavioral.loc[behavioral.index.isin(predict_train_idx)]
behavioral

In [None]:
# Wrap the filtered table in a BPt Dataset with `issrt` as the only
# target column.
target_columns = behavioral[['issrt']].columns
ds = bp.Dataset(behavioral, targets=target_columns)
ds

In [None]:
def define_crosspredict_pipeline(ds: bp.Dataset) -> bp.Pipeline:
    """Build the scaling + tuned ElasticNet pipeline for cross-prediction.

    Args:
        ds: Dataset the pipeline is meant for. Currently unused by the
            function body.

    Returns:
        A ``bp.Pipeline`` that applies a ``'robust'`` scaler, then a
        ``'normalize'`` scaler (both scoped to float features), then an
        ElasticNet whose ``alpha`` and ``l1_ratio`` are tuned with a
        100-iteration ``'HammersleySearch'``.
    """
    # Preprocessing steps, both limited to float-typed features.
    preprocessing = [
        bp.Scaler('robust', scope='float'),
        bp.Scaler('normalize', scope='float'),
    ]

    # Hyper-parameter distributions explored by the search.
    search_space = {
        'alpha': bp.p.Log(lower=1e-5, upper=1e5),
        'l1_ratio': bp.p.Scalar(lower=0.001, upper=1).set_mutation(sigma=0.165),
    }

    elastic_net = bp.Model(
        obj=ElasticNet(),
        params=search_space,
        param_search=bp.ParamSearch('HammersleySearch', n_iter=100, cv='default'),
    )

    return bp.Pipeline([*preprocessing, elastic_net])

def fit_crosspredict_model(ds: bp.Dataset) -> bp.EvalResults:
    """Evaluate the cross-prediction pipeline on ``ds``.

    The pipeline from ``define_crosspredict_pipeline`` is passed to
    ``bp.evaluate`` with a 5-split, single-repeat CV and a problem spec
    using 8 jobs and a fixed random state of 42.

    Args:
        ds: Dataset to evaluate the pipeline on.

    Returns:
        The object returned by ``bp.evaluate``. No ``bp.Compare`` options are
        passed, so this is annotated as ``bp.EvalResults``, which is also how
        the rest of the notebook consumes it (``get_fis``, ``mean_scores``,
        ``std_scores``).
    """
    pipe = define_crosspredict_pipeline(ds)
    cv = bp.CV(splits=5, n_repeats=1)
    ps = bp.ProblemSpec(n_jobs=8, random_state=42)

    results = bp.evaluate(
        pipeline=pipe,
        dataset=ds,
        problem_spec=ps,
        mute_warnings=True,
        cv=cv,
    )

    return results

# Run the evaluation on the dataset and display the returned results.
results = fit_crosspredict_model(ds)
results

In [None]:
def make_plot_df(results: bp.EvalResults, params: dict) -> pd.DataFrame:
    """Reshape feature importances into a long, plot-ready table.

    Args:
        results: Object exposing ``get_fis()``, which returns a wide table of
            feature importances with one column per feature.
        params: Configuration mapping. ``params['process_map']`` maps raw
            feature names to a process label and ``params['target_map']``
            maps raw feature names to display names.

    Returns:
        A long-format DataFrame with columns ``variable`` (display name),
        ``value`` (importance) and ``process``, with rows ordered by each
        variable's mean importance, ascending.
    """
    fis_long = results.get_fis().melt()

    # Derive the process label from the raw feature name *before* that name
    # is overwritten with its display label.
    fis_long['process'] = fis_long['variable'].replace(params['process_map'])
    fis_long['variable'] = fis_long['variable'].replace(params['target_map'])

    # Average only the numeric importance column. Calling .mean() on the whole
    # grouped frame would also try to average the string 'process' column,
    # which raises TypeError on pandas >= 2.0.
    fis_sorted = fis_long.groupby('variable')['value'].mean().sort_values().index
    fis_long = fis_long.set_index('variable').loc[fis_sorted].reset_index()
    return fis_long

# Long-format feature importances used by the plotting cell below.
fis_long = make_plot_df(results, params)

In [None]:
def make_crosspredict_plot(fis_long: pd.DataFrame, params: dict, metrics: tuple) -> None:
    """Draw, save and show the feature-importance bar plot.

    Args:
        fis_long: Long-format table with ``value``, ``variable`` and
            ``process`` columns.
        params: Configuration mapping. ``params['color_map']`` is used as the
            bar palette and ``params['plot_output_path']`` is the prefix the
            output file name is appended to.
        metrics: Two numbers shown in the title as ``metrics[0] ± metrics[1]``
            after the "Avg. R^2" label.

    Returns:
        None. The figure is written to
        ``params['plot_output_path'] + 'crosspredict_feature_importance.png'``
        and then displayed.
    """
    sns.set_theme(style="whitegrid")
    palette = params['color_map']

    # The second line is a raw f-string so that the mathtext command "\pm" is
    # not parsed as an (invalid) Python escape sequence; "\n" stays in the
    # ordinary literal so it is still a real newline.
    title = (
        'Feature Importance Predicting Empirical SSRT\n'
        rf'Avg. $R^2$: {metrics[0]:.2f} $\pm$ {metrics[1]:.2f}'
    )

    fig, ax = plt.subplots()
    sns.barplot(
        x='value',
        y='variable',
        hue='process',
        data=fis_long,
        palette=palette,
        dodge=False,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel('Avg. Feature Importance')
    ax.set_ylabel('')
    ax.legend(title='')

    plt.savefig(params['plot_output_path'] + 'crosspredict_feature_importance.png', dpi=300)
    plt.show()

# Mean and standard deviation of the 'r2' scores reported by the evaluation,
# shown in the plot title.
metrics = (results.mean_scores['r2'], results.std_scores['r2'])

make_crosspredict_plot(fis_long, params, metrics)