# Epoching for Continuous Data 
Created: 03/14/2024 by A Fink

In [1]:
import numpy as np
import mne
from glob import glob
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from scipy.stats import zscore, linregress, ttest_ind, ttest_rel, ttest_1samp
import pandas as pd
from mne.preprocessing.bads import _find_outliers
import os 
import joblib
import re
import datetime
import scipy


import warnings
# Silences every warning for the rest of the session (no category or module filter is given).
warnings.filterwarnings('ignore')

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
# Put the local LFPAnalysis checkout on the import path so the `from LFPAnalysis import ...` cell works.
# NOTE(review): machine-specific absolute path — adjust when running on another machine.
sys.path.append('/Users/alexandrafink/Documents/GitHub/LFPAnalysis/')

In [4]:
from LFPAnalysis import lfp_preprocess_utils, sync_utils, analysis_utils, nlx_utils

## Data loading and organization

In [5]:
# Root directory for un-archived data and results.
# base_dir must keep its trailing slash: every sub-directory below is built by
# appending a relative path directly to it.
# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
base_dir      = '/Users/alexandrafink/Documents/GraduateSchool/SaezLab/resting_state_proj/resting_state_ieeg/'
anat_dir      = f'{base_dir}anat/'
# No leading slash here: base_dir already ends in '/', and an extra one produced
# '.../resting_state_ieeg//preprocess/...' in every path built from neural_dir.
neural_dir    = f'{base_dir}preprocess/clean_data/'
subj_info_dir = f'{base_dir}patient_tracker/'

# Subject IDs come from the PatientID column (first column) of the master tracking sheet.
subj_tracker = pd.read_excel(f'{subj_info_dir}subjects_master_list.xlsx', usecols=[0])
subj_ids = list(subj_tracker.PatientID)


In [8]:
# Load each subject's bipolar-referenced recording, keep a fixed time window,
# and cut that window into equal-length epochs.
crop_start_s, crop_stop_s = 10, 430   # window (s) kept from every recording
epoch_len_s = 10                      # length (s) of each fixed epoch

bp_lfp_all_subj, epochs_all_subj = {}, {}

for subj_id in subj_ids:
    fif_path = f'{neural_dir}{subj_id}/{subj_id}_bp_ref_ieeg.fif'
    bp_data = mne.io.read_raw_fif(fif_path, preload=True)
    # crop() modifies bp_data in place, so the stored Raw is the cropped one
    bp_data.crop(tmin=crop_start_s, tmax=crop_stop_s)
    epochs = mne.make_fixed_length_epochs(bp_data, duration=epoch_len_s, preload=True)

    bp_lfp_all_subj[subj_id] = bp_data
    epochs_all_subj[subj_id] = epochs
    


Opening raw data file /Users/alexandrafink/Documents/GraduateSchool/SaezLab/resting_state_proj/resting_state_ieeg//preprocess/clean_data/MS001/MS001_bp_ref_ieeg.fif...
    Range : 0 ... 304687 =      0.000 ...   609.374 secs
Ready.
Reading 0 ... 304687  =      0.000 ...   609.374 secs...
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 42 events and 5000 original time points ...
0 bad epochs dropped
Opening raw data file /Users/alexandrafink/Documents/GraduateSchool/SaezLab/resting_state_proj/resting_state_ieeg//preprocess/clean_data/MS003/MS003_bp_ref_ieeg.fif...
    Range : 0 ... 299894 =      0.000 ...   599.788 secs
Ready.
Reading 0 ... 299894  =      0.000 ...   599.788 secs...
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 42 events and 5000 original time points ...
1 bad epochs dropped
Opening raw 

Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 42 events and 5000 original time points ...
0 bad epochs dropped
Opening raw data file /Users/alexandrafink/Documents/GraduateSchool/SaezLab/resting_state_proj/resting_state_ieeg//preprocess/clean_data/MS026/MS026_bp_ref_ieeg.fif...
    Range : 0 ... 226561 =      0.000 ...   453.122 secs
Ready.
Reading 0 ... 226561  =      0.000 ...   453.122 secs...
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 42 events and 5000 original time points ...
0 bad epochs dropped
Opening raw data file /Users/alexandrafink/Documents/GraduateSchool/SaezLab/resting_state_proj/resting_state_ieeg//preprocess/clean_data/MS027/MS027_bp_ref_ieeg.fif...
    Range : 0 ... 317874 =      0.000 ...   635.748 secs
Ready.
Reading 0 ... 317874  =      0.000 ...   635.748 secs...
Not setting 

In [9]:
# Per-subject summary of the fixed-length Epochs built in the loading loop.
epochs_all_subj

{'MS001': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~97.8 MB, data loaded,
  '1': 42>,
 'MS003': <Epochs |  41 events (all good), 0 - 9.998 sec, baseline off, ~125.3 MB, data loaded,
  '1': 41>,
 'MS006': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~115.5 MB, data loaded,
  '1': 42>,
 'MS007': <Epochs |  32 events (all good), 0 - 9.998 sec, baseline off, ~143.0 MB, data loaded,
  '1': 32>,
 'MS008': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~129.9 MB, data loaded,
  '1': 42>,
 'MS010': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~125.1 MB, data loaded,
  '1': 42>,
 'MS012': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~86.6 MB, data loaded,
  '1': 42>,
 'MS014': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~54.5 MB, data loaded,
  '1': 42>,
 'MS016': <Epochs |  42 events (all good), 0 - 9.998 sec, baseline off, ~104.3 MB, data loaded,
  '1': 42>,
 'MS017': <Epochs |  41 events 

In [10]:
# Inspect MS007 on its own: it retains 32 epochs, fewer than the 41-42 retained by the other subjects shown.
epochs_all_subj['MS007']

0,1
Number of events,32
Events,1: 32
Time range,0.000 – 9.998 sec
Baseline,off


## Epoching
- [`mne.make_fixed_length_epochs` API reference](https://mne.tools/dev/generated/mne.make_fixed_length_epochs.html#mne.make_fixed_length_epochs)
- [MNE tutorial on making fixed-length epochs](https://mne.tools/dev/auto_tutorials/epochs/60_make_fixed_length_epochs.html#sphx-glr-auto-tutorials-epochs-60-make-fixed-length-epochs-py)

In [15]:
# Per-subject Raw recordings; each was cropped in place to the 10-430 s window when it was loaded.
bp_lfp_all_subj

{'MS001': <Raw | MS001_bp_ref_ieeg.fif, 61 x 210001 (420.0 s), ~97.8 MB, data loaded>,
 'MS003': <Raw | MS003_bp_ref_ieeg.fif, 80 x 210001 (420.0 s), ~128.3 MB, data loaded>,
 'MS006': <Raw | MS006_bp_ref_ieeg.fif, 72 x 210001 (420.0 s), ~115.5 MB, data loaded>,
 'MS007': <Raw | MS007_bp_ref_ieeg.fif, 117 x 210001 (420.0 s), ~187.6 MB, data loaded>,
 'MS008': <Raw | MS008_bp_ref_ieeg.fif, 81 x 210001 (420.0 s), ~129.9 MB, data loaded>,
 'MS010': <Raw | MS010_bp_ref_ieeg.fif, 78 x 210001 (420.0 s), ~125.1 MB, data loaded>,
 'MS012': <Raw | MS012_bp_ref_ieeg.fif, 54 x 210001 (420.0 s), ~86.6 MB, data loaded>,
 'MS014': <Raw | MS014_bp_ref_ieeg.fif, 34 x 210001 (420.0 s), ~54.5 MB, data loaded>,
 'MS016': <Raw | MS016_bp_ref_ieeg.fif, 65 x 210001 (420.0 s), ~104.3 MB, data loaded>,
 'MS017': <Raw | MS017_bp_ref_ieeg.fif, 58 x 210001 (420.0 s), ~93.0 MB, data loaded>,
 'MS018': <Raw | MS018_bp_ref_ieeg.fif, 93 x 210001 (420.0 s), ~149.2 MB, data loaded>,
 'MS019': <Raw | MS019_bp_ref_ieeg.

In [None]:
#### make epochs 

# Planned steps:
#   - loop through subjects
#   - extract df
#   - clip the continuous data to the desired length, e.g.
#     raw.crop(0, 60).pick(picks=["mag", "stim"]).load_data() ?

# NOTE(review): `raw` is not defined in any cell shown in this notebook, so this line
# will raise NameError on a fresh kernel unless it is defined elsewhere. It also rebinds
# `epochs`, replacing the object left over from the per-subject loading loop.
epochs = mne.make_fixed_length_epochs(raw, duration=30, preload=False)
# Visualization example: event_related_plot = epochs.plot_image(picks=["MEG 1142"])

In [None]:
# Welch power spectral density for every channel: Hann window, median averaging,
# 2-second segments with 50% overlap, restricted to 2-40 Hz.
# NOTE(review): clean_raw, sr and compute_spectrum_welch are not defined or imported
# in any cell shown in this notebook.
ch_names = clean_raw.ch_names          # channel names
num_channels = len(clean_raw.ch_names) # number of channels

n_freq_bins = 77  # hard-coded; assumes the 2-40 Hz range at this segment length gives 77 bins — TODO confirm
freq_master = np.zeros((n_freq_bins,))              # frequency axis (shared by all channels)
psd_master = np.zeros((num_channels, n_freq_bins))  # one PSD row per channel

seg_len = sr * 2  # samples in a 2-second segment
for channel in range(num_channels):
    freq, psd = compute_spectrum_welch(
        sig=clean_raw._data[channel, :],
        fs=sr,
        window='hann',
        avg_type='median',
        nperseg=seg_len,
        f_range=(2, 40),
        noverlap=seg_len / 2,
    )
    psd_master[channel, :] = psd
    if channel == 0:
        # the frequency axis only needs to be stored once
        freq_master[:] = freq

In [None]:
# Event-locked epoching: define the windows of interest, then build and save one
# Epochs file per subject and event.
# NOTE(review): learn_df, slopes and offsets are not defined in any cell shown in this notebook.

# Buffer before and after each window, used to limit edge effects for TFRs.
# Not referenced below: the make_epochs call passes buf_s=1 as a literal.
buf = 1.0

# Settings forwarded to make_epochs through its IED_args parameter.
IED_args = dict(peak_thresh=4, closeness_thresh=0.25, width_thresh=0.2)

# evs = ['gamble_start', 'feedback_start', 'baseline_start']
# event name -> [start, end] of the window, passed on as ev_start_s / ev_end_s
evs = {
    'gamble_start':   [-1.0, 0.5],
    'feedback_start': [-0.5, 1.5],
    'baseline_start': [0, 0.75],
}

# add behavioral times of interest
for subj_id in subj_ids:
    # load_path is assigned but not used below; the file is read from save_path
    load_path = f'{base_dir}/projects/guLab/Salman/EMU/{subj_id}/neural/Day1'
    save_path = f'{base_dir}/projects/guLab/Salman/EphysAnalyses/{subj_id}/neural/Day1'

    epochs_all_evs = dict.fromkeys(evs, np.nan)
    for event, (pre, post) in evs.items():
        fixed_baseline = None  # assigned but not used below
        subj_rows = learn_df[learn_df.participant == subj_id]
        behav_times = subj_rows[event]

        epochs = lfp_preprocess_utils.make_epochs(
            load_path=f'{save_path}/bp_ref_ieeg.fif',
            slope=slopes[subj_id][0],
            offset=offsets[subj_id][0],
            behav_name=event,
            behav_times=behav_times,
            ev_start_s=pre,
            ev_end_s=post,
            buf_s=1,
            downsamp_factor=None,
            IED_args=IED_args,
        )

        epochs_all_evs[event] = epochs
        epochs.save(f'{save_path}/{event}-epo.fif', overwrite=True)


In [None]:
#things to consider:

#do we want to add a buffer?
#should epochs be overlapping?

#epoch cleaning:
# # 1/19/24: Let's also look for noisy epochs, which can persist even after notch filtering the whole session. 
# notch_freqs = [60] 
# notch_ranges = np.concatenate([np.arange(x-3,x+4) for x in notch_freqs]).flatten().tolist()
# noisy_epochs_dict = {f'{x}':np.nan for x in ev_epochs.ch_names}

# for ch_ in ev_epochs.ch_names:
#     sig = ev_epochs.get_data(picks=[ch_])[:,0,:]
#     noise_evs = []
#     # compute the power spectrum
#     freqs, psds = compute_spectrum(sig, ev_epochs.info['sfreq'], method='welch', avg_type='median')

#     for event in np.arange(sig.shape[0]):
#         # Find peaks in the power spectrum
#         peaks, _ = find_peaks(np.log10(psds[event, :]), prominence=1.)  # Adjust threshold as needed
#         peak_freqs = freqs[peaks]
#         # do they intersect with noise ranges?
#         intersection = set(peak_freqs) & set(notch_ranges)
#         if intersection:
#             noise_evs.append(event)
#     ev_epochs.metadata.loc[noise_evs, ch_] = 'noise'
#     # noisy_epochs_dict[ch_] = noise_evs




## Data Formatting + Saving for Connectivity Analyses

In [None]:
#https://mne.tools/dev/auto_tutorials/preprocessing/30_filtering_resampling.html#tut-filter-resample
# Band-pass to 8-12 Hz on a copy. Epochs.filter() works in place, and `epochs` can be the
# same object stored in epochs_all_subj, so filtering it directly would silently alter that
# entry and re-running this cell would filter already-filtered data.
# NOTE(review): `epochs` is whatever the most recently run cell bound to that name —
# confirm which subject is intended here.
alpha_epochs = epochs.copy().load_data().filter(l_freq=8, h_freq=12)
alpha_data = alpha_epochs.get_data(copy=False)  # copy=False avoids duplicating the filtered array

In [None]:
# https://mne.tools/mne-connectivity/stable/generated/mne_connectivity.spectral_connectivity_epochs.html

In [None]:
# # choose the connectivity metric
# metric = 'coh'  # ['coh', 'mic', 'mim', 'plv', 'ciplv', 'pli', 'wpli', 'gc', 'gc_tr']

# band_dict = {'theta' : [2, 9],
#                'beta' : [14, 30]}

# freqs = np.logspace(*np.log10([2, 200]), num=30)
# n_cycles = np.floor(np.logspace(*np.log10([3, 10]), num=30))

# buf_ms = 1000
# n_surr = 500

# source_region = 'HPC'

# # source_region = 'dmPFC'
# # ['OFC', 'ACC', 'AMY']
# # source_region = 'OFC'
# # ['ACC', 'AMY']
# # source_region = 'ACC'
# # ['AMY']

# # iterate through target regions
# # analysis_evs = ['baseline_start']
# analysis_evs = ['feedback_start', 'recog_time']


# for target_region in ['OFC', 'AMY']: 
#     conn_group_data = []

#     for subj_id in subj_ids:

#         filepath = f'{base_dir}/projects/guLab/Salman/EphysAnalyses/{subj_id}/scratch/PSI'
#         if not os.path.exists(filepath):
#             os.makedirs(filepath)

#         for event in analysis_evs:
#             save_path = f'{base_dir}/projects/guLab/Salman/EphysAnalyses/{subj_id}/neural/{day}'
#             epochs_reref = mne.read_epochs(f'{save_path}/{event}-epo.fif', preload=True) 

#             # Get electrode df 
#             elec_df = pd.read_csv(f'{base_dir}/projects/guLab/Salman/EphysAnalyses/{subj_id}/Day1_reref_elec_df')

#             # construct the seed-to-target mapping based on your rois - matters most for PSI as this is directional 
#             seed_target_df = pd.DataFrame(columns=['seed', 'target'], index=['l', 'r'])
#             for hemi in ['l', 'r']:
#                 seed_target_df['seed'][hemi] = np.where(elec_df.loc[elec_df.hemisphere==hemi, 'salman_region'] == source_region)[0]
# #                 seed_target_df['target'][hemi] = np.where(elec_df.loc[elec_df.hemisphere==hemi, 'salman_region'].isin(target_regions))[0]    

#                 seed_target_df['target'][hemi] = np.where(elec_df.loc[elec_df.hemisphere==hemi, 'salman_region'] == target_region)[0]    

#             seed_target_df = seed_target_df[
#                         (seed_target_df['seed'].map(lambda d: len(d) > 0)) & (seed_target_df['target'].map(lambda d: len(d) > 0))]


#             # for cond in conditions: 
#             for hemi in ['l', 'r']:
#                 # first determine if ipsi connectivity is even possible; if not, move on
#                 if hemi not in seed_target_df.index.tolist():
#                     continue
#                 else:
#                     seed_to_target = seed_target_indices(
#                         seed_target_df['seed'][hemi],
#                         seed_target_df['target'][hemi])
                    
                
                    
#                 # These epochs only grab those images that were later involved in memory recognition:
#                 if event in ['gamble_start', 'feedback_start', 'baseline_start']:
#                     epochs_to_analyze = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_gamble').reset_index(drop=True).trials_gamble.values - 1
#                 elif event == 'recog_time': 
#                     epochs_to_analyze = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_mem').reset_index(drop=True).trials_mem.values - 1
                    
#                 trunc_epochs = epochs_reref[epochs_to_analyze.astype(int)]
                
#                 avg_dim = 'time'
                
#                 # NOTE: If I compute any epoch-avged connectivity measures, I need to NAN bad epochs BEFORE hand 
#                 npairs = len(seed_to_target[0])
                
# #                 for pair in range(npairs): 
# #                     source_bad_epochs  = list(np.where(trunc_epochs.metadata[trunc_epochs.ch_names[seed_to_target[0][pair]]].notnull())[0])
# #                     target_bad_epochs  = list(np.where(trunc_epochs.metadata[trunc_epochs.ch_names[seed_to_target[1][pair]]].notnull())[0])
# #                     bad_epochs = np.array(source_bad_epochs+target_bad_epochs)
# #                     channels_involved = np.concatenate([np.unique(x) for x in seed_to_target])
# #                     trunc_epochs._data[bad_epochs[:, np.newaxis], channels_involved, :] = 0
                
#                 pwise_dfs = []
#                 for band in band_dict.keys():
                
#                     pwise = oscillation_utils.compute_connectivity(trunc_epochs, 
#                                                                band = band_dict[band], 
#                                                                metric = metric, 
#                                                                indices = seed_to_target, 
#                                                                freqs = freqs, 
#                                                                n_cycles = n_cycles,
#                                                                buf_ms = buf_ms, 
#                                                                n_surr=n_surr,
#                                                                avg_over_dim=avg_dim)

#                     if avg_dim == 'epoch':
#                         pwise = pwise[:, int(buf_ms  * (epochs_reref.info['sfreq']/1000)):-int(buf_ms  * (epochs_reref.info['sfreq']/1000))]

#                     for pair in range(npairs): 
#                         pwise_df = pd.DataFrame(columns=['participant', 'roi1', 'roi2', 'hemi', 'pair_label', 'metric', 'event', 'conn'])
#                         pwise_df['conn'] = pwise[:, pair] 
#                         pwise_df['participant'] = subj_id
#                         pwise_df['age'] = subj_df[subj_df.MSSMCode==subj_id].Age.values[0]
#                         pwise_df['sex'] = subj_df[subj_df.MSSMCode==subj_id].Sex.str.strip().values[0]

#                         pwise_df['roi1'] = source_region
#                         pwise_df['roi2'] = target_region
#                         pwise_df['hemi'] = hemi 
#                         pwise_df['metric'] = metric 
#                         pwise_df['event'] = event
#                         pwise_df['band'] = band
#                         pwise_df['pair_label'] = f'{epochs_reref.ch_names[seed_to_target[0][pair]]}:{epochs_reref.ch_names[seed_to_target[1][pair]]}'
#                         # pwise_df['trials'] = np.arange(1, pwise.shape[0]+1)
#                         if event in ['gamble_start', 'feedback_start', 'baseline_start']:
#                             pwise_df['trials_gamble'] = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_gamble').reset_index(drop=True).trials_gamble.values
#                             pwise_df = pwise_df.merge(combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_gamble').reset_index(drop=True), 
#                                                       on=['participant', 'trials_gamble'])
#                         elif event == 'recog_time':
#                             pwise_df['trials_mem'] = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_mem').reset_index(drop=True).trials_mem.values
#                             pwise_df = pwise_df.merge(combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_mem').reset_index(drop=True), 
#                                                       on=['participant', 'trials_mem'])

#                         pwise_df['zrpe'] = (pwise_df.rpe - np.nanmean(pwise_df.rpe)) / (2*np.nanstd(pwise_df.rpe))
#                         pwise_df['zpm'] = (pwise_df.DPRIME - np.nanmean(pwise_df.DPRIME)) / (2*np.nanstd(pwise_df.DPRIME))


#                         pwise_df['good_epoch'] = 1
#                         # NOTE: HOW TO HANDLE BAD EPOCHS? 
#                         # find all the bad epochs across both channels and add to the dataframe under pair label 
#                         # for ch_ix in progress_bar: 
#                         source_bad_epochs  = list(np.where(epochs_reref.metadata[epochs_reref.ch_names[seed_to_target[0][pair]]].notnull())[0])
#                         target_bad_epochs  = list(np.where(epochs_reref.metadata[epochs_reref.ch_names[seed_to_target[1][pair]]].notnull())[0])
#                         bad_epochs = np.array(source_bad_epochs+target_bad_epochs)
#                         if event in ['gamble_start', 'feedback_start', 'baseline_start']:
#                             pwise_df.loc[pwise_df['trials_gamble'].isin(bad_epochs), 'good_epoch'] = 0
#                         elif event == 'recog_time':
#                             pwise_df.loc[pwise_df['trials_mem'].isin(bad_epochs), 'good_epoch'] = 0
                            
#                         # aggregate
#                         pwise_dfs.append(pwise_df)
#                 pwise_dfs = pd.concat(pwise_dfs)
#                 conn_group_data.append(pwise_dfs)

#     all_pairs_df = pd.concat(conn_group_data)
#     all_pairs_df.to_csv(f'/sc/arion/projects/guLab/Salman/EphysAnalyses/Connectivity/{source_region}_{target_region}_{metric}_df', index=False)

