In [None]:
import os

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import neurotools.plotting as ntp
import seaborn as sns

from joblib import Parallel, delayed
from itertools import repeat

from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import make_axes

from nilearn.datasets import fetch_atlas_surf_destrieux

from neurotools.plotting.ref import SurfRef

from abcd_tools.image.preprocess import map_hemisphere

from scipy.stats import false_discovery_control, pearsonr
from itertools import product
from abcd_tools.utils.ConfigLoader import load_yaml

In [None]:
# Pairs of trial-type names (presumably the contrasts of interest -- confirm).
cond = [
    ("correctstop", "correctgo"),
    ("correctstop", "incorrectgo"),
    ("incorrectstop", "correctstop"),
    ("incorrectstop", "incorrectgo"),
]

In [None]:
# Settings are read from a YAML file addressed relative to the working directory.
params = load_yaml("../parameters.yaml")
model = "contrasts_ridge"

# The pickle is unpacked into four objects; only `haufe_avg` is used further down
# in the cells visible here.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
fis, best_fis, avg_fis, haufe_avg = pd.read_pickle(
    f'{params["model_results_path"]}{model}_feature_importance.pkl'
)

In [None]:
from nilearn.datasets import fetch_atlas_surf_destrieux

def map_destrieux(
    lh: pd.DataFrame,
    rh: pd.DataFrame,
    prefix: str = "",
    mask_non_significant: bool = False,
    use_fdr: bool = False,
    alpha: float = 0.01,
) -> tuple:
    """Aggregate the values of both hemispheres into Destrieux ROIs.

    For every combination of the unique ``correct`` and ``condition`` values
    found in ``lh``, the matching rows of ``lh`` and ``rh`` are handed to
    ``map_hemisphere`` (with ``return_statistics=True``). The ROI values,
    t-values and p-values it returns are stacked into per-hemisphere tables
    indexed by ``(correct, condition)``.

    Args:
        lh (pd.DataFrame): Left hemisphere. Must contain the columns
            ``correct`` and ``condition``; all remaining columns are passed to
            ``map_hemisphere`` (presumably one column per vertex -- confirm).
        rh (pd.DataFrame): Right hemisphere. Must contain the same two
            grouping columns.
        prefix (str, optional): Prefix forwarded to ``map_hemisphere``.
            Defaults to ''.
        mask_non_significant (bool, optional): If True, t-values whose p-value
            is greater than ``alpha`` are masked. NOTE(review): the masked
            t-value tables are not part of the return value, so this flag
            currently does not change what the function returns.
            Defaults to False.
        use_fdr (bool, optional): If True, the p-value tables are adjusted with
            ``false_discovery_control(method="by")`` before masking. Like the
            masking, this only affects tables that are not returned.
            Defaults to False.
        alpha (float, optional): Threshold used when masking. Defaults to 0.01.

    Returns:
        tuple: ``(lh_df, rh_df, vmin, vmax)`` -- the left and right ROI tables
        with their index reset, followed by the two values that
        ``get_fullrang_minmax`` returns for the column-wise concatenation of
        both tables.
    """

    # Helper defined outside this function; the code below reads the keys
    # "map_left", "map_right" and "labels" from its result.
    dest = load_destrieux_atlas()

    # Grouping values come from the left hemisphere only and are reused to
    # filter the right hemisphere.
    correct_values = pd.unique(lh["correct"])
    condition_values = pd.unique(lh["condition"])

    idx = ["correct", "condition"]

    # Running tables, grown by one group per loop iteration.
    lh_df = pd.DataFrame()
    rh_df = pd.DataFrame()

    lh_tvalues = pd.DataFrame()
    rh_tvalues = pd.DataFrame()

    lh_pvalues = pd.DataFrame()
    rh_pvalues = pd.DataFrame()

    def _assemble_df(lh_mapped, rh_mapped, lh_correct, rh_correct, lh, rh):
        """Label the freshly mapped tables and append them to the running ones.

        ``lh`` / ``rh`` here are the running tables (they shadow the outer
        arguments of the same name). The index of the mapped tables is
        overwritten in place with the index of the filtered input rows.
        """
        lh_mapped.index = lh_correct.index
        rh_mapped.index = rh_correct.index

        lh_tmp = pd.concat([lh, lh_mapped])
        rh_tmp = pd.concat([rh, rh_mapped])

        return lh_tmp, rh_tmp

    def apply_fdr(pvalues):
        """Return ``pvalues`` adjusted with the "by" method, same index/columns.

        NOTE(review): the 2-D table is passed without an explicit ``axis``;
        check that scipy's default matches the intended family of tests.
        """
        return pd.DataFrame(
            false_discovery_control(pvalues, method="by"),
            index=pvalues.index,
            columns=pvalues.columns,
        )

    for correct in correct_values:

        for condition in condition_values:

            # Rows belonging to the current (correct, condition) group.
            lh_correct = lh[(lh["correct"] == correct) & (lh["condition"] == condition)]
            rh_correct = rh[(rh["correct"] == correct) & (rh["condition"] == condition)]

            # Move the grouping columns into the index so that only the value
            # columns reach map_hemisphere.
            lh_correct = lh_correct.set_index(idx)
            rh_correct = rh_correct.set_index(idx)

            lh_mapped, lh_t, lh_p = map_hemisphere(
                lh_correct,
                mapping=dest["map_left"],
                labels=dest["labels"],
                prefix=prefix,
                suffix=".lh",
                return_statistics=True,
                decode_ascii=False,
            )
            rh_mapped, rh_t, rh_p = map_hemisphere(
                rh_correct,
                mapping=dest["map_right"],
                labels=dest["labels"],
                prefix=prefix,
                suffix=".rh",
                return_statistics=True,
                decode_ascii=False,
            )

            # Append ROI values, t-values and p-values of this group.
            lh_df, rh_df = _assemble_df(
                lh_mapped, rh_mapped, lh_correct, rh_correct, lh_df, rh_df
            )
            lh_tvalues, rh_tvalues = _assemble_df(
                lh_t, rh_t, lh_correct, rh_correct, lh_tvalues, rh_tvalues
            )
            lh_pvalues, rh_pvalues = _assemble_df(
                lh_p, rh_p, lh_correct, rh_correct, lh_pvalues, rh_pvalues
            )

    if use_fdr:
        lh_pvalues = apply_fdr(lh_pvalues)
        rh_pvalues = apply_fdr(rh_pvalues)

    # NOTE(review): the masked t-values are never used after this point; if the
    # intent was to mask the returned ROI tables (or to return the t-values),
    # that still has to be wired up.
    if mask_non_significant:
        lh_tvalues = lh_tvalues.mask(lh_pvalues > alpha)
        rh_tvalues = rh_tvalues.mask(rh_pvalues > alpha)

    # Value range over both hemispheres; helper defined outside this function.
    df = pd.concat([lh_df, rh_df], axis=1)
    vmin, vmax = get_fullrang_minmax(df)

    return lh_df.reset_index(), rh_df.reset_index(), vmin, vmax

In [None]:
# Build the left/right hemisphere inputs for `map_destrieux` from the 'EEA'
# entry of `haufe_avg`.
# NOTE(review): `broadcast_to_fsaverage` is neither imported nor defined in the
# visible part of this notebook -- confirm it is available on a fresh kernel.
lh, rh = broadcast_to_fsaverage(haufe_avg['EEA'])

from scipy.stats import ttest_1samp

def compute_tstat(mapping: dict) -> tuple:
    """Run a one-sample t-test against zero for every ROI.

    Each value of ``mapping`` holds the vertex values of one ROI, laid out as
    ``(n_rows, n_vertices)`` -- this is what ``map_hemisphere`` passes in. The
    test is taken across the vertices (last axis) with NaNs omitted, so every
    ROI yields one t-value and one p-value per input row.

    Args:
        mapping (dict): ROI identifier -> array of vertex values. A 1-D array
            is treated as a single row and yields scalar statistics.

    Returns:
        tuple: ``(t_values, p_values)`` -- two dicts keyed like ``mapping``.
        For 2-D input each value is a 1-D array with one entry per row.
    """
    t_values = {}
    p_values = {}

    for roi, vertex in mapping.items():
        # Test along the last axis (the ROI's vertices). Testing along axis 0
        # would return one statistic per vertex -- all NaN for single-row
        # input -- which cannot be arranged as a rows-by-ROIs table.
        t, p = ttest_1samp(vertex, 0, axis=-1, nan_policy='omit')
        t_values[roi] = t
        p_values[roi] = p

    return t_values, p_values

In [None]:
# Display the left-hemisphere table for a quick visual check.
lh

In [None]:
def map_hemisphere(vertices: pd.DataFrame, mapping: np.ndarray, labels: list,
                   prefix: str = None, suffix: str = None,
                   decode_ascii: bool = True, return_statistics: bool = False
                   ) -> "pd.DataFrame | tuple":
    """Map tabular vertexwise fMRI values to ROIs using nonzero average aggregation.

    NOTE: this notebook-local definition shadows the ``map_hemisphere`` imported
    from ``abcd_tools.image.preprocess`` at the top of the notebook.

    Args:
        vertices (pd.DataFrame): Tabular vertexwise data (columns are vertices).
            A plain 2-D array is accepted as well. The input is not modified.
        mapping (np.ndarray): Array of ROI indices. Must be the same length as
            the number of columns in `vertices`.
        labels (list): ROI labels for resulting averaged values. They are
            assigned positionally to the sorted ROI indices; if there are more
            labels than ROIs the first label is dropped.
            NOTE(review): this assumes the surplus label is the first one --
            confirm for the atlas in use.
        prefix (str, optional): Prefix added to all column names. Defaults to
            None (no prefix).
        suffix (str, optional): Suffix added to all column names. Defaults to
            None (no suffix).
        decode_ascii (bool, optional): Decode `bytes` labels to `str`; labels
            that are already `str` are kept as they are. Defaults to True.
        return_statistics (bool, optional): Additionally return the t- and
            p-values produced by `compute_tstat`, which must yield one value
            per row for every ROI. Defaults to False.

    Returns:
        pd.DataFrame: Nonzero-averaged ROIs (one row per input row), or the
        tuple ``(rois, tvalues, pvalues)`` when `return_statistics` is True.
    """

    # The defaults are None; treat them as "nothing to add" instead of failing
    # on `None + str` when the column names are built.
    prefix = "" if prefix is None else prefix
    suffix = "" if suffix is None else suffix

    if decode_ascii:
        labels = [
            label.decode() if isinstance(label, bytes) else label
            for label in labels
        ]

    if isinstance(vertices, pd.DataFrame):
        vertices = vertices.values
    vertices = np.asarray(vertices)
    mapping = np.asarray(mapping)

    map_dict = {}
    avg_dict = {}

    # Visit every ROI once; iterating over `mapping` itself would redo the same
    # work for each vertex of the ROI.
    for idx in np.unique(mapping):

        indices = np.where(mapping == idx)[0]

        # Zeros mark "no data" and are excluded from the average. np.where
        # builds a new float array, so integer input works and the caller's
        # data stays untouched.
        roi_vertices = vertices[:, indices]
        roi_vertices = np.where(roi_vertices == 0, np.nan, roi_vertices)

        map_dict[idx] = roi_vertices
        avg_dict[idx] = np.nanmean(roi_vertices, axis=1)

    def _assemble_df(collection: dict) -> pd.DataFrame:
        """Turn {roi: per-row values} into a table with labelled ROI columns."""
        # atleast_1d lets scalar entries through; no fixed index is imposed so
        # inputs with more than one row are supported.
        df = pd.DataFrame(
            {roi: np.atleast_1d(values) for roi, values in collection.items()}
        )
        df = df.reindex(sorted(df.columns), axis=1)

        names = labels[1:] if len(labels) > df.shape[1] else labels
        df.columns = [prefix + str(name) + suffix for name in names]

        return df

    rois = _assemble_df(avg_dict)

    if not return_statistics:
        return rois

    tvalues, pvalues = compute_tstat(map_dict)

    return rois, _assemble_df(tvalues), _assemble_df(pvalues)

# Run the mapping with the default options (no FDR adjustment, no masking).
map_destrieux(lh, rh)