# Beta Band Regressions
Created: 04/08/24 \
Updated: 06/15/2024 \
*Running analyses for the BRAIN Initiative conference - updating the mixed-effects regressions to be run across ROIs only.*

In [None]:
import numpy as np
import mne
from glob import glob
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from scipy.stats import zscore, linregress, ttest_ind, ttest_rel, ttest_1samp, pearsonr, spearmanr
import pandas as pd
from mne.preprocessing.bads import _find_outliers
import os 
import joblib
import re
import datetime
import scipy
import random
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.regression.mixed_linear_model import MixedLM 
from joblib import Parallel, delayed
import pickle
import itertools
import time 
from matplotlib.ticker import StrMethodFormatter



import warnings
warnings.filterwarnings('ignore')




In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Specify root directory for un-archived data and results.
# Every directory string ends with '/', so file names are appended to them without a separator.
base_dir   = '/sc/arion/projects/guLab/Alie/SWB/'
anat_dir   = f'{base_dir}ephys_analysis/recon_labels/'
neural_dir = f'{base_dir}ephys_analysis/data/'
behav_dir  = f'{base_dir}swb_behav_models/data/behavior_preprocessed/'
save_dir   = f'{base_dir}ephys_analysis/results/regression_data/'
script_dir = '/hpc/users/finka03/swb_ephys_analysis/scripts/'

# today's date as MMDDYYYY; used below as a suffix in saved figure names
date = datetime.date.today().strftime('%m%d%Y')
print(date)

In [None]:
import sys
sys.path.append(f'{base_dir}ephys_analysis/LFPAnalysis/')

from LFPAnalysis import analysis_utils

sys.path.append(f'{script_dir}analysis_notebooks/')

from ieeg_tools import *

In [None]:
subj_ids = list(pd.read_excel(f'{base_dir}ephys_analysis/subj_info/SWB_subjects.xlsx', sheet_name='Usable_Subjects', usecols=[0]).PatientID)
n_subj = len(subj_ids)
# subj_ids


In [None]:
bdi_list = pd.read_excel(f'{base_dir}ephys_analysis/subj_info/SWB_subjects.xlsx', sheet_name='Usable_Subjects', usecols=[3])
bdi_list = list(bdi_list.SWB_BDI)
subj_info_df = pd.DataFrame({'subj_id':subj_ids,'bdi':bdi_list})
subj_info_df



In [None]:
#### set theme for all plots
# %matplotlib inline
from matplotlib import rc
rc('text', usetex=False)

sns.set_theme(style='ticks') 
sns.set_context("poster")



In [None]:
from matplotlib.colors import ListedColormap,LinearSegmentedColormap
import matplotlib as mpl
from matplotlib import cm


# saez_cmap colors = org,yellow, sagegreen, darker green, blue, purple 

# 671351
# base palette as hex strings (see the colour-name note above)
saez_colors    = ['#ffb243','#FCE27C','#6a9752','#1d5d2a','#00467e','#561C46']
# discrete colormap made of exactly these six colours
saez_hex_cmap  = ListedColormap(saez_colors)
# same colours as RGB tuples, used as anchors for an interpolated 10-level colormap
saez_rgb       = list(map(mpl.colors.to_rgb, saez_colors))
saez_full_cmap = LinearSegmentedColormap.from_list('saez_full_cmap', saez_rgb, N=10)
#plot_palette = saez_linear_cmap.resampled(lutsize=30)
# hex string for each of the interpolated levels (alpha channel dropped)
expanded_cmap  = [mpl.colors.rgb2hex(saez_full_cmap(level)[:3])
                  for level in range(saez_full_cmap.N)]


In [None]:
saez_full_cmap

# Load beta power data 

In [None]:
# #### Try Beta Power 
# all_subj_beta_power = {}
# all_subj_metadata   = {}

# beta_range = [13,30]

# # load pow epochs data 
# for subj_id in subj_ids: 
#     # load power epochs for single subject 
#     power_epochs = mne.time_frequency.read_tfrs(f'{neural_dir}{subj_id}/{subj_id}_CpeOnset-tfr.h5')[0]
#     all_subj_metadata[subj_id] = power_epochs.metadata.copy()
    
#     # extract power df with mean beta power for each epoch for each elec 
#     power_df = power_epochs.to_data_frame()
#     del power_epochs 
#     beta_power_df = power_df[(power_df.freq >= beta_range[0])&(power_df.freq <= beta_range[1])].groupby(['epoch']).agg('mean').reset_index()
#     del power_df 
#     all_subj_beta_power[subj_id] = beta_power_df
#     del beta_power_df
    
    
# # create a binary pickle file 
# pickle.dump(all_subj_beta_power,open(f'{save_dir}all_subj_beta_power_nocrop.pkl',"wb"))
# pickle.dump(all_subj_metadata,open(f'{save_dir}all_subj_metadata.pkl',"wb"))

    

# Make Beta Regression DF

1. Update metadata vars - add t+1 vars, normalize vecs by 2std
2. Aggregate metadata, power data, ROI info, bdi info into one master df 

In [None]:
# Load the cached per-subject beta power tables and trial metadata written by the
# (commented-out) extraction cell above. Both objects are indexed by subject ID below.
# NOTE: unpickling can execute arbitrary code - only load files produced by this project.
# `with` closes each file handle as soon as the object has been read.
with open(f'{save_dir}all_subj_beta_power_nocrop.pkl', "rb") as pkl_file:
    all_subj_beta_power = pickle.load(pkl_file)

with open(f'{save_dir}all_subj_metadata.pkl', "rb") as pkl_file:
    all_subj_metadata = pickle.load(pkl_file)

In [None]:
### update metadata - add log RT, next-trial (t+1) copies of task variables, BDI and subject ID
### (no z-scoring happens in this cell; continuous regressors are z-scored when the master df is built)
for subj_id in subj_ids:
    subj_df = all_subj_metadata[subj_id]

    # derive every vector from the untouched metadata before writing any column back
    log_rt   = np.log(subj_df['RT']).tolist()       # take log of RT
    trial_ev = subj_df.GambleEV - subj_df.SafeBet   # gamble EV relative to the safe bet

    # t+1 vectors: drop the first trial and pad the end with nan (no t+1 for the last round)
    t1_vectors = {f'{col}_t1': subj_df[col].tolist()[1:] + [np.nan]
                  for col in ('Round', 'GambleChoice', 'GambleEV', 'TrialType')}

    subj_df['logRT']    = log_rt
    subj_df['logRT_t1'] = log_rt[1:] + [np.nan]
    for t1_col, t1_values in t1_vectors.items():
        subj_df[t1_col] = t1_values
    subj_df['TrialEV']    = trial_ev
    subj_df['TrialEV_t1'] = trial_ev.tolist()[1:] + [np.nan]

    # subject-level constants repeated on every trial row
    subj_df['bdi']     = subj_info_df.loc[subj_info_df.subj_id == subj_id, 'bdi'].values[0]
    subj_df['subj_id'] = subj_id

    # update metadata
    all_subj_metadata[subj_id] = subj_df
    
    
# pickle.dump(all_subj_metadata,open(f'{save_dir}all_subj_metadata_t1.pkl',"wb"))
    
### debugging/testing 

# round_test = [list(zip(all_subj_metadata[key].Round.tolist(),all_subj_metadata[key].RT.tolist(),all_subj_metadata[key].logRT.tolist(),all_subj_metadata[key].logRT_t1.tolist())) for key in all_subj_metadata.keys()]
# is_consecutive = [all([all_subj_metadata[subj_id].Round[i] == all_subj_metadata[subj_id].Round_t1[i-1] for i in range(1, len(all_subj_metadata[subj_id].Round))]) for subj_id in all_subj_metadata.keys()]


    

In [None]:
# ### load ROI reref master 


# # # load newest file
# # paths = glob(f'{save_dir}roi_reref_labels_master_*.csv')
# # load_date = str(np.max([int(path.split('_')[-1][:-4]) # find most recent saved df to load 
# #                         for path in paths]))
# # roi_reref_labels_master_df = pd.read_csv(glob(f'{save_dir}roi_reref_labels_master_0{load_date}.csv')[0])


# load the master ROI / re-referenced channel label table and drop the index column saved with the csv
roi_labels_path = glob(f'{base_dir}ephys_analysis/results/roi_info/roi_reref_labels_master.csv')[0]
roi_reref_labels_master_df = pd.read_csv(roi_labels_path).drop(columns=['Unnamed: 0'])




In [None]:
continuous_regressors = ['TrialEV','TrialEV_t1','GambleEV','GambleEV_t1','SafeBet',
                         'TotalProfit','CR', 'choiceEV', 'RPE','Profit',
                         'decisionCPE', 'decisionRegret','decisionRelief']

In [None]:
#### build the master regression df (power + trial metadata + roi info), one subject at a time
beta_master_df = []

valid_choices  = ['gamble', 'safe']
valid_outcomes = ['good', 'bad']

for subj_id, pow_df in all_subj_beta_power.items():

    # reshape pow_df from wide (one column per electrode) to long (one row per electrode per epoch)
    elec_cols  = pow_df.columns.drop(['epoch','time','freq']).tolist()
    pow_reg_df = pow_df.melt(id_vars=['epoch','freq'],
                             value_vars=elec_cols,
                             var_name='reref_ch_names', value_name='band_pow',
                             ignore_index=False)
    pow_reg_df['subj_id'] = subj_id # add subj_id to df values

    # reshape metadata df so the trial index becomes an 'epoch' column
    metadata_df = all_subj_metadata[subj_id].reset_index().rename(columns={'index':'epoch'})

    # merge pow and metadata dfs keeping epoch as index
    subj_regression_df = (pd.merge(pow_reg_df, metadata_df, on=['epoch','subj_id'])
                          .reset_index(drop=True).set_index('epoch'))
    subj_regression_df['epoch_num'] = subj_regression_df.index # keep epoch as a column

    # add roi info: take this subject's label rows (selected once; the original selected them twice)
    # and repeat them for every epoch so they can be joined on (subj_id, channel, epoch)
    subj_epochs = pow_reg_df.epoch.unique()
    subj_roi_reref_labels = roi_reref_labels_master_df[roi_reref_labels_master_df['subj_id']==subj_id]
    subj_roi_reref_labels = (pd.concat([subj_roi_reref_labels]*len(subj_epochs),
                                       keys = subj_epochs,
                                       names = ['epoch',None])).reset_index(level=1,drop=True) # epochs are the index
    subj_roi_reref_labels['epoch_num'] = subj_roi_reref_labels.index

    # create unique reref ch id for regression: '<subj_id>_<reref_ch_names>'
    subj_roi_reref_labels['unique_reref_ch'] = subj_roi_reref_labels[['subj_id', 'reref_ch_names']].agg('_'.join, axis=1)

    # merge regression df with roi reref info df
    subj_regression_df = (pd.merge(subj_regression_df, subj_roi_reref_labels,
                                   on=['subj_id','reref_ch_names','epoch_num'])
                          .reset_index(drop=True).set_index('epoch_num'))
    subj_regression_df['epoch'] = subj_regression_df.index

    ### cleaning within subject BEFORE zscoring regressors!
    keep_trials = (
        (subj_regression_df.RT > 0.3)                                  # RT too fast = random choice
        & subj_regression_df.GambleChoice.isin(valid_choices)          # no choice at trial t
        & subj_regression_df.GambleChoice_t1.isin(valid_choices)       # no choice at t1 (RT_t1 is nonsense)
        & subj_regression_df.Outcome.isin(valid_outcomes)              # no outcome
    )
    subj_regression_df = subj_regression_df[keep_trials].reset_index(drop=True)

    # remove Round 76 (original note: photodiode times are not correct at the breakpoint)
    subj_regression_df = subj_regression_df[subj_regression_df.Round != 76]
    # remove trials where logRT_t1 is inf or -inf (means RT = 0); copy so the column
    # assignments below write to an independent frame rather than a filtered slice
    subj_regression_df = subj_regression_df[~np.isinf(subj_regression_df.logRT_t1)].copy()

    # format GambleChoice + TrialType vars as categories
    for cat_col in ['GambleChoice', 'GambleChoice_t1', 'TrialType', 'TrialType_t1']:
        subj_regression_df[cat_col] = subj_regression_df[cat_col].astype('category')
    # make sure Round_t1 is an integer (nullable, so missing values are allowed)
    subj_regression_df['Round_t1'] = subj_regression_df['Round_t1'].astype('Int64')
    subj_regression_df['TrialEV']  = subj_regression_df.GambleEV - subj_regression_df.SafeBet

    # zscore continuous regressors AFTER removing bad trials!!
    # norm_zscore is not defined in this notebook (presumably from ieeg_tools); per the
    # original note it divides by 2 std - confirm against its definition
    for reg in continuous_regressors:
        subj_regression_df[reg] = norm_zscore(subj_regression_df[reg].values)

    beta_master_df.append(subj_regression_df)


beta_master_df = pd.concat(beta_master_df)
# reorder columns
beta_master_df = beta_master_df[['subj_id','bdi','epoch','freq','reref_ch_names','band_pow','unique_reref_ch',
                                 'roi','hemi','ch_label4roi','ch_type4roi','loc4roi','Round', 'Round_t1',
                                 'logRT_t1','logRT','RT','TrialType','TrialType_t1','GambleChoice',
                                 'GambleChoice_t1','GambleEV', 'GambleEV_t1','SafeBet','TrialEV','TrialEV_t1',
                                 'Profit','TotalProfit','RPE','decisionCPE', 'decisionRegret','decisionRelief']]



In [None]:
# # beta_master_df.to_csv(f'{save_dir}beta_master_data.csv')
beta_master_df.to_csv(f'{save_dir}beta_master_data_BRAINconf_interactionnocrop.csv')

# beta_master_df

In [None]:
beta_master_df

In [None]:
roi_reref_labels_master_df.roi.value_counts()

In [None]:
#### keep rois whose row count in the roi label table is >= 50
# (an older note here said the cutoff was "updated to >15"; the threshold actually applied is 50)
keep_rois = (roi_reref_labels_master_df.roi.value_counts()
             .loc[lambda n_per_roi: n_per_roi >= 50]
             .index.tolist())

# keep_rois = roi_reref_labels_master_df.roi.value_counts().index[
#     roi_reref_labels_master_df.roi.value_counts().values>=50].tolist()
# keep_rois = ['dlpfc','vlpfc','dmpfc','ofc','ains','pins','acc','hpc','amy','mtg']     
# keep_rois


In [None]:
####### NEW DATA FRAME
# Restrict the master df to the retained rois. Filtering first and copying second avoids
# duplicating the full master frame and leaves beta_reg_df as an independent frame, so
# columns added to it later are not written to a slice of beta_master_df.
beta_reg_df = beta_master_df[beta_master_df.roi.isin(keep_rois)].copy()
len(beta_reg_df)

In [None]:
np.max(beta_reg_df.band_pow)

In [None]:
beta_reg_df.roi.unique().tolist()

In [None]:
beta_reg_df.columns

# INTERACTION FEATURE SELECTION

In [None]:
######## make all possible variable combinations for regressions 
main_var  = ['decisionCPE','band_pow','decisionCPE:band_pow']

### iter vars for RT prediction dfs > 
iter_vars = ['C(GambleChoice)','TrialEV','TotalProfit',
             'C(GambleChoice_t1)','TrialEV_t1']


In [None]:
iter_vars

In [None]:
# Every non-empty subset of the nuisance covariates, each followed by the terms of interest.
# The original loop stopped one size short and never generated the full iter_vars set; with the
# exclusions below that set is dropped anyway, so the resulting list is unchanged, but the
# enumeration no longer silently depends on that.
var_combos = [list(subset) + main_var
              for n_vars in range(1, len(iter_vars) + 1)
              for subset in itertools.combinations(iter_vars, n_vars)]

# Pairs of terms that may not appear in the same model (the original note gives collinearity
# as the reason for Choice + Trial Type on the *same trial*). A combo is kept unless it contains
# BOTH members of some pair. The C(TrialType...) terms are not in iter_vars at the moment, so
# those pairs currently exclude nothing; they are kept in case the terms are added back.
excluded_pairs = [('C(GambleChoice_t1)', 'C(TrialType_t1)'),
                  ('C(GambleChoice)',    'C(TrialType)'),
                  ('C(GambleChoice)',    'TrialEV'),
                  ('TrialEV',            'C(TrialType)'),
                  ('C(GambleChoice_t1)', 'TrialEV_t1'),
                  ('TrialEV_t1',         'C(TrialType_t1)')]

var_combos = [combo for combo in var_combos
              if not any(first in combo and second in combo
                         for first, second in excluded_pairs)]
var_combos

In [None]:
len(var_combos)

In [None]:
# Fit one mixed model per covariate combination and keep results for the models that converge.
# NOTE(review): outcome_var and rand_eff_var are not assigned anywhere in the visible part of
# this notebook - confirm where they are defined before running top to bottom.
start = time.time() # start timer

all_models = {}   # model_key -> fitted model
all_rsq    = {}   # model_key -> compute_marginal_rsq(fit)
all_stats  = {}   # model_key -> mixed_effects_ftest_ttest(fit)
all_keys   = {}   # model_key -> list of terms used in the model
no_conv    = []   # model_keys of fits that did not converge

for test_vars in var_combos:
    # dict key = term names concatenated, leaving off the last term (the interaction)
    model_key = ''.join(test_vars[:-1])
    model_fit = fit_mixed_model(beta_reg_df,test_vars,outcome_var,rand_eff_var,reml=False)

    if not model_fit.converged:
        print(model_key,' did not converge :(')
        no_conv.append(model_key)
    else:
        print(model_key,' converged!')

        all_models[model_key] = model_fit
        all_rsq[model_key]    = compute_marginal_rsq(model_fit)
        all_stats[model_key]  = mixed_effects_ftest_ttest(model_fit)
        all_keys[model_key]   = test_vars

    # drop the local reference before the next fit
    del model_fit

end = time.time()
print('{:.4f} s'.format(end-start)) # print time elapsed for the whole sweep

In [None]:
no_conv

In [None]:
list(all_models.keys())

In [None]:
all_models['TrialEVdecisionCPEband_pow'].bic

In [None]:
# BIC of every converged model, keyed by model id
model_summary = {model_id: fit.bic for model_id, fit in all_models.items()}
model_summary

In [None]:
# Rank models from lowest (best) to highest BIC.
# Sorting the (model id, BIC) pairs directly keeps every model exactly once; looking keys up by
# value repeated the first match whenever two models shared a BIC. The sort is stable, so tied
# models keep their fitting order.
ranked_models = sorted(model_summary.items(), key=lambda id_and_bic: id_and_bic[1])
ordered_list  = [model_id for model_id, _ in ranked_models]
bic_list      = [bic for _, bic in ranked_models]
list(zip(ordered_list,bic_list))


In [None]:
all_models['TotalProfitTrialEV_t1decisionCPEband_pow'].pvalues

In [None]:
[(key,bic,all_models[key].pvalues['decisionCPE:band_pow']) for key,bic in list(zip(ordered_list,bic_list))]  


In [None]:
# model id with the highest BIC (first one in fitting order if several tie)
print('worst bic model: ',max(model_summary, key=model_summary.get))


In [None]:
# model id with the lowest BIC (first one in fitting order if several tie)
print('best bic model: ',min(model_summary, key=model_summary.get))


In [None]:
# lowest-BIC model id, followed by its BIC
print(min(model_summary, key=model_summary.get))
print(min(model_summary.values()))
      
      

In [None]:
[(key,all_models[key].pvalues['decisionCPE:band_pow']) for key in ordered_list if all_models[key].pvalues['decisionCPE:band_pow'] < 0.05]  



In [None]:
[(key,all_models[key].pvalues['decisionCPE:band_pow']) for key in ordered_list if all_models[key].pvalues['decisionCPE:band_pow'] > 0.05]  


In [None]:
all_models['TrialEVTotalProfitdecisionCPEband_pow'].pvalues

In [None]:
all_models['TrialEVTotalProfitdecisionCPEband_pow'].params

In [None]:
all_models['TotalProfitTrialEV_t1decisionCPEband_pow'].summary()

In [None]:
all_rsq

In [None]:
all_models

In [None]:
all_stats['TrialEVdecisionCPEband_pow']

In [None]:
all_stats

In [None]:
#### FOR NOW JUST SAVE CPE STATS!
# One row per model, built from the LAST row of the first table returned by
# mixed_effects_ftest_ttest (its t, coef and P>|t| columns).
# NOTE(review): presumably the last row is the decisionCPE:band_pow term - confirm the row order.
t_test_rows = []
for model_id, model_stats in all_stats.items():
    ttest_table = model_stats[0]
    t_test_rows.append(pd.DataFrame({'ttest_fe_tval': ttest_table.t.tolist()[-1],
                                     'ttest_fe_coef' : ttest_table.coef.tolist()[-1],
                                     'ttest_fe_pval': ttest_table['P>|t|'].tolist()[-1],
                                     'model_id': model_id}, index=[0]))

t_test_results_summary = pd.concat(t_test_rows).reset_index(drop=True)
t_test_results_summary



In [None]:
# One row per model from the second object returned by mixed_effects_ftest_ttest
# (its fvalue, df_num and pvalue attributes). Note that 'ftest_re_df' stores df_num.
f_test_rows = []
for model_id, model_stats in all_stats.items():
    ftest_result = model_stats[1]
    f_test_rows.append(pd.DataFrame({'ftest_fvalue': ftest_result.fvalue,
                                     'ftest_re_df': ftest_result.df_num,
                                     'ftest_f_pval': ftest_result.pvalue,
                                     'model_id': model_id}, index=[0]))

f_test_results_summary = pd.concat(f_test_rows).reset_index(drop=True)
f_test_results_summary

In [None]:
# join the t-test and F-test summaries on model id, then put model_id first
tstat_fstat_df = (t_test_results_summary
                  .merge(f_test_results_summary, on='model_id')
                  .reset_index(drop=True))

summary_columns = ['model_id','ttest_fe_tval', 'ttest_fe_coef', 'ttest_fe_pval',
                   'ftest_fvalue', 'ftest_re_df', 'ftest_f_pval']
tstat_fstat_df = tstat_fstat_df[summary_columns]
tstat_fstat_df



In [None]:
# look each value up by the row's model_id instead of relying on all_rsq's insertion order
# matching the row order of tstat_fstat_df (a missing id raises KeyError rather than misaligning)
tstat_fstat_df['rsq'] = [all_rsq[model_id] for model_id in tstat_fstat_df.model_id]

In [None]:
tstat_fstat_df

In [None]:
all_rsq

In [None]:

# BIC looked up by the row's model_id, so the column stays aligned even if rows are reordered
tstat_fstat_df['bic']         = [all_models[model_id].bic for model_id in tstat_fstat_df.model_id]


In [None]:
tstat_fstat_df

In [None]:
# interaction estimate looked up by the row's model_id rather than by dict order
tstat_fstat_df['fe_cpe_est'] = [all_models[model_id].params['decisionCPE:band_pow']
                                for model_id in tstat_fstat_df.model_id]

In [None]:
# interaction p-value looked up by the row's model_id rather than by dict order
tstat_fstat_df['fe_cpe_pval'] = [all_models[model_id].pvalues['decisionCPE:band_pow']
                                 for model_id in tstat_fstat_df.model_id]

In [None]:
tstat_fstat_df



In [None]:
#### multiple comparisons correction!!!!
bonferr_alpha = 0.05/len(tstat_fstat_df)
bonferr_alpha

In [None]:

tstat_fstat_df.sort_values(by='fe_cpe_pval')

In [None]:
# tstat_fstat_df = tstat_fstat_df[tstat_fstat_df.fe_cpe_pval < bonferr_alpha]
tstat_fstat_df.sort_values(by='ttest_fe_pval')

In [None]:
save_df = tstat_fstat_df.copy()

In [None]:
tstat_fstat_df.sort_values(by='ttest_fe_pval')

In [None]:
tstat_fstat_df.ttest_fe_pval

In [None]:
print('Best BIC model: ',tstat_fstat_df.model_id[tstat_fstat_df.bic == np.min(tstat_fstat_df.bic)].values[0])


In [None]:
print('Best fe_cpe_pval model: ',tstat_fstat_df.model_id[tstat_fstat_df.fe_cpe_pval == np.min(tstat_fstat_df.fe_cpe_pval)].values[0])


In [None]:
print('Best rsq model: ',tstat_fstat_df.model_id[tstat_fstat_df.rsq == np.max(tstat_fstat_df.rsq)].values[0])


In [None]:
print('Best ttest pval: ',tstat_fstat_df.model_id[tstat_fstat_df.ttest_fe_pval == 
                                                  np.min(tstat_fstat_df.ttest_fe_pval)].values[0])


In [None]:
print('Best tstat: ',tstat_fstat_df.model_id[tstat_fstat_df.ttest_fe_tval == 
                                                  np.max(tstat_fstat_df.ttest_fe_tval)].values[0])


In [None]:
print('Best ftest pval: ',tstat_fstat_df.model_id[tstat_fstat_df.ftest_f_pval == 
                                                  np.min(tstat_fstat_df.ftest_f_pval)].values[0])

In [None]:
print('Best Fstat model: ',tstat_fstat_df.model_id[tstat_fstat_df.ftest_fvalue == 
                                                  np.max(tstat_fstat_df.ftest_fvalue)].values[0])

In [None]:
# fixed-effect estimates of every converged model as a plain dict, plus the model id
all_models_fe_data   = {}

for key, fit in all_models.items():
    model_fe_params = {f'{param}': estimate for param, estimate in fit.fe_params.items()}
    model_fe_params['model_id'] = key
    all_models_fe_data[key] = model_fe_params



In [None]:
all_models_fe_data

In [None]:
# Per-model table of electrode-level random effects: one row per electrode with its subject,
# channel name, roi, bdi and one 'raw_<term>' column per random-effect term.
model_plot_data = {}

# Lookups built once, instead of scanning beta_reg_df / subj_info_df for every electrode.
# drop_duplicates keeps the first row per key, matching the original `.values[0]` lookups.
first_elec_rows = beta_reg_df.drop_duplicates('unique_reref_ch')
roi_by_elec     = dict(zip(first_elec_rows.unique_reref_ch, first_elec_rows.roi))
first_subj_rows = subj_info_df.drop_duplicates('subj_id')
bdi_by_subj     = dict(zip(first_subj_rows.subj_id, first_subj_rows.bdi))

for key in all_models.keys():
    fe_param_names = list(all_models[key].fe_params.index)  # kept for interactive inspection

    elec_rows = []
    for elec_id, elec_effects in all_models[key].random_effects.items():
        # unique_reref_ch was built as '<subj_id>_<reref_ch_names>', so split on the FIRST
        # underscore only; this keeps channel names that contain underscores intact
        # (assumes subject ids contain no underscore)
        subj, reref_ch = elec_id.split('_', 1)
        elec_rows.append({'subj_id': subj, 'reref_ch_names': reref_ch, 'unique_reref_ch': elec_id,
                          **{f'raw_{param}': elec_effects[param] for param in elec_effects.index}})

    # exactly one row per electrode (the original passed index=[range(...)], a list holding a
    # range object, and left the resulting row count to pandas' interpretation of it)
    plot_data = pd.DataFrame(elec_rows)

    plot_data['roi'] = [roi_by_elec[elec_id] for elec_id in plot_data.unique_reref_ch]
    plot_data['bdi'] = [bdi_by_subj[subj] for subj in plot_data.subj_id]
    model_plot_data[key] = plot_data



In [None]:
plot_data

In [None]:
# plot_data = pd.concat([pd.DataFrame(
#     {**{'subj_id':elec_id.split('_')[0],'reref_ch_names':elec_id.split('_')[1]},'unique_reref_ch':elec_id,
#      **{f'raw_{param}':interaction_fit.random_effects[elec_id][param] for param in list(
#             interaction_fit.random_effects[elec_id].index)}},index=[range(len(fe_param_names))])
        
#     for elec_id in interaction_fit.random_effects.keys()]).reset_index(drop=True)

# plot_data['roi'] = [beta_reg_df.roi[beta_reg_df.unique_reref_ch == row.unique_reref_ch].values[0]
#                                 for idx,row in plot_data.iterrows()]

# plot_data['bdi'] = [subj_info_df.bdi[subj_info_df.subj_id == subj_id].values[0]
#                                 for subj_id in plot_data.subj_id]

In [None]:
all_models_fe_data

In [None]:
all_models_fe_data[key]

In [None]:
plot_data

In [None]:

# electrode-level coefficients = fixed effect + that electrode's random effect
for key, plot_data in model_plot_data.items():
    fe_params    = all_models_fe_data[key]
    fe_intercept = fe_params['Intercept']
    fe_cpe_beta  = fe_params['decisionCPE:band_pow']

    elec_b0   = plot_data.raw_Group + fe_intercept                       # intercept per electrode
    elec_bCPE = plot_data['raw_decisionCPE:band_pow'] + fe_cpe_beta      # interaction slope per electrode

    plot_data['elec_b0']   = elec_b0
    plot_data['elec_bCPE'] = elec_bCPE

    plot_data = plot_data.reset_index(drop=True)
    model_plot_data[key] = plot_data
    
    
    

In [None]:
model_plot_data['TrialEVdecisionCPEband_pow']

In [None]:
ttest_1samp(model_plot_data[key][model_plot_data[key].roi == 'ains']['raw_decisionCPE:band_pow'].to_numpy(),popmean=0)

In [None]:
# model id -> {roi -> p-value of a one-sample t-test (against 0) on the electrodes' raw
# random-effect terms for the interaction, i.e. before the fixed effect is added}
RAW_model_roi_stats_data = {}
for model_name, model_df in model_plot_data.items():
    roi_pvals = {}
    for roi_name in model_df.roi.unique().tolist():
        raw_slopes = model_df[model_df.roi == roi_name]['raw_decisionCPE:band_pow'].to_numpy()
        roi_pvals[f'{roi_name}'] = ttest_1samp(raw_slopes, popmean=0).pvalue
    RAW_model_roi_stats_data[f'{model_name}'] = roi_pvals

In [None]:
RAW_model_roi_stats_data

In [None]:
# model id -> {roi -> p-value of a one-sample t-test (against 0) on the electrodes' elec_bCPE values}
model_roi_stats_data = {}
for model_name, model_df in model_plot_data.items():
    roi_pvals = {}
    for roi_name in model_df.roi.unique().tolist():
        elec_slopes = model_df[model_df.roi == roi_name].elec_bCPE.to_numpy()
        roi_pvals[f'{roi_name}'] = ttest_1samp(elec_slopes, popmean=0).pvalue
    model_roi_stats_data[f'{model_name}'] = roi_pvals


In [None]:
model_roi_stats_data

In [None]:
# Bonferroni threshold across rois: 0.05 divided by the number of distinct rois in plot_data.
# NOTE(review): plot_data is whatever the last loop over models left behind, and this rebinds
# the bonferr_alpha computed earlier from the number of models - confirm both are intended.
bonferr_alpha = 0.05/len(np.unique(plot_data.roi))
# bonferr_alpha = 0.05/len(bdi_plot_data.roi.unique().tolist())
bonferr_alpha

In [None]:
roi = 'ains'

In [None]:
# roi_order was only assigned further down the notebook, so on a clean top-to-bottom run this
# cell (and the print cell that follows) raised NameError; define it here, before first use
roi_order = ['dlpfc','vlpfc','dmpfc','ofc','ains','pins','acc','hpc','amy']
roi in roi_order

In [None]:
print([[(roi,RAW_model_roi_stats_data[model][roi]) for 
                                                    roi in RAW_model_roi_stats_data[model].keys()
                                                    if roi in roi_order]
                                                   for model in RAW_model_roi_stats_data.keys()
                                                   ])

In [None]:
print('sig ROIs after multiple comp correction: ',[[(roi,RAW_model_roi_stats_data[model][roi]) for 
                                                    roi in RAW_model_roi_stats_data[model].keys()
                                                    if RAW_model_roi_stats_data[model][roi]< bonferr_alpha]
                                                   for model in RAW_model_roi_stats_data.keys()
                                                   ])


In [None]:
print('sig ROIs after multiple comp correction: ',[[(roi,model_roi_stats_data[model][roi]) for 
                                                    roi in model_roi_stats_data[model].keys()
                                                    if model_roi_stats_data[model][roi]< bonferr_alpha]
                                                   for model in model_roi_stats_data.keys()
                                                   ])




In [None]:
model_roi_stats_data.keys()

In [None]:
roi_order = ['dlpfc','vlpfc','dmpfc','ofc','ains','pins','acc','hpc','amy']

In [None]:
print([[(model,RAW_model_roi_stats_data[model][roi]) for  roi in RAW_model_roi_stats_data[model].keys()
                                                    if roi =='ains']
                                                   for model in RAW_model_roi_stats_data.keys()
                                                   ])

In [None]:
print([[(model,model_roi_stats_data[model][roi]) for roi in model_roi_stats_data[model].keys()
                                                    if roi =='ains']
                                                   for model in model_roi_stats_data.keys()
                                                   ])

In [None]:
test_model = 'TrialEVdecisionCPEband_pow'
# 'TrialEVTotalProfitdecisionCPEband_pow'


In [None]:
roi_bdi_plot_data = model_plot_data[test_model]
roi_bdi_plot_data = roi_bdi_plot_data[roi_bdi_plot_data.roi.isin(roi_order)]

roi_bdi_plot_data

In [None]:
roi_bdi_plot_data

In [None]:
[roi_bdi_plot_data]

In [None]:
# one-sample t-test (against 0) of the raw interaction random effects per roi, run once per roi;
# only rois below the Bonferroni threshold are printed
raw_roi_pvals = [(roi_name,
                  ttest_1samp(roi_bdi_plot_data[roi_bdi_plot_data.roi == roi_name][
                      'raw_decisionCPE:band_pow'].to_numpy(), popmean=0).pvalue)
                 for roi_name in roi_bdi_plot_data.roi.unique().tolist()]

print('sig ROIs after multiple comp correction: ',
      [(roi_name, pval) for roi_name, pval in raw_roi_pvals if pval < bonferr_alpha])




In [None]:
# same test on the electrode-level slopes (fixed + random effect), run once per roi;
# only rois below the Bonferroni threshold are printed
elec_roi_pvals = [(roi_name,
                   ttest_1samp(roi_bdi_plot_data[roi_bdi_plot_data.roi == roi_name][
                       'elec_bCPE'].to_numpy(), popmean=0).pvalue)
                  for roi_name in roi_bdi_plot_data.roi.unique().tolist()]

print('sig ROIs after multiple comp correction: ',
      [(roi_name, pval) for roi_name, pval in elec_roi_pvals if pval < bonferr_alpha])




In [None]:
# Bar + strip plot of the raw interaction random effects by roi for the selected model,
# saved as a dated pdf.
sns.set_context("poster")
fig,ax = plt.subplots(figsize=(15, 6),dpi=500)

# output folder for this notebook's figures; created if it does not exist yet
fig_save_dir = f'{base_dir}ephys_analysis/figs/beta_regressions/'
os.makedirs(fig_save_dir,exist_ok=True)

# one bar per roi, drawn in roi_order with no error bars (errorbar=None)
sns.barplot(data = roi_bdi_plot_data,
            x = 'roi',
            y = 'raw_decisionCPE:band_pow',
#             y='elec_bCPE',
            color = '#125549',
            saturation=0.8,
            order=roi_order,
            errorbar=None) 


# individual electrodes overlaid on the bars, semi-transparent
sns.stripplot(data = roi_bdi_plot_data,
            x = 'roi',
            y = 'raw_decisionCPE:band_pow',                          
#               y='elec_bCPE',
              color = '#125549',
              alpha=.4,
              dodge=True,
              order=roi_order,
              size = 8)

# zero reference line
ax.axhline(y=0,color='#125549',linewidth=1.5)
plt.tick_params(length=0 )#,grid_linewidth=1)

# plt.ylim([-0.02,0.017])
ax.set(xlabel=None)

# plt.ylabel(r'$beta_{{{zpow}}} \sim cpe:bdi$')
plt.ylabel(r'$log(RT_t) \sim cpe:beta_{{{zpow}}}$')

# plt.gca().yaxis.set_major_formatter(StrMethodFormatter('{x:,.3f}')) 
plt.locator_params(axis='y', nbins=6,tight=True)



sns.despine()
    
# fig_save_dir already ends with '/', so the saved path contains a doubled slash
plt.savefig(f'{fig_save_dir}/rt_interactionbetas{date}.pdf', format='pdf', metadata=None,
bbox_inches=None, pad_inches=0.1,dpi=500,
facecolor='auto', edgecolor='auto',
backend=None)

In [None]:
roi_bdi_plot_data


In [None]:
# The bare text `raw_decisionCPE:band_pow` parses as a variable annotation and raises NameError
# (band_pow is not a defined name). Show the column of that name instead.
roi_bdi_plot_data['raw_decisionCPE:band_pow']

In [None]:
roi_bdi_plot_data[roi_bdi_plot_data.roi==roi]

In [None]:
# Pearson correlation between bdi and the raw interaction random effect, one row per roi.
# Each electrode is one observation, so a subject's bdi repeats across that subject's electrodes.
# NOTE(review): [:-1] skips the last roi in order of appearance - confirm which roi and why.
rois_to_correlate = roi_bdi_plot_data.roi.unique().tolist()[:-1]

rho_data = []
for roi in rois_to_correlate:
    roi_temp_data = roi_bdi_plot_data.loc[roi_bdi_plot_data.roi == roi]
    rho, rho_p = pearsonr(roi_temp_data.bdi, roi_temp_data['raw_decisionCPE:band_pow'])
    rho_data.append(pd.DataFrame({'roi': roi, 'rho': rho, 'rho_p': rho_p}, index=[0]))

rho_data = pd.concat(rho_data, ignore_index=True)

In [None]:
rho_data

In [None]:
rho_data[rho_data.roi=='pins']

In [None]:
# p-value used for plotting: take the computed Pearson p-value for each roi. The previous line
# hard-coded 0.0001 (0.0003 for 'pins') regardless of what the correlation returned, so the
# table could disagree with the statistics after any change to the data or the model.
rho_data['plot_p'] = rho_data['rho_p']

In [None]:
rho_data['plot_rho'] = np.round(rho_data.rho,2)

In [None]:
rho_data

In [None]:

# One scatter + regression line per roi: subject-mean interaction random effect against BDI,
# annotated with the correlation computed earlier, saved as a dated pdf per roi.
# NOTE(review): rho / p come from the electrode-level correlation in rho_data, while the points
# drawn here are subject means - confirm that this mismatch is intended.
# sns.set_context("poster",rc={'axes.linewidth': 2})

for roi in roi_bdi_plot_data.roi.unique().tolist()[:-1]:
    roi_temp_data = roi_bdi_plot_data[roi_bdi_plot_data.roi==roi]

    rho   = rho_data[rho_data.roi == roi].plot_rho.values[0]
    # use the computed p-value, not a hard-coded one, for the annotation below
    rho_p = rho_data[rho_data.roi == roi].rho_p.values[0]

    # subject-level means of the two plotted columns only; averaging the whole frame fails on
    # its string columns in pandas >= 2.0
    subj_means = roi_temp_data.groupby('subj_id')[['bdi','raw_decisionCPE:band_pow']].mean()

    fig = plt.figure(figsize=(7,5),dpi=300)

    sns.regplot(data = subj_means,x='bdi',y='raw_decisionCPE:band_pow'
                ,color='#135546') #0b4e5f #135546  0b4e5f

    plt.xlabel('BDI-II')

    plt.tick_params(length=0)
    plt.ylabel(r'$RT_t \sim cpe:beta_{{{zpow}}}$')

    # report "p < 0.001" only when that is true; otherwise print the value to 3 decimals
    p_label = r'$p < 0.001$' if rho_p < 0.001 else f'$p = {rho_p:.3f}$'

    plt.text(np.min(roi_temp_data['bdi'])+1.25,
             np.round(np.max(subj_means['raw_decisionCPE:band_pow']),3)-0.01,
             r'$rho$ = '+f'{rho} '+'\n '+p_label,
             color='#135546',fontsize=18)

    plt.title(f'{roi}',fontsize=24)

    sns.despine()

    plt.savefig(f'{fig_save_dir}/rt_cpebeta_bdi_2{roi}_{date}.pdf', format='pdf', metadata=None,
        bbox_inches='tight', pad_inches=0.1,dpi=300,
        facecolor='auto', edgecolor='auto',
        backend=None)
    

In [None]:
# Display the ROI-level plotting table.
roi_bdi_plot_data

In [None]:
# Two-level BDI grouping: scores below 20 become 'Low BDI', all other values
# 'High BDI'.
beta_reg_df['Depressed'] = np.where(beta_reg_df['bdi'] < 20, 'Low BDI', 'High BDI')
beta_reg_df

In [None]:
# List the distinct channel labels present in the ROI plotting table.
roi_bdi_plot_data.unique_reref_ch.unique().tolist()

In [None]:
# Channel-level summary table: numeric columns of beta_reg_df averaged per
# channel, limited to channels present in roi_bdi_plot_data, with the
# coefficient, elec_bCPE and roi values copied over from that table.

# numeric_only=True: averaging string columns (e.g. 'Depressed') raises in
# pandas >= 2.0; older versions dropped them silently, which this reproduces.
beta_lmplot = (beta_reg_df
               .groupby('unique_reref_ch')
               .mean(numeric_only=True)
               .reset_index())
beta_lmplot['Depressed'] = ['Low BDI' if bdi<20 else 'High BDI' for bdi in beta_lmplot.bdi]

# .copy() so the column assignments below act on an independent frame rather
# than on a filtered slice.
roi_channels = roi_bdi_plot_data.unique_reref_ch.unique().tolist()
beta_lmplot = beta_lmplot[beta_lmplot.unique_reref_ch.isin(roi_channels)].copy()

# First row per channel, keyed by channel label: one lookup table replaces a
# full boolean-mask scan of roi_bdi_plot_data for every channel.
first_row_per_chan = (roi_bdi_plot_data
                      .drop_duplicates('unique_reref_ch')
                      .set_index('unique_reref_ch'))

beta_lmplot['raw_rt_coeff'] = beta_lmplot.unique_reref_ch.map(
    first_row_per_chan['raw_decisionCPE:band_pow']).to_numpy()

beta_lmplot['elec_bCPE'] = beta_lmplot.unique_reref_ch.map(
    first_row_per_chan['elec_bCPE']).to_numpy()

beta_lmplot['roi'] = beta_lmplot.unique_reref_ch.map(
    first_row_per_chan['roi']).to_numpy()
beta_lmplot

In [None]:
# Display the band_pow column of lm_beta_reg_df. This cell does not create
# lm_beta_reg_df, so it must already exist in the notebook namespace.
lm_beta_reg_df['band_pow']

In [None]:
# Rows of beta_reg_df restricted to channels present in roi_bdi_plot_data,
# with that table's coefficient and elec_bCPE value attached to every row.

roi_channels = roi_bdi_plot_data.unique_reref_ch.unique().tolist()

# .copy() after filtering so the new columns are written to an independent
# frame instead of a slice of beta_reg_df.
lm_beta_reg_df = beta_reg_df[beta_reg_df.unique_reref_ch.isin(roi_channels)].copy()

# First row per channel, keyed by channel label. Mapping through this table
# replaces a full scan of roi_bdi_plot_data for every row of lm_beta_reg_df.
first_row_per_chan = (roi_bdi_plot_data
                      .drop_duplicates('unique_reref_ch')
                      .set_index('unique_reref_ch'))

lm_beta_reg_df['raw_rt_coeff'] = lm_beta_reg_df.unique_reref_ch.map(
    first_row_per_chan['raw_decisionCPE:band_pow']).to_numpy()

lm_beta_reg_df['elec_bCPE'] = lm_beta_reg_df.unique_reref_ch.map(
    first_row_per_chan['elec_bCPE']).to_numpy()
lm_beta_reg_df

In [None]:
# Mean split on beta power: rows strictly above the mean of band_pow are
# 'High', all others 'Low'. The mean is computed once up front; evaluating it
# inside a per-row comprehension recomputes it for every row.
mean_band_pow = np.mean(lm_beta_reg_df['band_pow'])
lm_beta_reg_df['Beta_Power_Thresh'] = np.where(
    lm_beta_reg_df['band_pow'] > mean_band_pow, 'High', 'Low')

# BDI split: scores below 20 are 'Low', everything else 'High'.
lm_beta_reg_df['BDI-II'] = np.where(lm_beta_reg_df['bdi'] < 20, 'Low', 'High')
lm_beta_reg_df

In [None]:
# Mean split on beta power: rows strictly above the mean of band_pow are
# 'High', all others 'Low'. The mean is computed once rather than once per row.
mean_band_pow = np.mean(lm_beta_reg_df['band_pow'])
lm_beta_reg_df['Beta_Power_Thresh'] = np.where(
    lm_beta_reg_df['band_pow'] > mean_band_pow, 'High', 'Low')

In [None]:

# Colours keyed by the 'Low' / 'High' level names.
# (scratch note kept from the original cell: hex 824973)
palette=dict(Low="#591a48", High="#125549")

# Subject-level means of BDI and log RT. Only the two plotted columns are
# averaged, so any non-numeric columns in all_behav cannot break the mean.
subj_behav_means = all_behav.groupby('subj_id')[['bdi', 'logRT_t1']].mean()

# NOTE(review): `ax` is not created in this cell; it has to be a matplotlib
# Axes already present in the notebook namespace -- confirm which one.
sns.regplot(ax=ax, data=subj_behav_means, x='bdi', y='logRT_t1', color='#0b4e5f')


In [None]:

# Mean split on elec_bCPE: rows strictly above the mean are 'High', all others
# 'Low'. The mean is computed once rather than once per row.
mean_elec_bcpe = np.mean(lm_beta_reg_df['elec_bCPE'])
lm_beta_reg_df['interaction_thresh'] = np.where(
    lm_beta_reg_df['elec_bCPE'] > mean_elec_bcpe, 'High', 'Low')

In [None]:
# Regression lines (95% CI, no scatter) of log RT on decisionCPE, one line per
# Beta_Power_Thresh level, coloured from the notebook-level `palette` dict.
sns.set_context("poster",rc={'axes.linewidth': 1.5})
# NOTE(review): set_theme takes its own `context` argument, so calling it
# after set_context may replace the poster context set above -- confirm the
# intended order.
sns.set_theme(style='ticks',rc={"figure.dpi": 500})

# lmplot returns a seaborn grid object; it is stored under the name `ax`.
ax = sns.lmplot(
    data=lm_beta_reg_df,
    x='decisionCPE',
    y='logRT_t1',
    hue='Beta_Power_Thresh',
    palette=palette,
    scatter=False,
    ci=95,
    height=4,
    aspect=1.25,
)
sns.move_legend(ax, "center right")

plt.ylabel(r'$log(RT_t) \sim cpe:beta_{{{zpow}}}$')
plt.tick_params(length=0)
sns.despine()




In [None]:
# Figure: log RT against beta power with one regression line (95% CI band,
# no scatter) per BDI group, saved as a PDF.
sns.set_theme(style='ticks')
sns.set_context("poster")

fig,ax = plt.subplots(figsize=(8,5),dpi=500)

# Left in the notebook namespace; the lines below take their colours directly.
palette={'high BDI':"#125549",'low BDI':"#824973"}

# 'Low BDI' is drawn first, then 'High BDI', both on the current axes.
for bdi_group, line_color in (('Low BDI', '#824973'), ('High BDI', '#125549')):
    group_rows = beta_reg_df[beta_reg_df.Depressed == bdi_group]
    sns.regplot(data=group_rows, x='band_pow', y='logRT_t1',
                scatter=False, ci=95, color=line_color, label=bdi_group)

plt.legend(fontsize=10)
# Re-place the legend above the axes: two columns, no frame.
sns.move_legend(
    ax, "upper center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=2, columnspacing=2,
    fontsize='small',
    frameon=False, framealpha=0, shadow=None,
)

# Two-decimal tick labels on the y axis.
plt.gca().yaxis.set_major_formatter(StrMethodFormatter('{x:,.2f}'))
plt.ylabel(r'$log(RT_t)$')
plt.tick_params(length=0)
sns.despine()

plt.savefig(
    f'{fig_save_dir}/rt_CPE_bdibandpow{date}.pdf',
    format='pdf', dpi=500,
    bbox_inches='tight', pad_inches=0.1,
    facecolor='auto', edgecolor='auto',
    metadata=None, backend=None,
)

In [None]:
# Figure: log RT against decisionCPE with one regression line (95% CI band,
# no scatter) per BDI group, saved as a PDF.
sns.set_theme(style='ticks')
sns.set_context("poster")

fig,ax = plt.subplots(figsize=(8,5),dpi=500)

# Left in the notebook namespace; the lines below take their colours directly.
palette={'high BDI':"#125549",'low BDI':"#ffb243"}

# 'Low BDI' is drawn first, then 'High BDI', both on the current axes.
for bdi_group, line_color in (('Low BDI', '#ffb243'), ('High BDI', '#125549')):
    group_rows = beta_reg_df[beta_reg_df.Depressed == bdi_group]
    sns.regplot(data=group_rows, x='decisionCPE', y='logRT_t1',
                scatter=False, ci=95, color=line_color, label=bdi_group)

plt.legend(fontsize=10)
# Re-place the legend above the axes: two columns, no frame.
sns.move_legend(
    ax, "upper center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=2, columnspacing=2,
    fontsize='small',
    frameon=False, framealpha=0, shadow=None,
)

# Two-decimal tick labels on the y axis.
plt.gca().yaxis.set_major_formatter(StrMethodFormatter('{x:,.2f}'))
plt.ylabel(r'$log(RT_t)$')
plt.xlabel(r'$cpe(z)$')
plt.tick_params(length=0)
sns.despine()

# NOTE(review): another cell in this notebook saves to this same
# rt_CPE_bdi{date}.pdf path, so the file on disk comes from whichever cell
# ran last.
plt.savefig(
    f'{fig_save_dir}/rt_CPE_bdi{date}.pdf',
    format='pdf', dpi=500,
    bbox_inches='tight', pad_inches=0.1,
    facecolor='auto', edgecolor='auto',
    metadata=None, backend=None,
)

In [None]:
# Display the current state of lm_beta_reg_df.
lm_beta_reg_df

In [None]:
# Figure: log RT against decisionCPE, one regression line (95% CI band, no
# scatter) for each BDI group, written out as a PDF.
sns.set_theme(style='ticks')
sns.set_context("poster")

fig,ax = plt.subplots(figsize=(8,5),dpi=500)

# Left in the notebook namespace; the regression lines get explicit colours.
palette={'high BDI':"#125549",'low BDI':"#824973"}

low_bdi_rows = beta_reg_df[beta_reg_df.Depressed == 'Low BDI']
high_bdi_rows = beta_reg_df[beta_reg_df.Depressed == 'High BDI']

# Low group first, then high group, both drawn on the current axes.
sns.regplot(data=low_bdi_rows, x='decisionCPE', y='logRT_t1',
            scatter=False, ci=95, color='#824973', label='Low BDI')
sns.regplot(data=high_bdi_rows, x='decisionCPE', y='logRT_t1',
            scatter=False, ci=95, color='#125549', label='High BDI')

plt.legend(fontsize=10)
# Legend moved above the plot area, laid out in two frameless columns.
sns.move_legend(
    ax, "upper center",
    bbox_to_anchor=(0.5, 1.1),
    frameon=False, framealpha=0,
    fontsize='small', shadow=None,
    columnspacing=2, ncol=2,
)

# y tick labels formatted with two decimals.
plt.gca().yaxis.set_major_formatter(StrMethodFormatter('{x:,.2f}'))
plt.xlabel(r'$cpe(z)$')
plt.ylabel(r'$log(RT_t)$')
sns.despine()
plt.tick_params(length=0)

# NOTE(review): an earlier cell in this notebook saves to this same
# rt_CPE_bdi{date}.pdf path; running this cell afterwards replaces that file.
plt.savefig(
    f'{fig_save_dir}/rt_CPE_bdi{date}.pdf',
    format='pdf', metadata=None,
    bbox_inches='tight', pad_inches=0.1, dpi=500,
    facecolor='auto', edgecolor='auto',
    backend=None,
)

In [None]:
# Scratch notes, kept as comments so the cell can run. The bare token 054a6f
# is not valid Python (it is a SyntaxError as written).
#   column name: raw_rt_coeff
#   hex colours: 054a6f, 824973

In [None]:
# Regression lines (95% CI, no scatter) of log RT on beta power, one line per
# level of the 'Depressed' column.
sns.set_context("poster",rc={'axes.linewidth': 1.5})

sns.lmplot(data=lm_beta_reg_df,x='band_pow',y='logRT_t1',hue='Depressed',scatter=False,height=7,
           aspect=1.25,ci=95)

# NOTE(review): the y label still names the cpe:beta interaction although this
# plot has no cpe term -- confirm the intended wording.
plt.ylabel(r'$log(RT_t) \sim cpe:beta_{{{zpow}}}$')
# The x axis carries band_pow, so label it as beta power; the previous label
# described cpe, which is not plotted in this figure.
plt.xlabel(r'$beta_{{{zpow}}}$')
sns.despine()
plt.tick_params(length=0)




In [None]:
# Regression lines (95% CI, no scatter) of log RT on decisionCPE, one line per
# Beta_Power_Thresh level, drawn with seaborn's default colours.
sns.set_context("poster",rc={'axes.linewidth': 1.5})

sns.lmplot(
    data=lm_beta_reg_df,
    x='decisionCPE',
    y='logRT_t1',
    hue='Beta_Power_Thresh',
    scatter=False,
    ci=95,
    height=7,
    aspect=1.5,
)

plt.xlabel(r'$cpe (z) $')
plt.ylabel(r'$log(RT_t) \sim cpe:beta_{{{zpow}}}$')
plt.tick_params(length=0)
sns.despine()

