In [None]:
import os
import glob
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from abcd_tools.utils.ConfigLoader import load_yaml

In [None]:
# Load the notebook's parameters from ../parameters.yaml.
params = load_yaml("../parameters.yaml")

In [None]:
def load_betas(betas_path: str) -> pd.DataFrame:
    """Load partitioned betas and combine them column-wise.

    Every ``*.parquet`` file matched by ``betas_path`` is read and the
    resulting frames are placed side by side (``axis=1``), aligned on their
    index.

    Args:
        betas_path (str): Either a directory holding the partitions (with or
            without a trailing path separator) or a filename prefix to which
            ``"*.parquet"`` is appended.

    Returns:
        pd.DataFrame: All partitions concatenated along the columns, in
            sorted filename order. An empty frame if no file matches.
    """
    if os.path.isdir(betas_path):
        # Works whether or not the caller supplied a trailing separator.
        pattern = os.path.join(betas_path, "*.parquet")
    else:
        # Not a directory: treat the argument as a plain filename prefix.
        pattern = betas_path + "*.parquet"

    # glob yields files in arbitrary order; sort so column order is reproducible.
    files = sorted(glob.glob(pattern))
    if not files:
        return pd.DataFrame()

    # Read every partition first and concatenate once, rather than growing a
    # frame inside the loop (which re-copies all previous data on every step).
    return pd.concat([pd.read_parquet(f) for f in files], axis=1)

def broadcast_to_fsaverage(fis_agg: pd.Series, n_vertices=10242 + 1) -> "tuple[pd.DataFrame, pd.DataFrame]":
    """Broadcast feature importance to fsaverage5.

    The index labels of ``fis_agg`` are split on ``"_"`` into four fields. The
    third field is converted to a number and becomes the vertex column; the
    other three become the ``correct``, ``condition`` and ``hemisphere`` keys.
    Vertices in ``1 .. n_vertices - 1`` that have no value are filled with NaN.

    Args:
        fis_agg (pd.Series): Feature importance, indexed by
            ``"<correct>_<condition>_<vertex>_<hemisphere>"`` strings.
        n_vertices (int, optional): Exclusive upper bound of the 1-based vertex
            ids to produce. Defaults to 10242+1.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(lh, rh)`` wide frames holding the
            rows whose hemisphere is ``"lh"`` / ``"rh"``. Each has the columns
            ``correct`` and ``condition`` followed by one column per vertex id,
            in ascending vertex order.

    Raises:
        KeyError: If a vertex id in ``fis_agg`` lies outside
            ``1 .. n_vertices - 1``.
    """

    def _split_hemisphere(df):
        """Move the row index into columns and return the lh and rh rows separately."""
        df = df.reset_index(names=["correct", "condition", "hemisphere"])
        lh = df[df["hemisphere"] == "lh"].drop(columns="hemisphere")
        rh = df[df["hemisphere"] == "rh"].drop(columns="hemisphere")

        return lh, rh

    fis = fis_agg.copy()

    # Split the labels into a 4-level MultiIndex and pivot level 2 (the vertex
    # id) into the columns.
    fis.index = fis.index.str.split("_", expand=True)
    fis = fis.unstack(level=2)

    # The pivoted labels are still strings in lexical order ("1", "10", "2", ...).
    # Convert them in place so every column keeps its own vertex id; assigning a
    # *sorted* copy of the labels here would attach values to the wrong vertex.
    fis.columns = pd.to_numeric(fis.columns)

    vertex_names = [*range(1, n_vertices)]
    unexpected = fis.columns.difference(vertex_names)
    if len(unexpected) > 0:
        raise KeyError(
            f"vertex ids outside 1..{n_vertices - 1}: {list(unexpected)}"
        )

    # Reindexing both inserts NaN columns for the missing vertices and puts all
    # columns in ascending vertex order.
    df = fis.reindex(columns=vertex_names)
    lh, rh = _split_hemisphere(df)

    return lh, rh


# betas_r5 = load_betas(params["processed_beta_dir_r5"])
# betas_r6 = load_betas(params["processed_beta_dir_r6"])

In [None]:
# Load both sets of betas from disk so that this cell runs on a fresh kernel.
# Previously betas_r5 was never defined here and betas_r6 was only a copy of
# it, which makes every downstream correlation trivially 1.
betas_r5 = load_betas(params["processed_beta_dir_r5"])
betas_r6 = load_betas(params["processed_beta_dir_r6"])

In [None]:
# A regressor name is the column label with its last two "_"-separated fields
# removed; collect the distinct ones.
unique_regressors = {column.rsplit("_", 2)[0] for column in betas_r5.columns}
conditions = unique_regressors
conditions

In [None]:
# Correlate each column of betas_r5 with the same-named column of betas_r6;
# the result holds one value per column label.
correlations = betas_r5.corrwith(other=betas_r6, axis=0)
correlations

In [None]:
def reshape_correlations(hemi: pd.DataFrame, name: str) -> pd.DataFrame:
    """Convert one hemisphere's wide table to long format.

    Args:
        hemi (pd.DataFrame): Wide frame with ``correct`` and ``condition``
            columns plus one column per vertex.
        name (str): Hemisphere label written into the ``hemi`` column.

    Returns:
        pd.DataFrame: One row per (correct, condition, vertex) with the columns
            ``correct``, ``condition``, ``vertex``, ``correlation`` and ``hemi``.
    """
    key_columns = ["correct", "condition"]
    long_format = hemi.melt(
        id_vars=key_columns,
        var_name="vertex",
        value_name="correlation",
    )
    # Tag every row with the hemisphere it came from.
    return long_format.assign(hemi=name)


In [None]:
# Spread the correlations over the full vertex range, reshape each hemisphere
# to long format, then stack the two hemispheres into a single frame.
lh, rh = (
    reshape_correlations(wide, hemi_name)
    for wide, hemi_name in zip(broadcast_to_fsaverage(correlations), ("lh", "rh"))
)
df = pd.concat([lh, rh])

In [None]:
# Mean and standard deviation of the correlations per condition and hemisphere.
# Select "correlation" explicitly: the frame also carries the non-numeric
# "correct" column, and aggregating it with mean/std raises a TypeError on
# pandas >= 2.0 (older versions silently dropped it).
(df
    .groupby(["condition", "hemi"])[["correlation"]]
    .agg(["mean", "std"])
)

In [None]:
# Histogram of the correlations: one panel per (correct, condition) pair,
# hemispheres overlaid by colour, x-axis not shared between panels.
(
    sns.FacetGrid(df, col="condition", row="correct", hue="hemi", sharex=False)
    .map(sns.histplot, "correlation")
    .add_legend()
)