In [None]:
# import
import os, sys
import numpy as np
import pandas as pd
import scipy as sp
from scipy import stats
from scipy.spatial import distance
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from tqdm import tqdm

# import plotting libraries
import matplotlib.pyplot as plt
plt.rcParams.update({"font.size": 8})
plt.rcParams["svg.fonttype"] = "none"
import seaborn as sns
from nilearn import datasets
from nilearn import plotting

sys.path.extend([r'/home/lindenmp/research_projects/snaplab_tools'])
sys.path.extend([r'/home/lindenmp/research_projects/nctpy/src'])

# import nctpy functions
from snaplab_tools.plotting.plotting import categorical_kde_plot, reg_plot, brain_scatter_plot, null_plot, surface_plot
from snaplab_tools.plotting.utils import get_my_colors, get_p_val_string
from nctpy.utils import matrix_normalization
from snaplab_tools.utils import get_schaefer_system_mask, get_null_p, get_fdr_p
from snaplab_tools.plotting.utils import roi_to_vtx

from brainsmash.mapgen.base import Base
from src.utils import get_adj_weights
from snaplab_tools.utils import normalize_x

from snaplab_tools.derivs import compute_acf
from nilearn.glm import first_level
from nilearn import plotting
from nilearn.glm.contrasts import compute_contrast
from scipy import signal
import statsmodels.api as sm
import statsmodels.formula.api as smf
from sklearn.metrics import r2_score

In [None]:
def compute_deflections(rest_timescales, task_timescales, nuisance_regression=False):
    """Compute task-minus-rest timescale deflections.

    Parameters
    ----------
    rest_timescales : array-like
        Resting-state timescales; flattened into a single column.
    task_timescales : array-like
        Task timescales with the same number of elements as ``rest_timescales``.
    nuisance_regression : bool, optional (default=False)
        If True, regress the rest timescales (plus an intercept) out of the
        deflections and return the residuals instead of the raw difference.

    Returns
    -------
    numpy.ndarray
        1-D array of (optionally residualized) deflections.
    """
    # np.asarray lets lists and pandas Series through; ndarrays are unchanged.
    rest_timescales = np.asarray(rest_timescales).reshape(-1, 1)
    task_timescales = np.asarray(task_timescales).reshape(-1, 1)

    int_deflections = task_timescales - rest_timescales

    if nuisance_regression:
        # Design matrix: rest timescales plus an intercept column.
        X = np.concatenate((rest_timescales, np.ones((rest_timescales.shape[0], 1))), axis=1)
        beta = np.dot(np.linalg.pinv(X), int_deflections)
        # The intercept absorbs the mean, so the residuals are mean-centered.
        int_deflections = int_deflections - np.dot(X, beta)

    return int_deflections[:, 0]

def nuis_reg(X, y, use_sklearn=False):
    """Regress nuisance covariates out of ``y`` and return the residuals.

    Parameters
    ----------
    X : array-like with ``ndim`` 1 or 2
        Nuisance regressors, one row per observation. A 1-D vector is treated
        as a single regressor column.
    y : array-like
        Data to residualize, with one row per observation.
    use_sklearn : bool, optional (default=False)
        If True, fit with ``sklearn.linear_model.LinearRegression()`` (default
        settings). Otherwise solve with the pseudo-inverse of ``X``; that path
        adds no intercept, so ``X`` must carry its own column of ones if one
        is wanted.

    Returns
    -------
    residuals
        ``y - predicted``. On the sklearn path a 1-D ``y`` is promoted to a
        column, so the result is 2-D there.
    """
    # Both solvers need a 2-D design matrix; previously only the sklearn path
    # promoted a 1-D regressor and the pinv path raised on it.
    if X.ndim == 1:
        X = X[:, np.newaxis]

    if use_sklearn:
        if y.ndim == 1:
            y = y[:, np.newaxis]

        regr = LinearRegression()
        regr.fit(X, y)
        predicted = regr.predict(X)
    else:
        beta = np.dot(np.linalg.pinv(X), y)
        predicted = np.dot(X, beta)

    return y - predicted

def run_regression(df_behavior, y_var, rest_timescales, task_timescales):
    """Fit a per-node OLS model of behavior on rest and task timescales.

    For every node the model ``y_var ~ rest * task + age + C(sex)`` is fitted
    on z-scored outcome, timescales and age.

    Parameters
    ----------
    df_behavior : pandas.DataFrame
        One row per subject; must contain ``y_var``, ``'age'`` and ``'sex'``.
    y_var : str
        Outcome column. Missing values are replaced by the column mean.
    rest_timescales, task_timescales : array-like
        Indexed as ``[node]`` to obtain one value per subject, i.e. the first
        axis is nodes.
        NOTE(review): other cells in this notebook hold timescales as
        (subjects, parcels) -- confirm callers transpose before calling.

    Returns
    -------
    effect_map : pandas.DataFrame
        Per-node coefficients for each available effect plus the adjusted
        R-squared in column ``'rsquared'``.
    pvals : pandas.DataFrame
        Per-node p-values for the same effects.
    """
    n_nodes = rest_timescales.shape[0]
    effect_map = pd.DataFrame(index=np.arange(n_nodes))
    pvals = pd.DataFrame(index=np.arange(n_nodes))

    # Outcome and covariates do not depend on the node, so build them once.
    # fillna is applied by assignment: the chained
    # `df[col].fillna(..., inplace=True)` form acts on a temporary under pandas
    # Copy-on-Write and leaves the NaNs in place.
    outcome = df_behavior[y_var].copy()
    outcome = outcome.fillna(np.nanmean(outcome))

    base_stats = pd.DataFrame(index=df_behavior.index)
    base_stats[y_var] = sp.stats.zscore(outcome)
    base_stats['age'] = sp.stats.zscore(df_behavior['age'])
    base_stats['sex'] = df_behavior['sex']

    # The formula interface adds its own intercept, so no explicit constant
    # column is needed.
    formula = '{0} ~ rest * task + age + C(sex)'.format(y_var)
    effects = ['rest', 'task', 'rest:task', 'age', 'C(sex)[T.1]']

    for i in tqdm(np.arange(n_nodes)):
        df_stats = base_stats.copy()
        df_stats['rest'] = sp.stats.zscore(rest_timescales[i])
        df_stats['task'] = sp.stats.zscore(task_timescales[i])

        # Fit the model
        results = smf.ols(formula=formula, data=df_stats).fit()

        for effect in effects:
            # A term can be absent from the fit (e.g. a differently named sex
            # level); skip it rather than swallowing every exception.
            if effect not in results.params.index:
                continue
            effect_map.loc[i, effect] = results.params[effect]
            pvals.loc[i, effect] = results.pvalues[effect]
        effect_map.loc[i, 'rsquared'] = results.rsquared_adj

    return effect_map, pvals

def detect_univariate_outliers(x, k=1.5):
    """Flag values lying outside the IQR fences of ``x``.

    A value is an outlier when it is below ``Q1 - k * IQR`` or above
    ``Q3 + k * IQR``.

    Parameters
    ----------
    x : numpy.ndarray
        Values to screen.
    k : float, optional (default=1.5)
        Fence multiplier applied to the interquartile range.

    Returns
    -------
    Integer labels shaped like ``x``: 1 for outliers, 0 otherwise.
    """
    lower_quartile, upper_quartile = np.percentile(x, [25, 75])
    fence_width = k * (upper_quartile - lower_quartile)

    below_fence = x < lower_quartile - fence_width
    above_fence = x > upper_quartile + fence_width

    return (below_fence | above_fence).astype(int)

def compute_r2(x, y):
    """Return the in-sample R-squared of a linear fit of ``y`` on ``x``, times 100.

    1-D inputs are promoted to single-column 2-D arrays before fitting.
    """
    x = x[:, np.newaxis] if x.ndim == 1 else x
    y = y[:, np.newaxis] if y.ndim == 1 else y

    fitted = LinearRegression().fit(x, y)
    return 100 * r2_score(y, fitted.predict(x))

def winsorize_iqr(vector, k=1.5, inplace=False):
    """
    Winsorizes a vector using the IQR method to handle outliers.

    Values below ``Q1 - k * IQR`` are set to that lower bound and values above
    ``Q3 + k * IQR`` are set to that upper bound. Quartiles ignore NaNs, so
    missing entries no longer disable winsorization; NaN entries themselves
    are left untouched.

    Parameters:
    -----------
    vector : array-like
        Input data to be winsorized (list, numpy array, or pandas Series)
    k : float, optional (default=1.5)
        Multiplier for IQR to determine outlier thresholds
    inplace : bool, optional (default=False)
        If True and ``vector`` is a numpy array, that array is modified and
        returned. Non-array inputs are first converted to a new array, so the
        caller's object is never modified.

    Returns:
    --------
    winsorized_vector : numpy array
        Winsorized version of the input vector

    Notes:
    ------
    The bounds are floats; writing them into an integer array makes NumPy
    cast them to the array's dtype.
    """
    # Convert input to numpy array if it isn't already
    if not isinstance(vector, np.ndarray):
        vector = np.array(vector)

    # NaN-aware quartiles: np.percentile returns NaN bounds as soon as one
    # NaN is present, which makes both comparisons below always False.
    q1, q3 = np.nanpercentile(vector, [25, 75])
    iqr = q3 - q1

    # Calculate lower and upper bounds
    lower_bound = q1 - k * iqr
    upper_bound = q3 + k * iqr

    # `vector` is always an ndarray here, so only `inplace` decides on a copy
    winsorized_vector = vector if inplace else vector.copy()

    # Winsorize the values (comparisons with NaN are False, so NaNs stay NaN)
    winsorized_vector[winsorized_vector < lower_bound] = lower_bound
    winsorized_vector[winsorized_vector > upper_bound] = upper_bound

    return winsorized_vector

## Setup

In [None]:
# directory where data is stored
indir = '/home/lindenmp/research_projects/nct_xr/data/int_deflections'
which_data = 'HCP-YA'
# which_data = 'HCP_D'
# which_data = 'RBC-PNC'
outdir = '/home/lindenmp/research_projects/nct_xr/results/int_deflections/{0}'.format(which_data.replace('_', ''))

# parcellation: the parcel/node counts follow from the atlas name
atlas = 'Schaefer4007'
if atlas == 'Schaefer4007':
    n_parcels = 400
    n_nodes = 400
elif atlas == 'Schaefer2007':
    n_parcels = 200
    n_nodes = 200
elif atlas == 'Schaefer1007':
    n_parcels = 100
    n_nodes = 100
else:
    # Fail here instead of with a NameError on n_parcels in a later cell.
    raise ValueError('Unsupported atlas: {0!r}'.format(atlas))

# dataset-specific settings: tr (presumably the repetition time in seconds -- confirm)
# and the task runs to analyze
if which_data == 'HCP-YA':
    tr = 0.720
    # tasks = ['tfMRI_EMOTION_LR', 'tfMRI_GAMBLING_LR', 'tfMRI_LANGUAGE_LR', 'tfMRI_RELATIONAL_LR', 'tfMRI_SOCIAL_LR', 'tfMRI_WM_LR']
    tasks = ['tfMRI_WM_LR', 'tfMRI_EMOTION_LR', 'tfMRI_LANGUAGE_LR', 'tfMRI_RELATIONAL_LR']

    cog_measures = [
        'PMAT24_A_CR', # fluid intelligence
        'VSPLOT_TC', # spatial orientation
        'ListSort_Unadj', # working memory
        'DDisc_AUC_40K', # delay discounting
        'Flanker_Unadj', # executive function
    ]
elif which_data == 'HCP_D':
    tr = 0.800
    # tasks = ['tfMRI_CARIT_PA', 'tfMRI_EMOTION_PA', 'tfMRI_GUESSING_PA']
    tasks = ['tfMRI_CARIT_PA',]
elif which_data == 'RBC-PNC':
    tr = 3
    tasks = ['frac2back',]
else:
    # Fail here instead of leaving tr/tasks undefined for later cells.
    raise ValueError('Unsupported dataset: {0!r}'.format(which_data))

In [None]:
# inputs for the surface plots in later cells: per-hemisphere Schaefer annotation
# files (fsaverage5, matching n_parcels), the fsaverage5 mesh, and a default colormap
annot_dir = '/home/lindenmp/research_projects/nctpy/data'
lh_annot_file = os.path.join(annot_dir, 'schaefer_parc', 'fsaverage5', 'lh.Schaefer2018_{0}Parcels_7Networks_order.annot'.format(n_parcels))
rh_annot_file = os.path.join(annot_dir, 'schaefer_parc', 'fsaverage5', 'rh.Schaefer2018_{0}Parcels_7Networks_order.annot'.format(n_parcels))
fsaverage = datasets.fetch_surf_fsaverage(mesh="fsaverage5")
cmap = "viridis"

In [None]:
# project color palette from snaplab_tools (presumably a name -> color mapping -- confirm)
my_colors = get_my_colors()

In [None]:
# Parcel centroid table, indexed by the CSV's second column, without the 'ROI Label' column
parc_centroids = pd.read_csv(
    os.path.join(indir, 'Schaefer2018_{0}Parcels_7Networks_order_FSLMNI152_1mm.Centroid_RAS.csv'.format(n_parcels)),
    index_col=1,
).drop(columns=['ROI Label'])
print(parc_centroids.head())

# Pairwise Euclidean distances between centroids, expanded to a square matrix
distance_matrix = distance.squareform(distance.pdist(parc_centroids, "euclidean"))

# Load data

In [None]:
# subject identifiers (read as strings) for this dataset/atlas combination
subjectids = np.loadtxt(os.path.join(indir, '{0}_{1}_subjids.txt'.format(which_data.replace('_', ''), atlas)), dtype=str)
n_subs = len(subjectids)
# number of subjects
print(n_subs)

In [None]:
# Per-parcel SA-axis values and the parcel order that sorts them ascending
sa_axis = np.load(os.path.join(indir, 'schaefer{0}-7_sa-axis.npy'.format(n_parcels)))
sa_axis_sort_idx = sa_axis.argsort()

# myelin = np.load(os.path.join(indir, '{0}_{1}_myelin.npy'.format(which_data.replace('_', ''), atlas)))
# myelin_mean = np.nanmean(myelin, axis=1)

# Name of the resting-state run for the selected dataset ('RBC' for anything else)
which_task = {'HCP-YA': 'rfMRIREST1LR', 'HCP_D': 'rfMRIREST1PA'}.get(which_data, 'RBC')

rest_timescales = np.load(os.path.join(outdir, '{0}_{1}_{2}_timescales_acflag.npy'.format(which_data.replace('_', ''), atlas, which_task)))
# rest_timescales = np.load(os.path.join(outdir, '{0}_{1}_{2}_timescales_alff.npy'.format(which_data.replace('_', ''), atlas, which_task)))

# rest_timescales = np.log10(rest_timescales, out=np.zeros_like(rest_timescales), where=(rest_timescales != 0))

# NaN-ignoring mean over the first axis
rest_timescales_mean = np.nanmean(rest_timescales, axis=0)

print(sa_axis.shape, rest_timescales.shape, rest_timescales_mean.shape)

In [None]:
# Load the timescales of every task run, winsorize each row, and accumulate
# the across-task mean of the per-parcel means.
task_timescales = dict()
task_timescales_mean = dict()
subject_filter = dict()
task_timescales_mean_all = np.zeros(n_parcels)
for which_task in tasks:
    print(which_task)
    tsk_tscales = np.load(os.path.join(outdir, '{0}_{1}_{2}_timescales_acflag.npy'.format(which_data.replace('_', ''), atlas, which_task.replace('_', ''))))
    # tsk_tscales = np.load(os.path.join(outdir, '{0}_{1}_{2}_timescales_alff.npy'.format(which_data.replace('_', ''), atlas, which_task.replace('_', ''))))

    # tsk_tscales = np.log10(tsk_tscales, out=np.zeros_like(tsk_tscales), where=(tsk_tscales != 0))
    # winsorize each row separately
    # NOTE(review): the rest timescales are not winsorized anywhere in this notebook -- confirm this asymmetry is intended
    for i in np.arange(tsk_tscales.shape[0]):
        tsk_tscales[i, :] = winsorize_iqr(tsk_tscales[i, :])

    # The subject filter is optional. Only a missing file triggers the
    # fallback; any other failure (e.g. a filter of the wrong length) is
    # raised instead of being silently ignored.
    try:
        subject_filter[which_task] = np.load(os.path.join(outdir, '{0}_{1}_{2}_subjectfilter.npy'.format(which_data.replace('_', ''), atlas, which_task.replace('_', ''))))
    except FileNotFoundError:
        task_timescales_mean[which_task] = np.nanmean(tsk_tscales, axis=0)
    else:
        print(subject_filter[which_task].sum())
        # rows flagged by the filter are excluded from the mean
        # NOTE(review): `~` assumes the stored filter is boolean -- confirm
        task_timescales_mean[which_task] = np.nanmean(tsk_tscales[~subject_filter[which_task], :], axis=0)
    task_timescales[which_task] = tsk_tscales
    task_timescales_mean_all += task_timescales_mean[which_task]

task_timescales_mean_all = np.divide(task_timescales_mean_all, len(tasks))

print(task_timescales_mean_all.shape)

In [None]:
# shape of the last task loaded above (which_task is left over from that loop)
task_timescales[which_task].shape

In [None]:
# f = surface_plot(
#     data=sa_axis,
#     lh_annot_file=lh_annot_file,
#     rh_annot_file=rh_annot_file,
#     fsaverage=fsaverage,
#     order="lr",
#     cmap="viridis",
# )
# f.savefig(os.path.join(outdir, 'sa_axis.png'), dpi=600, bbox_inches="tight", pad_inches=0.01)

In [None]:
# Parcel-level table: mean rest/task timescales and the SA-axis value
df_stats = pd.DataFrame(
    {
        'rest': rest_timescales_mean,
        'task': task_timescales_mean_all,
        'SA': sa_axis,
    },
    index=np.arange(n_parcels),
)
df_stats = sm.add_constant(df_stats)

# SA axis modelled from rest, task and their interaction
results = smf.ols(formula='SA ~ rest * task', data=df_stats).fit()
print(results.summary())

In [None]:
# int_deflections = compute_deflections(rest_timescales_mean, task_timescales_mean_all, nuisance_regression=True)
int_deflections = np.subtract(task_timescales_mean_all, rest_timescales_mean)

# One regression panel per map, each plotted against the SA axis
f, ax = plt.subplots(1, 3, figsize=(2*3, 1.75))
for panel_ax, panel_values, panel_label in zip(
    ax,
    (rest_timescales_mean, task_timescales_mean_all, int_deflections),
    ('INTs (rest)', 'INTs (task)', 'INTs (task-rest)'),
):
    reg_plot(sa_axis, panel_values, xlabel='SA axis', ylabel=panel_label, ax=panel_ax, annotate='both', add_pval=False)

f.tight_layout()
plt.show()
f.savefig(os.path.join(outdir, 'ints_deflections_sa_corr.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)

In [None]:
# Render each group-level map on the cortical surface and save it as a PNG
for surface_values, surface_cmap, png_name in (
    (rest_timescales_mean, "magma", 'ints_rest.png'),
    (task_timescales_mean_all, "magma", 'ints_task.png'),
    (int_deflections, "coolwarm", 'ints_deflections.png'),
):
    f = surface_plot(
        data=surface_values,
        lh_annot_file=lh_annot_file,
        rh_annot_file=rh_annot_file,
        fsaverage=fsaverage,
        order="lr",
        cmap=surface_cmap,
    )
    f.savefig(os.path.join(outdir, png_name), dpi=600, bbox_inches="tight", pad_inches=0.01)

# Subject-level effects

In [None]:
# Analyze the first configured task
which_task = tasks[0]
print(which_task)

# Performance-measure columns associated with each task name
measures_by_task = {
    'tfMRI_EMOTION_LR': ['Emotion_Task_Acc', 'Emotion_Task_Median_RT'],
    'tfMRI_GAMBLING_LR': ['Gambling_Task_Perc_Larger', 'Gambling_Task_Perc_Smaller'],
    'tfMRI_LANGUAGE_LR': ['Language_Task_Acc', 'Language_Task_Median_RT'],
    'tfMRI_RELATIONAL_LR': ['Relational_Task_Acc', 'Relational_Task_Median_RT'],
    'tfMRI_SOCIAL_LR': ['Social_Task_Perc_Random', 'Social_Task_Perc_TOM'],
    'tfMRI_WM_LR': ['WM_Task_Acc', 'WM_Task_Median_RT'],
    'tfMRI_WM_RL': ['WM_Task_Acc', 'WM_Task_Median_RT'],
}

if which_task in measures_by_task:
    task_performance_measures = measures_by_task[which_task]
elif which_task == 'frac2back':
    contrasts_list = ['0back', '1back', '2back']
# elif which_task == 'tfMRI_CARIT_PA':
    # task_performance_measures = ['WM_Task_Acc', 'WM_Task_Median_RT']
    # contrasts_list = ['go', 'miss', 'nogoCRLose', 'nogoCRWin', 'nogoFALose', 'nogoFAWin']
    # contrasts_list_short = ['go', 'nogo']

In [None]:
# Load per-subject behavior/demographics for the selected dataset and define:
#   df_behavior  - table aligned to subjectids
#   y_var        - outcome column analyzed below (y_var_label is its plot label)
#   y_var_filter - boolean array of subjects kept in the analysis
#   nuis_covs    - nuisance design matrix (covariates plus an intercept column)
n_subjects = len(subjectids)

if which_data == 'HCP-YA':
    column_filter = task_performance_measures.copy()
    column_filter.extend(cog_measures)
    print(column_filter)

    df_behavior = pd.read_csv('/mnt/storage_ssd_raid/research_data/HCP_YA/unrestricted_lindenmp_9_26_2022_16_52_14.csv', index_col=0)
    df_behavior = df_behavior.loc[subjectids.astype(int)]
    df_behavior_restricted = pd.read_csv('/mnt/storage_ssd_raid/research_data/HCP_YA/RESTRICTED_lindenmp_9_26_2022_16_42_3.csv', index_col=0)
    df_behavior_restricted = df_behavior_restricted.loc[subjectids.astype(int)]
    df_behavior['age'] = df_behavior_restricted['Age_in_Yrs']
    # sex coded as 1 for 'M', 0 otherwise
    df_behavior['sex'] = df_behavior['Gender'] == 'M'
    df_behavior['sex'] = df_behavior['sex'].astype(int)

    # for y_var in column_filter:
        # df_behavior[y_var].fillna(np.nanmean(df_behavior[y_var]), inplace=True)

    y_var = column_filter[0]
    # y_var = 'age'
    print(y_var)
    if 'Acc' in y_var:
        y_var_label = 'accuracy'
    elif 'RT' in y_var:
        y_var_label = 'RT'
    else:
        y_var_label = y_var

    if y_var == 'WM_Task_Acc':
        # keep subjects with accuracy of at least 50 (NaN compares False and is dropped);
        # .values keeps y_var_filter a plain boolean array in every branch
        y_var_filter = (df_behavior[y_var] >= 50).values
    else:
        y_var_filter = ~np.isnan(df_behavior[y_var].values)
        # y_var_filter = np.zeros(n_subjects).astype(bool)
        # y_var_filter[:] = True
    print(np.sum(y_var_filter))

    if y_var == 'age':
        nuis_covs = np.concatenate((df_behavior['sex'].values[:, np.newaxis],
                                    np.ones((n_subjects, 1))
                                    ), axis=1)
    else:
        # z-score the raw values: whether zscore returns a Series (with .values)
        # or an ndarray for Series input depends on the SciPy version
        nuis_covs = np.concatenate((sp.stats.zscore(df_behavior['age'].values)[:, np.newaxis],
                                    df_behavior['sex'].values[:, np.newaxis],
                                    np.ones((n_subjects, 1))
                                    ), axis=1)
elif which_data == 'HCP_D':
    df_behavior = pd.read_csv('/mnt/storage_ssd_raid/research_data/HCP_D/behavior/socdem01.txt', sep='\t', header=0)
    # drop the first data row (presumably a column-description row -- confirm)
    df_behavior = df_behavior.iloc[1:]
    behavior_subjectids = list(df_behavior['src_subject_id'])
    behavior_subjectids = ['sub-' + s + 'V1MR' for s in behavior_subjectids]
    df_behavior.index = behavior_subjectids
    df_behavior.index.name = 'subjectids'

    y_var = 'interview_age'
    y_var_label = 'age'
    df_behavior = df_behavior.loc[:, [y_var, 'sex']]
    df_behavior['sex'] = df_behavior['sex'] == 'M'
    df_behavior['sex'] = df_behavior['sex'].astype(int)

    # interview_age divided by 12 (presumably months -> years -- confirm)
    df_behavior[y_var] = df_behavior[y_var].astype(float) / 12
    df_behavior = df_behavior.loc[subjectids]
    y_var_filter = np.zeros(n_subjects).astype(bool)
    y_var_filter[:] = True

    # 'sex' is already coded 0/1 above; comparing it with 'M' a second time
    # produced an all-zero column, so sex was never regressed out
    nuis_covs = np.concatenate((df_behavior['sex'].values[:, np.newaxis],
                                np.ones((n_subjects, 1))
                                ), axis=1)

elif which_data == 'RBC-PNC':
    df_behavior = pd.read_csv('/mnt/storage_ssd_raid/research_data/RBC/PNC/study-PNC_desc-participants.tsv', sep='\t', header=0)
    behavior_subjectids = list(df_behavior['participant_id'])
    behavior_subjectids = ['sub-' + str(s) for s in behavior_subjectids]
    df_behavior.index = behavior_subjectids
    df_behavior.index.name = 'subjectids'

    y_var = 'age'
    y_var_label = 'age'
    df_behavior = df_behavior.loc[:, [y_var, 'sex']]
    df_behavior['sex'] = df_behavior['sex'] == 'Male'
    df_behavior['sex'] = df_behavior['sex'].astype(int)

    df_behavior[y_var] = df_behavior[y_var].astype(float)
    df_behavior = df_behavior.loc[subjectids]
    y_var_filter = np.zeros(n_subjects).astype(bool)
    y_var_filter[:] = True

    nuis_covs = np.concatenate((df_behavior['sex'].values[:, np.newaxis],
                                np.ones((n_subjects, 1))
                                ), axis=1)

df_behavior

In [None]:
# inspect the nuisance design matrix built above (covariate columns, then a column of ones)
nuis_covs

In [None]:
# timescales with the nuisance covariates regressed out (not used by brain_data below)
rest_timescales_resid = nuis_reg(nuis_covs, rest_timescales)
task_timescales_resid = nuis_reg(nuis_covs, task_timescales[which_task])
# int_deflections = task_timescales[which_task] - rest_timescales
# int_deflections = task_timescales_resid - rest_timescales_resid
# int_deflections_resid = nuis_reg(rest_timescales_resid, int_deflections)

# brain_data = rest_timescales_resid
# brain_data = task_timescales_resid
brain_data = task_timescales[which_task]
# brain_data = int_deflections_resid

# Pearson correlation between each parcel's brain_data and y_var across the filtered subjects
corr_map = np.zeros(n_parcels)
p_map = np.zeros(n_parcels)
for i in np.arange(n_parcels):
    # unpack directly: binding the result to `stats` would clobber the
    # `from scipy import stats` module imported at the top of the notebook
    corr_map[i], p_map[i] = sp.stats.pearsonr(brain_data[y_var_filter, i], df_behavior[y_var][y_var_filter])

    # corr_map[i] = compute_r2(brain_data[y_var_filter, i], df_behavior[y_var][y_var_filter].values)

# p-values adjusted via get_fdr_p; only used by the commented-out threshold below
p_map = get_fdr_p(p_map)
corr_map_filtered = corr_map.copy()
# corr_map_filtered[p_map>0.05] = np.nan

f2 = surface_plot(
    data=corr_map_filtered,
    lh_annot_file=lh_annot_file,
    rh_annot_file=rh_annot_file,
    fsaverage=fsaverage,
    order="lr",
    cmap="coolwarm",
    # title_str='corr(int_deflections, behavior)'
)
f2.savefig(os.path.join(outdir, 'corr(int_deflections,behavior)_brainplot.png'), dpi=600, bbox_inches="tight", pad_inches=0.01)

# How the parcel-wise brain-behavior correlation varies along the SA axis
f, ax = plt.subplots(1, 1, figsize=(1.1, 1.1))
reg_plot(sa_axis, corr_map, xlabel='SA axis', ylabel='corr(INTs, {0})'.format(y_var_label), ax=ax, annotate='spearman', add_pval=False)
plt.show()
f.savefig(os.path.join(outdir, 'int_deflections_task_performance_sa_corr_1.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)

# Scatter plots for the parcels with the most negative and most positive correlation
f, ax = plt.subplots(1, 2, figsize=(3, 1.5))
reg_plot(df_behavior[y_var][y_var_filter], brain_data[y_var_filter, np.argmin(corr_map)], xlabel=y_var_label, ylabel='INTs', ax=ax[0], annotate='both', add_pval=False)
reg_plot(df_behavior[y_var][y_var_filter], brain_data[y_var_filter, np.argmax(corr_map)], xlabel=y_var_label, ylabel='INTs', ax=ax[1], annotate='both', add_pval=False)
f.tight_layout()
plt.show()

# Per subject: correlation between the SA axis and that subject's brain_data across
# parcels, then related to y_var. From here on corr_map holds one value per subject.
corr_map = np.zeros(n_subjects)
for i in np.arange(n_subjects):
    corr_map[i] = sp.stats.pearsonr(sa_axis, brain_data[i, :])[0]
f, ax = plt.subplots(1, 1, figsize=(1.1, 1.1))
reg_plot(df_behavior[y_var].values[y_var_filter], corr_map[y_var_filter], ylabel='corr(SA axis, INTs)', xlabel=y_var_label, ax=ax, annotate='spearman', add_pval=False)
plt.show()
f.savefig(os.path.join(outdir, 'int_deflections_task_performance_sa_corr_2.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)

### GAM analysis of age

In [None]:
from pygam import LinearGAM, GammaGAM, s, f
from sklearn.model_selection import KFold
from sklearn.linear_model import Ridge
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

def fit_age_relationship(age, sex, y, model_type='gam', run_gridsearch=False):
    """
    Model the relationship between age and y while accounting for sex, and
    return the fitted age effect on a grid of 100 age values.

    Parameters:
        age (array-like): Age values, one per subject
        sex (array-like): Categorical sex values, one per subject
        y (array-like): Target variable, one per subject
        model_type (str): 'gam' or 'linear'
            'gam': LinearGAM with a spline term on age and a factor term on
                sex; rows with missing values are dropped. 'prediction' is the
                partial dependence of the age term, not a full model prediction.
            'linear': age and y are z-scored, sex is regressed out of age, and
                y is regressed on the residualized age. 'prediction' is
                mean-centered and 'age' is in z-score units.
        run_gridsearch (bool): GAM only. If True, search the smoothing
            penalties over np.linspace(0.001, 1, 10) for both terms; otherwise
            use 0.1 for both.

    Returns:
        dict: 'model' (the fitted model) and 'predictions' (DataFrame with
        columns 'age' and 'prediction')

    Raises:
        ValueError: if model_type is neither 'gam' nor 'linear'
    """
    # Validate up front; an unknown type used to end in a NameError on age_grid.
    if model_type not in ('gam', 'linear'):
        raise ValueError("model_type must be 'gam' or 'linear', got {0!r}".format(model_type))

    n_age_steps = 100

    if model_type == 'gam':
        # Import the term builders under local aliases: the notebook-level
        # names `s` and `f` are easy to clobber (`f` is reused for figures in
        # other cells), which would break this function on a later call.
        from pygam import LinearGAM, s as spline_term, f as factor_term

        # Create dataframe and drop missing values
        df = pd.DataFrame({'age': age, 'sex': sex, 'y': y}).dropna()
        df['sex'] = pd.Categorical(df['sex'])
        # n_age_steps = np.arange(int(df['age'].min()), int(df['age'].max()) + 1, 1).shape[0]

        n_splines = 3
        gam_terms = spline_term(0, n_splines=n_splines, spline_order=n_splines-1) + factor_term(1)
        if run_gridsearch:
            # lam = np.linspace(0.001, 0.1, 5)
            lam = np.linspace(0.001, 1, 10)
            lams = [lam] * 2
            model = LinearGAM(gam_terms)
            # The initial fit is kept before the search, as in the original code.
            model.fit(df[['age', 'sex']], df['y'])
            model.gridsearch(df[['age', 'sex']], df['y'], lam=lams, progress=False)
        else:
            lams = [0.1] * 2
            model = LinearGAM(gam_terms, lam=lams)
            model.fit(df[['age', 'sex']], df['y'])

        # Partial dependence of the age term (term 0) on a grid of n_age_steps points
        age_grid = model.generate_X_grid(term=0, n=n_age_steps)
        pred, _ = model.partial_dependence(term=0, X=age_grid, width=0.95)

        # pred = pred - pred.mean()
        # pred = sp.stats.zscore(pred)

    else:
        # age_grid = np.arange(int(age.min()), int(age.max()) + 1, 1).reshape(-1, 1)
        age = sp.stats.zscore(age)
        # The grid spans the z-scored age before sex is regressed out of it.
        age_grid = np.linspace(age.min(), age.max(), n_age_steps).reshape(-1, 1)
        y = sp.stats.zscore(y)

        age = age.reshape(-1, 1)
        sex = sex.reshape(-1, 1)
        y = y.reshape(-1, 1)

        # Residualize age with respect to sex
        model = LinearRegression()
        # model.fit(sex, y)
        # pred = model.predict(sex)
        # y = y - pred
        model.fit(sex, age)
        pred = model.predict(sex)
        age = age - pred

        # Fit the model
        model = LinearRegression()
        model.fit(age, y)

        pred = model.predict(age_grid).flatten()
        pred = pred - pred.mean()

    results_df = pd.DataFrame({
        'age': age_grid[:, 0],
        'prediction': pred,
    })

    return {
        'model': model,
        'predictions': results_df,
    }

In [None]:
# settings passed to fit_age_relationship in the cells below
model_type = 'gam'
run_gridsearch = False

In [None]:
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

my_colors = get_my_colors()
# colors = [my_colors['conch_shell'],
#           my_colors['north_sea_green']]
colors = ['orange',
          'purple']
# 10-level orange-to-purple colormap; later cells index it with integers (cm(i)) to color the SA-axis bins
cm = LinearSegmentedColormap.from_list("Custom", colors, N=10)
# display the colormap
cm

In [None]:
# fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# reg_plot(sa_axis, age_resolved_change[:, 0], xlabel='sa', ylabel='age change', ax=ax, annotate='both')

In [None]:
# Reorder parcels along the SA axis so that consecutive column blocks are SA-axis bins
rest_timescales_sorted = rest_timescales[:, sa_axis_sort_idx]
task_timescales_sorted = task_timescales[which_task][:, sa_axis_sort_idx]

predictions_rest = []
age_resolved_change_rest = []
predictions_task = []
age_resolved_change_task = []
predictions_deflections = []
age_resolved_change_deflections = []

# Use the real age column whenever the behavior table has one. In the HCP-YA
# configuration y_var is a task/cognitive measure (e.g. WM_Task_Acc), which
# would otherwise be modelled and plotted as "age". Datasets whose table only
# holds y_var (an age variable there) fall back to it.
age_column = 'age' if 'age' in df_behavior.columns else y_var
age = df_behavior[age_column].values.copy()
sex = df_behavior['sex'].values.copy()
# print(sex)

# ten equally sized bins of parcels along the SA axis
step_size = int(n_parcels/10)
print(step_size)
for i in np.arange(0, n_parcels, step_size):
    rest_decile = rest_timescales_sorted[:, i:i+step_size].mean(axis=1)
    task_decile = task_timescales_sorted[:, i:i+step_size].mean(axis=1)
    # task-rest deflection with the rest level regressed out; the mean is added back afterwards
    deflections = task_decile - rest_decile
    mean_back = deflections.mean()
    deflections = np.squeeze(nuis_reg(rest_decile, deflections, use_sklearn=True))
    deflections += mean_back

    # rest
    results = fit_age_relationship(age, sex, rest_decile, model_type=model_type, run_gridsearch=run_gridsearch)
    preds = results['predictions']['prediction']
    predictions_rest.append(preds)
    age_resolved_change_rest.append(np.diff(preds))

    # task
    results = fit_age_relationship(age, sex, task_decile, model_type=model_type, run_gridsearch=run_gridsearch)
    preds = results['predictions']['prediction']
    predictions_task.append(preds)
    age_resolved_change_task.append(np.diff(preds))

    # deflections
    results = fit_age_relationship(age, sex, deflections, model_type=model_type, run_gridsearch=run_gridsearch)
    preds = results['predictions']['prediction']
    predictions_deflections.append(preds)
    age_resolved_change_deflections.append(np.diff(preds))

# one row per bin; the *_change arrays hold first differences along the age grid
predictions_rest = np.asarray(predictions_rest)
age_resolved_change_rest = np.asarray(age_resolved_change_rest)
predictions_task = np.asarray(predictions_task)
age_resolved_change_task = np.asarray(age_resolved_change_task)
predictions_deflections = np.asarray(predictions_deflections)
age_resolved_change_deflections = np.asarray(age_resolved_change_deflections)

# print(predictions_rest.shape, age_resolved_change_rest.shape, predictions_task.shape, age_resolved_change_task.shape)

# one curve per bin, colored by its position along the SA axis
fig, ax = plt.subplots(1, 3, figsize=(7.5, 2.5))
for i in np.arange(int(n_parcels/step_size)):
    ax[0].plot(results['predictions']['age'], predictions_rest[i], linewidth=1, alpha=1, color=cm(i))
ax[0].set_title('rest')
for i in np.arange(int(n_parcels/step_size)):
    ax[1].plot(results['predictions']['age'], predictions_task[i], linewidth=1, alpha=1, color=cm(i))
ax[1].set_title('task')
for i in np.arange(int(n_parcels/step_size)):
    ax[2].plot(results['predictions']['age'], predictions_deflections[i], linewidth=1, alpha=1, color=cm(i))
ax[2].set_title('deflections')

for this_ax in ax:
    this_ax.set_xlabel('Age (years)')
    this_ax.set_ylabel('INTs (GAM prediction)')
fig.tight_layout()

In [None]:
# Difference in the magnitude of age-related change between task and rest, per bin
relative_change = np.abs(age_resolved_change_task) - np.abs(age_resolved_change_rest)
# relative_change = age_resolved_change_task - age_resolved_change_rest
print(relative_change.shape)
fig, ax = plt.subplots(1, 2, figsize=(5, 2.5))
# ax[0].plot(results['predictions']['age'][1:], age_resolved_change_rest[-1], linewidth=2, alpha=1, linestyle='--', color=cm(9))
# ax[0].plot(results['predictions']['age'][1:], age_resolved_change_task[-1], linewidth=2, alpha=1, color=cm(9))
# left panel: fitted curves for the first and last bin (dashed = rest, solid = task)
ax[0].plot(results['predictions']['age'], predictions_rest[0], linewidth=2, alpha=1, linestyle='--', color=cm(0))

ax[0].plot(results['predictions']['age'], predictions_task[0], linewidth=2, alpha=1, color=cm(0))
ax[0].plot(results['predictions']['age'], predictions_rest[-1], linewidth=2, alpha=1, linestyle='--', color=cm(9))
ax[0].plot(results['predictions']['age'], predictions_task[-1], linewidth=2, alpha=1, color=cm(9))
# ax[0].axhline(y=0, color='k', linestyle='--')
ax[0].set_xlabel('Age (years)')
ax[0].set_ylabel('INTs (GAM prediction)')

# right panel: relative change for the same two bins. The colormap has 10 levels
# (valid integer indices 0-9), so the last bin uses cm(9) like the left panel.
ax[1].plot(results['predictions']['age'][1:], relative_change[0, :], color=cm(0), linewidth=2, label='sensorimotor')
ax[1].plot(results['predictions']['age'][1:], relative_change[-1, :], color=cm(9), linewidth=2, label='association')
ax[1].axhline(y=0, color='k', linestyle='--')
ax[1].set_xlabel('Age (years)')
ax[1].set_ylabel('mean(age change, task-rest)')
# ax[1].grid(True)
# ax[1].legend()
fig.tight_layout()
plt.show()

In [None]:
# # export data for Bart
# df_rest = pd.DataFrame()
# df_rest['subject_id'] = subjectids
# df_rest.set_index('subject_id', inplace=True)
# df_rest['age'] = df_behavior[y_var].values.copy()
# df_rest['sex'] = df_behavior['sex'].values.copy()
# counter = 1
# for i in np.arange(0, n_parcels, step_size):
#     df_rest['SA_dec_{0}'.format(counter)] = rest_timescales_sorted[:, i:i+step_size].mean(axis=1)
#     counter += 1
# df_rest.to_csv(os.path.join(outdir, 'df_rest_larsen.csv'))

# df_task = pd.DataFrame()
# df_task['subject_id'] = subjectids
# df_task.set_index('subject_id', inplace=True)
# df_task['age'] = df_behavior[y_var].values.copy()
# df_task['sex'] = df_behavior['sex'].values.copy()
# counter = 1
# for i in np.arange(0, n_parcels, step_size):
#     df_task['SA_dec_{0}'.format(counter)] = task_timescales_sorted[:, i:i+step_size].mean(axis=1)
#     counter += 1
# df_task.to_csv(os.path.join(outdir, 'df_task_larsen.csv'))

In [None]:
# q_val = 0.2
# lower_q = sa_axis < np.quantile(sa_axis, q=q_val)
# upper_q = sa_axis > np.quantile(sa_axis, q=1-q_val)

# fig, ax = plt.subplots(1, 2, figsize=(4.5, 2))
# ax[0].plot(results['predictions']['age'][1:], np.mean(age_resolved_change_rest[lower_q, :], axis=0), 'r-', linewidth=2, label='rest')
# ax[0].plot(results['predictions']['age'][1:], np.mean(age_resolved_change_task[lower_q, :], axis=0), 'b-', linewidth=2, label='task')
# ax[0].set_title('Sensorimotor')
# ax[1].plot(results['predictions']['age'][1:], np.mean(age_resolved_change_rest[upper_q, :], axis=0), 'r-', linewidth=2, label='rest')
# ax[1].plot(results['predictions']['age'][1:], np.mean(age_resolved_change_task[upper_q, :], axis=0), 'b-', linewidth=2, label='task')
# ax[1].set_title('Association')
# for this_ax in ax:
#     this_ax.set_xlabel('Age (years)')
#     this_ax.set_ylabel('mean(age change)')
#     this_ax.grid(True)
#     this_ax.legend()
# fig.tight_layout()
# plt.show()

In [None]:
# Per-parcel age fits for rest and task; these names are reused from the binned analysis above
predictions_rest, age_resolved_change_rest = [], []
predictions_task, age_resolved_change_task = [], []

for i in tqdm(np.arange(n_parcels)):
    # rest first, then task, each appended to its own pair of lists
    for parcel_values, pred_store, change_store in (
        (rest_timescales[:, i], predictions_rest, age_resolved_change_rest),
        (task_timescales[which_task][:, i], predictions_task, age_resolved_change_task),
    ):
        results = fit_age_relationship(age, sex, parcel_values, model_type=model_type, run_gridsearch=run_gridsearch)
        preds = results['predictions']['prediction']
        pred_store.append(preds)
        # first difference of the fitted curve along the age grid
        change_store.append(np.diff(preds))

# stack into arrays with one row per parcel
predictions_rest = np.asarray(predictions_rest)
age_resolved_change_rest = np.asarray(age_resolved_change_rest)
predictions_task = np.asarray(predictions_task)
age_resolved_change_task = np.asarray(age_resolved_change_task)
print(predictions_rest.shape, age_resolved_change_rest.shape, predictions_task.shape, age_resolved_change_task.shape)

In [None]:
# sa_axis_corr = []
# for i in np.arange(age_resolved_change_rest.shape[1]):
#     sa_axis_corr.append(sp.stats.pearsonr(sa_axis, age_resolved_change_rest[:, i])[0])
# sa_axis_corr = np.asarray(sa_axis_corr)
# print(sa_axis_corr.shape)

# fig, ax = plt.subplots(1, 1, figsize=(2, 2))
# ax.plot(results['predictions']['age'][1:], sa_axis_corr, 'k-', linewidth=2)
# ax.set_xlabel('Age (years)')
# ax.set_ylabel('corr(SA axis, age effect)')
# # ax.grid(True)
# # fig.tight_layout()
# plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(2, 2))

# x-axis: the 'age' values of the fit still held in `results`, without the first
# entry so the length matches the np.diff-based change arrays.
age_grid = results['predictions']['age'][1:]

for change_by_parcel, line_fmt, line_label in (
    (age_resolved_change_rest, 'r-', 'rest'),
    (age_resolved_change_task, 'b-', 'task'),
):
    # One Pearson r per column: sa_axis against that column of the change array.
    # Alternative kept for reference: correlate against np.abs(change_by_parcel[:, col]).
    sa_axis_corr = np.asarray([
        sp.stats.pearsonr(sa_axis, change_by_parcel[:, col])[0]
        for col in range(change_by_parcel.shape[1])
    ])
    # sa_axis_corr = sp.stats.zscore(sa_axis_corr)
    ax.plot(age_grid, sa_axis_corr, line_fmt, linewidth=2, label=line_label)

ax.set_xlabel('Age (years)')
ax.set_ylabel('corr(SA axis, age effect)')
ax.grid(True)
ax.legend()
plt.show()

## Dataset-specific analyses

### HCP YA

In [None]:
if which_data == 'HCP-YA':
    # Three-panel figure relating timescales (INTs) to the behavioral measure `y_var`:
    #   panel 0: per-parcel corr(INTs, y_var) regressed against sa_axis, rest vs. task
    #   panel 1: mean INTs over parcels below the q_val quantile of sa_axis vs. y_var
    #   panel 2: mean INTs over parcels above the (1 - q_val) quantile of sa_axis vs. y_var
    # A surface map of corr(task - rest, y_var) is saved as a separate image.
    n_cols = 3
    f, ax = plt.subplots(1, n_cols, figsize=(2*n_cols, 2))
    plot_col = 0

    # version 1, deflections
    # int_deflections = task_timescales - rest_timescales
    # int_deflections_resid = np.zeros(int_deflections.shape)
    # for i in np.arange(n_subjects):
    #     int_deflections_resid[i, :] = nuis_reg(rest_timescales[i, :], int_deflections[i, :], use_sklearn=True)[:, 0]
        
    # corr_map = np.zeros(n_parcels)
    # p_map = np.zeros(n_parcels)
    # for i in np.arange(n_parcels):
    #     stats = sp.stats.pearsonr(int_deflections[y_var_filter, i], df_behavior[y_var][y_var_filter])
    #     # stats = sp.stats.pearsonr(int_deflections_resid[y_var_filter, i], df_behavior[y_var][y_var_filter])
    #     corr_map[i] = stats[0]
    #     p_map[i] = stats[1]

    # reg_plot(sa_axis, corr_map, xlabel='SA axis', ylabel='corr(INTs_deflection, {0})'.format(y_var), ax=ax[plot_col], annotate='both', order=1, add_pval=False)  # 'coef({0}, {1})'.format(effect,y_var)

    # version 2, interaction
    # NOTE(review): the rest/task maps below use the *_resid arrays, while the
    # deflection uses the non-residualised arrays -- confirm this is intended.
    # int_deflections = task_timescales_resid - rest_timescales_resid
    int_deflections = task_timescales[which_task] - rest_timescales

    # Behavioral scores for the rows selected by y_var_filter. Computed once and
    # reused by every correlation and regression panel below.
    y = df_behavior[y_var][y_var_filter]

    corr_map_rest = np.zeros(n_parcels)
    corr_map_task = np.zeros(n_parcels)
    corr_map_delta = np.zeros(n_parcels)
    for i in np.arange(n_parcels):
        # Only the correlation coefficient is kept. The result is deliberately not
        # bound to a name called `stats`, which would overwrite the
        # `from scipy import stats` module import for the rest of the session.
        corr_map_rest[i] = sp.stats.pearsonr(rest_timescales_resid[y_var_filter, i], y)[0]
        corr_map_task[i] = sp.stats.pearsonr(task_timescales_resid[y_var_filter, i], y)[0]
        corr_map_delta[i] = sp.stats.pearsonr(int_deflections[y_var_filter, i], y)[0]

    # Parcel-level model: SA axis as a function of the rest-by-task product term.
    df_stats = pd.DataFrame(index=np.arange(n_parcels))
    df_stats['rest'] = corr_map_rest
    df_stats['task'] = corr_map_task
    df_stats['SA'] = sa_axis
    # The 'const' column added here is not referenced by the formula below.
    df_stats = sm.add_constant(df_stats)
    results = smf.ols(formula='SA ~ rest : task', data=df_stats).fit()

    # Regression-line-only styling shared by every panel (scatter points are hidden).
    line_only = dict(scatter=False, marker='.', scatter_kws={"s": 1}, order=1)

    sns.regplot(x=sa_axis, y=corr_map_rest, ax=ax[plot_col], label='rest', line_kws={'linestyle': '-'}, color=my_colors['north_sea_green'], **line_only)
    sns.regplot(x=sa_axis, y=corr_map_task, ax=ax[plot_col], label='task', line_kws={'linestyle': '--'}, color=my_colors['north_sea_green'], **line_only)
    ax[plot_col].set_xlabel('SA axis')
    ax[plot_col].set_ylabel('corr(INTs, {0})'.format(y_var))
    ax[plot_col].set_title('t = {:.2f}, p = {:.2f}'.format(np.abs(results.tvalues['rest:task']), results.pvalues['rest:task']))
    ax[plot_col].legend()

    # Parcel masks for the two ends of sa_axis; the heatmap cell further down
    # reads lower_q / upper_q as well.
    q_val = 0.2
    lower_q = sa_axis < np.quantile(sa_axis, q=q_val)
    upper_q = sa_axis > np.quantile(sa_axis, q=1-q_val)

    # One panel per bin: mean timescale across the bin's parcels against behavior.
    for bin_mask, bin_color, bin_title in (
        (lower_q, 'orange', 'Sensorimotor'),
        (upper_q, 'purple', 'Association'),
    ):
        plot_col += 1

        rest_timescales_mean_bin = rest_timescales_resid[:, bin_mask].mean(axis=1)
        task_timescales_mean_bin = task_timescales_resid[:, bin_mask].mean(axis=1)

        x = rest_timescales_mean_bin[y_var_filter]
        sns.regplot(x=x, y=y, ax=ax[plot_col], label='rest', line_kws={'linestyle': '-'}, color=bin_color, **line_only)
        r_rest = sp.stats.pearsonr(x, y)[0]
        x = task_timescales_mean_bin[y_var_filter]
        sns.regplot(x=x, y=y, ax=ax[plot_col], label='task', line_kws={'linestyle': '--'}, color=bin_color, **line_only)
        r_task = sp.stats.pearsonr(x, y)[0]

        ax[plot_col].set_xlabel('INTs')
        ax[plot_col].set_ylabel(y_var)
        ax[plot_col].set_title(bin_title)
        # ax[plot_col].set_title('{0}\nr_rest={1:.2f}, r_task={2:.2f}'.format(bin_title, r_rest, r_task))
        # ax[plot_col].set_title('Sensorimotor (bottom {0}%)'.format(int(q_val*100)))
        # ax[plot_col].set_title('Association (top {0}%)'.format(int(q_val*100)))
        ax[plot_col].legend()

    for this_ax in ax.reshape(-1):
        sns.despine(right=True, top=True, ax=this_ax)

    # f.suptitle(which_task)
    f.tight_layout()
    plt.show()
    f.savefig(os.path.join(outdir, 'interaction_plot.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)

    # Cortical surface rendering of the per-parcel deflection/behavior correlation.
    f2 = surface_plot(
        data=corr_map_delta,
        lh_annot_file=lh_annot_file,
        rh_annot_file=rh_annot_file,
        fsaverage=fsaverage,
        order="lr",
        cmap="coolwarm",
        # title_str='corr(int_deflections, behavior)'
    )
    f2.savefig(os.path.join(outdir, 'interaction_plot_brain.png'), dpi=600, bbox_inches="tight", pad_inches=0.01)

In [None]:
if which_data == 'HCP-YA':
    # Summary statistics for every behavioral measure in column_filter.
    # Columns of heatmap_data:
    #   0    : t-value of the rest:task term in the OLS model 'SA ~ rest : task'
    #   1, 2 : Pearson r of behavior with mean rest / task timescales over lower_q parcels
    #   3, 4 : Pearson r of behavior with mean rest / task timescales over upper_q parcels
    # heatmap_data_pval only has column 0 filled (p-value of the rest:task term).
    # NOTE: the loop rebinds the notebook-level names y_var and y_var_filter; after
    # this cell they refer to the last entry of column_filter.
    heatmap_data = np.zeros((len(column_filter), 5))
    heatmap_data_pval = np.zeros((len(column_filter), 5))

    # The bin means do not depend on the behavioral measure, so they are computed
    # once here instead of on every loop iteration.
    rest_mean_lower = rest_timescales_resid[:, lower_q].mean(axis=1)
    task_mean_lower = task_timescales_resid[:, lower_q].mean(axis=1)
    rest_mean_upper = rest_timescales_resid[:, upper_q].mean(axis=1)
    task_mean_upper = task_timescales_resid[:, upper_q].mean(axis=1)

    for idx, y_var in enumerate(column_filter):
        print(y_var)
        y_var_filter = ~np.isnan(df_behavior[y_var].values)
        print(np.sum(y_var_filter))

        # Non-missing behavioral scores, reused by every correlation below.
        y = df_behavior[y_var][y_var_filter]

        corr_map_rest = np.zeros(n_parcels)
        corr_map_task = np.zeros(n_parcels)
        for i in np.arange(n_parcels):
            # Keep only r; avoid binding the result to `stats`, which would
            # overwrite the `from scipy import stats` module import.
            corr_map_rest[i] = sp.stats.pearsonr(rest_timescales_resid[y_var_filter, i], y)[0]
            corr_map_task[i] = sp.stats.pearsonr(task_timescales_resid[y_var_filter, i], y)[0]

        df_stats = pd.DataFrame(index=np.arange(n_parcels))
        df_stats['rest'] = corr_map_rest
        df_stats['task'] = corr_map_task
        df_stats['SA'] = sa_axis
        # The 'const' column added here is not referenced by the formula below.
        df_stats = sm.add_constant(df_stats)
        results = smf.ols(formula='SA ~ rest : task', data=df_stats).fit()

        heatmap_data[idx, 0] = results.tvalues['rest:task']
        heatmap_data_pval[idx, 0] = results.pvalues['rest:task']

        heatmap_data[idx, 1] = sp.stats.pearsonr(rest_mean_lower[y_var_filter], y)[0]
        heatmap_data[idx, 2] = sp.stats.pearsonr(task_mean_lower[y_var_filter], y)[0]

        heatmap_data[idx, 3] = sp.stats.pearsonr(rest_mean_upper[y_var_filter], y)[0]
        heatmap_data[idx, 4] = sp.stats.pearsonr(task_mean_upper[y_var_filter], y)[0]

In [None]:
if which_data == 'HCP-YA':
    # Stem-plot summary of heatmap_data, one stem position per entry of column_filter:
    #   ax[0]: |t| of the rest:task term (column 0)
    #   ax[1]: columns 1 (rest) and 2 (task), drawn in orange
    #   ax[2]: columns 3 (rest) and 4 (task), drawn in purple
    # f, ax = plt.subplots(3, 1, figsize=(1, 2))
    f, ax = plt.subplots(3, 1, figsize=(1.25, 2))

    # Tick labels: the text before the first underscore of each column name, with
    # the first two entries overridden by hand.
    # heatmap_labels = [my_str.split('_')[0] for my_str in cog_measures]
    heatmap_labels = [column_name.split('_')[0] for column_name in column_filter]
    heatmap_labels[0], heatmap_labels[1] = 'accuracy', 'RT'
    print(heatmap_labels)

    stem_positions = np.arange(len(heatmap_labels))
    green = my_colors['north_sea_green']

    # Top panel: absolute t-statistics.
    markerline, stemlines, baseline = ax[0].stem(np.abs(heatmap_data[:, 0]))
    plt.setp(stemlines, color=green)
    plt.setp(markerline, markersize=2, markeredgecolor=green, markeredgewidth=0.5, markerfacecolor=green)
    plt.setp(baseline, visible=False)

    # Lower panels: rest (solid) and task (dashed) stems, nudged left/right of
    # each position so the pair does not overlap.
    for panel, panel_color, (rest_col, task_col) in (
        (ax[1], 'orange', (1, 2)),
        (ax[2], 'purple', (3, 4)),
    ):
        for offset, data_col, stem_fmt, series_label in (
            (-0.1, rest_col, '-', 'rest'),
            (0.1, task_col, '--', 'task'),
        ):
            markerline, stemlines, baseline = panel.stem(stem_positions + offset, heatmap_data[:, data_col], linefmt=stem_fmt, label=series_label)
            plt.setp(stemlines, color=panel_color)
            plt.setp(markerline, markersize=2, markeredgecolor=panel_color, markeredgewidth=0.5, markerfacecolor=panel_color)
            plt.setp(baseline, visible=False)

    # Zero reference line and y-label on every panel.
    for panel, y_label in zip(ax, ('t-statistic', 'slope (r)', 'slope (r)')):
        panel.axhline(y=0, color='grey', linestyle=':')
        panel.set_ylabel(y_label)

    # Only the bottom panel carries x tick labels.
    for panel in ax[:2]:
        panel.set_xticklabels('')
        panel.set_xticks([])
    ax[2].set_xticks(stem_positions)
    ax[2].set_xticklabels(heatmap_labels, rotation=90, ha='center')

    # f.tight_layout()
    plt.show()
    f.savefig(os.path.join(outdir, 'interaction_stemplot.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)