# RDEX-ABCD Brain-Behavior Model Permutation Plots

In [None]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import os
import glob

from abcd_tools.utils.ConfigLoader import load_yaml
# Notebook-wide settings. The cells below read the input/output path prefixes
# and the target map from this object by key.
params = load_yaml('../parameters.yaml')

In [None]:
# Collect every path that starts with the configured permutation-results
# prefix (the glob pattern is simply the prefix followed by "*").
fpath = params['permutation_results_path']
files = glob.glob(fpath + "*")

# Handed to make_permutation_plot, which applies it with DataFrame.replace
# before plotting. Presumably maps raw target names to display labels;
# confirm against parameters.yaml.
target_map = params['target_map']

In [None]:
def load_null_model(fpath: str) -> pd.DataFrame:
    """Load one permutation-test result file into a long-format table.

    The target name is parsed from the file name: the text after the last
    underscore, up to the first dot (``null_mrt.pkl`` -> ``mrt``). Only the
    file name is inspected, so underscores or dots in parent directories
    cannot leak into the target.

    Parameters
    ----------
    fpath : str
        Path to a pickled result. The unpickled object is indexed as
        ``res[0]['r2']`` and ``res[1]['r2']``.

    Returns
    -------
    pd.DataFrame
        Columns ``target``, ``pval_r2`` (from ``res[0]['r2']``) and
        ``null_r2`` (from ``res[1]['r2']``).
    """
    fname = os.path.basename(fpath)
    target = fname.split("_")[-1].split(".")[0]

    # NOTE(security): unpickling can execute arbitrary code. Only load result
    # files that were produced by this project.
    res = pd.read_pickle(fpath)

    # Presumably res[0]['r2'] is a single p-value and res[1]['r2'] the array
    # of null scores, so the scalars repeat on every row -- TODO confirm
    # against the code that writes these files.
    return pd.DataFrame({
        "target": target,
        "pval_r2": res[0]['r2'],
        "null_r2": res[1]['r2'],
    })

def load_permutations(files: list) -> pd.DataFrame:
    """Load every permutation result file and stack them into one table.

    Parameters
    ----------
    files : list
        Paths accepted by ``load_null_model``.

    Returns
    -------
    pd.DataFrame
        The per-file tables concatenated row-wise (each file's own index is
        kept), with the short target names ``mrt`` and ``stdrt`` replaced by
        ``correct_go_mrt`` and ``correct_go_stdrt``. An empty DataFrame when
        ``files`` is empty.
    """
    # Build all frames first and concatenate once. Growing a DataFrame inside
    # the loop re-copies the accumulated rows on every iteration, and seeding
    # it with an empty DataFrame relies on deprecated concat dtype handling.
    frames = [load_null_model(file) for file in files]

    # pd.concat raises on an empty list, so keep returning an empty frame.
    if not frames:
        return pd.DataFrame()

    # load_null_model keeps only the text after the last underscore of the
    # file name, so these two targets come back truncated; restore them.
    target_rep = {
        'mrt': 'correct_go_mrt',
        'stdrt': 'correct_go_stdrt',
    }

    return pd.concat(frames).replace(target_rep)

# Stack the null-model results from all matched files into one table;
# the bare name on the last line shows it as the cell output.
permutations = load_permutations(files)
permutations

In [None]:
# Summary table of the fitted models, read from the configured results prefix.
model_res_path = params['model_results_path'] + "all_vertex_ridge_summary.csv"
model_res = pd.read_csv(model_res_path)

# Join the summary onto the permutation rows by target, keep only the target
# and the mean/std R^2 columns, then collapse the repeated rows so each
# distinct (target, mean, std) combination appears once.
model_values = (permutations
    .merge(model_res, on='target')
    .filter(items=['target', 'mean_scores_r2', 'std_scores_r2'])
    .drop_duplicates()
    
)
# Displayed as the cell output.
model_values

In [None]:
def make_permutation_plot(permutations: pd.DataFrame, model_res: pd.DataFrame,
                          target_map: dict) -> None:
    """Draw one null-R^2 histogram per target with the model's mean R^2 marked.

    Parameters
    ----------
    permutations : pd.DataFrame
        Needs a ``target`` column and a ``null_r2`` column.
    model_res : pd.DataFrame
        Needs ``target`` and ``mean_scores_r2`` columns.
    target_map : dict
        Replacement mapping applied with ``DataFrame.replace`` to both the
        histogram data and the score table, so facet names and score rows
        use the same labels.

    Returns
    -------
    None
        The figure is created through seaborn and left for the caller to
        show or save.
    """
    # Only the distinct targets are needed for the join; merging the full
    # permutation table would repeat every score once per permutation.
    observed = (permutations[['target']]
        .drop_duplicates()
        .merge(model_res, on='target')
        .filter(items=['target', 'mean_scores_r2'])
        .drop_duplicates()
        .replace(target_map)
    )

    g = sns.FacetGrid(data=permutations.replace(target_map),
                      col='target',
                      col_wrap=4,
                      sharex=False)
    g.map_dataframe(sns.histplot, x='null_r2')

    # Look each axis up by facet name rather than by position: the inner merge
    # drops targets that are missing from model_res and can yield several rows
    # per target, so a positional zip could mark the wrong facet.
    for name, score in zip(observed['target'], observed['mean_scores_r2']):
        ax = g.axes_dict.get(name)
        if ax is not None:
            ax.axvline(x=score, color='r', linestyle='dashed')

    g.set_titles("{col_name}")
    g.set_xlabels(r"Null $R^2$")
    g.set_ylabels("Permutations")

# Draw the faceted permutation histograms, then write the figure as a PDF
# under the configured plot output prefix. plt.savefig saves pyplot's current
# figure, which is expected to be the grid created by the call above.
make_permutation_plot(permutations, model_res, target_map)
plt.savefig(params['plot_output_path'] + "permutation_plots.pdf", bbox_inches='tight')