In [None]:
# import
import os, sys
import numpy as np
import pandas as pd
import scipy as sp
from scipy import stats
from scipy.spatial import distance
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from tqdm import tqdm

# import plotting libraries
import matplotlib.pyplot as plt

plt.rcParams.update({"font.size": 8})
plt.rcParams["svg.fonttype"] = "none"
import seaborn as sns
from nilearn import datasets
from nilearn import plotting

sys.path.extend([r'/home/lindenmp/research_projects/snaplab_tools'])
sys.path.extend([r'/home/lindenmp/research_projects/nctpy/src'])

# import nctpy functions
from snaplab_tools.plotting.plotting import categorical_kde_plot, reg_plot, brain_scatter_plot, null_plot, surface_plot
from snaplab_tools.plotting.utils import get_my_colors, get_p_val_string
from nctpy.utils import matrix_normalization
from snaplab_tools.utils import get_schaefer_system_mask, get_null_p, get_fdr_p
from snaplab_tools.plotting.utils import roi_to_vtx

from brainsmash.mapgen.base import Base
from src.utils import get_adj_weights
from snaplab_tools.utils import normalize_x

from snaplab_tools.derivs import compute_acf
from nilearn.glm import first_level
from nilearn import plotting
from nilearn.glm.contrasts import compute_contrast
from scipy import signal
import statsmodels.api as sm
import statsmodels.formula.api as smf
from sklearn.metrics import r2_score

In [None]:
def compute_deflections(rest_timescales, task_timescales, nuisance_regression=False):
    """Compute task-minus-rest timescale deflections.

    Parameters
    ----------
    rest_timescales : array-like
        Rest timescales; flattened into a single column.
    task_timescales : array-like
        Task timescales; flattened into a single column. Must have the same
        number of elements as ``rest_timescales``.
    nuisance_regression : bool, optional (default=False)
        If True, regress the deflections on the rest timescales (plus an
        intercept) and return the residuals, i.e. the part of
        ``task - rest`` not linearly explained by ``rest``.

    Returns
    -------
    numpy.ndarray
        1-D array of deflections (residualized if ``nuisance_regression``).
    """
    rest_timescales = np.asarray(rest_timescales).reshape(-1, 1)
    task_timescales = np.asarray(task_timescales).reshape(-1, 1)

    int_deflections = task_timescales - rest_timescales

    if nuisance_regression:
        # Design matrix: rest timescales plus an intercept column; OLS via pseudo-inverse.
        X = np.concatenate((rest_timescales, np.ones((rest_timescales.shape[0], 1))), axis=1)
        beta = np.dot(np.linalg.pinv(X), int_deflections)
        int_deflections = int_deflections - np.dot(X, beta)

    return int_deflections[:, 0]

def nuis_reg(X, y, use_sklearn=False):
    """Regress the covariates in ``X`` out of ``y`` and return the residuals.

    Parameters
    ----------
    X : array-like, shape (n_samples,) or (n_samples, n_covariates)
        Nuisance design matrix. A 1-D input is treated as a single column.
        On the default (pseudo-inverse) path no intercept is added, so include
        a column of ones if one is wanted; the sklearn path fits an intercept.
    y : numpy.ndarray, shape (n_samples,) or (n_samples, n_targets)
        Data to residualize; each column is regressed independently.
    use_sklearn : bool, optional (default=False)
        Use ``sklearn.linear_model.LinearRegression`` instead of the
        pseudo-inverse solution.

    Returns
    -------
    numpy.ndarray
        ``y`` minus its least-squares prediction from ``X``. On the sklearn
        path a 1-D ``y`` is returned as a (n_samples, 1) array; on the default
        path the result has the same shape as ``y``.
    """
    X = np.asarray(X)
    # np.linalg.pinv needs a 2-D matrix, so promote 1-D covariates on both paths.
    if X.ndim == 1:
        X = X[:, np.newaxis]

    if use_sklearn:
        if y.ndim == 1:
            y = y[:, np.newaxis]

        regr = LinearRegression()
        regr.fit(X, y)
        residuals = y - regr.predict(X)
    else:
        beta = np.dot(np.linalg.pinv(X), y)
        residuals = y - np.dot(X, beta)
    return residuals

def run_regression(df_behavior, y_var, rest_timescales, task_timescales):
    """Fit ``y_var ~ rest * task + age + C(sex)`` separately for every node.

    Parameters
    ----------
    df_behavior : pandas.DataFrame
        One row per subject; must contain ``y_var``, ``'age'`` and ``'sex'``.
    y_var : str
        Outcome column name; must be usable as a term in a statsmodels formula.
    rest_timescales, task_timescales : numpy.ndarray
        ``rest_timescales[i]`` / ``task_timescales[i]`` must yield one value per
        row of ``df_behavior`` for node ``i``.
        NOTE(review): this implies a (n_nodes, n_subjects) layout — confirm at
        the call site, since other arrays in this notebook are (subjects, parcels).

    Returns
    -------
    effect_map : pandas.DataFrame
        Per-node coefficients for 'rest', 'task', 'rest:task', 'age' and
        'C(sex)[T.1]' (when present), plus 'rsquared', which holds the
        *adjusted* R^2.
    pvals : pandas.DataFrame
        Per-node p-values for the same effects.
    """
    n_nodes = rest_timescales.shape[0]
    effect_map = pd.DataFrame(index=np.arange(n_nodes))
    pvals = pd.DataFrame(index=np.arange(n_nodes))

    # Subject-level columns do not depend on the node, so build them once.
    # Missing outcomes are mean-imputed, then z-scored. The imputed Series is
    # assigned explicitly: ``df[col].fillna(..., inplace=True)`` is a chained
    # assignment that does not update the frame under pandas copy-on-write,
    # which would leave NaNs and make zscore return all-NaN.
    y_values = df_behavior[y_var].copy()
    y_values = y_values.fillna(np.nanmean(y_values))
    subject_columns = pd.DataFrame(index=df_behavior.index)
    subject_columns[y_var] = sp.stats.zscore(y_values)
    age_z = sp.stats.zscore(df_behavior['age'])
    sex = df_behavior['sex']

    effects = ['rest', 'task', 'rest:task', 'age', 'C(sex)[T.1]']
    formula = '{0} ~ rest * task + age + C(sex)'.format(y_var)

    for i in tqdm(np.arange(n_nodes)):
        df_stats = subject_columns.copy()
        df_stats['rest'] = sp.stats.zscore(rest_timescales[i])
        df_stats['task'] = sp.stats.zscore(task_timescales[i])
        df_stats['age'] = age_z
        df_stats['sex'] = sex

        # The formula API adds its own intercept.
        results = smf.ols(formula=formula, data=df_stats).fit()

        for effect in effects:
            # e.g. 'C(sex)[T.1]' is absent when only one sex level is present.
            if effect in results.params.index:
                effect_map.loc[i, effect] = results.params[effect]
                pvals.loc[i, effect] = results.pvalues[effect]
        effect_map.loc[i, 'rsquared'] = results.rsquared_adj

    return effect_map, pvals

def detect_univariate_outliers(x, k=1.5):
    """Label values lying more than ``k`` IQRs outside the interquartile range.

    A value is an outlier when it is below ``Q1 - k*IQR`` or above
    ``Q3 + k*IQR``. Returns an int array (1 = outlier, 0 = inlier) shaped like ``x``.
    """
    quartiles = np.percentile(x, [25, 75])
    spread = k * (quartiles[1] - quartiles[0])
    fence_low = quartiles[0] - spread
    fence_high = quartiles[1] + spread
    is_outlier = np.logical_or(x < fence_low, x > fence_high)
    return is_outlier.astype(int)

def compute_r2(x, y):
    """Return the in-sample R^2 of an OLS fit of ``y`` on ``x``, as a percentage.

    1-D inputs are promoted to single-column 2-D arrays before fitting.
    """
    design = x[:, np.newaxis] if x.ndim == 1 else x
    target = y[:, np.newaxis] if y.ndim == 1 else y

    fitted = LinearRegression().fit(design, target).predict(design)
    return r2_score(target, fitted) * 100

def winsorize_iqr(vector, k=1.5, inplace=False):
    """
    Winsorizes a vector using the IQR method to handle outliers.

    Values below ``Q1 - k * IQR`` are replaced by that bound and values above
    ``Q3 + k * IQR`` by that bound. Quartiles are computed ignoring NaNs, and
    NaN entries are left unchanged.

    Parameters:
    -----------
    vector : array-like
        Input data to be winsorized (list, numpy array, or pandas Series)
    k : float, optional (default=1.5)
        Multiplier for IQR to determine outlier thresholds
    inplace : bool, optional (default=False)
        If True and ``vector`` is a numpy array, that array is modified in
        place and returned. Other inputs (lists, Series) are always converted
        to a new array, so the caller's object is never modified. An integer
        array modified in place can only store integer values, so fractional
        bounds are truncated in that case.

    Returns:
    --------
    winsorized_vector : numpy array
        Winsorized version of the input vector. When not modifying in place,
        integer input is returned as float so fractional bounds are exact.
    """
    is_ndarray = isinstance(vector, np.ndarray)
    vector = np.asarray(vector)
    modify_in_place = inplace and is_ndarray

    # Nothing to winsorize; percentiles of an empty array are undefined.
    if vector.size == 0:
        return vector if modify_in_place else vector.copy()

    # nanpercentile: with np.percentile a single NaN turns both bounds into
    # NaN, every comparison below is False, and nothing is clipped.
    q1, q3 = np.nanpercentile(vector, [25, 75])
    iqr = q3 - q1

    lower_bound = q1 - k * iqr
    upper_bound = q3 + k * iqr

    if modify_in_place:
        winsorized_vector = vector
    elif np.issubdtype(vector.dtype, np.integer):
        # Assigning float bounds into an int array would truncate them.
        winsorized_vector = vector.astype(float)
    else:
        winsorized_vector = vector.copy()

    winsorized_vector[winsorized_vector < lower_bound] = lower_bound
    winsorized_vector[winsorized_vector > upper_bound] = upper_bound

    return winsorized_vector

## Setup

In [None]:
# directory where data is stored
indir = '/home/lindenmp/research_projects/nct_xr/data/int_deflections'
which_data = 'HCP_YA'
# which_data = 'HCP_D'
outdir = '/home/lindenmp/research_projects/nct_xr/results/int_deflections/{0}'.format(which_data.replace('_', ''))

atlas = 'Schaefer4007'
# Parcel count for each supported Schaefer (7-network) atlas; nodes == parcels.
n_parcels_by_atlas = {'Schaefer4007': 400, 'Schaefer2007': 200, 'Schaefer1007': 100}
if atlas not in n_parcels_by_atlas:
    # Fail here instead of with a NameError on n_parcels several cells later.
    raise ValueError('Unsupported atlas: {0}'.format(atlas))
n_parcels = n_parcels_by_atlas[atlas]
n_nodes = n_parcels

if which_data == 'HCP_YA':
    tr = 0.720  # presumably the fMRI repetition time in seconds
    # tasks = ['tfMRI_EMOTION_LR', 'tfMRI_GAMBLING_LR', 'tfMRI_LANGUAGE_LR', 'tfMRI_RELATIONAL_LR', 'tfMRI_SOCIAL_LR', 'tfMRI_WM_LR']
    tasks = ['tfMRI_EMOTION_LR', 'tfMRI_LANGUAGE_LR', 'tfMRI_RELATIONAL_LR', 'tfMRI_WM_LR']

    # which_task = 'tfMRI_EMOTION_LR'
    # which_task = 'tfMRI_GAMBLING_LR'
    # which_task = 'tfMRI_LANGUAGE_LR'
    # which_task = 'tfMRI_RELATIONAL_LR'
    # which_task = 'tfMRI_SOCIAL_LR'
    # which_task = 'tfMRI_WM_LR'

    # which_task = 'tfMRI_WM_RL'

    # if which_task == 'tfMRI_EMOTION_LR':
    #     contrasts_list = ['fear', 'neut']
    #     task_performance_measures = ['Emotion_Task_Acc', 'Emotion_Task_Median_RT']
    # elif which_task == 'tfMRI_GAMBLING_LR':
    #     contrasts_list = ['win', 'loss']
    #     task_performance_measures = ['Gambling_Task_Perc_Larger', 'Gambling_Task_Perc_Smaller']
    # elif which_task == 'tfMRI_LANGUAGE_LR':
    #     contrasts_list = ['math', 'story']
    #     task_performance_measures = ['Language_Task_Acc', 'Language_Task_Median_RT']
    # elif which_task == 'tfMRI_RELATIONAL_LR':
    #     contrasts_list = ['match', 'relation']
    #     task_performance_measures = ['Relational_Task_Acc', 'Relational_Task_Median_RT']
    # elif which_task == 'tfMRI_SOCIAL_LR':
    #     contrasts_list = ['mental', 'rnd']
    #     task_performance_measures = ['Social_Task_Perc_Random', 'Social_Task_Perc_TOM']
    # elif which_task == 'tfMRI_WM_LR' or which_task == 'tfMRI_WM_RL':
    #     contrasts_list = ['0bk', '2bk']
    #     # task_performance_measures = ['WM_Task_Acc', 'WM_Task_Median_RT', 'WM_Task_2bk_Acc', 'WM_Task_2bk_Median_RT', 'WM_Task_0bk_Acc', 'WM_Task_0bk_Median_RT']
    #     task_performance_measures = ['WM_Task_Acc', 'WM_Task_Median_RT']

    cog_measures = [
        'PMAT24_A_CR', # fluid intelligence
        'VSPLOT_TC', # spatial orientation
        'ListSort_Unadj', # working memory
        'DDisc_AUC_40K', # delay discounting
        'Flanker_Unadj', # executive function
    ]
elif which_data == 'HCP_D':
    tr = 0.800  # presumably the fMRI repetition time in seconds
    tasks = ['tfMRI_CARIT_PA', 'tfMRI_EMOTION_PA', 'tfMRI_GUESSING_PA']
    # which_task = 'tfMRI_CARIT_PA'
    # which_task = 'tfMRI_EMOTION_PA'
    # which_task = 'tfMRI_GUESSING_PA'
    # NOTE(review): cog_measures is only defined for HCP_YA; the subject-level
    # cells that extend/iterate over it will fail for HCP_D.
else:
    raise ValueError('Unsupported dataset: {0}'.format(which_data))

In [None]:
# Schaefer annotation files (one per hemisphere) on the fsaverage5 mesh,
# plus the fsaverage5 surface used by surface_plot.
annot_dir = '/home/lindenmp/research_projects/nctpy/data'
schaefer_fsavg5_dir = os.path.join(annot_dir, 'schaefer_parc', 'fsaverage5')
lh_annot_file = os.path.join(schaefer_fsavg5_dir, 'lh.Schaefer2018_{0}Parcels_7Networks_order.annot'.format(n_parcels))
rh_annot_file = os.path.join(schaefer_fsavg5_dir, 'rh.Schaefer2018_{0}Parcels_7Networks_order.annot'.format(n_parcels))
fsaverage = datasets.fetch_surf_fsaverage(mesh="fsaverage5")
cmap = "viridis"

In [None]:
# Parcel centroid coordinates (RAS, MNI152 1mm), indexed by the file's second column.
centroid_csv = os.path.join(indir, 'Schaefer2018_{0}Parcels_7Networks_order_FSLMNI152_1mm.Centroid_RAS.csv'.format(n_parcels))
parc_centroids = pd.read_csv(centroid_csv, index_col=1).drop(columns=['ROI Label'])
print(parc_centroids.head())

# Square (n_parcels x n_parcels) matrix of Euclidean distances between centroids.
distance_matrix = distance.squareform(distance.pdist(parc_centroids, "euclidean"))

# Load data

In [None]:
# Subject identifiers for this dataset/atlas, kept as strings.
subjectids = np.loadtxt(
    os.path.join(indir, '{0}_{1}_subjids.txt'.format(which_data.replace('_', ''), atlas)),
    dtype=str,
)
n_subs = len(subjectids)
print(n_subs)

In [None]:
# Group-level reference maps: SA axis, myelin and resting-state timescales.
data_tag = which_data.replace('_', '')

sa_axis = np.load(os.path.join(indir, 'schaefer{0}-7_sa-axis.npy'.format(n_parcels)))

myelin = np.load(os.path.join(indir, '{0}_{1}_myelin.npy'.format(data_tag, atlas)))
# NOTE(review): myelin is averaged over axis 1 but the timescales over axis 0;
# presumably myelin is stored (parcels, subjects) — confirm.
myelin_mean = np.nanmean(myelin, axis=1)

which_task = 'rfMRIREST1LR'
rest_timescales = np.load(os.path.join(outdir, '{0}_{1}_{2}_timescales.npy'.format(data_tag, atlas, which_task)))
rest_timescales_mean = np.nanmean(rest_timescales, axis=0)

print(sa_axis.shape, myelin_mean.shape, rest_timescales.shape, rest_timescales_mean.shape)

In [None]:
# Load per-task timescales and average them across subjects, then across tasks.
task_timescales = dict()
task_timescales_mean = dict()
subject_filter = dict()
task_timescales_mean_all = np.zeros(n_parcels)
for which_task in tasks:
    print(which_task)
    data_tag = which_data.replace('_', '')
    task_tag = which_task.replace('_', '')
    task_timescales[which_task] = np.load(os.path.join(outdir, '{0}_{1}_{2}_timescales.npy'.format(data_tag, atlas, task_tag)))

    # An optional per-subject filter file; only its absence should trigger the
    # unfiltered fallback (a bare except also hid indexing/shape errors).
    filter_file = os.path.join(outdir, '{0}_{1}_{2}_subjectfilter.npy'.format(data_tag, atlas, task_tag))
    if os.path.exists(filter_file):
        # Cast to bool: `~` on a 0/1 integer array is bitwise NOT (-1/-2),
        # which numpy would treat as integer indices instead of a mask.
        subject_filter[which_task] = np.load(filter_file).astype(bool)
        print(subject_filter[which_task].sum())
        # Average over subjects whose filter entry is False
        # (presumably True marks excluded subjects — confirm).
        task_timescales_mean[which_task] = np.nanmean(task_timescales[which_task][~subject_filter[which_task], :], axis=0)
    else:
        task_timescales_mean[which_task] = np.nanmean(task_timescales[which_task], axis=0)

    task_timescales_mean_all += task_timescales_mean[which_task]

task_timescales_mean_all = np.divide(task_timescales_mean_all, len(tasks))

In [None]:
# Sensorimotor-association (SA) axis rendered on the fsaverage5 surface.
sa_axis_png = os.path.join(outdir, 'sa_axis.png')
f = surface_plot(
    data=sa_axis,
    lh_annot_file=lh_annot_file,
    rh_annot_file=rh_annot_file,
    fsaverage=fsaverage,
    order="lr",
    cmap="viridis",
)
f.savefig(sa_axis_png, dpi=600, bbox_inches="tight", pad_inches=0.01)

In [None]:
# Parcel-level association between group-mean resting timescales (x) and group-mean myelin (y).
f, ax = plt.subplots(1, 1, figsize=(2.5, 2.5))
reg_plot(rest_timescales_mean, myelin_mean, ylabel='Myelin (mean)', xlabel='ACF_lag (mean)', ax=ax, annotate='both')

In [None]:
# Parcel-level OLS: SA-axis position ~ rest INTs * task INTs (main effects + interaction).
parcel_table = pd.DataFrame(
    {'rest': rest_timescales_mean, 'task': task_timescales_mean_all, 'SA': sa_axis},
    index=np.arange(n_parcels),
)
df_stats = sm.add_constant(parcel_table)
results = smf.ols(formula='SA ~ rest * task', data=df_stats).fit()
print(results.summary())

In [None]:
# Alternative: compute_deflections(rest_timescales_mean, task_timescales_mean_all, nuisance_regression=True)
# residualizes the deflections on rest INTs; here the raw task-minus-rest difference is used.
int_deflections = task_timescales_mean_all - rest_timescales_mean

# One panel per map, each regressed on the SA axis.
sa_panels = [
    (rest_timescales_mean, 'INTs (rest)'),
    (task_timescales_mean_all, 'INTs (task)'),
    (int_deflections, 'INTs (task-rest)'),
]
f, ax = plt.subplots(1, len(sa_panels), figsize=(2 * len(sa_panels), 1.75))
for panel_ax, (panel_values, panel_label) in zip(ax, sa_panels):
    reg_plot(sa_axis, panel_values, xlabel='SA axis', ylabel=panel_label, ax=panel_ax, annotate='both', add_pval=False)

f.tight_layout()
plt.show()
f.savefig(os.path.join(outdir, 'ints_deflections_sa_corr.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)

In [None]:
# Group-mean rest INTs, task INTs and their difference, each rendered on the surface.
surface_maps = [
    (rest_timescales_mean, "magma", 'ints_rest.png'),
    (task_timescales_mean_all, "magma", 'ints_task.png'),
    (int_deflections, "coolwarm", 'ints_deflections.png'),
]
for brain_map, map_cmap, png_name in surface_maps:
    f = surface_plot(
        data=brain_map,
        lh_annot_file=lh_annot_file,
        rh_annot_file=rh_annot_file,
        fsaverage=fsaverage,
        order="lr",
        cmap=map_cmap,
    )
    f.savefig(os.path.join(outdir, png_name), dpi=600, bbox_inches="tight", pad_inches=0.01)

# Subject-level effects

In [None]:
# Gates the `if run_cell:` cells below (behavior loading and per-parcel correlations).
run_cell = True

In [None]:
# Task used for the subject-level analyses: the last entry of `tasks`.
which_task = tasks[-1]
print(which_task)

# In-scanner performance measures (HCP_YA behavioral column names) per task.
task_performance_by_task = {
    'tfMRI_EMOTION_LR': ['Emotion_Task_Acc', 'Emotion_Task_Median_RT'],
    'tfMRI_GAMBLING_LR': ['Gambling_Task_Perc_Larger', 'Gambling_Task_Perc_Smaller'],
    'tfMRI_LANGUAGE_LR': ['Language_Task_Acc', 'Language_Task_Median_RT'],
    'tfMRI_RELATIONAL_LR': ['Relational_Task_Acc', 'Relational_Task_Median_RT'],
    'tfMRI_SOCIAL_LR': ['Social_Task_Perc_Random', 'Social_Task_Perc_TOM'],
    'tfMRI_WM_LR': ['WM_Task_Acc', 'WM_Task_Median_RT'],
    'tfMRI_WM_RL': ['WM_Task_Acc', 'WM_Task_Median_RT'],
}
# Fail loudly for unmapped tasks (e.g. the HCP_D tasks) rather than leaving
# task_performance_measures undefined or stale from a previous run.
if which_task not in task_performance_by_task:
    raise ValueError('No task performance measures defined for {0}'.format(which_task))
task_performance_measures = list(task_performance_by_task[which_task])

In [None]:
if run_cell:
    n_subjects = len(subjectids)
    # Behavioral columns of interest: this task's performance measures followed
    # by the cognitive measures. NOTE(review): cog_measures is only defined for
    # HCP_YA in the setup cell, so this line fails for HCP_D.
    column_filter = task_performance_measures.copy()
    column_filter.extend(cog_measures)
    print(column_filter)

    if which_data == 'HCP_YA':
        # Unrestricted behavioral table, subset/reordered to match subjectids.
        df_behavior = pd.read_csv('/mnt/storage_ssd_raid/research_data/HCP_YA/unrestricted_lindenmp_9_26_2022_16_52_14.csv', index_col=0)
        df_behavior = df_behavior.loc[subjectids.astype(int)]
        # Age comes from the restricted table, aligned the same way.
        df_behavior_restricted = pd.read_csv('/mnt/storage_ssd_raid/research_data/HCP_YA/RESTRICTED_lindenmp_9_26_2022_16_42_3.csv', index_col=0)
        df_behavior_restricted = df_behavior_restricted.loc[subjectids.astype(int)]
        df_behavior['age'] = df_behavior_restricted['Age_in_Yrs']
        # Binary sex code: 1 if Gender == 'M', else 0.
        df_behavior['sex'] = df_behavior['Gender'] == 'M'
        df_behavior['sex'] = df_behavior['sex'].astype(int)

        # for y_var in column_filter:
            # df_behavior[y_var].fillna(np.nanmean(df_behavior[y_var]), inplace=True)

        # Outcome for the cells below: the first task performance measure.
        y_var = column_filter[0]
        print(y_var)
        # Short axis label derived from the column name.
        if 'Acc' in y_var:
            y_var_label = 'accuracy'
        elif 'RT' in y_var:
            y_var_label = 'RT'
        else:
            y_var_label = y_var

        # Subject mask for the outcome. WM accuracy keeps subjects scoring >= 50
        # (NaN >= 50 is False, so missing values are dropped too); other
        # outcomes keep every non-missing value.
        if y_var == 'WM_Task_Acc':
            y_var_filter = df_behavior[y_var] >= 50
        else:
            y_var_filter = ~np.isnan(df_behavior[y_var].values)
            # y_var_filter = np.zeros(n_subjects).astype(bool)
            # y_var_filter[:] = True
        print(np.sum(y_var_filter))
            
        # Nuisance design matrix (n_subjects x 3): z-scored age, sex, intercept.
        nuis_covs = np.concatenate((sp.stats.zscore(df_behavior['age']).values[:, np.newaxis],
                                    df_behavior['sex'].values[:, np.newaxis],
                                    np.ones((n_subjects, 1))
                                    ), axis=1)

        # nuis_covs = np.concatenate((sp.stats.rankdata(df_behavior['age'])[:, np.newaxis],
        #                             df_behavior['sex'].values[:, np.newaxis],
        #                             np.ones((n_subjects, 1))
        #                             ), axis=1)

    elif which_data == 'HCP_D':
        # Tab-separated socio-demographics; drop the first data row
        # (presumably the NDA column-description row — confirm).
        df_behavior = pd.read_csv('/mnt/storage_ssd_raid/research_data/HCP_D/behavior/socdem01.txt', sep='\t', header=0)
        df_behavior = df_behavior.iloc[1:]
        # Re-key rows to the 'sub-<id>V1MR' form used by subjectids.
        behavior_subjectids = list(df_behavior['src_subject_id'])
        behavior_subjectids = ['sub-' + s + 'V1MR' for s in behavior_subjectids]
        df_behavior.index = behavior_subjectids
        df_behavior.index.name = 'subjectids'

        y_var = 'interview_age'
        # df_behavior = df_behavior.loc[:, ['WM_Task_Acc', 'WM_Task_Median_RT', 'WM_Task_2bk_Acc', 'WM_Task_2bk_Median_RT', 'WM_Task_0bk_Acc', 'WM_Task_0bk_Median_RT']]
        df_behavior = df_behavior.loc[:, [y_var, 'sex']]
        df_behavior[y_var] = df_behavior[y_var].astype(int)
        df_behavior = df_behavior.loc[subjectids]
        # Keep every subject.
        y_var_filter = np.zeros(n_subjects).astype(bool)
        y_var_filter[:] = True
        # NOTE(review): this branch defines neither y_var_label nor nuis_covs,
        # both of which the next cell uses.

    # NOTE(review): a bare expression inside an `if` block is likely not
    # rendered as cell output by Jupyter, so this line is effectively a no-op.
    df_behavior

In [None]:
if run_cell:
    # Regress the nuisance covariates (nuis_covs: z-scored age, sex, intercept)
    # out of every parcel's INTs across subjects.
    rest_timescales_resid = nuis_reg(nuis_covs, rest_timescales)
    task_timescales_resid = nuis_reg(nuis_covs, task_timescales[which_task])
    # int_deflections = task_timescales - rest_timescales
    # int_deflections = task_timescales_resid - rest_timescales_resid
    # int_deflections_resid = nuis_reg(rest_timescales_resid, task_timescales_resid)

    # brain_data = rest_timescales_resid
    brain_data = task_timescales_resid
    # brain_data = int_deflections_resid

    # Per parcel: Pearson r between brain_data and behavior over the
    # y_var_filter subjects. The result is held in `pearson_res` rather than
    # `stats` so the module-level `from scipy import stats` is not shadowed.
    corr_map = np.zeros(n_parcels)
    p_map = np.zeros(n_parcels)
    for i in np.arange(n_parcels):
        pearson_res = sp.stats.pearsonr(brain_data[y_var_filter, i], df_behavior[y_var][y_var_filter])
        corr_map[i] = pearson_res[0]
        p_map[i] = pearson_res[1]

        # corr_map[i] = compute_r2(brain_data[y_var_filter, i], df_behavior[y_var][y_var_filter].values)

    # FDR-corrected p-values (only used by the commented-out mask below).
    p_map = get_fdr_p(p_map)
    corr_map_filtered = corr_map.copy()
    # corr_map_filtered[p_map>0.05] = np.nan

    f2 = surface_plot(
        data=corr_map_filtered,
        lh_annot_file=lh_annot_file,
        rh_annot_file=rh_annot_file,
        fsaverage=fsaverage,
        order="lr",
        cmap="coolwarm",
        # title_str='corr(int_deflections, behavior)'
    )
    f2.savefig(os.path.join(outdir, 'corr(int_deflections,behavior)_brainplot.png'), dpi=600, bbox_inches="tight", pad_inches=0.01)

    # Do parcel-wise brain-behavior correlations vary along the SA axis?
    f, ax = plt.subplots(1, 1, figsize=(1.1, 1.1))
    reg_plot(sa_axis, corr_map, xlabel='SA axis', ylabel='corr(INTs, {0})'.format(y_var_label), ax=ax, annotate='spearman', add_pval=False)
    plt.show()
    f.savefig(os.path.join(outdir, 'int_deflections_task_performance_sa_corr_1.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)

    # Per subject: correlation between the SA axis and that subject's brain_data
    # (reuses the name corr_map, now of length n_subjects), then against behavior.
    corr_map = np.zeros(n_subjects)
    for i in np.arange(n_subjects):
        corr_map[i] = sp.stats.pearsonr(sa_axis, brain_data[i, :])[0]
    f, ax = plt.subplots(1, 1, figsize=(1.1, 1.1))
    reg_plot(df_behavior[y_var].values[y_var_filter], corr_map[y_var_filter], ylabel='corr(SA axis, INTs)', xlabel=y_var_label, ax=ax, annotate='spearman', add_pval=False)
    plt.show()
    f.savefig(os.path.join(outdir, 'int_deflections_task_performance_sa_corr_2.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)

In [None]:
import numpy as np
from pygam import LinearGAM, s, f
from sklearn.model_selection import KFold
from itertools import product
from tqdm import tqdm  # for progress bar


def fit_smooth_gam(X, y, n_splines=20, lam_range=np.logspace(-3, 3, 25), 
                  term_types='auto', n_folds=5, plot=False, 
                  return_model=False, verbose=0, random_state=None,
                  progress_bar=True, **kwargs):
    """
    Fit a smooth Generalized Additive Model (GAM) with automated smoothing parameter selection.

    The smoothing penalty is chosen by an exhaustive K-fold cross-validated
    grid search over every combination of per-feature lambda values
    (``len(lam_range) ** n_features`` candidates, each fitted ``n_folds``
    times), scored with ``LinearGAM.score`` on the held-out fold. The best
    combination is then refit on all of the data.

    Parameters:
    -----------
    X : array-like or pd.DataFrame
        Input features (n_samples, n_features); 1-D input is treated as one feature.
    y : array-like
        Target values (n_samples,)
    n_splines : int or list, optional
        Number of splines to use (default: 20). If list, specify per feature.
    lam_range : array-like, optional
        Lambda values to search (default: ``np.logspace(-3, 3, 25)``). A 1-D
        array is shared by all features; a 2-D array with one row per feature
        gives each feature its own grid.
    term_types : str or list, optional
        Term types ('auto', 's' for spline, 'f' for factor, or list per feature).
        'auto' uses a factor term when a feature has fewer than 10 unique
        values or X has an integer dtype, and a spline term otherwise.
    n_folds : int, optional
        Number of CV folds (default: 5)
    plot : bool, optional
        Whether to plot partial dependence (default: False)
    return_model : bool, optional
        Whether to return fitted model (default: False)
    verbose : int, optional
        Verbosity level (0-2)
    random_state : int, optional
        Random seed for the shuffled KFold split
    progress_bar : bool, optional
        Whether to show progress bar (default: True)
    **kwargs : dict
        Additional arguments for LinearGAM

    Returns:
    --------
    tuple or array:
        If return_model=False: y_pred
        If return_model=True: (y_pred, gam, best_params)
            where best_params contains 'lam' and 'n_splines'

    Raises:
    -------
    ImportError
        If pygam is not installed.
    ValueError
        If X has no features, if n_splines / term_types / lam_range do not
        match the number of features, or if a term type is unknown.
    """
    try:
        from pygam import LinearGAM, s, f
    except ImportError:
        raise ImportError("pygam is required. Install with: pip install pygam")
    
    X = np.asarray(X)
    y = np.asarray(y)
    
    # Handle 1D input
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    
    n_features = X.shape[1]
    if n_features == 0:
        raise ValueError("X must have at least one feature")
    lam_range = np.asarray(lam_range)
    
    # Process n_splines parameter
    if isinstance(n_splines, int):
        n_splines = [n_splines] * n_features
    elif len(n_splines) != n_features:
        raise ValueError("n_splines must be int or list with length equal to n_features")
    
    # Determine term types
    if isinstance(term_types, str) and term_types == 'auto':
        term_types = []
        for i in range(n_features):
            unique_vals = len(np.unique(X[:, i]))
            if unique_vals < 10 or np.issubdtype(X.dtype, np.integer):
                term_types.append('f')
            else:
                term_types.append('s')
    elif isinstance(term_types, str):
        term_types = [term_types] * n_features
    elif len(term_types) != n_features:
        raise ValueError("term_types must be 'auto', a str, or a list with length equal to n_features")
    
    # Create terms with proper n_splines for spline terms
    terms = []
    for i, term_type in enumerate(term_types):
        if term_type == 's':
            terms.append(s(i, n_splines=n_splines[i]))
        elif term_type == 'f':
            terms.append(f(i))
        else:
            raise ValueError(f"Unknown term type: {term_type}")

    # LinearGAM takes the whole term list as its FIRST positional argument;
    # later positional arguments are solver options (max_iter, tol, ...).
    # Unpacking `*terms` would pass extra terms into those slots, so combine
    # the terms with `+` into a single term list instead.
    model_terms = terms[0]
    for term in terms[1:]:
        model_terms = model_terms + term
    
    # Prepare parameter grid: one array of candidate lambdas per feature.
    if lam_range.ndim == 1:
        param_grid = [lam_range] * n_features
    elif lam_range.ndim == 2 and lam_range.shape[0] == n_features:
        param_grid = list(lam_range)
    else:
        raise ValueError("lam_range must be 1-D, or 2-D with one row per feature")
    
    # Generate all lambda combinations
    lam_combinations = list(product(*param_grid))
    
    # Cross-validation setup
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)
    cv_results = []
    
    # Grid search with progress bar
    iter_combinations = tqdm(lam_combinations, desc="Grid search") if progress_bar else lam_combinations
    
    for lam in iter_combinations:
        fold_scores = []
        
        for train_idx, val_idx in kf.split(X):
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            
            gam = LinearGAM(model_terms, lam=lam, **kwargs)
            gam.fit(X_train, y_train)
            score = gam.score(X_val, y_val)
            fold_scores.append(score)
        
        mean_score = np.mean(fold_scores)
        cv_results.append({'lam': lam, 'mean_score': mean_score})
        
        if verbose >= 2:
            print(f"Lambda: {lam} - Mean CV score: {mean_score:.4f}")
    
    # Find best parameters (higher held-out score is better)
    best_result = max(cv_results, key=lambda x: x['mean_score'])
    best_lam = best_result['lam']
    
    if verbose >= 1:
        print(f"\nBest lambda: {best_lam} with score: {best_result['mean_score']:.4f}")
        print(f"Using n_splines: {n_splines}")
    
    # Fit final model on all data
    final_gam = LinearGAM(model_terms, lam=best_lam, **kwargs)
    final_gam.fit(X, y)
    y_pred = final_gam.predict(X)
    
    # Plotting
    if plot:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(3, 3 * n_features))
        
        for i in range(n_features):
            plt.subplot(n_features, 1, i+1)
            XX = final_gam.generate_X_grid(term=i)
            plt.scatter(X[:, i], y, alpha=0.2, label='Data')
            plt.plot(XX[:, i], final_gam.partial_dependence(term=i, X=XX), 
                    'r', linewidth=2, label='GAM fit')
            
            title = f"Feature {i}"
            if term_types[i] == 's':
                title += f" (λ={best_lam[i]:.2f}, n_splines={n_splines[i]})"
            plt.title(title)
            plt.legend()
        
        plt.tight_layout()
        plt.show()
    
    if return_model:
        best_params = {'lam': best_lam, 'n_splines': n_splines}
        return y_pred, final_gam, best_params
    return y_pred

In [None]:
my_colors = get_my_colors()

# Three panels: (0) SA-axis trends of rest vs task INT-behavior correlations,
# (1) behavior vs mean INTs in the bottom-q SA parcels ("Sensorimotor"),
# (2) the same for the top-q SA parcels ("Association").
n_cols = 3
f, ax = plt.subplots(1, n_cols, figsize=(2*n_cols, 2))
plot_col = 0

# version 1, deflections
# int_deflections = task_timescales - rest_timescales
# int_deflections_resid = np.zeros(int_deflections.shape)
# for i in np.arange(n_subjects):
#     int_deflections_resid[i, :] = nuis_reg(rest_timescales[i, :], int_deflections[i, :], use_sklearn=True)[:, 0]
    
# corr_map = np.zeros(n_parcels)
# p_map = np.zeros(n_parcels)
# for i in np.arange(n_parcels):
#     stats = sp.stats.pearsonr(int_deflections[y_var_filter, i], df_behavior[y_var][y_var_filter])
#     # stats = sp.stats.pearsonr(int_deflections_resid[y_var_filter, i], df_behavior[y_var][y_var_filter])
#     corr_map[i] = stats[0]
#     p_map[i] = stats[1]

# reg_plot(sa_axis, corr_map, xlabel='SA axis', ylabel='corr(INTs_deflection, {0})'.format(y_var), ax=ax[plot_col], annotate='both', order=1, add_pval=False)  # 'coef({0}, {1})'.format(effect,y_var)

# version 2, interaction
# int_deflections = task_timescales_resid - rest_timescales_resid
# Raw (not residualized) task-minus-rest INTs for the selected task.
int_deflections = task_timescales[which_task] - rest_timescales
corr_map_rest = np.zeros(n_parcels)
corr_map_task = np.zeros(n_parcels)
corr_map_delta = np.zeros(n_parcels)
# Per parcel: Pearson r of behavior with residualized rest INTs, residualized
# task INTs and raw deflections (y_var_filter subjects). Results go in
# `pearson_res` so the module-level `from scipy import stats` is not shadowed.
for i in np.arange(n_parcels):
    pearson_res = sp.stats.pearsonr(rest_timescales_resid[y_var_filter, i], df_behavior[y_var][y_var_filter])
    corr_map_rest[i] = pearson_res[0]
    
    pearson_res = sp.stats.pearsonr(task_timescales_resid[y_var_filter, i], df_behavior[y_var][y_var_filter])
    corr_map_task[i] = pearson_res[0]
    
    pearson_res = sp.stats.pearsonr(int_deflections[y_var_filter, i], df_behavior[y_var][y_var_filter])
    corr_map_delta[i] = pearson_res[0]

# Parcel-level OLS with only the rest:task product term (no main effects).
df_stats = pd.DataFrame(index=np.arange(n_parcels))
df_stats['rest'] = corr_map_rest
df_stats['task'] = corr_map_task
df_stats['SA'] = sa_axis
df_stats = sm.add_constant(df_stats)
results = smf.ols(formula='SA ~ rest : task', data=df_stats).fit()

sns.regplot(x=sa_axis, y=corr_map_rest, ax=ax[plot_col], scatter=False, marker='.', scatter_kws={"s": 1}, label='rest', line_kws={'linestyle': '-'}, color=my_colors['north_sea_green'], order=1)
sns.regplot(x=sa_axis, y=corr_map_task, ax=ax[plot_col], scatter=False, marker='.', scatter_kws={"s": 1}, label='task', line_kws={'linestyle': '--'}, color=my_colors['north_sea_green'], order=1)
ax[plot_col].set_xlabel('SA axis')
ax[plot_col].set_ylabel('corr(INTs, {0})'.format(y_var))
ax[plot_col].set_title('t = {:.2f}, p = {:.2f}'.format(np.abs(results.tvalues['rest:task']), results.pvalues['rest:task']))
ax[plot_col].legend()

plot_col += 1

# Parcels in the bottom / top q_val fraction of the SA axis.
q_val = 0.2
lower_q = sa_axis < np.quantile(sa_axis, q=q_val)
upper_q = sa_axis > np.quantile(sa_axis, q=1-q_val)

rest_timescales_mean_bin = rest_timescales_resid[:, lower_q].mean(axis=1)
task_timescales_mean_bin = task_timescales_resid[:, lower_q].mean(axis=1)
y = df_behavior[y_var][y_var_filter]
x = rest_timescales_mean_bin[y_var_filter]
sns.regplot(x=x, y=y, ax=ax[plot_col], scatter=False, marker='.', scatter_kws={"s": 1}, label='rest', line_kws={'linestyle': '-'}, color='orange', order=1)
r_rest = sp.stats.pearsonr(x, y)[0]
x = task_timescales_mean_bin[y_var_filter]
sns.regplot(x=x, y=y, ax=ax[plot_col], scatter=False, marker='.', scatter_kws={"s": 1}, label='task', line_kws={'linestyle': '--'}, color='orange', order=1)
r_task = sp.stats.pearsonr(x, y)[0]

ax[plot_col].set_xlabel('INTs')
ax[plot_col].set_ylabel(y_var)
ax[plot_col].set_title('Sensorimotor')
# ax[plot_col].set_title('Sensorimotor\nr_rest={:.2f}, r_task={:.2f}'.format(r_rest, r_task))
# ax[plot_col].set_title('Sensorimotor (bottom {0}%)'.format(int(q_val*100)))
ax[plot_col].legend()
plot_col += 1

rest_timescales_mean_bin = rest_timescales_resid[:, upper_q].mean(axis=1)
task_timescales_mean_bin = task_timescales_resid[:, upper_q].mean(axis=1)
x = rest_timescales_mean_bin[y_var_filter]
sns.regplot(x=x, y=y, ax=ax[plot_col], scatter=False, marker='.', scatter_kws={"s": 1}, label='rest', line_kws={'linestyle': '-'}, color='purple', order=1)
r_rest = sp.stats.pearsonr(x, y)[0]
x = task_timescales_mean_bin[y_var_filter]
sns.regplot(x=x, y=y, ax=ax[plot_col], scatter=False, marker='.', scatter_kws={"s": 1}, label='task', line_kws={'linestyle': '--'}, color='purple', order=1)
r_task = sp.stats.pearsonr(x, y)[0]

ax[plot_col].set_xlabel('INTs')
ax[plot_col].set_ylabel(y_var)
ax[plot_col].set_title('Association')
# ax[plot_col].set_title('Association\nr_rest={:.2f}, r_task={:.2f}'.format(r_rest, r_task))
# ax[plot_col].set_title('Association (top {0}%)'.format(int(q_val*100)))
ax[plot_col].legend()

for this_ax in ax.reshape(-1):
    sns.despine(right=True, top=True, ax=this_ax)

# f.suptitle(which_task)
f.tight_layout()
plt.show()
f.savefig(os.path.join(outdir, 'interaction_plot.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)

# Surface map of the per-parcel r(raw task-rest deflection, behavior).
f2 = surface_plot(
    data=corr_map_delta,
    lh_annot_file=lh_annot_file,
    rh_annot_file=rh_annot_file,
    fsaverage=fsaverage,
    order="lr",
    cmap="coolwarm",
    # title_str='corr(int_deflections, behavior)'
)
f2.savefig(os.path.join(outdir, 'interaction_plot_brain.png'), dpi=600, bbox_inches="tight", pad_inches=0.01)

In [None]:
# Summary across cognitive measures; one row per measure in cog_measures.
# Columns of heatmap_data:
#   0    t of the rest:task term in SA ~ rest:task (p-value in heatmap_data_pval[:, 0];
#        the other heatmap_data_pval columns are left at 0)
#   1, 2 r(behavior, mean rest / task INTs) over the bottom-q SA parcels (lower_q)
#   3, 4 r(behavior, mean rest / task INTs) over the top-q SA parcels (upper_q)
# Relies on lower_q / upper_q and the residualized INTs from the cells above.
heatmap_data = np.zeros((len(cog_measures), 5))
heatmap_data_pval = np.zeros((len(cog_measures), 5))

for idx, y_var in enumerate(cog_measures):
    # Keep subjects with a non-missing score for this measure.
    y_var_filter = ~np.isnan(df_behavior[y_var].values)
    print(np.sum(y_var_filter))
    
    # `pearson_res` (not `stats`) so the module-level scipy `stats` import survives.
    corr_map_rest = np.zeros(n_parcels)
    corr_map_task = np.zeros(n_parcels)
    for i in np.arange(n_parcels):
        pearson_res = sp.stats.pearsonr(rest_timescales_resid[y_var_filter, i], df_behavior[y_var][y_var_filter])
        corr_map_rest[i] = pearson_res[0]
        
        pearson_res = sp.stats.pearsonr(task_timescales_resid[y_var_filter, i], df_behavior[y_var][y_var_filter])
        corr_map_task[i] = pearson_res[0]

    df_stats = pd.DataFrame(index=np.arange(n_parcels))
    df_stats['rest'] = corr_map_rest
    df_stats['task'] = corr_map_task
    df_stats['SA'] = sa_axis
    df_stats = sm.add_constant(df_stats)
    results = smf.ols(formula='SA ~ rest : task', data=df_stats).fit()
    
    heatmap_data[idx, 0] = results.tvalues['rest:task']
    heatmap_data_pval[idx, 0] = results.pvalues['rest:task']

    y = df_behavior[y_var][y_var_filter]

    rest_timescales_mean_bin = rest_timescales_resid[:, lower_q].mean(axis=1)
    task_timescales_mean_bin = task_timescales_resid[:, lower_q].mean(axis=1)
    heatmap_data[idx, 1] = sp.stats.pearsonr(rest_timescales_mean_bin[y_var_filter], y)[0]
    heatmap_data[idx, 2] = sp.stats.pearsonr(task_timescales_mean_bin[y_var_filter], y)[0]

    rest_timescales_mean_bin = rest_timescales_resid[:, upper_q].mean(axis=1)
    task_timescales_mean_bin = task_timescales_resid[:, upper_q].mean(axis=1)
    heatmap_data[idx, 3] = sp.stats.pearsonr(rest_timescales_mean_bin[y_var_filter], y)[0]
    heatmap_data[idx, 4] = sp.stats.pearsonr(task_timescales_mean_bin[y_var_filter], y)[0]

In [None]:
# Inspect `column_filter` (defined in an earlier cell, outside this view).
column_filter

In [None]:
def _styled_stem(axis, x, y, color, linefmt=None, label=None):
    """Draw a stem plot on `axis` with small filled markers in `color` and a hidden baseline.

    Parameters
    ----------
    axis : matplotlib Axes to draw on.
    x, y : stem positions and heights, passed straight to `Axes.stem`.
    color : colour used for both the stem lines and the markers.
    linefmt : optional stem line format (e.g. '-' or '--'); None keeps matplotlib's default.
    label : optional legend label for the stem container.
    """
    markerline, stemlines, baseline = axis.stem(x, y, linefmt=linefmt, label=label)
    plt.setp(stemlines, color=color)
    plt.setp(markerline, markersize=2, markeredgecolor=color, markeredgewidth=0.5, markerfacecolor=color)
    plt.setp(baseline, visible=False)


def _finish_stem_panel(axis, ylabel, tick_labels=None):
    """Add a dotted zero line and y-label; label x ticks with `tick_labels`, or hide them if None."""
    axis.axhline(y=0, color='grey', linestyle=':')
    axis.set_ylabel(ylabel)
    if tick_labels is None:
        # No ticks -> no labels; a separate set_xticklabels('') call is redundant.
        axis.set_xticks([])
    else:
        axis.set_xticks(np.arange(len(tick_labels)))
        axis.set_xticklabels(tick_labels, rotation=90, ha='center')


f, ax = plt.subplots(3, 1, figsize=(1, 2))
# f, ax = plt.subplots(3, 1, figsize=(1.25, 2))
heatmap_labels = [my_str.split('_')[0] for my_str in cog_measures]
# heatmap_labels = [my_str.split('_')[0] for my_str in column_filter]
# heatmap_labels[0] = 'accuracy'
# heatmap_labels[1] = 'RT'
print(heatmap_labels)

x_pos = np.arange(len(heatmap_labels))
stem_offset = 0.1  # horizontal shift separating rest (left) and task (right) stems

# Panel 0: |t| of the rest:task interaction term for each measure (heatmap_data column 0).
_styled_stem(ax[0], x_pos, np.abs(heatmap_data[:, 0]), my_colors['north_sea_green'])
_finish_stem_panel(ax[0], 't-statistic')

# Panels 1-2: brain-behaviour r for the lower_q (orange) and upper_q (purple) parcel bins.
# Solid stems = rest (columns 1 / 3), dashed stems = task (columns 2 / 4).
for panel_ax, rest_col, color, tick_labels in [
    (ax[1], 1, 'orange', None),
    (ax[2], 3, 'purple', heatmap_labels),
]:
    _styled_stem(panel_ax, x_pos - stem_offset, heatmap_data[:, rest_col], color, linefmt='-', label='rest')
    _styled_stem(panel_ax, x_pos + stem_offset, heatmap_data[:, rest_col + 1], color, linefmt='--', label='task')
    _finish_stem_panel(panel_ax, 'slope (r)', tick_labels)

# f.tight_layout()
# Save before show(), the order matplotlib recommends when both a file and a display are wanted.
f.savefig(os.path.join(outdir, 'interaction_stemplot.svg'), dpi=600, bbox_inches="tight", pad_inches=0.01)
plt.show()

In [None]:
# Guard flag for the exploratory per-measure regression/plot cell below;
# that cell is skipped unless this is set to True.
run_cell = False

In [None]:
if run_cell:
    # For each cognitive measure, plot:
    #   col 0: the per-parcel rest:task coefficient (from run_regression) against the SA axis,
    #          with a quadratic fit whose R^2 goes into the title;
    #   col 1: behaviour vs mean nuisance-regressed timescale over the bottom-q_val SA parcels;
    #   col 2: the same for the top-q_val SA parcels.
    n_rows = len(cog_measures)
    n_cols = 3
    # squeeze=False keeps `ax` 2-D so ax[row, col] also works when there is only one measure.
    f, ax = plt.subplots(n_rows, n_cols, figsize=(2.5*n_cols, 2.5*n_rows), squeeze=False)

    # Parcel bins at the two ends of the SA axis (loop-invariant, so computed once).
    # Like the original cell, this overwrites the notebook-level `lower_q` / `upper_q`.
    q_val = 0.2
    lower_q = sa_axis < np.quantile(sa_axis, q=q_val)
    upper_q = sa_axis > np.quantile(sa_axis, q=1-q_val)

    # Nuisance regression does not depend on the behavioural measure; run it once instead
    # of twice per bin per measure. The [mask, :].mean(axis=0) indexing below implies the
    # result is (n_parcels, n_subjects) — TODO confirm against nuis_reg.
    rest_timescales_nuis = nuis_reg(nuis_covs, rest_timescales)
    task_timescales_nuis = nuis_reg(nuis_covs, task_timescales)

    # (parcel mask, colour, title template) for the two right-hand columns.
    bin_specs = [
        (lower_q, 'orange', 'Sensorimotor (bottom {0}%)'),
        (upper_q, 'purple', 'Association (top {0}%)'),
    ]

    for plot_row, y_var in enumerate(cog_measures):
        print(y_var)
        # Non-missing mask for *this* measure. Previously the cell reused a stale
        # `y_var_filter` left by an earlier cell, which only matched the last measure.
        y_var_filter = ~np.isnan(df_behavior[y_var].values)
        y_obs = df_behavior[y_var][y_var_filter]

        effect_map, pvals = run_regression(df_behavior, y_var, rest_timescales, task_timescales)
        plot_col = 0
        # for effect in ['rest', 'task', 'rest:task']:
        for effect in ['rest:task',]:
            x = sa_axis[:, np.newaxis]
            y = effect_map[effect].values[:, np.newaxis]
            # Quadratic fit of the effect map on the SA axis; its R^2 (in %) is shown in the title.
            x_quad = np.hstack([x, x**2])
            quad_model = LinearRegression()
            quad_model.fit(x_quad, y)
            y_pred_quad = quad_model.predict(x_quad)
            r2_quad = r2_score(y, y_pred_quad)*100

            reg_plot(x, y, xlabel='SA axis', ylabel='coef({0})'.format(effect), ax=ax[plot_row, plot_col], annotate='both', order=2, add_pval=False)  # 'coef({0}, {1})'.format(effect,y_var)
            if plot_col == 0:
                if y_var in task_performance_measures:
                    ax[plot_row, plot_col].set_title('{:} (in scanner)\nr2: {:.2f}%'.format(y_var, r2_quad))
                else:
                    ax[plot_row, plot_col].set_title('{:} (out of scanner)\nr2: {:.2f}%'.format(y_var, r2_quad))
            plot_col += 1

            # if effect == 'rest:task':
            #     plot_data = effect_map[effect].values
            #     # p_map = pvals[effect].values
            #     # p_map = get_fdr_p(pvals[effect])
            #     # plot_data[p_map>0.05] = np.nan
            #     f2 = surface_plot(
            #         data=effect_map[effect].values,
            #         lh_annot_file=lh_annot_file,
            #         rh_annot_file=rh_annot_file,
            #         fsaverage=fsaverage,
            #         order="lr",
            #         cmap="coolwarm",
            #         # title_str='corr(int_deflections, behavior)'
            #     )
            #     f2.savefig(os.path.join(outdir, 'interaction_brainplot.png'), dpi=600, bbox_inches="tight", pad_inches=0.01)

        # Behaviour vs bin-averaged timescales: solid line = rest, dashed line = task.
        for parcel_mask, color, title in bin_specs:
            bin_ax = ax[plot_row, plot_col]
            for timescales_nuis, label, linestyle in [
                (rest_timescales_nuis, 'rest', '-'),
                (task_timescales_nuis, 'task', '--'),
            ]:
                mean_bin = timescales_nuis[parcel_mask, :].mean(axis=0)
                sns.regplot(y=y_obs, x=mean_bin[y_var_filter], ax=bin_ax, scatter=False, marker='.', scatter_kws={"s": 1}, label=label, line_kws={'linestyle': linestyle}, color=color, order=1)
            bin_ax.set_xlabel('INTs')
            bin_ax.set_ylabel(y_var)
            bin_ax.set_title(title.format(int(q_val*100)))
            bin_ax.legend()
            plot_col += 1

    for this_ax in ax.reshape(-1):
        sns.despine(right=True, top=True, ax=this_ax)

    f.suptitle(which_task+'\n\n')
    f.tight_layout()
    plt.show()

# Quantile effects

In [None]:
# Guard flag read by the two quantile-effect cells below; both are skipped
# unless this is set to True.
run_cell = False

In [None]:
if run_cell:
    # Compare behaviour with INT deflections (task - rest), averaged over the parcels at
    # the bottom and top q_val fraction of the SA axis.
    q_val = 0.2
    lower_q = sa_axis < np.quantile(sa_axis, q=q_val)
    upper_q = sa_axis > np.quantile(sa_axis, q=1-q_val)
    print(sa_axis[lower_q].mean(), sa_axis[upper_q].mean())

    # Recompute the non-missing mask from the current `y_var` so it cannot drift out of
    # sync with a `y_var_filter` left over from another cell.
    y_var_filter = ~np.isnan(df_behavior[y_var].values)

    f, ax = plt.subplots(1, 2, figsize=(5, 2.75))

    q_pct = int(q_val*100)
    bin_specs = [
        (ax[0], lower_q, 'Lower {0}%\n(i.e., S regions)'.format(q_pct)),
        (ax[1], upper_q, 'Upper {0}%\n(i.e., A regions)'.format(q_pct)),
    ]
    for bin_ax, parcel_mask, title in bin_specs:
        # Indexing rows by a parcel mask implies *_timescales is (n_parcels, n_subjects).
        rest_timescales_mean_bin = np.nanmean(rest_timescales[parcel_mask, :], axis=0)
        task_timescales_mean_bin = np.nanmean(task_timescales[parcel_mask, :], axis=0)

        int_deflections = task_timescales_mean_bin - rest_timescales_mean_bin
        reg_plot(df_behavior[y_var][y_var_filter], int_deflections[y_var_filter], xlabel=y_var, ylabel='INTs_def', ax=bin_ax, annotate='both')
        bin_ax.set_title(title)

    f.suptitle(which_task)
    f.tight_layout()
    plt.show()

In [None]:
if run_cell:
    # Split *subjects* (not parcels) into the bottom and top q_val fraction of the current
    # behavioural measure, and plot each group's mean INT deflection map against the SA axis.
    # The masks get subject-specific names: reusing `lower_q` / `upper_q` here overwrote the
    # parcel-level masks that other cells (e.g. the heatmap cell) index parcels with.
    q_val = 0.2
    behavior_scores = df_behavior[y_var]
    lower_q_subj = behavior_scores < behavior_scores.quantile(q=q_val)
    upper_q_subj = behavior_scores > behavior_scores.quantile(q=1-q_val)
    print(behavior_scores[lower_q_subj].mean(), behavior_scores[upper_q_subj].mean())

    f, ax = plt.subplots(1, 2, figsize=(5, 2.5))

    q_pct = int(q_val*100)
    for bin_ax, subj_mask, side in [(ax[0], lower_q_subj, 'Lower'), (ax[1], upper_q_subj, 'Upper')]:
        # Average over the selected subjects (columns) to get one value per parcel.
        rest_timescales_mean_bin = np.nanmean(rest_timescales[:, subj_mask.values], axis=1)
        task_timescales_mean_bin = np.nanmean(task_timescales[:, subj_mask.values], axis=1)

        int_deflections = compute_deflections(rest_timescales_mean_bin, task_timescales_mean_bin, nuisance_regression=False)
        reg_plot(sa_axis, int_deflections, xlabel='SA axis', ylabel='INTs_def', ax=bin_ax, annotate='both')
        bin_ax.set_title('{0} {1}%\n({2})'.format(side, q_pct, y_var))

    f.suptitle(which_task)
    f.tight_layout()
    plt.show()