# Mask feature importance to the salience network using the Yeo et al. (2011) atlas

In [None]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from nilearn import datasets
from nilearn.surface import load_surf_data

import neurotools.plotting as ntp

from abcd_tools.utils.ConfigLoader import load_yaml
# Shared notebook settings: result/plot paths and the condition -> title map used below.
params = load_yaml("../parameters.yaml")

## Load the FreeSurfer fsaverage5 Yeo 2011 7-network atlas

In [None]:
# Per-hemisphere FreeSurfer annotation files for the Yeo 2011 7-network
# parcellation.
fs_templates = "../../data/01_raw/label/"
yeo_lh_path = f"{fs_templates}lh.Yeo2011_7Networks_N1000.annot"
yeo_rh_path = f"{fs_templates}rh.Yeo2011_7Networks_N1000.annot"

# Per-vertex network labels (compared against integer network ids below).
yeo_lh_atlas, yeo_rh_atlas = map(load_surf_data, (yeo_lh_path, yeo_rh_path))

In [None]:
# Visual sanity check: draw the full Yeo parcellation on both hemispheres.
# NOTE(review): fsaverage / fsaverage_sulcal are not used anywhere else in this
# notebook — confirm they are still needed.
fsaverage = datasets.load_fsaverage('fsaverage5')
fsaverage_sulcal = datasets.load_fsaverage_data(data_type="sulcal")

ntp.plot(
    dict(lh=yeo_lh_atlas.flatten(), rh=yeo_rh_atlas.flatten()),
    threshold=0,
    cmap="Set2",
)

In [None]:
# Binary per-vertex masks (1 = vertex in the network, 0 = elsewhere).
# Label 4 is the ventral attention ("salience") network; label 1 is visual.
salience_idx = 4
salience_lh, salience_rh = (
    np.where(hemi == salience_idx, 1, 0) for hemi in (yeo_lh_atlas, yeo_rh_atlas)
)

visual_idx = 1
visual_lh, visual_rh = (
    np.where(hemi == visual_idx, 1, 0) for hemi in (yeo_lh_atlas, yeo_rh_atlas)
)

In [None]:
# Show the binarised salience mask on the surface.
ntp.plot(dict(lh=salience_lh, rh=salience_rh))

### Import feature importance

In [None]:
# Feature-importance results from the vertex-wise ridge models, stored as a
# pickled 5-tuple. (read_pickle can execute arbitrary code: load trusted files only.)
fis_path = f"{params['model_results_path']}vertex_ridge_feature_importance.pkl"
fis, best_fis, avg_fis, haufe_avg, haufe_fis = pd.read_pickle(fis_path)

# Pull out the two targets (EEA and pTF) from the `haufe_avg` importances.
EEA_hauf, tf_hauf = haufe_avg['EEA'], haufe_avg['tf']

In [None]:
def broadcast_to_fsaverage(fis_agg: pd.Series, n_vertices=10242) -> tuple:
    """Broadcast vertex-wise feature importance onto full fsaverage hemispheres.

    Each feature name is split on ``"_"`` into four parts:
    ``<correct>_<condition>_<vertex>_<hemisphere>``, e.g.
    ``"correct_stop_123_lh"``. The third part is the (integer, 1-based)
    vertex number.

    Args:
        fis_agg (pd.Series): Feature importance indexed by feature name.
        n_vertices (int, optional): Number of vertices per hemisphere; vertex
            columns are ``1..n_vertices``. Defaults to 10242 (fsaverage5).

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(lh, rh)``. Each table has one
        row per (correct, condition) pair. Its columns are ``correct`` and
        ``condition``, followed by one column per vertex in order
        ``1..n_vertices``. Vertices without a feature are NaN.

    Raises:
        ValueError: If a feature refers to a vertex outside ``1..n_vertices``.
    """

    def _split_hemisphere(df):
        # Move the remaining row index levels into columns, then separate hemispheres.
        df = df.reset_index(names=["correct", "condition", "hemisphere"])
        lh = df[df["hemisphere"] == "lh"].drop(columns="hemisphere")
        rh = df[df["hemisphere"] == "rh"].drop(columns="hemisphere")

        return lh, rh

    fis = fis_agg.copy()

    # Feature names -> 4-level MultiIndex; pivot the vertex level (2) into columns.
    fis.index = fis.index.str.split("_", expand=True)
    fis = fis.unstack(level=2)
    fis.columns = fis.columns.astype(int)

    out_of_range = fis.columns[(fis.columns < 1) | (fis.columns > n_vertices)]
    if len(out_of_range):
        raise ValueError(
            f"vertex numbers outside 1..{n_vertices}: {list(out_of_range)[:5]}"
        )

    # Reindex onto the full, ordered vertex range. Missing vertices become NaN
    # columns, and every column sits at its own vertex position, so a
    # flattened row lines up with the surface mesh. (Joining a separate block
    # of missing-vertex columns would append them at the end, out of order.)
    fis = fis.reindex(columns=list(range(1, n_vertices + 1)))

    return _split_hemisphere(fis)

In [None]:
def table_to_dict(df: pd.DataFrame, idx=('correct', 'condition')) -> dict:
    """Collapse a per-hemisphere table into ``{condition_key: values}``.

    Args:
        df (pd.DataFrame): Table holding the identifier columns named in
            ``idx`` plus one column per vertex.
        idx (sequence of str, optional): Identifier columns whose values are
            joined with ``"_"`` to form each key (e.g. ``"correct_stop"``).
            Defaults to ``('correct', 'condition')``.

    Returns:
        dict: Maps each key to a flat numpy array of the remaining columns'
        values. Rows sharing a key are concatenated in row order.
    """
    idx = list(idx)  # drop(columns=...) needs a list, not a tuple

    # Build "<idx0>_<idx1>_..." keys for any number of identifier columns.
    cond = df[idx[0]].astype(str)
    for col in idx[1:]:
        cond = cond + '_' + df[col].astype(str)

    return (df
        .assign(cond=cond)
        .drop(columns=idx)
        .set_index('cond')
        .groupby(level=0)
        .apply(lambda x: x.values.flatten())
        .to_dict()
    )

def apply_mask(hemi_dict, hemi_mask):
    """Zero out every vertex that lies outside ``hemi_mask``.

    Args:
        hemi_dict: Mapping of condition name -> per-vertex values.
        hemi_mask: Per-vertex mask; truthy entries are kept.

    Returns:
        dict: A new dict with the same keys. Values are replaced by 0 wherever
        the mask is falsy; the input dict is not modified.
    """
    return {cond: np.where(hemi_mask, vals, 0) for cond, vals in hemi_dict.items()}


In [None]:
def plot_hist(lh, rh, ax):
    """Draw a histogram of non-zero feature importances pooled across hemispheres.

    Zeros (vertices removed by the mask) and NaNs (vertices without a feature)
    are excluded. Bars are coloured by sign. The title reports the number and
    proportion of positive features and the mean ± std of the plotted values.

    Args:
        lh, rh: Per-vertex feature-importance arrays for each hemisphere.
        ax (matplotlib.axes.Axes): Axes to draw on.
    """

    def _flat_df(lh, rh):
        array = np.concatenate([lh, rh]).astype(float)
        # Drop masked-out vertices (0) and vertices with no feature (NaN).
        # NaN > 0 is False, so NaNs would otherwise be counted as negative.
        array = array[np.isfinite(array) & (array != 0)]

        return pd.DataFrame({
            'values': array,
            'posneg': array > 0
        })

    df = _flat_df(lh, rh)

    n_pos = int(df['posneg'].sum())
    prop = n_pos / len(df) if len(df) else float('nan')

    mean = df['values'].mean()
    std = df['values'].std()

    # Relabel only the sign column; replace() on the whole frame would also
    # run over the numeric 'values' column.
    label_map = {True: 'Pos.', False: "Neg."}
    df['posneg'] = df['posneg'].map(label_map)

    sns.histplot(
        data=df,
        x='values',
        hue='posneg',
        palette='seismic',
        hue_order=['Neg.', 'Pos.'],
        ax=ax
    )
    ax.set_xlabel("")
    ax.set_ylabel("Number of Features")

    ax.set_title(
        rf'N Pos. Features: {n_pos} ({prop:.2%}); Mean FIS = {mean:.2e} $\pm$ {std:.2e}'
    )

# fig, ax = plt.subplots()
# plot_hist(lh_salience['correct_stop'], rh_salience['correct_stop'], ax)

In [None]:
def make_masked_figure(lh, rh, conditions, target_map, title):
    """Draw one column per condition: surface map on top, histogram below.

    Args:
        lh, rh (dict): Condition name -> per-vertex values for each hemisphere.
        conditions (iterable): Condition names to plot, in column order.
        target_map (dict): Condition name -> panel title. Names missing from
            the map are shown as-is.
        title (str): Figure super-title.

    Returns:
        matplotlib.figure.Figure: The created figure.
    """
    conditions = list(conditions)

    # squeeze=False keeps `axs` 2-D (rows x columns) even for one condition.
    # With the default squeeze, axs[0, i] fails when there is a single column.
    # NOTE(review): sharey=True links the y-axes of the surface panels and the
    # histograms alike; sharey='row' may be what is intended — confirm.
    fig, axs = plt.subplots(ncols=len(conditions), nrows=2,
                            figsize=(30, 8), sharey=True,
                            height_ratios=[2.25, 1], squeeze=False)

    for i, condition in enumerate(conditions):

        lh_plot = lh[condition]
        rh_plot = rh[condition]

        # Top row: masked surface map.
        ax = axs[0, i]
        ax.set_title(target_map.get(condition, condition))

        ntp.plot({'lh': lh_plot, 'rh': rh_plot},
            threshold=0,
            cmap='seismic',
            colorbar=False,
            ax=ax
        )

        # Bottom row: distribution of the non-zero values.
        plot_hist(lh_plot, rh_plot, axs[1, i])

    plt.suptitle(title, size=20, x=0.15)

    return fig

def generate_plot(fis, mask_lh, mask_rh, target_map, title):
    """Mask feature importance to a surface ROI and plot every condition.

    Args:
        fis (pd.Series): Feature importance, in the naming scheme expected by
            ``broadcast_to_fsaverage``.
        mask_lh, mask_rh: Per-vertex masks for the left/right hemisphere.
        target_map (dict): Condition name -> panel title.
        title (str): Figure super-title.
    """
    lh_tab, rh_tab = broadcast_to_fsaverage(fis)

    # Table -> {condition: per-vertex array} -> masked, one hemisphere at a time.
    lh_masked, rh_masked = [
        apply_mask(table_to_dict(tab), mask)
        for tab, mask in ((lh_tab, mask_lh), (rh_tab, mask_rh))
    ]

    make_masked_figure(lh_masked, rh_masked, lh_masked.keys(), target_map, title)


## Ventral Attention (Salience) network

### EEA

In [None]:
# Panel titles for each condition key (used by make_masked_figure).
target_map = params["target_map"]

# generate_plot(EEA_hauf, salience_lh, salience_rh, target_map, "EEA")
# plt.savefig(params['plot_output_path'] + 'EEA_salience_fis.png', dpi=300, bbox_inches='tight')

### $pTF$

In [None]:
# generate_plot(tf_hauf, salience_lh, salience_rh, target_map, r"$p$TF")

# plt.savefig(params['plot_output_path'] + 'tf_salience_fis.png', dpi=300, bbox_inches='tight')

## Visual network

### EEA

In [None]:
# generate_plot(EEA_hauf, visual_lh, visual_rh, target_map, "EEA")
# plt.savefig(params['plot_output_path'] + 'EEA_visual_fis.png', dpi=300, bbox_inches='tight')

### $pTF$

In [None]:
# generate_plot(tf_hauf, visual_lh, visual_rh, target_map, r"$p$TF")
# plt.savefig(params['plot_output_path'] + 'tf_visual_fis.png', dpi=300, bbox_inches='tight')

## Examine the ACC

In [None]:
from netneurotools import datasets as netds

# Fetch the Schaefer 2018 parcellations in fsaverage5 space into data_dir.
schaefer = netds.fetch_schaefer2018(
    version='fsaverage5',
    data_dir="../../data/01_raw/nnt/",
)

In [None]:
# Display the entry for the 100-parcel, 7-network variant (the cells below use 400).
schaefer["100Parcels7Networks"]

In [None]:
# Per-vertex Schaefer-400 (7-network) parcel labels; entry 0 holds the left
# hemisphere file, entry 1 the right.
schaefer_lh, schaefer_rh = (
    load_surf_data(schaefer['400Parcels7Networks'][hemi]) for hemi in (0, 1)
)

In [None]:
# Select Schaefer-400 (7-network) parcels by name and preview them as vertex masks.
# Parcel names come from nilearn's volumetric (MNI) copy of the atlas, which
# lists all parcels in one whole-brain sequence.
mni_schaefer = datasets.fetch_atlas_schaefer_2018()
# Older nilearn releases return labels as bytes, newer ones as str.
labs = [l.decode() if isinstance(l, bytes) else l for l in mni_schaefer.labels]

rois = [
    # 'Med_1',
    'Med_2',
    # 'Med_3',
    # 'Med_4'
    # 'Med_5'
    # 'Med_6'
    # 'Med_7'
    # 'Med_8'
]

# (whole-brain label index, label name) for every parcel matching an ROI.
salience_rois = [(idx, label) for idx, label in enumerate(labs) for roi in rois if roi in label]

# The surface annotations are loaded per hemisphere (schaefer_lh / schaefer_rh),
# so their values number parcels *within* each hemisphere. Whole-brain label
# positions from `labs` do not match those values: right-hemisphere positions
# start after all left-hemisphere labels, and the left side is shifted if
# value 0 is the medial wall. Number each hemisphere's parcels 1..N in label
# order instead.
# NOTE(review): assumes annotation value 0 is background/medial wall and that
# parcels follow the same order as the MNI label list — confirm against the
# .annot colour table.
lh_labs = [label for label in labs if '_LH_' in label]
rh_labs = [label for label in labs if '_RH_' in label]
salience_lh_idx = [i for i, label in enumerate(lh_labs, start=1) if any(roi in label for roi in rois)]
salience_rh_idx = [i for i, label in enumerate(rh_labs, start=1) if any(roi in label for roi in rois)]

ntp.plot(
    {
        'lh': np.where(np.isin(schaefer_lh, salience_lh_idx), 1, 0),
        'rh': np.where(np.isin(schaefer_rh, salience_rh_idx), 1, 0)
    },
    threshold=0
)

salience_rois

In [None]:
# Integer (0/1) per-vertex masks for the parcels selected above (the ACC ROI).
acc_lh_mask, acc_rh_mask = (
    np.where(np.isin(parc, ids), 1, 0)
    for parc, ids in ((schaefer_lh, salience_lh_idx), (schaefer_rh, salience_rh_idx))
)


In [None]:
# EEA feature importance restricted to the ACC ROI. Each hemisphere gets its
# own mask (built from schaefer_lh / schaefer_rh); previously the left mask was
# passed for both hemispheres.
generate_plot(haufe_avg['EEA'], acc_lh_mask, acc_rh_mask, target_map, 'EEA (ACC)')
plt.savefig(params['plot_output_path'] + 'EEA_acc_mask.png', dpi=300, bbox_inches='tight')