# Epoching for Continuous Data 
Created: 03/14/2024 by A Fink

In [1]:
import numpy as np
import mne
from glob import glob
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from scipy.stats import zscore, linregress, ttest_ind, ttest_rel, ttest_1samp
import pandas as pd
from mne.preprocessing.bads import _find_outliers
import os 
import joblib
import re
import datetime
import scipy


import warnings
warnings.filterwarnings('ignore')

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys

# Make the local LFPAnalysis checkout importable.
# The append is guarded so that re-running this cell (or reloading the module)
# does not keep adding the same entry to sys.path.
_lfp_analysis_dir = '/sc/arion/projects/guLab/Alie/SWB/ephys_analysis/LFPAnalysis/'
if _lfp_analysis_dir not in sys.path:
    sys.path.append(_lfp_analysis_dir)

In [4]:
from LFPAnalysis import lfp_preprocess_utils, sync_utils, analysis_utils, nlx_utils

## Data loading and organization

In [5]:
# Root directory for un-archived data and results; every path below is built from it.
base_dir = '/sc/arion/projects/guLab/Alie/SWB/'

# base_dir already ends with '/', so sub-paths are appended without a separator.
anat_dir = base_dir + 'ephys_analysis/recon_labels/'
neural_dir = base_dir + 'ephys_analysis/data/'
behav_dir = base_dir + 'swb_behav_models/data/behavior_preprocessed/'

# Subject handled by this notebook; format and site are single-element lists.
subj_id = 'MS001'
subj_format = ['edf']
subj_site = ['MSSM']


## Epoching
Reference: [`mne.make_fixed_length_epochs`](https://mne.tools/dev/generated/mne.make_fixed_length_epochs.html#mne.make_fixed_length_epochs)

## Data Formatting + Saving for Connectivity Analyses

In [None]:
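# NOTE: this entire cell is commented out. As written it relies on names that are not
# defined or imported in the cells visible above (subj_ids, day, combined_df, subj_df,
# oscillation_utils, seed_target_indices); define/import them before re-enabling it.
# # choose the connectivity metric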
# metric = 'coh'  # ['coh', 'mic', 'mim', 'plv', 'ciplv', 'pli', 'wpli', 'gc', 'gc_tr']

# band_dict = {'theta' : [2, 9],
#                'beta' : [14, 30]}

# freqs = np.logspace(*np.log10([2, 200]), num=30)
# n_cycles = np.floor(np.logspace(*np.log10([3, 10]), num=30))

# buf_ms = 1000
# n_surr = 500

# source_region = 'HPC'

# # source_region = 'dmPFC'
# # ['OFC', 'ACC', 'AMY']
# # source_region = 'OFC'
# # ['ACC', 'AMY']
# # source_region = 'ACC'
# # ['AMY']

# # iterate through target regions
# # analysis_evs = ['baseline_start']
# analysis_evs = ['feedback_start', 'recog_time']


# for target_region in ['OFC', 'AMY']: 
#     conn_group_data = []

#     for subj_id in subj_ids:

#         filepath = f'{base_dir}/projects/guLab/Salman/EphysAnalyses/{subj_id}/scratch/PSI'
#         if not os.path.exists(filepath):
#             os.makedirs(filepath)

#         for event in analysis_evs:
#             save_path = f'{base_dir}/projects/guLab/Salman/EphysAnalyses/{subj_id}/neural/{day}'
#             epochs_reref = mne.read_epochs(f'{save_path}/{event}-epo.fif', preload=True) 

#             # Get electrode df 
#             elec_df = pd.read_csv(f'{base_dir}/projects/guLab/Salman/EphysAnalyses/{subj_id}/Day1_reref_elec_df')

#             # construct the seed-to-target mapping based on your rois - matters most for PSI as this is directional 
#             seed_target_df = pd.DataFrame(columns=['seed', 'target'], index=['l', 'r'])
#             for hemi in ['l', 'r']:
#                 seed_target_df['seed'][hemi] = np.where(elec_df.loc[elec_df.hemisphere==hemi, 'salman_region'] == source_region)[0]
# #                 seed_target_df['target'][hemi] = np.where(elec_df.loc[elec_df.hemisphere==hemi, 'salman_region'].isin(target_regions))[0]    

#                 seed_target_df['target'][hemi] = np.where(elec_df.loc[elec_df.hemisphere==hemi, 'salman_region'] == target_region)[0]    

#             seed_target_df = seed_target_df[
#                         (seed_target_df['seed'].map(lambda d: len(d) > 0)) & (seed_target_df['target'].map(lambda d: len(d) > 0))]


#             # for cond in conditions: 
#             for hemi in ['l', 'r']:
#                 # first determine if ipsi connectivity is even possible; if not, move on
#                 if hemi not in seed_target_df.index.tolist():
#                     continue
#                 else:
#                     seed_to_target = seed_target_indices(
#                         seed_target_df['seed'][hemi],
#                         seed_target_df['target'][hemi])
                    
                
                    
#                 # These epochs only grab those images that were later involved in memory recognition:
#                 if event in ['gamble_start', 'feedback_start', 'baseline_start']:
#                     epochs_to_analyze = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_gamble').reset_index(drop=True).trials_gamble.values - 1
#                 elif event == 'recog_time': 
#                     epochs_to_analyze = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_mem').reset_index(drop=True).trials_mem.values - 1
                    
#                 trunc_epochs = epochs_reref[epochs_to_analyze.astype(int)]
                
#                 avg_dim = 'time'
                
#                 # NOTE: If I compute any epoch-averaged connectivity measures, I need to NaN out bad epochs beforehand.
#                 npairs = len(seed_to_target[0])
                
# #                 for pair in range(npairs): 
# #                     source_bad_epochs  = list(np.where(trunc_epochs.metadata[trunc_epochs.ch_names[seed_to_target[0][pair]]].notnull())[0])
# #                     target_bad_epochs  = list(np.where(trunc_epochs.metadata[trunc_epochs.ch_names[seed_to_target[1][pair]]].notnull())[0])
# #                     bad_epochs = np.array(source_bad_epochs+target_bad_epochs)
# #                     channels_involved = np.concatenate([np.unique(x) for x in seed_to_target])
# #                     trunc_epochs._data[bad_epochs[:, np.newaxis], channels_involved, :] = 0
                
#                 pwise_dfs = []
#                 for band in band_dict.keys():
                
#                     pwise = oscillation_utils.compute_connectivity(trunc_epochs, 
#                                                                band = band_dict[band], 
#                                                                metric = metric, 
#                                                                indices = seed_to_target, 
#                                                                freqs = freqs, 
#                                                                n_cycles = n_cycles,
#                                                                buf_ms = buf_ms, 
#                                                                n_surr=n_surr,
#                                                                avg_over_dim=avg_dim)

#                     if avg_dim == 'epoch':
#                         pwise = pwise[:, int(buf_ms  * (epochs_reref.info['sfreq']/1000)):-int(buf_ms  * (epochs_reref.info['sfreq']/1000))]

#                     for pair in range(npairs): 
#                         pwise_df = pd.DataFrame(columns=['participant', 'roi1', 'roi2', 'hemi', 'pair_label', 'metric', 'event', 'conn'])
#                         pwise_df['conn'] = pwise[:, pair] 
#                         pwise_df['participant'] = subj_id
#                         pwise_df['age'] = subj_df[subj_df.MSSMCode==subj_id].Age.values[0]
#                         pwise_df['sex'] = subj_df[subj_df.MSSMCode==subj_id].Sex.str.strip().values[0]

#                         pwise_df['roi1'] = source_region
#                         pwise_df['roi2'] = target_region
#                         pwise_df['hemi'] = hemi 
#                         pwise_df['metric'] = metric 
#                         pwise_df['event'] = event
#                         pwise_df['band'] = band
#                         pwise_df['pair_label'] = f'{epochs_reref.ch_names[seed_to_target[0][pair]]}:{epochs_reref.ch_names[seed_to_target[1][pair]]}'
#                         # pwise_df['trials'] = np.arange(1, pwise.shape[0]+1)
#                         if event in ['gamble_start', 'feedback_start', 'baseline_start']:
#                             pwise_df['trials_gamble'] = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_gamble').reset_index(drop=True).trials_gamble.values
#                             pwise_df = pwise_df.merge(combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_gamble').reset_index(drop=True), 
#                                                       on=['participant', 'trials_gamble'])
#                         elif event == 'recog_time':
#                             pwise_df['trials_mem'] = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_mem').reset_index(drop=True).trials_mem.values
#                             pwise_df = pwise_df.merge(combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_mem').reset_index(drop=True), 
#                                                       on=['participant', 'trials_mem'])

#                         pwise_df['zrpe'] = (pwise_df.rpe - np.nanmean(pwise_df.rpe)) / (2*np.nanstd(pwise_df.rpe))
#                         pwise_df['zpm'] = (pwise_df.DPRIME - np.nanmean(pwise_df.DPRIME)) / (2*np.nanstd(pwise_df.DPRIME))


#                         pwise_df['good_epoch'] = 1
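#                         # NOTE: open question - how should bad epochs be handled?
#                         # Find all the bad epochs across both channels and record them in this pair's dataframe (good_epoch = 0).
#                         # NOTE(review): np.where(...)[0] below yields 0-based row positions in epochs_reref.metadata, while
#                         # trials_gamble / trials_mem appear to be 1-based (1 is subtracted from them earlier in this cell
#                         # to index epochs) - confirm the .isin(bad_epochs) comparison is not off by one before re-enabling.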
#                         # for ch_ix in progress_bar: 
#                         source_bad_epochs  = list(np.where(epochs_reref.metadata[epochs_reref.ch_names[seed_to_target[0][pair]]].notnull())[0])
#                         target_bad_epochs  = list(np.where(epochs_reref.metadata[epochs_reref.ch_names[seed_to_target[1][pair]]].notnull())[0])
#                         bad_epochs = np.array(source_bad_epochs+target_bad_epochs)
#                         if event in ['gamble_start', 'feedback_start', 'baseline_start']:
#                             pwise_df.loc[pwise_df['trials_gamble'].isin(bad_epochs), 'good_epoch'] = 0
#                         elif event == 'recog_time':
#                             pwise_df.loc[pwise_df['trials_mem'].isin(bad_epochs), 'good_epoch'] = 0
                            
#                         # aggregate
#                         pwise_dfs.append(pwise_df)
#                 pwise_dfs = pd.concat(pwise_dfs)
#                 conn_group_data.append(pwise_dfs)

#     all_pairs_df = pd.concat(conn_group_data)
#     all_pairs_df.to_csv(f'/sc/arion/projects/guLab/Salman/EphysAnalyses/Connectivity/{source_region}_{target_region}_{metric}_df', index=False)

