In [None]:
import os
import pickle
import glob

from itertools import repeat

import pandas as pd
import numpy as np

import BPt as bp
from BPt.extensions import LinearResidualizer

from joblib import Parallel, delayed    

import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib.colors import Normalize
from matplotlib.colorbar import make_axes
from matplotlib.cm import ScalarMappable

from sklearn.preprocessing import OneHotEncoder
from abcd_tools.utils.io import load_tabular
from abcd_tools.utils.ConfigLoader import load_yaml
from abcd_tools.image.preprocess import map_hemisphere

import BPt as bp

import neurotools.plotting as ntp
from neurotools.plotting.ref import SurfRef, VolRef

from nilearn.datasets import fetch_atlas_surf_destrieux

In [None]:
# `params` is read below but the cell that defines it comes later in the
# notebook; load it here so this cell works on a fresh kernel (Restart & Run All).
params = load_yaml("../parameters.yaml")

model = 'ridge'
summary = pd.read_csv(params['model_results_path'] + f'{model}_models_summary.csv')
summary

In [None]:
# Load the project configuration from the YAML file one directory up.
# Other cells read path entries (e.g. 'model_results_path') and the
# label/colour maps from this object.
params = load_yaml("../parameters.yaml")

In [None]:

def load_betas(betas_path: str) -> pd.DataFrame:
    """Load partitioned betas and concatenate them column-wise.

    Every entry of ``betas_path`` is read with ``pd.read_parquet`` and the
    resulting frames are concatenated along ``axis=1``. Files are visited in
    sorted name order so the resulting column order is reproducible.

    Args:
        betas_path (str): Directory holding the partitioned betas. A trailing
            path separator is optional.

    Returns:
        pd.DataFrame: Partitioned betas; an empty frame when the directory
            contains no entries.
    """
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    files = sorted(os.listdir(betas_path))
    if not files:
        return pd.DataFrame()

    # os.path.join works whether or not betas_path ends with a separator, and
    # a single concat avoids re-copying the growing frame on every iteration.
    frames = [pd.read_parquet(os.path.join(betas_path, f)) for f in files]
    return pd.concat(frames, axis=1)
# betas = load_betas(params["betas_path"])

In [None]:
# mri_confounds_path = params['mri_confounds_path']
# mri_confounds = load_tabular(mri_confounds_path)
# mri_confounds

In [None]:
# targets_path = params['filtered_behavioral_path']
# targets = load_tabular(targets_path)
# targets

In [None]:
def gather_scopes(betas: pd.DataFrame, mri_confounds: pd.DataFrame, 
        fpath: str=None) -> dict:
    """Assign names to predictors based on SST condition, plus covariates.

    Scope names are the part of each beta column name before its last two
    ``_``-separated fields; a column belongs to a scope when its first two
    ``_``-separated fields, re-joined with ``_``, equal that name. Scopes are
    emitted in order of first appearance in ``betas.columns`` so the result
    is identical between runs.

    Args:
        betas (pd.DataFrame): Concatenated betas. 
        mri_confounds (pd.DataFrame): MRI confounds; its columns are stored
            under the 'mri_confounds' key.
        fpath (str, optional): If given, the scopes dict is pickled to this
            path. Defaults to None.

    Returns:
        dict: Scopes
    """
    # Group columns by their leading "<a>_<b>" prefix in a single pass
    # instead of rescanning every column for every regressor.
    columns_by_prefix = {}
    for c in betas.columns:
        prefix = '_'.join(c.split('_', 2)[:2])
        columns_by_prefix.setdefault(prefix, []).append(c)

    # dict.fromkeys de-duplicates while keeping first-appearance order; a
    # set's iteration order for strings can differ between interpreter runs.
    unique_regressors = dict.fromkeys(
        c.rsplit('_', 2)[0] for c in betas.columns
    )
    scopes = {u: list(columns_by_prefix.get(u, [])) for u in unique_regressors}

    scopes['mri_confounds'] = mri_confounds.columns

    if fpath is not None:
        pd.to_pickle(scopes, fpath)
        print(f"Scopes saved to {fpath}")

    return scopes


# predictor_scopes = gather_scopes(betas, mri_confounds, params['model_results_path'])
# predictor_scopes.keys()

In [None]:
def make_bpt_dataset(betas: pd.DataFrame, scopes: dict, mri_confounds: pd.DataFrame, 
                targets: pd.DataFrame, test_split=0.2, random_state=42,
                fpath: str=None) -> bp.Dataset:
    """Build a BPt dataset out of betas, confounds, and behavioral targets.

    The three frames are concatenated column-wise, every entry of ``scopes``
    is registered as a scope, categorical columns are detected and
    ordinalized, rows with missing values are dropped, and a test split is set.

    Note that ``scopes`` is modified: a 'covariates' entry listing the
    confound columns is written into the dict that was passed in.

    Args:
        betas (pd.DataFrame): Concatenated betas.
        scopes (dict): Scope name -> columns.
        mri_confounds (pd.DataFrame): MRI confounds.
        targets (pd.DataFrame): Behavioral targets.
        test_split (float, optional): Test split. Defaults to 0.2.
        random_state (int, optional): Random state. Defaults to 42.
        fpath (str, optional): Where to pickle the dataset. Defaults to None.

    Returns:
        bp.Dataset: BPt dataset.
    """
    scopes['covariates'] = mri_confounds.columns.tolist()

    combined = pd.concat([betas, mri_confounds, targets], axis=1)
    dataset = bp.Dataset(combined, targets=targets.columns.tolist())

    for scope_name, scope_cols in scopes.items():
        dataset.add_scope(scope_cols, scope_name, inplace=True)

    # Same sequence of dataset operations as before, written as one chain.
    dataset = (
        dataset
        .auto_detect_categorical()
        .add_scope('mri_info_deviceserialnumber', 'category')
        .ordinalize('category')
        .dropna()
        .set_test_split(test_split, random_state=random_state)
    )

    if fpath is None:
        return dataset

    dataset.to_pickle(fpath)
    print(f"Dataset saved to {fpath}")
    return dataset


# ds = make_bpt_dataset(betas, predictor_scopes, mri_confounds, targets)

# ds.to_pickle("../../data/04_model_input/rdex_prediction_dataset.pkl")
# ds = pd.read_pickle("../../data/04_model_input/rdex_prediction_dataset.pkl")
    

In [None]:
def define_regression_pipeline(ds: bp.Dataset, scopes: dict, 
        model: str) -> bp.Pipeline:
    """Build the scale -> one-hot -> residualize -> regress pipeline.

    Args:
        ds (bp.Dataset): Dataset; its 'covariates' columns are handed to the
            residualizer as ``to_resid_df``.
        scopes (dict): Scope name -> columns. The residualizer step is scoped
            to every key of this dict.
        model (str): One of 'ridge', 'elastic' or 'lasso'.

    Returns:
        bp.Pipeline: Pipeline whose final step is a ``bp.Model`` tuned with a
            100-iteration 'HammersleySearch' parameter search.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    # Resolve the estimator first so an unsupported name fails immediately
    # instead of surfacing later as an unbound local variable.
    mod_params = {'alpha': bp.p.Log(lower=1e-5, upper=1e5)}

    if model == "ridge":
        mod_obj = "ridge"
    elif model == 'elastic':
        # ElasticNet is not imported at the top of the notebook; import it
        # here so this branch does not raise NameError.
        from sklearn.linear_model import ElasticNet
        mod_obj = ElasticNet()
        l1_ratio = bp.p.Scalar(lower=0.001, upper=1).set_mutation(sigma=0.165)
        mod_params['l1_ratio'] = l1_ratio
    elif model == 'lasso':
        mod_obj = 'lasso regressor'
    else:
        raise ValueError(f'Specified model {model} is not implemented')

    # Just scale float type features
    robust_scaler = bp.Scaler('robust', scope='float')

    # Define residualization procedure.
    # Newer scikit-learn spells the dense-output flag `sparse_output`; older
    # releases only know `sparse`. Try the new name and fall back.
    try:
        ohe = OneHotEncoder(categories='auto', drop='if_binary',
                sparse_output=False, handle_unknown='ignore')
    except TypeError:
        ohe = OneHotEncoder(categories='auto', drop='if_binary',
                sparse=False, handle_unknown='ignore')
    ohe_tr = bp.Transformer(ohe, scope='category')

    resid = LinearResidualizer(to_resid_df=ds['covariates'], fit_intercept=True)
    resid_tr = bp.Scaler(resid, scope=list(scopes.keys()))

    # Define regression model (kept in its own name rather than rebinding the
    # `model` argument).
    param_search = bp.ParamSearch('HammersleySearch', n_iter=100, cv='default')
    regressor = bp.Model(
        obj=mod_obj, 
        params=mod_params,  
        param_search=param_search
    )

    return bp.Pipeline([robust_scaler, ohe_tr, resid_tr, regressor])

# pipe = define_regression_pipeline(ds, predictor_scopes, 'ridge')


In [None]:
def fit_model(ds: bp.Dataset, scopes: dict, model='ridge', complete=False,
    n_cores=-1, random_state=42) -> bp.CompareDict:
    """Evaluate a regression pipeline on every target with 5-fold CV.

    Args:
        ds (bp.Dataset): Dataset to evaluate on.
        scopes (dict): Scope name -> columns.
        model (str, optional): 'elastic', 'ridge' or 'lasso'. Defaults to 'ridge'.
        complete (bool, optional): When False, one comparison is built per
            key of ``scopes`` (covariates plus that scope, or covariates
            alone for the 'covariates' key). When True, no scope comparison
            is passed. Defaults to False.
        n_cores (int, optional): Number of jobs; -1 means ``os.cpu_count()``.
        random_state (int, optional): Random state. Defaults to 42.

    Returns:
        bp.CompareDict: Output of ``bp.evaluate``.

    Raises:
        Exception: If ``model`` is not one of the supported names.
    """
    if model not in ['elastic', 'ridge', 'lasso']:
        raise Exception(f'Specified model {model} is not implemented')
    pipe = define_regression_pipeline(ds, scopes, model=model)

    if complete:
        compare_scopes = None
    else:
        options = [
            'covariates' if key == 'covariates'
            else bp.Option(['covariates', key], name=f'cov + {key}')
            for key in scopes.keys()
        ]
        compare_scopes = bp.Compare(options)

    n_jobs = os.cpu_count() if n_cores == -1 else n_cores

    problem_spec = bp.ProblemSpec(n_jobs=n_jobs, random_state=random_state)
    folds = bp.CV(splits=5, n_repeats=1)

    return bp.evaluate(
        pipeline=pipe,
        dataset=ds,
        problem_spec=problem_spec,
        scope=compare_scopes,
        mute_warnings=True,
        target=bp.Compare(ds.get_cols('target')),
        cv=folds,
    )

In [None]:
def filter_scopes(keys: list, scopes: dict) -> dict:
    """Keep only the requested scopes, plus the 'mri_confounds' scope.

    Args:
        keys (list): Scope names to keep. The list is not modified.
        scopes (dict): Scope name -> columns.

    Returns:
        dict: Entries of ``scopes`` whose key is in ``keys`` or is
            'mri_confounds', in the order they appear in ``scopes``.
    """
    # Build a new list rather than extending the caller's: extending in place
    # appended another 'mri_confounds' to the argument on every call.
    wanted = [*keys, 'mri_confounds']
    return {k: v for k, v in scopes.items() if k in wanted}
    
def save_model_results(res: bp.CompareDict, name: str, model: str, path: str) -> None:
    """Write a results object and its summary table to disk.

    Two files are produced under ``path``: ``{name}_{model}_summary.csv``
    (from ``res.summary()``) and ``{name}_{model}_results.pkl`` (the pickled
    results object).

    Args:
        res (bp.CompareDict): Model results.
        name (str): Model name, used as the file-name prefix.
        model (str): Model type.
        path (str): Directory to save results in.
    """
    stem = f'{path}/{name}_{model}'

    res.summary().to_csv(f'{stem}_summary.csv')
    pd.to_pickle(res, f'{stem}_results.pkl')

    print(f"Results saved to {path}")

    

In [None]:
def fit_full_model(ds: bp.Dataset, scopes: dict, model: str, fpath: str) -> bp.CompareDict:
    """Fit the model without per-scope comparison and save the results.

    Args:
        ds (bp.Dataset): Dataset to evaluate on.
        scopes (dict): Scope name -> columns.
        model (str): Model type passed to ``fit_model``.
        fpath (str): Directory the results are saved in under the name
            'all_conditions'.

    Returns:
        bp.CompareDict: The results returned by ``fit_model``.
    """
    res = fit_model(ds, scopes, model, complete=True)
    save_model_results(res, 'all_conditions', model, fpath)
    # The signature promises the results; previously nothing was returned.
    return res

In [None]:
model_comparisons = ['ridge', 'elastic', 'lasso']
# Four separate conditions: the last two were previously fused into the single
# string 'correct_stop, incorrect_stop' by a missing pair of quotes.
predictor_comparison = ['incorrect_go', 'correct_go', 'correct_stop', 'incorrect_stop']


# scopes = filter_scopes(['incorrect_go'], predictor_scopes)

In [None]:
def generate_slurm_script(model: str, 
                    condition: str,
                    walltime: str='24:00:00',
                    mem: str='256G', 
                    envname: str='sst-rdex',
                    fpath:str="../slurm/",
                    cpus_per_task: int=8) -> str:
    """Write a SLURM batch script that runs ``run_model.py`` for one model/condition.

    Args:
        model (str): Model type; passed to ``run_model.py`` and used in the
            job and file names.
        condition (str): Condition; passed to ``run_model.py`` and used in
            the job and file names.
        walltime (str, optional): Value for ``--time``. Defaults to '24:00:00'.
        mem (str, optional): Value for ``--mem``. Defaults to '256G'.
        envname (str, optional): Conda environment activated by the script.
        fpath (str, optional): Directory the script is written to; created
            if it does not exist. Defaults to "../slurm/".
        cpus_per_task (int, optional): Value for ``--cpus-per-task``.
            Defaults to 8.

    Returns:
        str: The text of the generated script.
    """
    script = f"""#!/bin/bash
#SBATCH --job-name={model}_{condition}
#SBATCH --output=logs/%x_%j.out
#SBATCH --error=logs/%x_%j.err
#SBATCH --time={walltime}
#SBATCH --partition=general
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={cpus_per_task}
#SBATCH --mem={mem}
#SBATCH --mail-type=ALL

source ${{HOME}}/.bashrc
conda activate {envname}

cd ../pipelines/

python3 run_model.py {model} {condition}
"""
    # open() does not create missing directories; make sure the target exists.
    if fpath:
        os.makedirs(fpath, exist_ok=True)

    script_path = f"{fpath}/{model}_{condition}.sh"
    with open(script_path, 'w') as f:
        f.write(script)
    print(f"Slurm script saved to {script_path}")

    return script


# model_comparisons = ['ridge', 'elastic', 'lasso']
# predictor_comparison = ['incorrect_go', 'correct_go', 'correct_stop', 'incorrect_stop']

# for model in model_comparisons:
#     for condition in predictor_comparison:
#         generate_slurm_script(model, condition)


# generate_slurm_script('ridge', 'incorrect_go')


In [None]:
# def make_effect_compare_plot(df, title='Vertexwise Regressor Model Fit Comparison', contrasts=False):

#     if contrasts:
#         df['procedure'] = df['procedure'].str.cat(df['inputs'], sep=': ')
#         hatches = ['', 'xx', 'oo', 'OO']
#     else:
#         hatches = ['', 'oo', '+', 'OO']

#     fig, ax = plt.subplots(figsize=(15,5))
    
#     # greypallete = np.repeat('lightgrey', len(df))

#     n_procedures = len(df['procedure'].drop_duplicates())
#     greypallete = list(np.repeat('lightgrey', n_procedures))

#     order =df['target'].drop_duplicates()

#     g = sns.barplot(
#         x='target', 
#         y='mean_scores_r2', 
#         hue='procedure', 
#         data=df,
#         palette=greypallete,
#         order=order)

#     g.legend_.set_title('')

#     ax.grid(linestyle=':')
#     bars = ax.patches[:len(ax.patches)-n_procedures]
#     x_coords = [p.get_x() + 0.5 * p.get_width() for p in bars]
#     y_coords = [p.get_height() for p in bars]

#     ax.errorbar(x=x_coords, y=y_coords, yerr=df["std_scores_r2"], fmt="none", c="k")

#     # only want one set of colors
#     palette = df[['target', 'color']].drop_duplicates()['color']
 
#     for bars, hatch, legend_handle in zip(ax.containers, hatches, 
#                                           ax.legend_.legendHandles):
#         for bar, color in zip(bars, palette):
#             bar.set_facecolor(color)
#             bar.set_hatch(hatch)
#         legend_handle.set_hatch(hatch + hatch)

#     ptplt = sns.pointplot(
#         x='target', 
#         y='test_r2', 
#         data=df, 
#         hue='procedure', 
#         markersize=5,
#         dodge=0.5, 
#         linestyles="none",
#         palette=greypallete,
#         order=order,
#         legend=False
#     )

    # # formatting
    # ax.set(xlabel=None)
    # ax.set(ylabel='Avg. $R^{2}$')
    # ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    # ax.spines[['top', 'right']].set_visible(False)

    # fig.subplots_adjust(top=0.9)
    # fig.suptitle(title)

In [None]:
def join_test_prediction(results_summary: pd.DataFrame, test_prediction: pd.DataFrame) -> pd.DataFrame:
    """Attach test-set predictions to the results summary.

    Rows are matched on the ('target', 'scope') pair with a left join, so
    every summary row is kept and rows without a prediction get NaN.

    Args:
        results_summary (pd.DataFrame): Results summary.
        test_prediction (pd.DataFrame): Test predictions.
    
    Returns:
        pd.DataFrame: Summary with the prediction columns appended.
    """
    join_keys = ['target', 'scope']
    keyed_predictions = test_prediction.set_index(join_keys)
    joined = results_summary.join(keyed_predictions, on=join_keys, how='left')
    return joined

def get_test_prediction(results, metric='r2'):
    """Score the best cross-validation estimator of each result on the test set.

    For every entry of ``results`` the estimator with the highest
    ``scores[metric]`` is selected and its ``score`` method is evaluated on
    the dataset's 'test' subjects. Entries whose scope is
    'cov + mri_confounds' are skipped.

    Args:
        results (bp.CompareDict or str): Model results, or the path of a
            pickle file containing them.
        metric (str, optional): Key of ``scores`` used to pick the best
            estimator. Defaults to 'r2'.
    
    Returns:
        pd.DataFrame: One row per kept entry with columns 'scope', 'target'
            and 'test_r2'. The columns are present even when there are no
            rows, and the index is a plain 0..n-1 range.
    """

    if isinstance(results, str):
        # NOTE: unpickling runs arbitrary code; only load trusted result files.
        results = pd.read_pickle(results)

    # Collect plain records and build the frame once, instead of growing a
    # DataFrame with concat on every iteration.
    records = []

    for label, evaluator in results.items():

        options = label.__dict__['options']

        if len(options) == 1:
            # A single option means no scope comparison was made.
            scope = 'all'
            target = options[0].__dict__['name']
        else:
            scope = options[0].__dict__['name']
            target = options[1].__dict__['name']

        if scope == 'cov + mri_confounds':
            continue

        best_model_idx = np.argmax(evaluator.scores[metric])
        best_model = evaluator.estimators[best_model_idx]

        X_test, y_test = evaluator._dataset.get_Xy(evaluator.ps, subjects='test')
        # The column is named 'test_r2' for downstream code regardless of
        # `metric`; the value is whatever the estimator's `score` returns.
        test_score = best_model.score(X_test, y_test)

        records.append({'scope': scope, 'target': target, 'test_r2': test_score})

    # Explicit columns keep the frame usable (e.g. set_index on
    # ['target', 'scope']) even when nothing was collected.
    return pd.DataFrame(records, columns=['scope', 'target', 'test_r2'])

In [None]:
def assemble_summary(results_path: str, params: dict, model='ridge', n_jobs=2) -> pd.DataFrame:
    """Combine per-run summaries with their test-set scores and save the table.

    Inputs are discovered under ``params['model_results_path']`` (files
    ending in ``{model}_results.pkl`` and ``{model}_summary.csv``); the
    combined table is written to ``results_path`` as
    ``{model}_models_summary.csv``.
    
    Args:
        results_path (str): Prefix/directory the combined CSV is written to.
        params (dict): Configuration providing 'model_results_path'.
        model (str, optional): Model type used in the file names.
            Defaults to 'ridge'.
        n_jobs (int, optional): Parallel workers used to compute the test
            predictions. Defaults to 2.
    
    Returns:
        pd.DataFrame: Model summary.
    """
    source_dir = params['model_results_path']
    pickle_files = glob.glob(source_dir + f"*{model}_results.pkl")
    summary_files = glob.glob(source_dir + f"*{model}_summary.csv")

    # One test-prediction frame per results pickle, computed in parallel.
    prediction_jobs = (delayed(get_test_prediction)(path) for path in pickle_files)
    test_predictions = Parallel(n_jobs=n_jobs)(prediction_jobs)

    per_run_summaries = [pd.read_csv(path) for path in summary_files]
    summary = join_test_prediction(
        pd.concat(per_run_summaries),
        pd.concat(test_predictions),
    )

    summary.to_csv(results_path + f'{model}_models_summary.csv')
    return summary


In [None]:
def relabel_plotting_data(df, process_map, target_map, color_map):
  """Relabel data for plotting.

  Rows whose scope is 'cov + mri_confounds' are dropped, the 'cov + ' prefix
  is removed from the remaining scopes, and 'process', 'color' and relabelled
  'target' columns are derived. The input frame is left untouched.

  Args:
      df (pd.DataFrame): Dataframe with 'scope' and 'target' columns.
      process_map (dict): Maps raw target values to process labels.
      target_map (dict): Maps raw target values to display labels.
      color_map (dict): Maps process labels to colors.

  Returns:
      pd.DataFrame: Relabeled dataframe
  """
  # Work on an explicit copy: assigning columns on a filtered slice triggers
  # pandas' SettingWithCopyWarning and may not modify what was intended.
  df = df[df['scope'] != 'cov + mri_confounds'].copy()

  # Raw string: '\+' is not a valid escape sequence in a normal string literal.
  df['scope'] = df['scope'].str.replace(r'cov \+ ', '', regex=True)

  # 'process' is derived from the raw target, so compute it before the
  # target values themselves are relabelled.
  df['process'] = df['target'].replace(process_map)
  df['color'] = df['process'].replace(color_map)
  df['target'] = df['target'].replace(target_map)

  return df

def sort(df: pd.DataFrame) -> pd.DataFrame:
  """Order rows by process, then by each target's average score (best first).

  The per-target average of 'mean_scores_r2' is computed, joined onto the
  rows, used as the secondary sort key and then removed again.

  Args:
      df (pd.DataFrame): Dataframe with 'target', 'process',
          'mean_scores_r2' and 'std_scores_r2' columns.
    
  Returns:
      pd.DataFrame: Sorted dataframe with 'target' as the first column.
  """
  score_columns = ['target', 'mean_scores_r2', 'std_scores_r2']
  per_target = df[score_columns].groupby('target').mean(numeric_only=True)
  per_target = per_target.sort_values('mean_scores_r2', ascending=False)
  per_target.columns = ['avg_mean', 'avg_std']

  with_averages = df.set_index('target').join(per_target)
  ordered = with_averages.sort_values(
      by=['process', 'avg_mean'], ascending=[True, False]
  )

  # The helper columns were only needed for ordering.
  return ordered.reset_index().drop(columns=['avg_mean', 'avg_std'])


# Lookup tables taken from the loaded configuration. The relabelling and
# collage-plot helpers in this notebook accept them as arguments of the same names.
process_map = params['process_map']
target_map = params['target_map']
color_map = params['color_map']

# summary = pd.read_csv(params['model_results_path'] + 'ridge_models_summary.csv')
# summary = relabel_plotting_data(summary, process_map, target_map, color_map)
# sort(summary)

In [None]:
def make_effect_compare_plot(df, title='Vertexwise Regressor Model Fit Comparison'):
    """Draw grouped bars of mean R^2 per target and scope with test scores overlaid.

    Bars show 'mean_scores_r2' grouped by 'scope'; they are recoloured per
    target from the 'color' column and hatched per scope. Error bars use
    'std_scores_r2' and the points show 'test_r2'.

    Args:
        df (pd.DataFrame): Plotting data with 'target', 'scope', 'color',
            'mean_scores_r2', 'std_scores_r2' and 'test_r2' columns.
        title (str, optional): Figure title.
    """

    hatches = ['', '/', '-', 'X', 'O']

    fig, ax = plt.subplots(figsize=(15,5))

    n_scopes = len(df['scope'].drop_duplicates())
    greypallete = list(np.repeat('lightgrey', n_scopes))

    order = df['target'].drop_duplicates()

    # Draw on `ax` explicitly instead of relying on it being the current axes.
    g = sns.barplot(
        x='target', 
        y='mean_scores_r2', 
        hue='scope', 
        data=df,
        palette=greypallete,
        order=order,
        ax=ax)

    g.legend_.set_title('')

    ax.grid(linestyle=':')
    # The last n_scopes patches are left out of the error-bar positions.
    value_bars = ax.patches[:len(ax.patches)-n_scopes]
    x_coords = [p.get_x() + 0.5 * p.get_width() for p in value_bars]
    y_coords = [p.get_height() for p in value_bars]

    # NOTE(review): yerr is taken in df row order while the coordinates follow
    # patch order - confirm the two orders agree for the data being plotted.
    ax.errorbar(x=x_coords, y=y_coords, yerr=df["std_scores_r2"], fmt="none", c="k")

    # only want one set of colors
    palette = df[['target', 'color']].drop_duplicates()['color']

    # Newer matplotlib exposes the handles as `legend_handles`; older
    # releases only have `legendHandles`. Support both.
    legend_handles = getattr(ax.legend_, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = ax.legend_.legendHandles

    for container, hatch, legend_handle in zip(ax.containers, hatches, 
                                               legend_handles):
        for bar, color in zip(container, palette):
            bar.set_facecolor(color)
            bar.set_hatch(hatch)
        legend_handle.set_hatch(hatch + hatch)

    sns.pointplot(
        x='target', 
        y='test_r2', 
        data=df, 
        hue='scope', 
        markersize=2,
        dodge=0.5, 
        linestyles="none",
        palette=greypallete,
        order=order,
        legend=False,
        ax=ax
    )

    # formatting
    ax.set(xlabel=None)
    ax.set(ylabel='Avg. $R^{2}$')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.spines[['top', 'right']].set_visible(False)

    fig.subplots_adjust(top=0.9)
    fig.suptitle(title)

# make_effect_compare_plot(sort(summary), title='Vertexwise Regressor Model Fit Comparison')
# plt.savefig(params['plot_output_path'] + 'vertexwise_regressor_model_fit_comparison_ridge.png', dpi=300, bbox_inches='tight')

In [None]:
# List the pickled ridge results found under the configured results path
# (every file whose name ends in "ridge_results.pkl") and display the list.
results_paths = glob.glob(params['model_results_path']+ f"*ridge_results.pkl")
results_paths

In [None]:
def remove_nonfeatures(coefs: pd.Series, filter_strings = ['mri', 'iqc']) -> pd.Series:
    """Drop entries whose index label matches any of the filter patterns.

    Args:
        coefs (pd.Series): Values indexed by feature name.
        filter_strings (list, optional): Patterns (interpreted as regular
            expressions by ``str.contains``); a label matching any of them
            is removed. Defaults to ['mri', 'iqc'].

    Returns:
        pd.Series: Entries of ``coefs`` whose label matches none of the
            patterns. With no patterns, a copy of ``coefs`` is returned.
    """
    patterns = list(filter_strings)
    if not patterns:
        # An empty join yields the pattern '', which matches every label and
        # would remove everything; nothing to filter means keep everything.
        return coefs.copy()
    return coefs[~coefs.index.str.contains('|'.join(patterns))]

def get_feature_importance(results: bp.CompareDict, metric='r2'
                                        ) -> (pd.DataFrame, pd.DataFrame, pd.DataFrame):
    """Get feature importance from results.

    For every entry of ``results`` the importances returned by ``get_fis()``
    are stored, together with their mean and the row belonging to the entry
    with the highest ``scores[metric]``. The mean and best-row values are
    passed through ``remove_nonfeatures``. Entries whose scope is
    'mri_confounds' are skipped.

    NOTE: despite the return annotation, three dicts keyed by target name
    are returned, not DataFrames. Because the key is the target only, a
    later entry for the same target replaces an earlier one.
    
    Args:
        results (bp.CompareDict): Model results, or the path of a pickle
            file containing them.
        metric (str, optional): Metric to extract. Defaults to 'r2'.
    
    Returns:
        tuple: ``(fis, best_fis, avg_fis)`` - per target, the full
            importances, the best-scoring row, and the mean.
    """
    
    fis = {}
    best_fis = {}
    avg_fis = {}

    # A string is treated as the path of a pickled results object.
    if isinstance(results, str):
        results = pd.read_pickle(results)

    for l, m in results.items():

        lab = l.__dict__['options']

        # One option: only a target was compared, so the scope is 'all'.
        # Otherwise the first option names the scope and the second the target.
        if len(lab) == 1:
            scope = 'all'
            target = lab[0].__dict__['name']
        else:
            scope = lab[0].__dict__['name']
            scope = scope.replace('cov + ', '')
            target = lab[1].__dict__['name']

        if scope == 'mri_confounds':
            continue
        else:
            coefs = m.get_fis()
            fis[target] = coefs

            # presumably one row per fitted estimator, so this averages
            # across them - TODO confirm against get_fis()
            avg_coefs = coefs.mean()
            avg_coefs = remove_nonfeatures(avg_coefs)
            avg_fis[target] = avg_coefs

            # Row at the position of the highest score for `metric`.
            best_model_idx = np.argmax(m.scores[metric])
            best_coefs = coefs.iloc[best_model_idx, :]
            best_coefs = remove_nonfeatures(best_coefs)
            best_fis[target] = best_coefs

    return fis, best_fis, avg_fis

def assemble_feature_importance(fpath: str, params: dict, model: str='ridge', n_jobs=2) -> None:
    """Extract feature importances from every results pickle and save them.

    Results files ending in ``{model}_results.pkl`` are looked up under
    ``params['model_results_path']``; the list of per-file outputs of
    ``get_feature_importance`` is pickled to
    ``fpath + '{model}_feature_importance.pkl'``.

    Args:
        fpath (str): Prefix/directory the output pickle is written to.
        params (dict): Configuration providing 'model_results_path'.
        model (str, optional): Model type used in the input and output file
            names. Defaults to 'ridge'.
        n_jobs (int, optional): Parallel workers. Defaults to 2.
    """

    # Select the inputs by `model`: the pattern used to be fixed to ridge, so
    # other models were saved under their own name but built from ridge results.
    results_paths = glob.glob(params['model_results_path']+ f"*{model}_results.pkl")
    res = Parallel(n_jobs=n_jobs)(
        delayed(get_feature_importance)(r) for r in results_paths
    )
    pd.to_pickle(res, fpath + f'{model}_feature_importance.pkl')


# assemble_feature_importance(params['model_results_path'], model='ridge')


In [None]:
def broadcast_to_fsaverage(fis_agg: pd.Series, n_vertices=10242+1) -> tuple:
    """Spread feature importances over the full vertex range, per hemisphere.

    The index labels of ``fis_agg`` are split on '_' into four fields; the
    third field (the vertex number) becomes the columns. Columns cover
    vertices ``1 .. n_vertices - 1`` in ascending order, with NaN for
    vertices that have no value.

    Args:
        fis_agg (pd.Series): Values indexed by labels of the form
            ``<correct>_<condition>_<vertex>_<hemisphere>``.
        n_vertices (int, optional): One more than the highest vertex number.
            Defaults to 10242+1.

    Returns:
        tuple: ``(lh, rh)`` DataFrames with 'correct' and 'condition'
            columns followed by one column per vertex.

    Raises:
        KeyError: If a vertex number lies outside ``1 .. n_vertices - 1``.
    """

    def _split_hemisphere(df):
        # Turn the three remaining index levels into columns and split rows
        # by the hemisphere column.
        df = df.reset_index(names=['correct', 'condition', 'hemisphere'])
        lh = df[df['hemisphere'] == 'lh'].drop(columns='hemisphere')
        rh = df[df['hemisphere'] == 'rh'].drop(columns='hemisphere')
        return lh, rh

    fis = fis_agg.copy()

    fis.index = fis.index.str.split('_', expand=True)
    fis = fis.unstack(level=2)

    # Convert the labels in place, keeping each label on its own data column.
    # (Assigning a *sorted* copy of the labels would relabel the columns
    # positionally: the string labels come out as '1', '10', '100', '2', ...)
    fis.columns = pd.to_numeric(fis.columns)

    vertex_names = [*range(1, n_vertices)]
    unexpected = fis.columns.difference(vertex_names)
    if len(unexpected) > 0:
        raise KeyError(
            f"{list(unexpected)} not in the vertex range 1..{n_vertices - 1}"
        )

    # reindex inserts NaN columns for missing vertices *in vertex order*, so
    # positional use of the values lines up with vertex numbers.
    df = fis.reindex(columns=vertex_names)
    lh, rh = _split_hemisphere(df)

    return lh, rh
def load_destrieux_atlas():
    """Return the atlas object produced by ``fetch_atlas_surf_destrieux``."""
    return fetch_atlas_surf_destrieux()
def map_destrieux(lh: pd.DataFrame, rh: pd.DataFrame, prefix: str='') -> pd.DataFrame:
    """Map both hemispheres through the Destrieux atlas and get a shared value range.

    NOTE: despite the return annotation, a 4-tuple
    ``(lh_mapped, rh_mapped, vmin, vmax)`` is returned.

    Args:
        lh (pd.DataFrame): Left-hemisphere values with 'correct' and
            'condition' columns.
        rh (pd.DataFrame): Right-hemisphere values with 'correct' and
            'condition' columns.
        prefix (str, optional): Passed through to ``map_hemisphere``.

    Returns:
        tuple: The two mapped frames (with 'correct'/'condition' restored as
            columns) and the symmetric ``vmin``/``vmax`` computed over both.
    """

    dest = load_destrieux_atlas()

    # Move the identifying columns into the index so only values are mapped.
    idx = ['correct', 'condition']
    lh = lh.set_index(idx)
    rh = rh.set_index(idx)

    # presumably aggregates vertex columns into atlas regions named with the
    # given prefix/suffix - TODO confirm against map_hemisphere
    lh_mapped = map_hemisphere(lh, 
               mapping=dest['map_left'], 
               labels=dest['labels'],
               prefix=prefix,
               suffix='.lh')
    rh_mapped = map_hemisphere(rh,
                mapping=dest['map_right'],
                labels=dest['labels'],
                prefix=prefix,
                suffix='.rh')

    # Carry the (correct, condition) index over to the mapped frames.
    lh_mapped.index = lh.index
    rh_mapped.index = rh.index
 
    # One symmetric range across both hemispheres.
    df = pd.concat([lh_mapped, rh_mapped], axis=1)
    vmin, vmax = get_fullrang_minmax(df)

    return lh_mapped.reset_index(), rh_mapped.reset_index(), vmin, vmax


def absmax(x):
    """Return the element of ``x`` with the largest absolute value, sign kept.

    Args:
        x: Array-like of numbers (list, NumPy array, pandas Series, ...).

    Returns:
        The signed element whose magnitude is largest. For input with more
        than one dimension the search runs over all elements.
    """
    values = np.asarray(x)
    # np.argmax returns a position in the flattened data, so look the value
    # up positionally in the flattened array. Indexing `x` itself with that
    # number is label-based for a Series and row-based for a 2-D array.
    return values.ravel()[np.argmax(np.abs(values))]


def get_fullrang_minmax(series: pd.Series):
    """Return a range symmetric around zero that covers every value.

    Args:
        series: Values to cover; ``min``/``max`` are each applied twice.

    Returns:
        tuple: ``(-bound, bound)`` where ``bound`` is the larger of the
            absolute minimum and the absolute maximum.
    """
    lowest = series.min().min()
    highest = series.max().max()

    bound = max(np.abs(lowest), np.abs(highest))
    return -bound, bound


def draw_plot(lh, rh, ax, mode, cmap='bwr', vmin=None, vmax=None, avg_method=absmax):
    """Plot one pair of hemispheres onto ``ax`` without a colorbar.

    Args:
        lh: Left-hemisphere values.
        rh: Right-hemisphere values.
        ax: Axes to draw on.
        mode (str): 'roi' resolves the values through a
            ``SurfRef(space='fsaverage5', parc='destr')`` reference; any
            other value plots the raw ``lh``/``rh`` arrays.
        cmap (str, optional): Colormap name. Defaults to 'bwr'.
        vmin, vmax (optional): Color limits.
        avg_method (callable, optional): Passed to the plot call in the
            non-'roi' mode only. Defaults to ``absmax``.
    """

    if mode == 'roi':
        # The atlas used to be fetched here as well, but its result was never
        # used, so that call is gone.
        plot_df = pd.concat([lh, rh], axis=1)
        surf_ref = SurfRef(space='fsaverage5', parc='destr')
        to_plot = surf_ref.get_hemis_plot_vals(plot_df)

        ntp.plot(to_plot, threshold=0, ax=ax, cmap=cmap,
            vmin=vmin, vmax=vmax, colorbar=False)

    else:
        plot_dict = {'lh': lh.values, 'rh': rh.values}
        ntp.plot(plot_dict, 
            avg_method=avg_method, 
            ax=ax, cmap=cmap, 
            vmin=vmin, vmax=vmax, 
            colorbar=False, 
            threshold=0)
    
def make_collage_plot(fis_agg: dict, target, target_map, basepath='../../data/06_reporting/rdex_prediction/fis_plots', 
    agg='avg_fis', mode='vertex', model='enet', fontsize=25) :
    """Save a grid of brain plots (conditions x correct/incorrect) for one target.

    The figure has one column per condition, a 'correct' row on top and an
    'incorrect' row at the bottom, a label column on the left and a shared
    colorbar on the right. It is written to
    ``{basepath}/{agg}/{mode}/{model}/{target}.png`` and then closed.

    Args:
        fis_agg (dict): Target name -> feature-importance Series.
        target: Key of ``fis_agg`` (and of ``target_map``) to plot.
        target_map (dict): Target name -> display title.
        basepath (str, optional): Root output directory.
        agg (str, optional): Sub-directory naming the aggregation.
        mode (str, optional): 'roi' maps values through the Destrieux atlas
            first; any other value plots vertex values. Defaults to 'vertex'.
        model (str, optional): Sub-directory naming the model.
        fontsize (int, optional): Base font size. Defaults to 25.
    """

    def _format_for_plotting(fis: pd.DataFrame, correct: str, condition: str) -> pd.DataFrame:
        # Select one (correct, condition) cell, drop the identifying columns
        # and plot missing values as 0. fillna returns a new frame instead of
        # writing into a slice of `fis`.
        selected = fis[(fis['correct'] == correct) & (fis['condition'] == condition)]
        return selected.drop(columns=['correct', 'condition']).fillna(0)

    lh, rh = broadcast_to_fsaverage(fis_agg[target])

    conditions = pd.unique(lh['condition'])
    n_cond = conditions.shape[0]
    directions = ['correct', '', 'incorrect']

    # Layout: narrow label column, one wide column per condition, colorbar
    # column; a thin middle row separates the two plot rows.
    width_ratios = [1]
    width_ratios.extend(repeat(10, n_cond))
    width_ratios.extend([2]) # colorbar
    height_ratios = [100, 1, 100]

    gs = {
        'width_ratios': width_ratios,
        'height_ratios': height_ratios,
        'hspace':0,
        'wspace':0
    }

    cmap = 'bwr'
    nb_ticks = 5
    cbar_tick_format='%.2g'

    fig, axs = plt.subplots(3, n_cond + 1+1, figsize = (35, 20) , gridspec_kw=gs)

    # Row labels in the first column; the middle row stays blank.
    for i, direction in enumerate(directions):
        ax = axs[i, 0]
        ax.set_axis_off()
        if i != 1:
            ax.text(0, .5, direction, fontsize=fontsize)

    if mode == 'roi':
        lh, rh, vmin, vmax = map_destrieux(lh, rh)
    else:
        vmin, vmax = get_fullrang_minmax(fis_agg[target])

    for cnt, condition in enumerate(conditions, start=1):
        top = axs[0, cnt]
        middle = axs[1, cnt]
        bottom = axs[2, cnt]

        lh_correct = _format_for_plotting(lh, 'correct', condition)
        rh_correct = _format_for_plotting(rh, 'correct', condition)
        draw_plot(lh_correct, rh_correct, top, mode, vmin=vmin, vmax=vmax)

        middle.set_axis_off() # make blank space

        lh_incorrect = _format_for_plotting(lh, 'incorrect', condition)
        rh_incorrect = _format_for_plotting(rh, 'incorrect', condition)
        draw_plot(lh_incorrect, rh_incorrect, bottom, mode, vmin=vmin, vmax=vmax)

        top.set_title(condition, fontsize=fontsize)

    # plot colorbar
    norm = Normalize(vmin=vmin, vmax=vmax)
    proxy_mappable = ScalarMappable(norm=norm, cmap=cmap)
    ticks = np.linspace(vmin, vmax, nb_ticks)

    right = axs[:, n_cond + 1]

    for ax in right.flat:
        ax.set_axis_off()

    cax, _ = make_axes(right, fraction=.5, shrink=0.5)
    cbar = fig.colorbar(proxy_mappable, cax=cax, ticks=ticks,
                        orientation='vertical', format=cbar_tick_format,
                        ticklocation='left')

    cbar.set_label(label='Avg. Feature Imp.', fontsize=fontsize - 2)

    prefix = 'Target: '
    title = target_map[target]

    fig.suptitle(prefix + title, fontsize=fontsize+2)

    fpath = f'{basepath}/{agg}/{mode}/{model}'
    # exist_ok avoids the check-then-create race when several workers plot
    # into the same directory at once.
    os.makedirs(fpath, exist_ok=True)

    # Save and close this figure explicitly rather than "the current one".
    fig.savefig(f'{fpath}/{target}.png', dpi=300, bbox_inches='tight')
    plt.close(fig)



In [None]:
def plot_mode(fis_agg: dict, target_map: dict, agg: str, mode: str,  model: str,
     basepath='./data/08_reporting/fis_plots'):
    """Render the collage plot of every target in ``fis_agg`` with 8 workers.

    Args:
        fis_agg (dict): Target name -> feature importances.
        target_map (dict): Target name -> display title.
        agg (str): Aggregation label forwarded to ``make_collage_plot``.
        mode (str): Plot mode forwarded to ``make_collage_plot``.
        model (str): Model label forwarded to ``make_collage_plot``.
        basepath (str, optional): Root output directory.
    """
    plot_one = delayed(make_collage_plot)
    jobs = (
        plot_one(fis_agg, target, target_map,
                 basepath=basepath, agg=agg, mode=mode, model=model)
        for target in fis_agg.keys()
    )
    Parallel(n_jobs=8)(jobs)

def make_fis_plots(avg_fis, best_fis,  target_map, model='enet'):
    """Plot averaged and best feature importances in 'vertex' and 'roi' modes.

    Args:
        avg_fis (dict): Target name -> averaged feature importances.
        best_fis (dict): Target name -> best-estimator feature importances.
        target_map (dict): Target name -> display title.
        model (str, optional): Model label forwarded to ``plot_mode``.
    """
    aggregations = (('avg_fis', avg_fis), ('best_fis', best_fis))

    for mode in ('vertex', 'roi'):
        for agg_name, importances in aggregations:
            plot_mode(importances, target_map, mode=mode, agg=agg_name, model=model)

In [None]:
# Load the pickled ridge feature importances from the configured results path.
# NOTE: unpickling runs arbitrary code; only load files this project produced.
fis = pd.read_pickle(params['model_results_path'] + 'ridge_feature_importance.pkl')

In [None]:
def gather_fis(fis: list):
    """Gather feature importance for plotting.

    Args:
        fis (list): One entry per results file; position 1 of an entry maps
            target -> best feature importances and position 2 maps
            target -> averaged feature importances. The targets of the
            first entry are used for all entries.

    Returns:
        tuple: ``(best_fis, avg_fis)`` dicts mapping each target to the
            concatenation of that target's values across all entries.
    """

    targets = fis[0][1].keys()
    best_fis = {}
    avg_fis = {}

    for target in targets:
        best_fis[target] = pd.concat([f[1][target] for f in fis])
        avg_fis[target] = pd.concat([f[2][target] for f in fis])

    # Must be indented into the function body: at column 0 this line is a
    # SyntaxError ('return' outside function).
    return best_fis, avg_fis


# avg_fis


In [None]:
# Aggregate the loaded feature importances first: `avg_fis` is not assigned
# anywhere else in the notebook, so the plot call would raise NameError on a
# fresh kernel.
best_fis, avg_fis = gather_fis(fis)

target = 'EEA'

make_collage_plot(avg_fis, target, target_map, agg='avg_fis', mode='roi', model='ridge')
