In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import neurotools.plotting as ntp

from abcd_tools.utils.ConfigLoader import load_yaml
# Project configuration (data/output paths, target display labels, ...) used
# by every cell below.
params = load_yaml('../parameters.yaml')

In [None]:
def broadcast_to_fsaverage(
    fis_agg: pd.Series, n_vertices=10242, first_vertex: int = 1
) -> "tuple[pd.DataFrame, pd.DataFrame]":
    """Broadcast vertex-wise feature importance onto a full fsaverage5 grid.

    The index of ``fis_agg`` is split on ``"_"``. The third token (level 2)
    is the integer vertex number. The remaining tokens are
    (correct, condition, hemisphere) in that order, e.g.
    ``"incorrect_go_123_lh"``.

    Vertices absent from ``fis_agg`` are inserted as NaN columns. The vertex
    columns are ordered by vertex number, so column k holds vertex
    ``first_vertex + k`` and flattening a row yields a per-vertex array.

    Args:
        fis_agg (pd.Series): Feature importance indexed by
            ``"{correct}_{condition}_{vertex}_{hemisphere}"`` strings.
        n_vertices (int, optional): Number of vertices per hemisphere.
            Defaults to 10242 (fsaverage5).
        first_vertex (int, optional): Number of the first vertex: 1 for
            1-based vertex labels, 0 for 0-based labels. Defaults to 1.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(lh, rh)`` frames. Each has
        ``correct`` and ``condition`` columns followed by one column per
        vertex, in ascending vertex order.

    Raises:
        KeyError: If ``fis_agg`` contains vertex numbers outside
            ``[first_vertex, first_vertex + n_vertices)``.
    """

    def _split_hemisphere(df):
        # Move the (correct, condition, hemisphere) index levels into columns,
        # then split the rows by hemisphere.
        df = df.reset_index(names=["correct", "condition", "hemisphere"])
        lh = df[df["hemisphere"] == "lh"].drop(columns="hemisphere")
        rh = df[df["hemisphere"] == "rh"].drop(columns="hemisphere")
        return lh, rh

    fis = fis_agg.copy()

    fis.index = fis.index.str.split("_", expand=True)
    # Pivot the vertex token (index level 2) into columns.
    fis = fis.unstack(level=2)
    fis.columns = fis.columns.astype(int)

    vertex_names = list(range(first_vertex, first_vertex + n_vertices))

    unexpected = fis.columns.difference(vertex_names)
    if len(unexpected) > 0:
        raise KeyError(
            f"Vertex numbers outside [{first_vertex}, {first_vertex + n_vertices}): "
            f"{list(unexpected[:10])}"
        )

    # reindex inserts NaN columns for missing vertices AND puts all vertex
    # columns in ascending order. Joining a separate all-NaN frame would
    # append the missing vertices after the present ones. The flattened
    # arrays would then no longer line up with vertex positions.
    fis = fis.reindex(columns=vertex_names)

    return _split_hemisphere(fis)

def table_to_dict(df: pd.DataFrame, idx=['correct', 'condition']):
    """Collapse a per-hemisphere table into ``{"<label1>_<label2>": array}``.

    Each row of ``df`` is keyed by joining the values of the two label columns
    named in ``idx`` with ``"_"``. NaNs are replaced by 0 and the label
    columns are dropped. The remaining values of all rows that share a key
    are flattened row-major into one NumPy array.
    """
    first_label, second_label = idx[0], idx[1]
    keyed = df.assign(cond=lambda frame: frame[first_label] + '_' + frame[second_label])
    values_only = keyed.fillna(0).drop(columns=idx).set_index('cond')
    flattened = values_only.groupby(level=0).apply(lambda rows: rows.values.flatten())
    return flattened.to_dict()

def apply_mask(hemi_dict, hemi_mask):
    """Zero out values outside a hemisphere mask.

    Args:
        hemi_dict (dict): Mapping of condition name to an array of values.
        hemi_mask (array-like): Mask combined with each array via
            ``np.where``. Positions where it is falsy become 0.

    Returns:
        dict: A new dict with the same keys and masked arrays.
    """
    return {
        condition: np.where(hemi_mask, values, 0)
        for condition, values in hemi_dict.items()
    }

In [None]:
def collect_features(fis, filter='incorrect', parameters=['EEA', 'tf', 'SSRT', 'B']):
    """Build per-parameter, per-hemisphere plotting dictionaries.

    For every parameter, ``fis[parameter]`` is restricted to index labels
    containing ``filter`` (substring match via ``.filter(like=...)``). The
    result is broadcast onto the fsaverage vertex grid and converted to
    ``{condition: array}`` dictionaries.

    Args:
        fis: Feature importance where ``fis[parameter]`` yields a Series
            indexed by label strings.
        filter (str, optional): Substring the index labels must contain.
            Note that ``'correct'`` also matches ``'incorrect'``.
        parameters (list[str], optional): Parameters to process, in order.

    Returns:
        dict: ``{parameter: {'lh': {...}, 'rh': {...}}}``.
    """
    def _hemisphere_dicts(parameter):
        lh, rh = broadcast_to_fsaverage(fis[parameter].filter(like=filter))
        return {'lh': table_to_dict(lh), 'rh': table_to_dict(rh)}

    return {parameter: _hemisphere_dicts(parameter) for parameter in parameters}


In [None]:
ridge_fis_path = params['model_results_path'] + 'ridge_feature_importance.pkl'
lasso_fis_path = params['model_results_path'] + 'lasso_feature_importance.pkl'

# Each pickle unpacks into four feature-importance objects. Judging by the
# names, they are raw, best-model, averaged and Haufe-transformed averaged FIS.
# NOTE(review): pickle loading can execute arbitrary code — only load trusted files.
ridge_fis, ridge_best_fis, ridge_avg_fis, ridge_haufe_avg = pd.read_pickle(ridge_fis_path)
lasso_fis, lasso_best_fis, lasso_avg_fis, lasso_haufe_avg = pd.read_pickle(lasso_fis_path)

# Maps target codes (e.g. 'EEA') to the row labels drawn on the figures.
target_map = params['target_map']

In [None]:
# {model: {'avg' | 'haufe': {parameter: {'lh': {...}, 'rh': {...}}}}}
plot_data = {
    model: {'avg': collect_features(avg_fis), 'haufe': collect_features(haufe_fis)}
    for model, (avg_fis, haufe_fis) in (
        ('ridge', (ridge_avg_fis, ridge_haufe_avg)),
        ('lasso', (lasso_avg_fis, lasso_haufe_avg)),
    )
}

In [None]:
from itertools import product

In [None]:
def get_global_minmax(plot_data):
    """Return the largest absolute non-NaN value in one model's plotting data.

    Args:
        plot_data (dict): ``{aggregation: {target: {hemisphere: {condition: array}}}}``,
            e.g. ``plot_data['ridge']``.

    Returns:
        float: ``max(|min|, |max|)`` over all values, i.e. a symmetric
        colour limit.

    Raises:
        ValueError: If there are no non-NaN values.
    """
    # Flatten every array and concatenate, so arrays of different lengths are
    # supported. Building one np.array from a ragged list of arrays raises.
    arrays = [
        np.ravel(values)
        for aggregation in plot_data.values()
        for hemispheres in aggregation.values()
        for conditions in hemispheres.values()
        for values in conditions.values()
    ]
    all_values = np.concatenate(arrays) if arrays else np.array([])
    all_values = all_values[~np.isnan(all_values)]
    if all_values.size == 0:
        raise ValueError("plot_data contains no non-NaN values")
    return np.abs(all_values).max()
    

def make_error_plot(plot_data, shared_scale=False):
    """Plot error-trial feature importance on fsaverage surfaces as a grid.

    Rows are targets and columns are (aggregation, condition) pairs. The first
    row holds column headers and the first column holds row labels, looked up
    in the global ``target_map``. The top-left cell is left empty.

    Args:
        plot_data (dict): ``{aggregation: {target: {'lh': {...}, 'rh': {...}}}}``
            with ``'avg'`` and ``'haufe'`` aggregations, as produced by
            ``collect_features`` (e.g. ``plot_data['ridge']``).
        shared_scale (bool, optional): If True, every surface panel uses the
            same symmetric colour limits from ``get_global_minmax``. Defaults
            to False, in which case each panel keeps its own colour scale.

    Returns:
        matplotlib.figure.Figure: The grid figure, so callers can save or
        close it explicitly.
    """
    # Aggregation key -> column-header label.
    aggregation = {
        'avg': 'Average FIS',
        'haufe': "Haufe-Transformed FIS",
    }
    conditions = ['incorrect_go', 'incorrect_stop']
    # Leading blank entries reserve the first row / column for labels.
    targets = [''] + ['EEA', 'tf', 'SSRT', 'B']
    cond_mod = [('', '')] + list(product(aggregation.keys(), conditions))

    fig, axs = plt.subplots(ncols=len(cond_mod), nrows=len(targets), figsize=(25, 15))
    fontsize = 15

    scale_kwargs = {}
    if shared_scale:
        bound = get_global_minmax(plot_data)
        scale_kwargs = {'vmin': -bound, 'vmax': bound}

    axs[0, 0].set_axis_off()

    for i, target in enumerate(targets):
        for j, (agg_key, condition) in enumerate(cond_mod):
            ax = axs[i, j]
            if i == 0 and j > 0:
                # Column header: aggregation label above the condition name.
                ax.set_axis_off()
                text = aggregation[agg_key] + '\n\n' + condition.replace('_', ' ').title()
                ax.text(0, 0.5, text, fontsize=fontsize)
            elif j == 0 and i > 0:
                # Row label.
                ax.set_axis_off()
                ax.text(0, 0.5, target_map[target], fontsize=fontsize)
            elif i > 0 and j > 0:
                data = plot_data[agg_key][target]
                ntp.plot(
                    {'lh': data['lh'][condition], 'rh': data['rh'][condition]},
                    threshold=0,
                    cmap='seismic',
                    ax=ax,
                    **scale_kwargs
                )

    return fig
   
# Ridge error-trial figure: save it, then close it to release the figure.
# NOTE(review): plt.savefig saves pyplot's *current* figure — this assumes
# ntp.plot leaves the grid created in make_error_plot as the current figure.
make_error_plot(plot_data['ridge'])
plt.savefig(params['plot_output_path'] + 'ridge_error_fis_plot.png', dpi=300, bbox_inches='tight')
plt.close()


In [None]:

# Lasso error-trial figure. Unlike the ridge figure, it is not closed after
# saving, so it remains open.
make_error_plot(plot_data['lasso'])
plt.savefig(params['plot_output_path'] + 'lasso_error_fis_plot.png', dpi=300, bbox_inches='tight')

In [None]:
# Feature importance from the contrasts ridge pickle. It is unpacked in the
# same four-part order as the ridge/lasso pickles.
# NOTE(review): pickle loading can execute arbitrary code — only load trusted files.
contrast_fis_path = params['model_results_path'] + 'contrasts_ridge_feature_importance.pkl'

contrast_fis, contrast_best_fis, contrast_avg_fis, contrast_haufe_avg = pd.read_pickle(contrast_fis_path)

In [None]:
def get_global_contrast_minmax(plot_data):
    """Return the largest absolute non-NaN value across all contrast maps.

    Args:
        plot_data (dict): ``{target: {hemisphere: {contrast: array}}}`` as
            returned by ``collect_features``.

    Returns:
        float: ``max(|min|, |max|)`` over all values, used as a symmetric
        colour limit.

    Raises:
        ValueError: If there are no non-NaN values.
    """
    # Concatenate whole arrays so the work stays vectorised. This avoids
    # creating one Python object per vertex value.
    arrays = [
        np.ravel(values)
        for hemispheres in plot_data.values()
        for conditions in hemispheres.values()
        for values in conditions.values()
    ]
    all_values = np.concatenate(arrays) if arrays else np.array([])
    all_values = all_values[~np.isnan(all_values)]
    if all_values.size == 0:
        raise ValueError("plot_data contains no non-NaN values")
    return np.abs(all_values).max()

# Haufe-transformed contrast FIS. filter='correct' is a substring of both
# 'correct...' and 'incorrect...' labels, so every such contrast is kept.
contrasts_plot = collect_features(contrast_haufe_avg, filter='correct')
    

In [None]:
# Largest absolute value over all contrast maps. make_contrast_plot uses it
# as the symmetric colour limit.
get_global_contrast_minmax(contrasts_plot)

In [None]:
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import make_axes


    

def make_colorbar(
    fig, ax, vmin, vmax, cmap, label="Haufe-Transformed Feature Importance",
    nb_ticks=5, tick_format="%.2g",
):
    """Draw a standalone vertical colourbar in the space of ``ax``.

    ``ax`` is hidden, and a colourbar axes is carved out of it with
    ``matplotlib.colorbar.make_axes``. A proxy ``ScalarMappable`` supplies the
    norm and colormap, so no plotted artist is required.

    Args:
        fig: Figure to draw the colourbar into.
        ax: Axes whose space hosts the colourbar; its own axis is turned off.
        vmin (float): Lower colour limit.
        vmax (float): Upper colour limit.
        cmap: Colormap name or object.
        label (str, optional): Colourbar label.
        nb_ticks (int, optional): Number of evenly spaced ticks. Defaults to 5.
        tick_format (str, optional): printf-style tick format. Defaults to "%.2g".

    Returns:
        matplotlib.colorbar.Colorbar: The created colourbar.
    """
    norm = Normalize(vmin=vmin, vmax=vmax)
    proxy_mappable = ScalarMappable(norm=norm, cmap=cmap)
    ticks = np.linspace(vmin, vmax, nb_ticks)

    ax.set_axis_off()

    # make_axes also returns suggested colorbar kwargs; they are not used
    # because orientation and tick location are passed explicitly below.
    cax, _ = make_axes(ax, fraction=0.5, shrink=0.5)

    return fig.colorbar(
        proxy_mappable,
        cax=cax,
        ticks=ticks,
        orientation="vertical",
        format=tick_format,
        ticklocation="left",
        label=label,
    )


def make_contrast_plot(plot_data, target_map):
    """Plot contrast feature importance on fsaverage surfaces as a grid.

    Rows are targets (the keys of ``plot_data``) and columns are contrasts.
    The top-left cell holds a shared colourbar, and the rest of the first row
    and column hold text labels. All surface panels share the symmetric
    colour limits ``[-bound, bound]`` from ``get_global_contrast_minmax``.

    Args:
        plot_data (dict): ``{target: {'lh': {contrast: array}, 'rh': {...}}}``
            as produced by ``collect_features``.
        target_map (dict): Target code -> row label.

    Returns:
        matplotlib.figure.Figure: The grid figure, so callers can save or
        close it explicitly.
    """
    # Contrast key -> column-header label. The leading '' reserves column 0.
    conditions = {
        '': '',
        'correctstop_correctgo': "Correct Stop vs. Correct Go",
        'correctstop_incorrectgo': "Correct Stop vs. Incorrect Go",
        'incorrectstop_correctstop': "Incorrect Stop vs. Correct Stop",
        'incorrectstop_incorrectgo': "Incorrect Stop vs. Incorrect Go",
    }
    # The leading '' reserves row 0 for column headers.
    targets = [''] + list(plot_data)

    fig, axs = plt.subplots(ncols=len(conditions), nrows=len(targets), figsize=(25, 15))
    fontsize = 15

    # Named vmin/vmax so the builtins min/max are not shadowed.
    vmax = get_global_contrast_minmax(plot_data)
    vmin = -vmax

    make_colorbar(fig, axs[0, 0], vmin, vmax, cmap='seismic')

    for i, target in enumerate(targets):
        for j, (cond, title) in enumerate(conditions.items()):
            ax = axs[i, j]
            if i == 0 and j > 0:
                # Column header.
                ax.set_axis_off()
                ax.text(0, 0.5, title, fontsize=fontsize)
            elif j == 0 and i > 0:
                # Row label.
                ax.set_axis_off()
                ax.text(0, 0.5, target_map[target], fontsize=fontsize)
            elif i > 0 and j > 0:
                data = plot_data[target]
                ntp.plot(
                    {'lh': data['lh'][cond], 'rh': data['rh'][cond]},
                    threshold=0,
                    cmap='seismic',
                    colorbar=False,
                    vmin=vmin,
                    vmax=vmax,
                    ax=ax,
                )

    return fig

    


In [None]:
# NOTE(review): plt.savefig saves pyplot's *current* figure — this assumes it
# is still the grid created inside make_contrast_plot.
make_contrast_plot(contrasts_plot, params['target_map'])
plt.savefig(params['plot_output_path'] + 'ridge_contrast_fis_plot.png', dpi=300, bbox_inches='tight')

In [None]:
# Reload the contrasts ridge pickle (same file as `contrast_fis_path`) under
# generic names for the exploratory cells that follow.
fis, best_fis, avg_fis, haufe_avg = pd.read_pickle(
    params["model_results_path"] + "contrasts_ridge_feature_importance.pkl"
)


In [None]:
def format_for_plotting(
    fis: pd.DataFrame, correct: str, condition: str
) -> pd.DataFrame:
    """Select the rows for one (correct, condition) pair, NaNs filled with 0.

    Args:
        fis (pd.DataFrame): Table with ``correct`` and ``condition`` label
            columns, such as the ``lh``/``rh`` frames from
            ``broadcast_to_fsaverage``.
        correct (str): Value to match in the ``correct`` column.
        condition (str): Value to match in the ``condition`` column.

    Returns:
        pd.DataFrame: Matching rows without the two label columns, with
        missing values replaced by 0.
    """
    mask = (fis["correct"] == correct) & (fis["condition"] == condition)
    # fillna returns a new frame, so nothing is assigned into a filtered slice
    # (which triggers SettingWithCopyWarning). Unlike np.isnan, it does not
    # raise on non-float columns.
    return fis.loc[mask].drop(columns=["correct", "condition"]).fillna(0)


In [None]:
# Broadcast one target's contrast FIS (no label filter applied) to inspect the layout.
lh, rh = broadcast_to_fsaverage(haufe_avg['EEA'])

In [None]:
# Preview the left-hemisphere table produced by the previous cell.
lh

In [None]:
from itertools import product
correct = ["correct", "incorrect"]
cond = ["go", "stop"]

# Every (accuracy, trial type) pair, outer loop over accuracy.
correct_cond = [(outcome, trial) for outcome in correct for trial in cond]
correct_cond

In [None]:
import glob

def strip_suffix(s: str):
    """Drop the first two underscore-separated tokens of ``s``.

    Despite its name this removes a *prefix*, e.g. ``"a_b_c_d" -> "c_d"``.
    Strings with fewer than three tokens yield ``""``.
    """
    return "_".join(s.split("_")[2:])


def _find_unique_file(files, key: str) -> str:
    """Return the single path in ``files`` whose *file name* contains ``key``."""
    matches = sorted(f for f in files if key in os.path.basename(f))
    if not matches:
        raise FileNotFoundError(
            f"No parquet file name contains {key!r} (searched {len(files)} files)"
        )
    if len(matches) > 1:
        raise ValueError(f"Condition key {key!r} is ambiguous; matching files: {matches}")
    return matches[0]


def make_contrast(fpath: str, cond1: tuple, cond2: tuple) -> pd.DataFrame:
    """Make a contrast (cond1 - cond2) between two conditions' betas.

    Each condition's parquet file in ``fpath`` is located by matching its key
    (``cond[0]``, e.g. ``"cs"``) against file names only. A directory name
    that happens to contain a key therefore cannot cause a wrong match. NaNs
    are treated as 0. Each column keeps its name minus the first two
    underscore-separated tokens, prefixed with ``"{cond1[1]}_{cond2[1]}_"``.

    Args:
        fpath (str): Path to the folder containing the betas, including the
            trailing separator (``"*.parquet"`` is appended by concatenation).
        cond1 (tuple): First condition ``(file key, condition name)``.
        cond2 (tuple): Second condition ``(file key, condition name)``.

    Returns:
        pd.DataFrame: Contrast between the two conditions.

    Raises:
        FileNotFoundError: If no parquet file name contains a condition key.
        ValueError: If several parquet file names contain a condition key.
    """
    files = glob.glob(fpath + "*.parquet")
    df1 = pd.read_parquet(_find_unique_file(files, cond1[0])).fillna(0)
    df2 = pd.read_parquet(_find_unique_file(files, cond2[0])).fillna(0)

    df1.columns = [strip_suffix(col) for col in df1.columns]
    df2.columns = [strip_suffix(col) for col in df2.columns]

    contrast_name = f"{cond1[1]}_{cond2[1]}"
    df = df1 - df2
    return df.rename(columns=lambda col: f"{contrast_name}_{col}")


def load_contrasts(params: dict) -> pd.DataFrame:
    """Load the four SST contrasts from processed betas, concatenated column-wise."""
    # (first, second) conditions per contrast; each contrast is first - second.
    contrasts = [
        (("cs", "correctstop"), ("cg", "correctgo")),
        (("cs", "correctstop"), ("ig", "incorrectgo")),
        (("is", "incorrectstop"), ("cs", "correctstop")),
        (("is", "incorrectstop"), ("ig", "incorrectgo")),
    ]
    return pd.concat(
        [make_contrast(params["sst_betas_path"], first, second) for first, second in contrasts],
        axis=1,
    )
# All four SST contrasts, concatenated column-wise.
sst_contrasts = load_contrasts(params)




In [None]:
from abcd_tools.utils.io import load_tabular

# Confounds, feature scopes and behavioural targets loaded from configured paths.
# NOTE(review): pd.read_pickle can execute arbitrary code — only load trusted files.
mri_confounds_sst = load_tabular(params["mri_confounds_sst_path"])
sst_scopes = pd.read_pickle(params['sst_scopes_path'])
sst_contrast_scopes = pd.read_pickle(params['sst_contrast_scopes_path'])
targets_no_tf = load_tabular(params["targets_no_tf_path"])
targets_nback = load_tabular(params["nback_targets_path"])


In [None]:
import BPt as bp

def make_bpt_dataset(
    betas: pd.DataFrame,
    scopes: dict,
    mri_confounds: pd.DataFrame,
    targets: pd.DataFrame,
    test_split=0.2,
    random_state=123,
    fpath: str = None,
) -> bp.Dataset:
    """Create a BPt dataset from betas, confounds, and targets.

    The three frames are concatenated column-wise. All-NaN columns are
    dropped, then any row still containing a NaN. Every entry of ``scopes``
    plus a ``"covariates"`` scope (the confound columns) is added to the
    dataset. Categorical columns are auto-detected and ordinalized,
    infinities are replaced by NaN and dropped, and a test split is set.

    Args:
        betas (pd.DataFrame): Concatenated betas.
        scopes (dict): Scope name -> columns. Not modified; the
            ``"covariates"`` entry is added to an internal copy.
        mri_confounds (pd.DataFrame): MRI confounds.
        targets (pd.DataFrame): Behavioral targets.
        test_split (float, optional): Test split. Defaults to 0.2.
        random_state (int, optional): Random state. Defaults to 123.
        fpath (str, optional): Path to save dataset. Defaults to None.

    Returns:
        bp.Dataset: BPt dataset.
    """
    df = pd.concat([betas, mri_confounds, targets], axis=1)
    # Drop all-NaN columns first so the row-wise dropna does not remove
    # every row because of a single empty column.
    df = df.dropna(axis=1, how="all").dropna()

    dataset = bp.Dataset(df, targets=targets.columns.tolist())

    # Build a new dict so the caller's `scopes` is left unchanged.
    all_scopes = {**scopes, "covariates": mri_confounds.columns.tolist()}
    for scope_name, columns in all_scopes.items():
        dataset.add_scope(columns, scope_name, inplace=True)

    dataset = dataset.auto_detect_categorical()
    dataset = dataset.add_scope("mri_info_deviceserialnumber", "category")
    dataset = dataset.ordinalize("category")

    # deal with possible inf values
    dataset = dataset.replace([np.inf, -np.inf], np.nan)
    dataset = dataset.dropna()

    dataset = dataset.set_test_split(test_split, random_state=random_state)

    if fpath is not None:
        dataset.to_pickle(fpath)
        print(f"Dataset saved to {fpath}")

    return dataset


# Build the SST-contrast / n-back-target dataset. It is pickled to
# params["sst_nback_dataset_path"], and as the cell's last expression it is
# also displayed as the cell output.
make_bpt_dataset(
    sst_contrasts,
    sst_contrast_scopes,
    mri_confounds_sst,
    targets_nback,
    fpath=params["sst_nback_dataset_path"]
)