In [None]:
import os
import glob
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from abcd_tools.utils.ConfigLoader import load_yaml

In [None]:
# Project configuration; later cells read paths and settings from it
# (vol_info_path_*, hemispheres, beta_input_dir_*, beta_output_dir_*).
params = load_yaml("../parameters.yaml")

In [None]:
def parse_vol_info(vol_info: pd.DataFrame) -> pd.DataFrame:
    """Extract subject IDs and event names from a vol_info table.

    The first column of ``vol_info`` is split on underscores. The third token
    becomes the subject ID (``NDAR_`` is prepended) and the fourth token is a
    short timepoint label that is mapped to the ABCD event name.

    Args:
        vol_info (pd.DataFrame): Table whose first column holds
            underscore-delimited volume identifiers.

    Returns:
        pd.DataFrame: ``src_subject_id`` and ``eventname`` columns, sharing
        ``vol_info``'s index. Timepoint labels missing from the mapping
        become NaN.
    """
    timepoint_to_event = {
        'baseline': 'baseline_year_1_arm_1',
        '2year': '2_year_follow_up_y_arm_1',
        '4year': '4_year_follow_up_y_arm_1',
        '6year': '6_year_follow_up_y_arm_1',
    }

    tokens = vol_info.iloc[:, 0].str.split("_", expand=True)
    subject_ids = 'NDAR_' + tokens[2]
    event_names = tokens[3].map(timepoint_to_event)

    return pd.DataFrame({'src_subject_id': subject_ids, 'eventname': event_names})

def concatenate_hemispheres(lh: pd.DataFrame, rh: pd.DataFrame) -> pd.DataFrame:
    """Concatenate left and right hemisphere dataframes side by side.

    Columns from ``lh`` get an ``_lh`` suffix and columns from ``rh`` get an
    ``_rh`` suffix, so identically named columns do not collide. Rows are
    aligned on the index, as with ``pd.concat(axis=1)``. The input frames are
    left unmodified.

    Args:
        lh (pd.DataFrame): Left hemisphere data
        rh (pd.DataFrame): Right hemisphere data

    Returns:
        pd.DataFrame: Concatenated data, ``lh`` columns first, then ``rh``.
    """
    combined = pd.concat([lh, rh], axis=1)
    # Relabel the new frame rather than assigning to lh/rh.columns, which would
    # rename the caller's DataFrames as a side effect.
    combined.columns = [f"{c}_lh" for c in lh.columns] + [f"{c}_rh" for c in rh.columns]
    return combined

In [None]:
def _sst_run_paths(beta_input_dir: str, condition: str, label: str, hemi: str,
                   release: str) -> tuple:
    """Return the (run 1, run 2) beta parquet paths for one condition/hemisphere.

    ``release`` must be 'r5' or 'r6'; the caller validates it.
    """
    if release == 'r5':
        # r5 files use the long condition label, e.g. SST_1_correct_go-lh.parquet
        return tuple(f"{beta_input_dir}SST_{run}_{label}-{hemi}.parquet" for run in (1, 2))
    # r6 files use the short condition code, e.g. sst_cg_beta_r01_lh.parquet
    return tuple(f"{beta_input_dir}sst_{condition}_beta_r0{run}_{hemi}.parquet" for run in (1, 2))


def combine_betas(sst_conditions: dict, hemispheres: list, beta_input_dir: str,
                  vol_info: pd.DataFrame, release: str = 'r5') -> pd.DataFrame:
    """Load run-level SST betas for each condition and stack them into one table.

    For every condition, runs 1 and 2 are read per hemisphere. Each run is
    joined column-wise with ``vol_info`` (rows matched on the index) and the
    runs are stacked. The two hemispheres are then placed side by side.

    Args:
        sst_conditions (dict): Maps a short condition code (used in r6 file
            names and the ``condition`` column) to a long label (used in r5
            file names and beta column prefixes), e.g. ``{'cg': 'correct_go'}``.
        hemispheres (list): Hemisphere labels used in file names; must include
            ``'lh'`` and ``'rh'``.
        beta_input_dir (str): Directory prefix, concatenated directly with the
            file name, so it should end with a path separator.
        vol_info (pd.DataFrame): Per-row identifiers, expected to provide the
            ``src_subject_id`` and ``eventname`` columns used as index keys;
            row-aligned (by index) with each beta file.
        release (str): ``'r5'`` or ``'r6'``; selects the file naming scheme.

    Returns:
        pd.DataFrame: Columns ``src_subject_id``, ``eventname``, ``run``, the
        hemisphere-suffixed beta columns, and ``condition``.

    Raises:
        ValueError: If ``release`` is unsupported or ``hemispheres`` lacks
            ``'lh'`` or ``'rh'``. Both checks run before any file is read.
    """
    if release not in ('r5', 'r6'):
        raise ValueError(f"Unsupported release {release!r}; expected 'r5' or 'r6'")
    missing = {'lh', 'rh'} - set(hemispheres)
    if missing:
        raise ValueError(f"hemispheres must include 'lh' and 'rh'; missing {sorted(missing)}")

    idx = ['src_subject_id', 'eventname', 'run']
    per_condition = []
    for condition, label in sst_conditions.items():
        betas = {}
        for hemi in hemispheres:
            runs = []
            run_paths = _sst_run_paths(beta_input_dir, condition, label, hemi, release)
            for run_number, fpath in enumerate(run_paths, start=1):
                run_betas = pd.read_parquet(fpath)
                run_betas['run'] = run_number
                # column-wise join with vol_info; rows are matched on the index
                runs.append(pd.concat([vol_info, run_betas], axis=1))
            betas[hemi] = pd.concat(runs)

        betas_df = concatenate_hemispheres(betas['lh'].set_index(idx),
                                           betas['rh'].set_index(idx))
        # assumes beta columns are named 'tableData<k>'; prefix them with the label
        betas_df.columns = [c.replace('tableData', label + '_') for c in betas_df.columns]
        betas_df['condition'] = condition
        per_condition.append(betas_df)

    # Concatenate once rather than growing a frame inside the loop.
    betas_out = pd.concat(per_condition) if per_condition else pd.DataFrame()
    return betas_out.reset_index()

## Load unprocessed (run-level) betas for one condition (correct go, `cg`)

In [None]:
# r5 vol_info is used exactly as read; the r6 table is parsed into
# src_subject_id / eventname columns.
vol_info_r5 = pd.read_parquet(params['vol_info_path_r5'])
vol_info_r6 = parse_vol_info(pd.read_parquet(params['vol_info_path_r6']))

# condition code -> condition label; only correct go is examined here
sst_conditions = {'cg': 'correct_go'}

In [None]:
# Baseline-only, NaN-free r5 run-level betas. Reuse the cached parquet when it
# exists; otherwise rebuild it from the per-run beta files and cache it, so the
# notebook also runs from a clean checkout.
R5_BETAS_PATH = "../../data/02_intermediate/cg_r5_combined_betas.parquet"
if os.path.exists(R5_BETAS_PATH):
    r5 = pd.read_parquet(R5_BETAS_PATH)
else:
    r5 = combine_betas(sst_conditions,
                       params['hemispheres'],
                       params['beta_input_dir_r5'],
                       vol_info_r5,
                       release='r5')
    r5 = r5[r5['eventname'] == 'baseline_year_1_arm_1'].dropna()
    r5.to_parquet(R5_BETAS_PATH)

In [None]:
# Baseline-only, NaN-free r6 run-level betas. Reuse the cached parquet when it
# exists; otherwise rebuild it from the per-run beta files and cache it, so the
# notebook also runs from a clean checkout.
R6_BETAS_PATH = "../../data/02_intermediate/cg_r6_combined_betas.parquet"
if os.path.exists(R6_BETAS_PATH):
    r6 = pd.read_parquet(R6_BETAS_PATH)
else:
    r6 = combine_betas(sst_conditions,
                       params['hemispheres'],
                       params['beta_input_dir_r6'],
                       vol_info_r6,
                       release='r6')
    r6 = r6[r6['eventname'] == 'baseline_year_1_arm_1'].dropna()
    r6.to_parquet(R6_BETAS_PATH)

In [None]:
# pick a random subset of subjects
# Reproducibly draw 15 distinct r5 subjects for the per-subject plots
# (seeds numpy's global RNG).
np.random.seed(64)
subjects = r5['src_subject_id'].unique()
selected_subjects = np.random.choice(a=subjects, size=15, replace=False)

In [None]:
def merge_release_betas(r5: pd.DataFrame, r6: pd.DataFrame, subjects) -> pd.DataFrame:
    """Align r5 and r6 run-level betas value-by-value for a subset of subjects.

    Named distinctly from the ``combine_betas`` file loader defined earlier in
    the notebook, so defining this helper no longer shadows that function.

    Args:
        r5 (pd.DataFrame): Wide r5 table with ``src_subject_id``,
            ``eventname``, ``run`` and ``condition`` columns plus one column
            per beta.
        r6 (pd.DataFrame): Same layout for r6.
        subjects (array-like): ``src_subject_id`` values to keep.

    Returns:
        pd.DataFrame: Indexed by (src_subject_id, eventname, run, condition,
        variable), where ``variable`` is the original beta column name, with
        columns ``r5`` and ``r6``. Entries present in only one release are NaN
        in the other column.
    """
    id_vars = ['src_subject_id', 'eventname', 'run', 'condition']
    idx = id_vars + ['variable']

    def _to_long(betas: pd.DataFrame, release: str) -> pd.DataFrame:
        # keep the sampled subjects, then one row per beta column, valued under `release`
        return (
            betas[betas['src_subject_id'].isin(subjects)]
            .melt(id_vars=id_vars, value_name=release)
            .set_index(idx)
        )

    return pd.concat([_to_long(r5, 'r5'), _to_long(r6, 'r6')], axis=1)

combined = merge_release_betas(r5, r6, selected_subjects).dropna()

In [None]:
def plot_betas(subset, xvar='r5_cg', yvar='r6_cg', hue=None, lim=10):
    """Scatter ``yvar`` against ``xvar`` with one panel per subject.

    Args:
        subset (pd.DataFrame): Long-format table with a ``src_subject_id``
            column plus the ``xvar`` / ``yvar`` columns (and ``hue``, if given).
        xvar (str): Column plotted on the x axis.
        yvar (str): Column plotted on the y axis.
        hue (str, optional): Column used to colour points; adds a legend.
        lim (float): Both axes span ``[-lim, lim]``; points outside that range
            are not visible. Defaults to 10.

    Returns:
        sns.FacetGrid: The grid, so callers can adjust or save the figure.
    """
    g = sns.FacetGrid(subset, col='src_subject_id', col_wrap=5, height=3, hue=hue)
    g.map(sns.scatterplot, xvar, yvar)

    if hue is not None:
        g.add_legend()

    for ax in g.axes.flat:
        # dashed identity line: points on it have equal x and y values
        ax.axline((0, 0), slope=1, color='k', ls='--')
        ax.grid(True, axis='both', linestyle=':')
        ax.set_xlim(-lim, lim)
        ax.set_ylim(-lim, lim)
        ax.set_aspect('equal', adjustable='box')

    return g

In [None]:
# Merged run-level betas: indexed by subject / event / run / condition /
# beta column name ('variable'), with one column per release (r5, r6).
combined

In [None]:
# One panel per sampled subject: r5 (x) vs r6 (y) run-level betas, coloured by run.
plot_betas(subset=combined.reset_index(), hue='run', xvar='r5', yvar='r6');

## Compare average betas between releases (r5 vs r6)

In [None]:
def combine_average_betas(params, subjects, fpath="../../data/02_intermediate/average_betas.parquet",
                          condition: str = 'cg'):
    """Pair r5 and r6 average betas vertex-by-vertex for a subset of subjects.

    Reads ``average_betas_<condition>.parquet`` from each release's output
    directory and keeps only ``subjects``. Each release is reshaped to one row
    per subject/event/vertex and the two releases are aligned side by side.
    The result is also written to ``fpath``.

    Args:
        params (dict): Must contain ``beta_output_dir_r5`` and
            ``beta_output_dir_r6``. These are directory prefixes concatenated
            directly with the file name.
        subjects (array-like): ``src_subject_id`` values to keep.
        fpath (str): Destination of the combined long table (parquet).
        condition (str): Condition code used in the input file name and in
            the output column names. Defaults to ``'cg'``.

    Returns:
        pd.DataFrame: Columns ``src_subject_id``, ``eventname``, ``vertex``,
        ``r5_<condition>`` and ``r6_<condition>``. A vertex present in only
        one release is NaN in the other release's column.
    """
    id_vars = ['src_subject_id', 'eventname']

    def _load_long(release: str) -> pd.DataFrame:
        # presumably the saved frame is indexed by src_subject_id/eventname;
        # reset_index turns those levels back into columns
        avg = pd.read_parquet(
            params[f'beta_output_dir_{release}'] + f'average_betas_{condition}.parquet'
        ).reset_index()
        avg = avg[avg['src_subject_id'].isin(subjects)]
        long = avg.melt(id_vars=id_vars, var_name='vertex',
                        value_name=f'{release}_{condition}')
        return long.set_index(id_vars + ['vertex'])

    long_compare = pd.concat([_load_long('r5'), _load_long('r6')], axis=1).reset_index()

    # persist; this is a big df
    long_compare.to_parquet(fpath)

    return long_compare

In [None]:
# Recomputes the r5/r6 comparison for the sampled subjects on every run and
# rewrites the parquet at `fpath`; use pd.read_parquet(fpath) to skip that.
fpath = "../../data/02_intermediate/cg_average_betas.parquet"
avg_betas = combine_average_betas(params=params, subjects=selected_subjects, fpath=fpath)

In [None]:
# Long table: one row per subject / event / vertex, with r5_cg and r6_cg columns.
avg_betas

In [None]:
# Average betas: r5 (x) vs r6 (y), one panel per sampled subject.
plot_betas(avg_betas, xvar='r5_cg', yvar='r6_cg');