# MS009 - preprocessing

In this notebook we walk through a single-patient example. Some steps are probably patient-specific and may need to change for other patients, but the notebook should demonstrate how to use the different functions in the toolbox.

1. Load raw data (.edf in this notebook) using mne

2. Add in electrode information

3. Notch filter line noise and clean out bad channels

4. Re-reference the data (bipolar re-referencing)


Must-read guides:

https://www.sciencedirect.com/science/article/pii/S1053811922005559


In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import mne
from glob import glob
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from scipy.stats import zscore, linregress
import pandas as pd
from mne.preprocessing.bads import _find_outliers
import re

In [None]:
import sys
sys.path.append('/Users/christinamaher/Documents/Github/LFPAnalysis')

In [None]:
from LFPAnalysis import lfp_preprocess_utils, sync_utils, analysis_utils, nlx_utils

In [None]:
subject = 'MS009'

## Load raw data (.edf in this notebook) using mne

It's a good idea to set up a sensible directory structure like the one below. Note that my data usually live on '/sc/arion', which is Minerva; in this example, everything is read from a local subject directory instead.


mne: https://mne.tools/stable/index.html

In [None]:
base_dir = '/Users/christinamaher/Documents/Gem_Hunters/data/ieeg/' # this is the root directory for most un-archived data and results 

save_dir = f'{base_dir}/{subject}'  # save intermediate results in the subject's directory
    
# In this example, the raw behavioral, neural, and anatomical data also live in the subject's directory
behav_dir = save_dir
neural_dir = save_dir
anat_dir = save_dir
edf_files = glob(f'{neural_dir}/{subject}.edf')

Load the electrophysiology data.

In [None]:
mne_data = mne.io.read_raw_edf(edf_files[0], preload=True)

In [None]:
mne_data.ch_names

In [None]:
# Sanity check
plt.plot(mne_data._data[0, :5000])  # the slice end is exclusive, so this covers samples 0-4999
plt.title("Raw iEEG, electrode 0, samples 0-4999")
plt.show()

In [None]:
# Sanity check the photodiode
trig_ix = mne_data.ch_names.index('DC1') # either named DC1 or Research
plt.plot(mne_data._data[trig_ix])
plt.title("Photodiode")
plt.show()

In [None]:
# Save out the photodiode channel separately
mne_data.save(f'{save_dir}/photodiode.fif', picks='DC1', overwrite=True)

In [None]:
# Drop the photodiode channel
mne_data.drop_channels(['DC1'])

## Add in electrode information

In [None]:
new_name_dict = {x:x.replace(" ", "").lower() for x in mne_data.ch_names}
mne_data.rename_channels(new_name_dict)

In [None]:
# Load the electrode localization data and add it in
csv_files = glob(f'{anat_dir}/{subject}_labels.csv')
elec_locs = pd.read_csv(csv_files[0])

# Sometimes there are extra columns with no entries:
elec_locs = elec_locs[elec_locs.columns.drop(list(elec_locs.filter(regex='Unnamed')))]
elec_locs

In [None]:
list(elec_locs.label)

The electrode names read from the EDF file do not always match those in the PDF used for localization. The mismatch can come from the technician who entered the labels or from MNE reading the labels in; the most common mix-up is between lowercase 'l' and capital 'I'.

Sometimes there are electrodes in the PDF that are NOT in the MNE data structure, so let's identify those as well.
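To make the 'l'/'I' problem concrete, here is a minimal sketch of the kind of normalization that helps. It assumes the only discrepancies are whitespace, letter case, and a capital 'I' read in place of the lowercase 'l' hemisphere prefix (sEEG labels here start with 'l' or 'r', as used further down to split left and right electrodes). `normalize_label`, `fix_hemisphere_prefix`, and `match_labels` are hypothetical helpers for illustration, not part of LFPAnalysis; the notebook itself uses `lfp_preprocess_utils.match_elec_names`.

```python
import re

def normalize_label(label: str) -> str:
    """Lowercase a channel label and strip all whitespace (similar to the renaming of mne_data above)."""
    return re.sub(r"\s+", "", label).lower()

def fix_hemisphere_prefix(label: str, known: set) -> str:
    """Hypothetical fix for the 'I'/'l' mix-up: if `label` is unknown but swapping a
    leading 'i' for 'l' gives a known localization label, return the swapped label."""
    if label in known or not label.startswith("i"):
        return label
    candidate = "l" + label[1:]
    return candidate if candidate in known else label

def match_labels(edf_names, anat_names):
    """Return (renamed EDF labels, localization labels with no EDF match)."""
    known = {normalize_label(x) for x in anat_names}
    renamed = [fix_hemisphere_prefix(normalize_label(x), known) for x in edf_names]
    unmatched_anat = sorted(known - set(renamed))
    return renamed, unmatched_anat

# Toy example: 'IAC 1' was read with a capital I, 'EKG' is not neural, 'RHH2' is missing from the EDF
renamed, missing = match_labels(["IAC 1", "LAC 2", "RHH1", "EKG"], ["LAC1", "LAC2", "RHH1", "RHH2"])
print(renamed)  # ['lac1', 'lac2', 'rhh1', 'ekg']
print(missing)  # ['rhh2']
```

Only a leading 'i' is swapped, so labels that legitimately contain an 'i' elsewhere (e.g. an insula probe) are left alone.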


In [None]:
anat_names = list(elec_locs.label.str.lower())
# number of localized electrodes missing from mne_data.ch_names (0 if none are missing; if > 0, find the missing electrodes)
print(sum([ch not in mne_data.ch_names for ch in anat_names]))
# print extra channels in mne_data.ch_names and make sure none of them are neural channels (they should be EEG etc.)
print([ch for ch in mne_data.ch_names if ch not in anat_names])

In [None]:
new_mne_names, unmatched_names, unmatched_seeg = lfp_preprocess_utils.match_elec_names(mne_data.ch_names, elec_locs.label)

In [None]:
unmatched_seeg # inspect any localized electrodes that have no match in the MNE data


In [None]:
new_name_dict = {x:y for (x,y) in zip(mne_data.ch_names, new_mne_names)}
new_name_dict #make sure this passes the eye test 


So `match_elec_names` returns a new list of channel names for the MNE data structure, as well as a list of channels in the localization CSV that are not found in the MNE structure (`unmatched_seeg`). Make sure that `unmatched_seeg` does not factor into any referencing scheme later, because those channels are not in the MNE data.

In [None]:
# Rename the mne data according to the localization data
mne_data.rename_channels(new_name_dict)

In [None]:
right_seeg_names = [i for i in mne_data.ch_names if i.startswith('r')]
left_seeg_names = [i for i in mne_data.ch_names if i.startswith('l')]
print(f'We have a total of {len(left_seeg_names)} left sEEG and {len(right_seeg_names)} right sEEG electrodes')
print(f'We have a total of {len(left_seeg_names) + len(right_seeg_names)} sEEG electrodes')


In [None]:
sEEG_mapping_dict = {f'{x}':'seeg' for x in left_seeg_names+right_seeg_names}
mne_data.set_channel_types(sEEG_mapping_dict)


In [None]:
# Drop all non-sEEG channels (e.g., scalp EEG and other auxiliary channels)
drop_chans = list(set(mne_data.ch_names) - set(left_seeg_names + right_seeg_names))
mne_data.drop_channels(drop_chans) # the remaining channel count should equal the number of sEEG electrodes

In [None]:
# make montage (convert mm to m)

montage = mne.channels.make_dig_montage(ch_pos=dict(zip(elec_locs.label, 
                                                        elec_locs[['mni_x', 'mni_y', 'mni_z']].to_numpy(dtype=float)/1000)),
                                        coord_frame='mni_tal')

mne_data.set_montage(montage, match_case=False, on_missing='warn')

## Notch filter line noise and resample to 500 Hz


We want to remove the line noise (60 Hz and harmonics in US data, 50 Hz and harmonics in EU data). 

To do so, we use a band-stop filter that removes a narrow band of frequencies. 

We may eventually want to avoid filtering altogether, especially if we are interested in ERPs: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6456018/
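To see what a narrow band-stop filter does, here is a small self-contained sketch that designs an IIR notch with `scipy.signal.iirnotch` and applies it with zero-phase filtering (`scipy.signal.filtfilt`) at 60 Hz and its harmonics. This only illustrates the idea: MNE's `notch_filter` uses its own design (FIR by default), and the 1000 Hz sampling rate, the Q factor of 30, and the synthetic test signal are assumptions made for the example.

```python
import numpy as np
from scipy.signal import iirnotch, filtfilt

def notch_harmonics(x, sfreq, line_freq=60.0, n_harmonics=4, q=30.0):
    """Apply a zero-phase IIR notch at line_freq and its harmonics below Nyquist."""
    y = np.asarray(x, dtype=float)
    for k in range(1, n_harmonics + 1):
        f0 = k * line_freq
        if f0 >= sfreq / 2:  # cannot notch at or above Nyquist
            break
        b, a = iirnotch(w0=f0, Q=q, fs=sfreq)  # bandwidth at f0 is roughly f0 / q
        y = filtfilt(b, a, y)
    return y

def band_amplitude(x, sfreq, freq):
    """Amplitude of the FFT bin closest to `freq`."""
    spectrum = np.abs(np.fft.rfft(x)) / len(x)
    freqs = np.fft.rfftfreq(len(x), d=1 / sfreq)
    return spectrum[np.argmin(np.abs(freqs - freq))]

sfreq = 1000.0  # assumed sampling rate for this toy example
t = np.arange(0, 10, 1 / sfreq)
brain = np.sin(2 * np.pi * 10 * t)  # 10 Hz "signal of interest"
line = 0.5 * np.sin(2 * np.pi * 60 * t) + 0.2 * np.sin(2 * np.pi * 120 * t)  # line noise
raw = brain + line
clean = notch_harmonics(raw, sfreq)

for f in (10, 60, 120):
    print(f, "Hz:", round(band_amplitude(raw, sfreq, f), 3), "->", round(band_amplitude(clean, sfreq, f), 3))
```

The 10 Hz component passes through essentially unchanged, while the 60 and 120 Hz components are strongly attenuated; a larger Q gives a narrower notch.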

In [None]:
mne_data.info['line_freq'] = 60
# Notch out 60 Hz noise and harmonics 
mne_data.notch_filter(freqs=(60, 120, 180, 240))

In [None]:
#all patients should be resampled to 500 Hz
resample_sr = 500
mne_data.resample(sfreq=resample_sr, npad='auto', n_jobs=-1)

## Signal Cleaning 
Methods:
- Use manual bad channel detection for **bipolar** referencing
    - bipolar referencing completely ignores channels labeled as 'bad', so do not mark a channel as bad unless absolutely necessary
    - mark a channel as bad if it shows a massive artifact that is absent from the rest of its probe
    - if the entire probe shows the same artifact, try to keep its channels (an artifact shared by neighboring contacts largely cancels in the bipolar difference)
- Remove additional channels as needed
- This process is iterative and depends on how the TFRs look.
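The "artifact in one channel but not the rest of its probe" rule can be roughed out numerically before eyeballing the traces. A minimal sketch, assuming a `(n_channels, n_samples)` array and one probe label per channel (like `ch_names_no_num` built in the next cell): each channel's variance is compared with the median variance of its own probe. `flag_probe_outliers` and its variance ratio of 5 are hypothetical choices, not the toolbox's `detect_bad_elecs`; treat the output as a shortlist for visual inspection, not a final bad-channel list.

```python
import numpy as np

def flag_probe_outliers(data, probes, ratio=5.0):
    """Return indices of channels whose variance differs from their probe's median
    variance by more than `ratio` (too noisy) or less than 1 / ratio (flat)."""
    var = np.var(data, axis=1)
    probes = np.asarray(probes)
    flagged = []
    for p in np.unique(probes):
        idx = np.where(probes == p)[0]
        ref = np.median(var[idx])  # typical variance on this probe
        if ref == 0:
            continue
        rel = var[idx] / ref
        flagged.extend(idx[(rel > ratio) | (rel < 1.0 / ratio)].tolist())
    return sorted(flagged)

# Toy data: probe 'lac' has one contact with an isolated artifact,
# probe 'rhh' has an artifact shared by every contact (which should NOT be flagged).
rng = np.random.default_rng(0)
n_samp = 5000
shared = np.zeros(n_samp)
shared[2000:2100] = 50.0                         # probe-wide artifact
lac = rng.normal(size=(8, n_samp))
lac[5] += rng.normal(scale=20.0, size=n_samp)   # artifact only on lac contact index 5
rhh = rng.normal(size=(8, n_samp)) + shared
data = np.vstack([lac, rhh])
probes = ["lac"] * 8 + ["rhh"] * 8
print(flag_probe_outliers(data, probes))  # [5]
```

Comparing within a probe, rather than across the whole implant, mirrors the rule above: a probe-wide artifact raises every contact's variance equally and is left for the bipolar difference to handle.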

In [None]:
ch_names = list(elec_locs.label.str.lower())
pattern = '[0-9]'
ch_names_no_num = [re.sub(pattern, '', i) for i in ch_names]
probe_names = np.unique(ch_names_no_num)
probe_names

In [None]:
probe_ch_counts = {} #need this to select channel number for visualization
for p in probe_names:
    c = ch_names_no_num.count(p)
    probe_ch_counts[p] = c

In [None]:
probe_ch_counts

In [None]:
lfp_preprocess_utils.detect_bad_elecs(mne_data, sEEG_mapping_dict) # likely inaccurate, just use it to get a general idea of problematic channels

In [None]:
%matplotlib notebook
fig = mne_data.plot(start=0, duration=1000, n_channels=40, scalings=mne_data._data.max()/50)

In [None]:
mne_data.info['bads']

In [None]:
mne_data.info #sanity check that bads info saved

## Bipolar re-referencing 

If you're like me, you find the concept of re-referencing somewhat confusing. Isn't the data recorded relative to a ground and reference in the EMU (https://ahleighton.github.io/OE-ephys-course/EEA/theoryday3.html)? 

It is, but we do digital re-referencing of the recorded signal to clean up any remaining shared noise. 

**Re-referencing should be an EXTREMELY conscious choice as it changes the LFP signal dramatically!** One common option is local white-matter re-referencing: electrodes in white matter should be fairly stable (low-variance) and should not contain local, slow oscillations of interest, so the localization data can be used to separate gray- from white-matter electrodes and each gray-matter electrode can be re-referenced to the closest, lowest-amplitude white-matter electrode. In this notebook, however, we use bipolar re-referencing, in which each contact is referenced to its neighboring contact on the same probe.

Make sure 'bad' electrodes are not used in the re-referencing. The same goes for unmatched sEEG electrodes (those not present in the MNE data structure).
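For intuition, here is a minimal NumPy sketch of bipolar re-referencing. It assumes channel names of the form `<probe><contact number>` (e.g. `lac3`), that consecutive contact numbers are physically adjacent, and that each new channel is named `anode-cathode`. Any pair involving a bad channel, or a channel missing from the data, is skipped, which is why marking bad channels (and keeping `unmatched_seeg` out) matters. `make_bipolar_pairs` and `apply_bipolar` are hypothetical helpers; the notebook itself uses `lfp_preprocess_utils.ref_mne`, and MNE provides `mne.set_bipolar_reference` for the subtraction step.

```python
import re
import numpy as np

def make_bipolar_pairs(ch_names, bads=()):
    """Pair each usable contact with the next contact number on the same probe."""
    contacts_by_probe = {}
    for name in ch_names:
        m = re.fullmatch(r"([a-z]+)(\d+)", name)
        if m and name not in bads:
            contacts_by_probe.setdefault(m.group(1), {})[int(m.group(2))] = name
    pairs = []
    for probe in sorted(contacts_by_probe):
        contacts = contacts_by_probe[probe]
        for num in sorted(contacts):
            if num + 1 in contacts:  # neighbor must exist and not be bad
                pairs.append((contacts[num], contacts[num + 1]))
    return pairs

def apply_bipolar(data, ch_names, pairs):
    """Return (bipolar data, bipolar names), each channel being anode minus cathode."""
    if not pairs:
        return np.empty((0, data.shape[1])), []
    index = {name: i for i, name in enumerate(ch_names)}
    bp = np.stack([data[index[a]] - data[index[c]] for a, c in pairs])
    return bp, [f"{a}-{c}" for a, c in pairs]

# Toy example: all contacts share the same noise; lac3 is marked bad.
ch_names = ["lac1", "lac2", "lac3", "lac4", "rhh1", "rhh2"]
rng = np.random.default_rng(1)
shared_noise = rng.normal(size=1000)                         # common to every contact
data = rng.normal(scale=0.1, size=(6, 1000)) + shared_noise  # broadcasts across channels
pairs = make_bipolar_pairs(ch_names, bads={"lac3"})
bp, names = apply_bipolar(data, ch_names, pairs)
print(names)  # ['lac1-lac2', 'rhh1-rhh2']
```

Marking a single contact bad removes both pairs it would have joined (here `lac2-lac3` and `lac3-lac4`), and the shared noise largely disappears from the differences.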

In [None]:
# Re-reference neural data; elec_path points to the localization CSV loaded above
mne_data_bp_reref = lfp_preprocess_utils.ref_mne(mne_data=mne_data, 
                                              elec_path=csv_files[0], 
                                              method='bipolar', 
                                              site='MSSM')
mne_data_bp_reref

In [None]:
mne_data_bp_reref.ch_names

## Examine bp re-referenced data

In [None]:
%matplotlib notebook
fig = mne_data_bp_reref.plot(start=0, duration=1000, n_channels=40, scalings=mne_data_bp_reref._data.max())

In [None]:
mne_data_bp_reref.compute_psd().plot()

## Save data


In [None]:
mne_data_bp_reref.save(f'{neural_dir}/bp_ref_ieeg.fif', overwrite=True)
mne_data.save(f'{neural_dir}/raw_ieeg.fif', overwrite=True)

# Epoching + TFRs
- Check whether the data are ready to be analyzed; if this step shows noise, repeat the steps above
- Align photodiode to behavior 
- Epoch data (and mark bad epochs)
- Baseline data
- Visualize TFRs 
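The alignment step boils down to a linear mapping between clocks, `neural_time = slope * behav_time + offset`, which is how `slope` and `offset` are applied to the behavioral timestamps below. `sync_utils.synchronize_data` estimates these from the cleaned photodiode trace and the behavioral timestamps; the sketch below only shows the final regression step with `scipy.stats.linregress`, assuming photodiode onsets have already been detected and matched one-to-one with behavioral events. `detect_onsets`, `fit_clock_mapping`, and the toy clock drift and offset are hypothetical.

```python
import numpy as np
from scipy.stats import linregress

def detect_onsets(trace, sfreq, thresh):
    """Times (s) at which the trace crosses `thresh` upward (simple rising-edge detector)."""
    above = trace > thresh
    rising = np.flatnonzero(~above[:-1] & above[1:]) + 1
    return rising / sfreq

def fit_clock_mapping(behav_ts, neural_ts):
    """Least-squares fit of neural_ts = slope * behav_ts + offset."""
    fit = linregress(behav_ts, neural_ts)
    return fit.slope, fit.intercept, fit.rvalue

# Toy data: behavioral clock in ms, neural clock in s, with slight drift and a start offset.
sfreq = 500.0
behav_ts = np.array([1000.0, 4000.0, 7500.0, 11000.0, 15000.0])  # ms
true_slope, true_offset = 0.001 * 1.0001, 12.3                   # ms -> s, plus offset
trace = np.zeros(int(40 * sfreq))
for t in true_slope * behav_ts + true_offset:
    start = int(round(t * sfreq))
    trace[start:start + 750] = 1.0  # 750 samples = 1.5 s deflection at 500 Hz
onsets = detect_onsets(trace, sfreq, thresh=0.5)
slope, offset, r = fit_clock_mapping(behav_ts, onsets)
print(slope, offset, r)
```

An r value very close to 1 is a quick check that the onsets were matched to the right behavioral events.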

In [None]:
#remove mne_data from environment to save memory 
del mne_data, mne_data_bp_reref

# Photodiode alignment

In [None]:
# load the photodiode and resample to match the neural data
photodiode = mne.io.read_raw_fif(f'{save_dir}/photodiode.fif', preload=True)
resample_sr = 500
photodiode.resample(sfreq=resample_sr, npad='auto', n_jobs=-1)

In [None]:
# load behavior data and save timestamp(s) of interest as variable 
behav_df = pd.read_csv(f'{behav_dir}/{subject}_clean.csv')

choice_ts = behav_df['choice_ts'].copy()

# add a column of ITI timestamps: choice timestamp + 1500 (in the behavioral file's time units)
behav_df['iti_ts'] = behav_df['choice_ts'].copy() + 1500
iti_ts = behav_df['iti_ts'].copy()

In [None]:
# plot photodiode and choice timestamps before alignment
%matplotlib qt
plt.plot(photodiode._data[0])
plt.xlabel("Time")
plt.ylabel("V")
plt.title("Photodiode")

# mark the (unaligned) choice timestamps at a fixed height for reference
y_ts = np.full(len(choice_ts), 0.05)
plt.scatter(choice_ts, y_ts, color='red')

plt.show()

In [None]:
# Photodiode voltage trace (one value per sample)
photodiode_trace = photodiode._data[0]

# Voltage threshold separating a deflection from baseline (patient-specific)
threshold_min = -0.15

photodiode_deflected = []

# Binarize every sample: +1 above threshold, -1 otherwise (so no sample is skipped)
for value in photodiode_trace:
    if value > threshold_min:
        photodiode_deflected.append(1)
    else:
        photodiode_deflected.append(-1)

In [None]:
def sum_sequential_values(arr):
    """Collapse runs of identical values into signed run lengths, e.g. [1, 1, -1] -> [2, -1]."""
    if not arr:
        return []

    result = []
    current_sum = arr[0]

    for i in range(1, len(arr)):
        if arr[i] == arr[i - 1]:
            current_sum += arr[i]
        else:
            result.append(current_sum)
            current_sum = arr[i]

    # Append the last calculated sum
    result.append(current_sum)

    return result

In [None]:
lengths = sum_sequential_values(photodiode_deflected)

In [None]:
def assign_peak_indices(lengths):
    """Expand signed run lengths into a per-sample mask: 0 = keep, 1 = drop.

    Negative runs are below-threshold samples and are dropped. Positive runs are kept only
    if they last 700-800 samples (1.4-1.6 s at 500 Hz); shorter or longer deflections are dropped.
    """
    photodiode_indices = []
    
    for l in lengths:
        if l < 0:
            # below threshold: drop
            photodiode_indices.append(np.ones(-l, dtype=int))
        elif 700 <= l <= 800:
            # deflection with the expected duration: keep
            photodiode_indices.append(np.zeros(l, dtype=int))
        else:
            # deflection too short or too long: drop
            photodiode_indices.append(np.ones(l, dtype=int))
    
    return np.concatenate(photodiode_indices)

In [None]:
photodiode_indices = assign_peak_indices(lengths)

In [None]:
new_diode = photodiode._data[0].copy()
peaks_to_exclude = [index for index, value in enumerate(photodiode_indices) if value == 1]
new_diode[peaks_to_exclude] = np.min(photodiode._data[0]) # make these all super small

photodiode_final = photodiode.copy()
photodiode_final._data[0] = new_diode

In [None]:
# plot photodiode
plt.plot(photodiode_final._data[0])
plt.xlabel("Time")
plt.ylabel("V")
plt.title("Photodiode")
plt.show()

In [None]:
slope, offset = sync_utils.synchronize_data(choice_ts, 
                                            photodiode_final, 
                                            smoothSize=11, windSize=10, height=0.7)
print(slope,offset) 

In [None]:
# visualize updated choice ts 
choice_ts = choice_ts * slope + offset

%matplotlib qt
plt.plot(photodiode_final._data[0])
plt.xlabel("Time")
plt.ylabel("V")
plt.title("Photodiode")

# aligned times are in seconds; convert to samples at resample_sr to overlay on the trace
x_ts = choice_ts * resample_sr
y_ts = np.full(len(choice_ts), -0.01)
plt.scatter(x_ts,y_ts,color='red')

plt.show()

In [None]:
# visualize updated iti 
iti_ts = iti_ts * slope + offset

%matplotlib qt
plt.plot(photodiode_final._data[0])
plt.xlabel("Time")
plt.ylabel("V")
plt.title("Photodiode")

# aligned times are in seconds; convert to samples at resample_sr to overlay on the trace
x_ts = iti_ts * resample_sr
y_ts = np.full(len(iti_ts), -0.01)
plt.scatter(x_ts,y_ts,color='red')

plt.show()

In [None]:
behav_df['choice_ts'] = [(x*slope + offset) for x in behav_df['choice_ts']]
behav_df['iti_ts'] = [(x*slope + offset) for x in behav_df['iti_ts']]

## Epoch Data
- Epoch neural data into trial epochs 
- Add behavioral data to epochs metadata
- Save epochs
- Baseline + decompose data into TFRs
- Plot + save TFRs (examine quality)
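The baselining step below converts event-locked power into z-scores relative to ITI power. As a rough sketch of what a non-trialwise z-score baseline computes, assuming arrays shaped `(n_epochs, n_channels, n_freqs, n_times)` like `pow_struct` below: the baseline mean and SD are pooled over all baseline epochs and time points, separately for each channel and frequency, and NaN entries (epochs rejected for a channel) are ignored. `zscore_to_baseline` is a hypothetical helper; `lfp_preprocess_utils.baseline_trialwise_TFR` may differ in its details.

```python
import numpy as np

def zscore_to_baseline(power, baseline_power):
    """z-score `power` against baseline statistics pooled over epochs and time.

    power, baseline_power : arrays (n_epochs, n_channels, n_freqs, n_times);
    NaNs mark rejected epochs and are ignored when computing the statistics.
    """
    mu = np.nanmean(baseline_power, axis=(0, 3), keepdims=True)  # shape (1, ch, freq, 1)
    sd = np.nanstd(baseline_power, axis=(0, 3), keepdims=True)
    return (power - mu) / sd

# Toy data: task power is 50% higher than baseline power on average
rng = np.random.default_rng(2)
baseline = rng.gamma(shape=2.0, scale=1.0, size=(20, 3, 5, 250))
task = 1.5 * rng.gamma(shape=2.0, scale=1.0, size=(20, 3, 5, 750))
baseline[3, 1] = np.nan  # epoch 3 rejected on channel 1
z = zscore_to_baseline(task, baseline)
print(z.shape, np.nanmean(z))
```

Because the statistics are computed per channel and frequency, z-scoring removes the roughly 1/f falloff of raw power, so the resulting TFRs are comparable across frequencies.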

In [None]:
# IED removal requires that we set some parameters for IED detection. 
# 1. peak_thresh: how many stds should they exceed the baseline by? 
# 2. closeness_thresh: how close should they be allowed to be (in sec) to other candidate IEDs? 
# 3. width_thresh: how wide should they have to be (in sec)?

# Defaults:
IED_args = {'peak_thresh':4,
           'closeness_thresh':0.25, 
           'width_thresh':0.2}
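To make the three IED parameters concrete, here is a rough sketch of how they could translate into a peak detector on one channel using `scipy.signal.find_peaks`. It assumes a robust (median/MAD) z-score of the signal, requires candidate peaks to exceed `peak_thresh`, keeps only the largest peak within `closeness_thresh` seconds, and (following the wording of the comments above) requires peaks to be at least `width_thresh` seconds wide at half prominence. `find_ied_candidates` is hypothetical and is not necessarily how `lfp_preprocess_utils.make_epochs` detects IEDs; the toy bumps are chosen only to satisfy these rules.

```python
import numpy as np
from scipy.signal import find_peaks

def find_ied_candidates(x, sfreq, peak_thresh=4.0, closeness_thresh=0.25, width_thresh=0.2):
    """Return sample indices of candidate IED peaks in a 1-D signal (hypothetical sketch)."""
    med = np.median(x)
    mad = np.median(np.abs(x - med)) * 1.4826  # robust estimate of the SD
    z = np.abs(x - med) / mad                  # either polarity counts
    peaks, _ = find_peaks(
        z,
        height=peak_thresh,                              # amplitude in robust SDs
        distance=max(1, int(closeness_thresh * sfreq)),  # min spacing in samples
        width=width_thresh * sfreq,                      # min width in samples
    )
    return peaks

# Toy signal: unit-variance noise plus two Gaussian bumps at 5 s and 12 s
sfreq = 500.0
t = np.arange(0, 20, 1 / sfreq)
rng = np.random.default_rng(3)
x = rng.normal(size=t.size)
for center in (5.0, 12.0):
    x += 15.0 * np.exp(-0.5 * ((t - center) / 0.2) ** 2)
peaks = find_ied_candidates(x, sfreq, **{'peak_thresh': 4, 'closeness_thresh': 0.25, 'width_thresh': 0.2})
print(peaks / sfreq)  # candidate peak times in seconds
```

Isolated noise spikes can cross the height threshold, but they are far narrower than `width_thresh`, so the width rule removes them.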

In [None]:
# Create a dictionary with your event name (matching your dataframe), and the time-window for the event
evs = {'choice_ts': [-1.5, 1.5], 
       'iti_ts': [0.0, 0.5]} 

In [None]:
epochs_all_evs = {f'{x}': np.nan for x in evs}

In [None]:
for event in evs.keys():
    # Make the epochs. 
    ev_epochs = lfp_preprocess_utils.make_epochs(load_path=f'{neural_dir}/bp_ref_ieeg.fif', 
                 slope=slope, offset=offset,
                 behav_name=event, behav_times=behav_df[event].values, 
                 ev_start_s=evs[event][0], ev_end_s=evs[event][1], buf_s = 1.0, IED_args=IED_args) #1.0 buf unsaved

    epochs_all_evs[event] = ev_epochs

epochs_all_evs

### Add behavioral data to metadata 

In [None]:
behav_df.columns

In [None]:
# select params of interest 
behav_params = ["learned","reward",
               "correct","rd","tf","condition",
                "phiEVcombo","phiRPEcombo",
                "phiEVrd","phiRPErd",
               "chosen_f_rd","chosen_f_ird"]

In [None]:
for event in evs.keys():

    event_metadata = epochs_all_evs[event].metadata.copy()
    
    #independent vars
    for param in behav_params: 
        event_metadata[param] = behav_df[param].tolist()

    epochs_all_evs[event].metadata = event_metadata 

In [None]:
epochs_all_evs[event].metadata

### Save raw epoched data 

In [None]:
for event in evs.keys():
    epochs_all_evs[event].save(f'{neural_dir}/bp_epoch_{event}.fif', overwrite=True)

### Baseline + Decompose into TFRs

In [None]:
# Explicitly define the analysis events and the baseline event; these should correspond to the keys of evs
analysis_evs = ['choice_ts']
baseline_ev = 'iti_ts'
evs = {'choice_ts': [-1.5, 1.5], 
       'iti_ts': [0.0, 0.5]}

In [None]:
# Set some spectrogram parameters 
freqs = np.logspace(*np.log10([2, 200]), num=30)
n_cycles = np.floor(np.logspace(*np.log10([3, 10]), num=30))
sr = 500.0 
buf = 1.0
buf_ix = int(buf*sr)

In [None]:
epochs_all_baseline = {} # baseline TFRs, keyed by event name

In [None]:
#baseline epoch - 
event = 'iti_ts'
epochs = epochs_all_evs[event]

good_chans = [x for x in epochs.ch_names if x not in epochs.info['bads']]
picks = [x for x in good_chans]

pow_struct = np.nan * np.ones([epochs._data.shape[0], 
                       epochs._data.shape[1], len(freqs), 
                       epochs._data.shape[-1]])

for ch_ix in np.arange(epochs._data.shape[1]): 
    ch_data = epochs._data[:, ch_ix:ch_ix+1, :]
    bad_epochs  = np.where(epochs.metadata[epochs.ch_names[ch_ix]].notnull())[0]
    good_epochs = np.delete(np.arange(ch_data.shape[0]), bad_epochs)
    ch_data = np.delete(ch_data, bad_epochs, axis=0)
    ch_pow = mne.time_frequency.tfr_array_morlet(ch_data, sfreq=epochs.info['sfreq'], 
                                        freqs=freqs, n_cycles=n_cycles, zero_mean=False, 
                                        use_fft=True, output='power', n_jobs=1)

    pow_struct[good_epochs, ch_ix, :, :] = ch_pow[:, 0, :, :]

temp_pow = mne.time_frequency.EpochsTFR(epochs.info, pow_struct, 
                                        epochs.times, freqs)
temp_pow.crop(tmin=evs[event][0], tmax=evs[event][1])

epochs_all_baseline[event] = temp_pow

In [None]:
epochs_all_baseline

In [None]:
power_epochs = {}

In [None]:
event = 'choice_ts'

epochs = epochs_all_evs[event]

# Let's make sure we only do this for good channels
good_chans = [x for x in epochs.ch_names if x not in epochs.info['bads']]
picks = [x for x in good_chans]

pow_struct = np.nan * np.ones([epochs._data.shape[0], 
                       epochs._data.shape[1], len(freqs), 
                       epochs._data.shape[-1]])

for ch_ix in np.arange(epochs._data.shape[1]): 
    ch_data = epochs._data[:, ch_ix:ch_ix+1, :]
    bad_epochs  = np.where(epochs.metadata[epochs.ch_names[ch_ix]].notnull())[0] 
    good_epochs = np.delete(np.arange(ch_data.shape[0]), bad_epochs)
    ch_data = np.delete(ch_data, bad_epochs, axis=0) #this is where bad epochs for ch are deleted!!
    ch_pow = mne.time_frequency.tfr_array_morlet(ch_data, sfreq=epochs.info['sfreq'], 
                                        freqs=freqs, n_cycles=n_cycles, zero_mean=False, 
                                        use_fft=True, output='power', n_jobs=1)

    pow_struct[good_epochs, ch_ix, :, :] = ch_pow[:, 0, :, :]

temp_pow = mne.time_frequency.EpochsTFR(epochs.info, pow_struct, 
                                        epochs.times, freqs)

temp_pow.crop(tmin=evs[event][0], tmax=evs[event][1])


baseline_corrected_power = lfp_preprocess_utils.baseline_trialwise_TFR(data=temp_pow.data, 
                                                  baseline_mne=epochs_all_baseline['iti_ts'], 
                                                  mode='zscore', 
                                                  trialwise=False, ## make sure this is FALSE! More robust baselining method if set to TRUE.
                                                  baseline_only=True)


zpow = mne.time_frequency.EpochsTFR(epochs.info, baseline_corrected_power, 
                                temp_pow.times, freqs)

zpow.metadata = epochs_all_evs[event].metadata

power_epochs[event] = zpow


In [None]:
power_epochs['choice_ts']

In [None]:
# saving TFR data requires h5io: conda install -c conda-forge h5io
power_epochs['choice_ts'].save(f'{save_dir}/bp_pow_epochs-tfr.h5', overwrite=True)

## Plot TFRs

In [None]:
import os
import datetime

# directory for TFR figures (created if it does not already exist)
tfr_dir = f'{base_dir}/{subject}/tfr/'
os.makedirs(tfr_dir, exist_ok=True)
date = datetime.date.today().strftime('%m%d%Y')

print(date)
# mne_data_bp_reref = mne.io.read_raw_fif(f'{neural_dir}/bp_ref_ieeg.fif', preload=True)

In [None]:
event = 'choice_ts'
yticks = [4, 12, 30, 60, 90, 120, 150, 180, 200]
good_ch = [x for x in power_epochs[event].ch_names if '-' in x]  # bipolar channels are named 'anode-cathode'
save_path = tfr_dir

for ch in good_ch:
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    times = power_epochs[event].times
    # average z-scored power over epochs, then over the single picked channel -> (n_freqs, n_times)
    plot_data = np.nanmean(np.nanmean(power_epochs[event].copy().pick_channels([ch]).data, axis=0), axis=0)
    vmax = np.nanmax(np.abs(plot_data))  # symmetric color limits around 0

    # pcolormesh places each row at its actual (log-spaced) frequency, unlike imshow with a linear extent
    im = ax.pcolormesh(times, freqs, plot_data, shading='auto',
                       cmap='RdBu_r', vmin=-vmax, vmax=vmax)
    ax.set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency (Hz)', title=f'{ch} {event}')
    ax.yaxis.set_tick_params(labelsize=8)
    fig.colorbar(im, ax=ax)
    fig.savefig(f'{save_path}/{ch}_{date}_bp_ref.png', format='png', pad_inches=0.1)
    plt.close(fig)  # free memory when looping over many channels