# SWB CPE Connectivity Analysis: Theta Coherence

Created: 10/07/2024 \
Updated: 11/20/2024 


In [1]:
import numpy as np
import mne
from glob import glob
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from scipy.stats import zscore, linregress, ttest_ind, ttest_rel, ttest_1samp, pearsonr, spearmanr
import pandas as pd
from mne.preprocessing.bads import _find_outliers
import os 
import joblib
import re
import datetime
import scipy
import random
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.regression.mixed_linear_model import MixedLM 
from joblib import Parallel, delayed
import pickle
import itertools
import time 
from matplotlib.ticker import StrMethodFormatter


import mne_connectivity
from mne_connectivity import phase_slope_index, seed_target_indices, spectral_connectivity_epochs
# import fooof
# Import plotting functions
# from fooof.plts.spectra import plot_spectrum, plot_spectra
# # Import the FOOOF object
# from fooof import FOOOF
# from fooof import FOOOFGroup

from tqdm import tqdm
from IPython.display import clear_output

from joblib import delayed, Parallel
from statsmodels.stats import multitest
import warnings
warnings.filterwarnings('ignore')
# print('\n'.join(f'{m.__name__}=={m.__version__}' for m in globals().values() if getattr(m, '__version__', None)))

%load_ext autoreload
%autoreload 2



In [2]:
# frequency band analysed in this notebook; also used to build the results directory (save_dir)
band = 'theta'

In [3]:
# Specify root directory for un-archived data and results
base_dir   = '/sc/arion/projects/guLab/Alie/SWB/'
neural_dir = f'{base_dir}ephys_analysis/data/'
behav_dir  = f'{base_dir}ephys_analysis/behav/behav_data/'
# band-specific results directory (`band` is set in the previous cell)
save_dir   = f'{base_dir}ephys_analysis/results/connectivity/coherence/{band}/'
os.makedirs(save_dir, exist_ok=True)
# The compute loop writes into these subdirectories and the bookkeeping cells list them with
# os.listdir, so create them up front; a fresh save_dir would otherwise raise FileNotFoundError.
for sub_dir in ['pair_data/', 'single_subj/']:
    os.makedirs(f'{save_dir}{sub_dir}', exist_ok=True)

script_dir = '/hpc/users/finka03/swb_ephys_analysis/scripts/'

# date stamp (mmddYYYY) used in output filenames
date = datetime.date.today().strftime('%m%d%Y')
print(date)


11202024


In [4]:
import sys
# make the project's LFPAnalysis package importable
sys.path.append(f'{base_dir}ephys_analysis/LFPAnalysis/')

from LFPAnalysis import oscillation_utils

# project helper scripts for analysis notebooks
sys.path.append(f'{script_dir}analysis_notebooks/')

# NOTE(review): star imports hide which module provides names used later
# (e.g. format_all_behav in the commented-out behavior cells) — consider explicit imports
from ieeg_tools import *

# behavior helper scripts
sys.path.append(f'{script_dir}behav/')

from behav_utils import *
from swb_subj_behav import *


In [5]:
# Usable subjects: first column (PatientID) of the 'Usable_Subjects' sheet
usable_subjects_df = pd.read_excel(f'{base_dir}ephys_analysis/subj_info/SWB_subjects.xlsx',
                                   sheet_name='Usable_Subjects',
                                   usecols=[0])
subj_ids = list(usable_subjects_df['PatientID'])
n_subj = len(subj_ids)


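# Load Behav + Elec ROI Data
- `all_behav`: built from the updated task_dfs with `behav_utils` formatting (the build cells below are commented out; the cached CSV is loaded instead)
- `roi_reref_labels_master_df`: the same ROI / re-reference label master file as usual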

In [8]:
# # all_behav = pd.read_csv(f'{behav_dir}all_behav.csv') ## this isn't normalized yet 
# raw_behav = [pd.read_csv(f'{behav_dir}{subj_id}_task_df.csv') for subj_id in subj_ids]
# all_behav,drops_data = format_all_behav(raw_behav,return_drops=True)

# all_behav

In [6]:
# raw_behav = [pd.read_csv(f'{behav_dir}{subj_id}_task_df.csv') for subj_id in subj_ids]
# temp_behav,beh_drops = format_all_behav(raw_behav,return_drops=True,norm=False)


In [7]:
# all_orthog_rpe  = []
# all_rpe_cpe_rsq = []
# behav_list = []

# for ix, subj_id in enumerate(subj_ids): 
# #     subj_df = raw_behav[ix]
#     subj_df = temp_behav[temp_behav.subj_id==subj_id].reset_index(drop=True)
#     # find indices where cpe is not nan
#     notnan_idx = [ix for ix,cpe in enumerate(subj_df.cpe) if not np.isnan(cpe)]
#     # initialize save vector
#     subj_df['orthog_rpe'] = np.nan
#     # run ols regression rpe ~ cpe
#     rpe_cpe = sm.OLS(subj_df.rpe,sm.add_constant(subj_df.cpe),missing='drop').fit()
#     all_rpe_cpe_rsq.append(rpe_cpe.rsquared)
#     subj_df['orthog_rpe'][notnan_idx] = rpe_cpe.resid
#     all_orthog_rpe.extend(subj_df['orthog_rpe'])
#     behav_list.append(subj_df)
#     del rpe_cpe 
    


In [8]:
# all_behav,beh_drops = format_all_behav(behav_list,return_drops=True,norm=True)


In [9]:
# all_behav.to_csv(f'{save_dir}all_behav_{band}_coh_{date}.csv', index=False)

In [6]:
# Load the normalized behavior table (incl. orthog_rpe) cached by the commented-out cells above.
# The filename follows the writer's pattern all_behav_{band}_coh_{date}.csv; use `band` rather
# than a hard-coded 'theta' so it stays consistent with save_dir, and pin the export date here.
behav_cache_date = '11192024'
all_behav = pd.read_csv(f'{save_dir}all_behav_{band}_coh_{behav_cache_date}.csv')

In [7]:
# quick look at the loaded behavior table (pandas truncates the display)
all_behav

Unnamed: 0,subj_id,bdi,bdi_thresh,Round,TrialNum,RT,TrialOnset,ChoiceOnset,DecisionOnset,FeedbackOnset,...,TrialEV_t1,CR_t1,choiceEV_t1,rpe_t1,res_type_t1,cf_t1,cpe_t1,keep_epoch,keep_epoch_t1,orthog_rpe
0,MS002,14,low,1,25.0,2.059852,513.380590,513.390239,515.450091,515.457173,...,-0.423529,,-0.545852,-0.553325,gamble_bad,-0.150761,-0.450578,keep,keep,0.213638
1,MS002,14,low,2,117.0,1.954564,522.640856,522.641563,524.596127,526.627092,...,0.737993,,0.495438,0.954543,gamble_good,0.277858,0.728779,keep,keep,-0.383128
2,MS002,14,low,3,79.0,1.583462,531.174799,531.175599,532.759061,534.780269,...,0.152000,0.090121,,-0.008658,safe_good,-0.456918,0.435467,keep,keep,0.684004
3,MS002,14,low,4,42.0,2.491611,545.592613,545.593355,548.084966,548.092333,...,0.619399,,0.321229,0.702276,gamble_good,0.339089,0.398803,keep,keep,-0.255425
4,MS002,14,low,5,85.0,1.768936,555.337336,555.345720,557.114656,559.135069,...,-0.116581,0.090121,,-0.008658,safe_good,-0.640612,0.618787,keep,keep,0.579806
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4045,DA039,22,high,146,79.0,1.079701,2259.827656,2259.828749,2260.908450,2262.926195,...,-0.297739,0.343221,,0.007477,safe_good,-1.253105,0.916364,keep,keep,0.279076
4046,DA039,22,high,147,30.0,1.837272,2267.502359,2267.534059,2269.371331,2269.377701,...,-0.014213,0.343221,,0.007477,safe_bad,0.356530,-0.300100,keep,keep,-0.947095
4047,DA039,22,high,148,13.0,4.030006,2282.349445,2282.350662,2286.380667,2286.389886,...,-0.115722,0.343221,,0.007477,safe_bad,0.364957,-0.306469,keep,keep,0.336081
4048,DA039,22,high,149,18.0,3.167144,2293.040983,2293.042042,2296.209186,2296.218136,...,-1.032806,,-1.312281,-1.093440,gamble_bad,-0.410364,-1.312759,keep,keep,0.342799


In [8]:
# Electrode -> ROI / re-referenced channel-name table for all subjects.
# Read the literal path directly: wrapping it in glob(...)[0] added nothing and turned a missing
# file into an opaque IndexError instead of pandas' FileNotFoundError naming the path.
roi_reref_labels_master_df = pd.read_csv(
    f'{base_dir}ephys_analysis/results/roi_info/roi_reref_labels_master.csv').drop(columns=['Unnamed: 0'])


In [9]:
# confirm the band being analysed before running the connectivity computations
band

'theta'

# Connectivity Computations: Theta Coherence

In [10]:
# ---- connectivity analysis parameters ----

# Wavelet settings: 30 log-spaced frequencies from 2 to 200 Hz; the number of cycles grows
# log-linearly from 3 to 10 across those frequencies and is floored to whole cycles.
freqs    = np.logspace(*np.log10([2, 200]), num=30)
n_cycles = np.floor(np.logspace(*np.log10([3, 10]), num=30))

# frequency band edges in Hz
freq_dict = dict(theta=[4, 8],
                 alpha=[8, 13],
                 beta=[13, 30],
                 gamma=[30, 70],
                 hfa=[70, 200])

# connectivity metric, band of interest, number of surrogates for permutations, buffer (ms)
metric = 'coh'
# NOTE(review): duplicates the `band` set at the top of the notebook (which also builds save_dir);
# keep the two in sync
band   = 'theta'
n_surr = 500
buf_ms = 1000

# analysis epoch (locked to CPE onset) and the ROIs combined pairwise
epoch_id = 'CpeOnset'
coh_rois = ['acc', 'ains', 'ofc', 'dlpfc', 'vlpfc', 'amy', 'dmpfc', 'hpc']

# roi -> subjects (in order of first appearance) with at least one electrode in that roi
roi_subj_ids = {roi: roi_reref_labels_master_df.loc[roi_reref_labels_master_df.roi == roi, 'subj_id'].unique().tolist()
                for roi in coh_rois}


In [11]:
# every unordered ROI pair, in itertools.combinations order, as [roi_a, roi_b] lists
all_combos = [[roi_a, roi_b] for roi_a, roi_b in itertools.combinations(coh_rois, 2)]
all_combos

[['acc', 'ains'],
 ['acc', 'ofc'],
 ['acc', 'dlpfc'],
 ['acc', 'vlpfc'],
 ['acc', 'amy'],
 ['acc', 'dmpfc'],
 ['acc', 'hpc'],
 ['ains', 'ofc'],
 ['ains', 'dlpfc'],
 ['ains', 'vlpfc'],
 ['ains', 'amy'],
 ['ains', 'dmpfc'],
 ['ains', 'hpc'],
 ['ofc', 'dlpfc'],
 ['ofc', 'vlpfc'],
 ['ofc', 'amy'],
 ['ofc', 'dmpfc'],
 ['ofc', 'hpc'],
 ['dlpfc', 'vlpfc'],
 ['dlpfc', 'amy'],
 ['dlpfc', 'dmpfc'],
 ['dlpfc', 'hpc'],
 ['vlpfc', 'amy'],
 ['vlpfc', 'dmpfc'],
 ['vlpfc', 'hpc'],
 ['amy', 'dmpfc'],
 ['amy', 'hpc'],
 ['dmpfc', 'hpc']]

In [12]:
# ROI pairs whose combined pair-level CSV (<roi1>_<roi2>_<metric>_<band>_df.csv) already exists
pair_data_dir = f'{save_dir}pair_data/'
finished_pairs = [fname.split('_')[0:2] for fname in os.listdir(pair_data_dir)]
finished_pairs

[['dlpfc', 'amy'],
 ['ains', 'amy'],
 ['acc', 'amy'],
 ['ofc', 'dlpfc'],
 ['acc', 'ains'],
 ['acc', 'dlpfc'],
 ['ofc', 'amy'],
 ['ains', 'dmpfc'],
 ['acc', 'dmpfc'],
 ['ains', 'dlpfc'],
 ['ains', 'vlpfc'],
 ['acc', 'vlpfc'],
 ['ains', 'ofc'],
 ['acc', 'ofc'],
 ['ofc', 'dmpfc']]

In [13]:
# ## unique list of all pairs that have any single subjects saved 
all_pair_files = [file.split('_')[1:3] for file in os.listdir(f'{save_dir}single_subj/') if file.split('_')[-1] == 'df.csv']
all_pair_files

pair_subj_data = []
for pair in all_pair_files:
    if pair not in pair_subj_data:
        pair_subj_data.append(pair)
pair_subj_data

# # find which pairs weren't completed 
incomplete_pairs = [pair for pair in pair_subj_data if pair not in finished_pairs]
incomplete_pairs

[['amy', 'dmpfc']]

In [14]:
# For each interrupted pair, split its subjects into those that already have a saved
# single-subject CSV (complete) and those still to compute (incomplete).
complete_subj_dict   = {}
incomplete_subj_dict = {}

for pair in incomplete_pairs:
    source_region, target_region = pair
    # unique pair id roi1_roi2
    pair_id = '_'.join([source_region, target_region])

    # subjects with elecs in both rois
    pair_subj = list(set(roi_subj_ids[source_region]).intersection(roi_subj_ids[target_region]))

    # completed subject files; the pattern mirrors the compute loop's writer
    # (<subj>_<pair_id>_<metric>_<band>_df.csv) — use `metric` instead of a hard-coded 'coh' and
    # anchor on '_' so only whole subject ids before the pair id match
    complete_subj_files = glob(f'{save_dir}single_subj/*_{pair_id}_{metric}_{band}_df.csv')
    complete_subj = [os.path.basename(fpath).split('_')[0] for fpath in complete_subj_files]
    complete_subj_dict[pair_id] = complete_subj

    # subjects from pair_subj with no saved data, sorted like np.setdiff1d's result; a plain set
    # difference avoids setdiff1d mixing a str array with the float array np makes from []
    incomplete_subj = sorted(set(pair_subj) - set(complete_subj))
    incomplete_subj_dict[pair_id] = incomplete_subj


In [15]:
# ROI pairs with no saved output at all: neither a pair-level CSV nor any single-subject CSVs
run_pairs = [roi_pair for roi_pair in all_combos
             if roi_pair not in finished_pairs and roi_pair not in incomplete_pairs]
run_pairs

[['acc', 'hpc'],
 ['ains', 'hpc'],
 ['ofc', 'vlpfc'],
 ['ofc', 'hpc'],
 ['dlpfc', 'vlpfc'],
 ['dlpfc', 'dmpfc'],
 ['dlpfc', 'hpc'],
 ['vlpfc', 'amy'],
 ['vlpfc', 'dmpfc'],
 ['vlpfc', 'hpc'],
 ['amy', 'hpc'],
 ['dmpfc', 'hpc']]

In [16]:
# add incomplete pair to beginning of list to run first 
pairs = incomplete_pairs + run_pairs
pairs

[['amy', 'dmpfc'],
 ['acc', 'hpc'],
 ['ains', 'hpc'],
 ['ofc', 'vlpfc'],
 ['ofc', 'hpc'],
 ['dlpfc', 'vlpfc'],
 ['dlpfc', 'dmpfc'],
 ['dlpfc', 'hpc'],
 ['vlpfc', 'amy'],
 ['vlpfc', 'dmpfc'],
 ['vlpfc', 'hpc'],
 ['amy', 'hpc'],
 ['dmpfc', 'hpc']]

In [17]:
#### running 2nd half of pairs in parallel with main notebook doing first half - drop any in progress pairs
# NOTE(review): this hard-coded list overrides the `pairs` computed in the previous cell
# (e.g. ['amy', 'dmpfc'] and ['acc', 'hpc'] are not run here); skip this cell to run every remaining pair.
pairs = [list(roi_pair) for roi_pair in (
    ('dlpfc', 'hpc'),
    ('amy', 'hpc'),
    ('dmpfc', 'hpc'),
    ('vlpfc', 'hpc'),
    ('vlpfc', 'dmpfc'),
    ('dlpfc', 'vlpfc'),
    ('dlpfc', 'dmpfc'),
)]
pairs

[['dlpfc', 'hpc'],
 ['amy', 'hpc'],
 ['dmpfc', 'hpc'],
 ['vlpfc', 'hpc'],
 ['vlpfc', 'dmpfc'],
 ['dlpfc', 'vlpfc'],
 ['dlpfc', 'dmpfc']]

In [None]:
######## to compute coh after notebook ends in middle of roi pair
# For each ROI pair in `pairs`, and each subject with electrodes in both ROIs, compute per-epoch
# coherence in `band` (averaged over time) for every source x target electrode pair.
# One CSV per subject goes to single_subj/, then all subjects of the pair go to pair_data/.
# For pairs in `incomplete_pairs`, subjects with an existing single-subject CSV are skipped
# and their saved CSV is reloaded instead.

for pair in pairs:

    source_region, target_region = pair
    # unique pair id roi1_roi2
    pair_id = '_'.join([source_region, target_region])

    # subjects with elecs in both rois; sorted because set iteration order over strings changes
    # between Python sessions, which otherwise makes the run order irreproducible
    pair_subj = sorted(set(roi_subj_ids[source_region]).intersection(roi_subj_ids[target_region]))

    # subjects already saved by an interrupted earlier run of this pair (none for fresh pairs)
    complete_subj = complete_subj_dict[pair_id] if pair in incomplete_pairs else []

    # initialize the storage list
    all_subj_pair_df = []

    # iterate through pair subjects
    for subj_id in pair_subj:

        if subj_id in complete_subj:
            continue

        # load & format rereferenced epoch data
        subj_epochs = mne.read_epochs(f'{neural_dir}{subj_id}/{subj_id}_conn_epochs_{epoch_id}.fif', preload=True)
        subj_elecs  = subj_epochs.ch_names

        # construct the seed-to-target mapping based on subject's roi coverage
        elec_roi_df = roi_reref_labels_master_df[roi_reref_labels_master_df.subj_id == subj_id].reset_index(drop=True)
        # ch names of subj elecs in each roi
        source_ch_names = elec_roi_df.reref_ch_names[elec_roi_df.roi == source_region].tolist()
        target_ch_names = elec_roi_df.reref_ch_names[elec_roi_df.roi == target_region].tolist()
        # idx of each ch in subj_elecs (corresponds to the channel idx in the epochs array)
        source_elec_idx = [subj_elecs.index(elec) for elec in source_ch_names]
        target_elec_idx = [subj_elecs.index(elec) for elec in target_ch_names]
        # seed/target index arrays covering every source x target electrode combination
        seed_to_target = seed_target_indices(source_elec_idx, target_elec_idx)

        # compute pwise coherence; indexed below as pwise[epoch, elec_pair]
        pwise = oscillation_utils.compute_connectivity(subj_epochs,
                                           band = freq_dict[band],
                                           metric = metric,
                                           indices = seed_to_target,
                                           freqs = freqs,
                                           n_cycles = n_cycles,
                                           buf_ms = buf_ms,
                                           n_surr=n_surr,
                                           avg_over_dim='time',
                                           band1 = freq_dict[band],
                                           parallelize=True)
        n_epochs = pwise.shape[0]

        # One block of n_epochs rows (index 0..n_epochs-1) per electrode pair.
        # Channel names come straight from subj_elecs rather than from splitting a '_'-joined string.
        # The elec idx columns hold this pair's own indices, not the whole seed/target arrays.
        pair_frames = []
        for ch_ix, (roi1_idx, roi2_idx) in enumerate(zip(seed_to_target[0], seed_to_target[1])):
            roi1_ch, roi2_ch = subj_elecs[roi1_idx], subj_elecs[roi2_idx]
            pair_frames.append(pd.DataFrame({'epoch': np.arange(0, n_epochs),
                                             'coh': pwise[:, ch_ix],
                                             'unique_ch_pair': f'{subj_id}_{roi1_ch}_{roi2_ch}',
                                             'roi_pair_chans': f'{roi1_ch}_{roi2_ch}',
                                             'roi1_ch_names': roi1_ch,
                                             'roi2_ch_names': roi2_ch,
                                             'roi1_elec_idx': roi1_idx,
                                             'roi2_elec_idx': roi2_idx}))
        coh_df = pd.concat(pair_frames)

        # Re-index the subject's behavior to 0..n-1 so Round/epoch align by label with each
        # electrode pair's 0..n_epochs-1 block. all_behav's global row index only lined up for
        # a subject whose rows start at 0; every other subject got all-NaN Round/epoch.
        subj_behav = all_behav[all_behav.subj_id == subj_id].reset_index(drop=True)
        if len(subj_behav) != n_epochs:
            print(f'{subj_id} {pair_id}: {len(subj_behav)} behav rows vs {n_epochs} epochs - Round/epoch aligned by position')

        coh_df['subj_id']  = subj_id
        coh_df['bdi']      = subj_behav.bdi.unique().tolist()[0]
        coh_df['Round']    = subj_behav.Round
        coh_df['epoch']    = subj_behav.epoch
        coh_df['band']     = band
        coh_df['metric']   = metric
        coh_df['pair_id']  = pair_id
        coh_df['roi1']     = source_region
        coh_df['roi2']     = target_region

        # one pair one subj data (index written as the first column)
        coh_df.to_csv(f'{save_dir}single_subj/{subj_id}_{pair_id}_{metric}_{band}_df.csv')
        all_subj_pair_df.append(coh_df)
        del coh_df, subj_epochs, pwise

    # Reload subjects saved by an earlier run. index_col=0 restores the index that to_csv wrote,
    # so reloaded frames have the same columns as freshly computed ones (no 'Unnamed: 0').
    # NOTE(review): CSVs written before this fix may hold NaN Round/epoch for most subjects and
    # whole-array elec idx columns — consider recomputing them.
    for subj_id in complete_subj:
        all_subj_pair_df.append(
            pd.read_csv(f'{save_dir}single_subj/{subj_id}_{pair_id}_{metric}_{band}_df.csv', index_col=0))

    if not all_subj_pair_df:
        # pd.concat([]) would raise ValueError; nothing to save for this pair
        print(f'no subject data for {pair_id}; skipping')
        continue

    # reset_index() keeps the per-pair epoch position as an 'index' column
    all_subj_pair_df = pd.concat(all_subj_pair_df).reset_index()
    # save roi pair df separately
    all_subj_pair_df.to_csv(f'{save_dir}pair_data/{pair_id}_{metric}_{band}_df.csv', index=False)


Reading /sc/arion/projects/guLab/Alie/SWB/ephys_analysis/data/MS033/MS033_conn_epochs_CpeOnset.fif ...
    Found the data of interest:
        t =   -1000.00 ...    2000.00 ms
        0 CTF compensation matrices available
Not setting metadata
150 matching events found
No baseline correction applied
0 projection items activated
