In [None]:
import os
import glob
# import pandas as pd
import numpy as np

import polars as pl
import pandas as pd

from typing import Dict, List

import matplotlib.pyplot as plt
import seaborn as sns

from abcd_tools.utils.io import load_tabular
from abcd_tools.utils.ConfigLoader import load_yaml

In [None]:
# Project parameters; later cells read the vol_info paths and beta input
# directories for each release from this mapping.
params = load_yaml("../parameters.yaml")

In [None]:
def parse_vol_info(vol_info: pl.DataFrame) -> pl.DataFrame:
    """Parse volume information to extract subject ID and event name.

    Only the first column of ``vol_info`` is read. Each value is split on
    ``'_'``; the token at index 2 is prefixed with ``'NDAR_'`` to form the
    subject ID, and the token at index 3 is translated to an event name
    through ``TPT_MAP``.

    Args:
        vol_info (pl.DataFrame): Volume information DataFrame

    Returns:
        pl.DataFrame: Parsed information with exactly two columns,
            ``src_subject_id`` and ``eventname``

    Raises:
        Whatever ``replace_strict`` raises when a timepoint token is not a
        key of ``TPT_MAP`` (the mapping is deliberately strict so unknown
        timepoints are not passed through silently).
    """
    TPT_MAP = {
        'baseline': 'baseline_year_1_arm_1',
        '2year': '2_year_follow_up_y_arm_1',
        '4year': '4_year_follow_up_y_arm_1',
        '6year': '6_year_follow_up_y_arm_1',
    }

    # Build the split expression once and reuse it for both derived columns.
    tokens = pl.col(vol_info.columns[0]).str.split('_')

    return vol_info.select(
        # Native string concatenation instead of a per-row Python lambda:
        # stays inside the polars engine and yields a well-defined dtype.
        (pl.lit('NDAR_') + tokens.list.get(2)).alias('src_subject_id'),
        tokens.list.get(3).replace_strict(TPT_MAP).alias('eventname'),
    )

def concatenate_hemispheres(lh: pl.DataFrame, rh: pl.DataFrame) -> pl.DataFrame:
    """Combine left- and right-hemisphere frames into one wide frame.

    Every column except the identifier columns (``src_subject_id``,
    ``eventname``, ``run``) receives a ``_lh`` or ``_rh`` suffix, then the
    two frames are combined with ``pl.concat(..., how='align')``.

    Args:
        lh (pl.DataFrame): Left hemisphere data
        rh (pl.DataFrame): Right hemisphere data

    Returns:
        pl.DataFrame: Concatenated data with hemisphere-suffixed columns
    """
    id_columns = ('src_subject_id', 'eventname', 'run')

    def _tag(frame: pl.DataFrame, suffix: str) -> pl.DataFrame:
        # Identifier columns keep their names so they remain common to both
        # frames when the frames are aligned.
        renames = {
            name: name + suffix
            for name in frame.columns
            if name not in id_columns
        }
        return frame.rename(renames)

    return pl.concat([_tag(lh, '_lh'), _tag(rh, '_rh')], how='align')

def load_degrees_of_freedom(r1_fpath: str, r2_fpath: str) -> pd.DataFrame:
    """Load the beta degrees-of-freedom column for SST runs 1 and 2.

    Args:
        r1_fpath (str): Filepath to run 1 info
        r2_fpath (str): Filepath to run 2 info

    Returns:
        pd.DataFrame: ``tfmri_sstr1_beta_dof`` and ``tfmri_sstr2_beta_dof``
            placed side by side (column-wise concat)
    """
    # (file, column) pairs, in run order.
    dof_sources = (
        (r1_fpath, 'tfmri_sstr1_beta_dof'),
        (r2_fpath, 'tfmri_sstr2_beta_dof'),
    )
    per_run = [
        load_tabular(fpath, cols=[column])
        for fpath, column in dof_sources
    ]
    return pd.concat(per_run, axis=1)

In [None]:
def combine_betas(
    sst_conditions: Dict[str, str],
    hemispheres: List[str],
    beta_input_dir: str,
    vol_info: pl.DataFrame,
    subjects: list,
    release: str = 'r5'
) -> pl.DataFrame:
    """
    Combine beta values from different conditions and hemispheres into a single long DataFrame.

    For each condition, runs 1 and 2 of every hemisphere are read, tagged
    with a ``run`` column, stacked next to ``vol_info``, and concatenated.
    The hemispheres are then merged, the rows are restricted to the baseline
    event and the requested subjects, rows containing nulls are dropped, and
    the result is unpivoted to one row per vertex.

    Args:
        sst_conditions: Dictionary mapping short condition names to their
            long identifiers. The short name is used in 'r6' file names and
            in the output ``condition`` column; the long identifier is used
            in file names for every other release
        hemispheres: List of hemispheres to process. Must contain both 'lh'
            and 'rh', since those two entries are passed to
            ``concatenate_hemispheres`` (a missing one raises ``KeyError``)
        beta_input_dir: Directory containing the beta parquet files. It is
            concatenated directly with the file name, so it must end with a
            path separator
        vol_info: DataFrame containing volume information. It is horizontally
            stacked onto each run's betas and is expected to provide the
            ``src_subject_id`` and ``eventname`` columns; row order is
            assumed to match the beta files (TODO confirm)
        subjects: Subject IDs to keep (matched against ``src_subject_id``)
        release: Release label. 'r6' selects the
            ``sst_{condition}_beta_r0{run}_{hemi}.parquet`` naming; any other
            value (e.g. 'r4', 'r5') selects
            ``SST_{run}_{identifier}-{hemi}.parquet``. The label also names
            the output value column, ``{release}_beta``

    Returns:
        Long DataFrame with columns ``src_subject_id``, ``eventname``,
        ``run``, ``condition``, ``vertex`` and ``{release}_beta``, where
        beta values equal to 0 are replaced with null. An empty DataFrame is
        returned when ``sst_conditions`` is empty.
    """
    index_cols = ['src_subject_id', 'eventname', 'run', 'condition']
    value_col = f'{release}_beta'

    # Collect one long frame per condition and concatenate once at the end
    # instead of re-copying the accumulated frame on every iteration.
    per_condition = []

    for condition, identifier in sst_conditions.items():
        betas = {}

        for hemi in hemispheres:
            runs = []
            for run in (1, 2):
                # Construct file path based on release version
                if release == 'r6':
                    fpath = f"{beta_input_dir}sst_{condition}_beta_r0{run}_{hemi}.parquet"
                else:
                    fpath = f"{beta_input_dir}SST_{run}_{identifier}-{hemi}.parquet"

                # Read betas, tag the run, then attach the vol_info columns
                run_betas = pl.read_parquet(fpath).with_columns(pl.lit(run).alias('run'))
                runs.append(vol_info.hstack(run_betas))

            # Combine runs
            betas[hemi] = pl.concat(runs)

        betas_df = (
            concatenate_hemispheres(betas['lh'], betas['rh'])
            # Strip the 'tableData' prefix/fragment from column names
            .rename(lambda name: name.replace('tableData', ''))
            .with_columns(pl.lit(condition).alias('condition'))
            # Filter for baseline and the requested subjects
            .filter(pl.col('eventname') == 'baseline_year_1_arm_1')
            .filter(pl.col('src_subject_id').is_in(subjects))
            # Drops any row with a null in any column (either hemisphere)
            .drop_nulls()
            .unpivot(
                index=index_cols,
                variable_name='vertex', value_name=value_col
            )
        )
        per_condition.append(betas_df)

    # Nothing to combine: avoid referencing a value column that does not exist.
    if not per_condition:
        return pl.DataFrame()

    betas_out = pl.concat(per_condition)

    # replace 0 with Null
    return betas_out.with_columns(pl.col(value_col).replace(0, None))


## Load unprocessed betas for each SST condition

In [None]:
# Releases 4 and 5: volume info is used as stored on disk.
vol_info_r4 = pl.read_parquet(params['vol_info_path_r4'])
vol_info_r5 = pl.read_parquet(params['vol_info_path_r5'])

# Release 6: the raw file is run through parse_vol_info, which returns
# src_subject_id / eventname columns.
vol_info_r6 = parse_vol_info(pl.read_parquet(params['vol_info_path_r6']))

In [None]:
# select a few random subjects from release 4 (not all will carry through because of missingness)
sample_subjects = vol_info_r4.filter(pl.col('eventname') == 'baseline_year_1_arm_1').sample(5, seed=42)
sample_subjects = sample_subjects['src_subject_id'].to_list()
sample_subjects

In [None]:
# Short condition name -> long identifier. combine_betas uses the short name
# in 'r6' file names and in the output `condition` column, and the long
# identifier in file names for the other releases.
sst_conditions = {
    'cs': 'correct_stop',
    'cg': 'correct_go',
    'is': 'incorrect_stop',
    'ig': 'incorrect_go',
    'csvcg': 'correct_stop_vs_correct_go',
    'igvcg': 'incorrect_go_vs_correct_go'
    }

In [None]:
r5_path = "../../data/02_intermediate/r5_betas.parquet"

r5 = combine_betas(
    sst_conditions=sst_conditions,
    hemispheres=['lh', 'rh'],
    beta_input_dir=params['beta_input_dir_r5'],
    vol_info=vol_info_r5,
    subjects=sample_subjects,
    release='r5',
)

# Persist the long table, then reload it so later cells use the stored copy.
r5.write_parquet(r5_path)
r5 = pl.read_parquet(r5_path)
r5

In [None]:
r6_path = "../../data/02_intermediate/r6_betas.parquet"

r6 = combine_betas(
    sst_conditions=sst_conditions,
    hemispheres=['lh', 'rh'],
    beta_input_dir=params['beta_input_dir_r6'],
    vol_info=vol_info_r6,
    subjects=sample_subjects,
    release='r6',
)

# Persist the long table, then reload it so later cells use the stored copy.
r6.write_parquet(r6_path)
r6 = pl.read_parquet(r6_path)
r6

In [None]:
# Align the two releases on the columns they share, so the r5_beta and
# r6_beta value columns end up side by side in one frame.
joined = pl.concat([r5, r6], how='align')
joined

## Compare Releases 5 and 6

In [None]:
def plot_betas(subset, row = 'src_subject_id', col = 'condition', xvar='r5_cg', yvar='r6_cg', hue=None, lim=(-10, 10)):
    """Scatter one beta column against another on a grid of facets.

    Each facet gets a dashed identity line through the origin, a dotted
    grid, identical x/y limits and an equal aspect ratio, so departures from
    agreement between the two columns are directly visible.

    Args:
        subset: Data handed to ``sns.FacetGrid`` (callers in this notebook
            pass a pandas DataFrame)
        row: Column whose values define the facet rows
        col: Column whose values define the facet columns
        xvar: Column plotted on the x axis
        yvar: Column plotted on the y axis
        hue: Optional column used for colouring; a legend is added when set
        lim: ``(low, high)`` limits applied to both axes of every facet
    """
    g = sns.FacetGrid(subset, col=col, row=row, hue=hue)
    g.map(sns.scatterplot, xvar,  yvar)
    g.set_titles(col_template="{col_name}", row_template="{row_name}")

    if hue is not None:
        g.add_legend()

    for ax in g.axes.flat:
        # Identity line: points on it have equal x and y values.
        ax.axline((0, 0), slope=1, color='k', ls='--')
        ax.grid(True, axis='both', linestyle=':')
        ax.set_xlim(*lim)
        ax.set_ylim(*lim)
        ax.set_aspect('equal', adjustable='box')
    

# Release 5 vs release 6 betas, one facet per subject x condition, coloured by run.
plot_betas(joined.to_pandas(), xvar='r5_beta', yvar='r6_beta', hue='run')

## Compare Release 5 to Release 4

In [None]:
# I didn't grab these conditions for r4
sst_conditions.pop('csvcg')
sst_conditions.pop('igvcg')
sst_conditions

In [None]:
r4_path = "../../data/02_intermediate/r4_betas.parquet"

r4 = combine_betas(
    sst_conditions=sst_conditions,
    hemispheres=['lh', 'rh'],
    beta_input_dir=params['beta_input_dir_r4'],
    vol_info=vol_info_r4,
    subjects=sample_subjects,
    release='r4',
)

# Persist the long table, then reload it so later cells use the stored copy.
r4.write_parquet(r4_path)
r4 = pl.read_parquet(r4_path)
r4

In [None]:
# Align releases 4 and 5 on the columns they share.
# NOTE: this rebinds `joined`, which previously held the r5/r6 comparison.
joined = pl.concat([r4, r5], how='align')
joined

In [None]:
# Release 4 vs release 5 betas. dropna() removes rows with a missing value in
# any column (presumably rows where one of the two releases has no beta).
plot_betas(joined.to_pandas().dropna(), xvar='r4_beta', yvar='r5_beta', hue='run')

### Examine Average Betas

In [None]:
# def combine_average_betas(params, subjects, fpath="../../data/02_intermediate/average_betas.parquet"):

#     r5_cg_avg = pd.read_parquet(params['beta_output_dir_r5'] + 'average_betas_cg.parquet').reset_index()
#     r6_cg_avg = pd.read_parquet(params['beta_output_dir_r6'] + 'average_betas_cg.parquet').reset_index()

#     r5_cg_avg = r5_cg_avg[r5_cg_avg['src_subject_id'].isin(subjects)]
#     r6_cg_avg = r6_cg_avg[r6_cg_avg['src_subject_id'].isin(subjects)]

#     r5_long = r5_cg_avg.melt(id_vars=['src_subject_id', 'eventname'], var_name='vertex', value_name='r5_cg')
#     r6_long = r6_cg_avg.melt(id_vars=['src_subject_id', 'eventname'], var_name='vertex', value_name='r6_cg')

#     idx = ['src_subject_id', 'eventname', 'vertex']
#     long_compare = pd.concat([r5_long.set_index(idx), r6_long.set_index(idx)], axis=1)
#     long_compare = long_compare.reset_index()

#     # persist; this is a big df
#     long_compare.to_parquet(fpath)

#     return long_compare

## Examine Degrees of Freedom (used in averaging)

In [None]:
# dof_r5 = load_degrees_of_freedom(params['mri_r1_dof_path_r5'], params['mri_r2_dof_path_r5'])
# dof_r6 = load_degrees_of_freedom(params['mri_r1_dof_path_r6'], params['mri_r2_dof_path_r6'])

# def reshape_dof(dof):
#     dof.columns = ['run1', 'run2']
#     dof = dof.reset_index().melt(id_vars=['src_subject_id', 'eventname'], var_name='run', value_name='dof')
#     return dof.set_index(['src_subject_id', 'eventname', 'run'])

# dof_r5 = reshape_dof(dof_r5)
# dof_r6 = reshape_dof(dof_r6)

# dof = dof_r5.join(dof_r6, lsuffix='_r5', rsuffix='_r6').reset_index()
# dof

In [None]:
# sns.scatterplot(data=dof.dropna(), x='dof_r5', y='dof_r6', hue='run', alpha=0.5)