# BLAES Units Preprocessing

This notebook contains code for preprocessing single unit data collected during the BLAES encoding experiments. The raw 30 kHz data is loaded from the `.ns6`/`.nev` or `.mat` files, and subsequently restructured as `pd.DataFrames`. Basic data features are summarized (e.g., # units) and exported as `.csv` or `.txt` files. Stimulation pulses are detected from the NSP Sync channel, and used to define peri-stimulation epochs (± 1) for analysis. Plotting functions enable visualization of individual unit waveforms, a raster plot for the entire recording session, and peri-stimulation raster plots (within-channel).

---

> *Author: Justin Campbell (justin.campbell@hsc.utah.edu)*  
> *Version: 06/24/2024*


## 1. Setup Notebook
Import libraries, utilities, and packages. Configure notebook figure settings.

In [None]:
# Import Libraries
import os
import sys
import glob
import mat73
import datetime
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.io import loadmat
from scipy.signal import find_peaks, butter, filtfilt, find_peaks, correlate, decimate
from scipy.fftpack import fft, ifft

# Import Blackrock Python Utilities
sys.path.append('/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits/Code/Blacrock-Python-Utilities')
import brpylib

%matplotlib inline
# %config InlineBackend.figure_format='retina'
%config InlineBackend.figure_format='svg'

## 2. Load Data

In [None]:
def loadNEV(nev_path):
    '''
    Load .nev file and extract spike events and waveforms.

    Only events assigned to a valid unit label (1-16) are returned; events
    labelled 0 (unclassified) or 255 (noise) are dropped together with their
    waveforms, so row k of `waveforms` belongs to row k of `events`.
    
    Arguments:
        nev_path (str): path to .nev file
    Returns:
        events (pd.DataFrame): spike events (index reset to 0..n-1)
        waveforms (np.ndarray): spike waveforms, row-aligned with events
    '''
    
    # Open .nev file
    nev_file = brpylib.NevFile(nev_path)
    
    # Extract data; the file is closed even if reading fails
    try:
        print('\nLoading .nev data...')
        nev_data = nev_file.getdata()['spike_events']
    finally:
        nev_file.close()
    
    # Extract waveforms (coerced to an array so they can be indexed by row below)
    waveforms = np.asarray(nev_data['Waveforms'])
    
    # Convert spike events to pd.DataFrame (everything except the waveforms)
    spikes = nev_data.copy()
    spikes.pop('Waveforms', None)
    events = pd.DataFrame.from_dict(spikes)
    
    # Keep only valid unit labels
    valid_unit_labels = np.arange(1,17) # 0 = unclassified, 255 = noise, 1-16 = valid units
    events = events[events['Unit'].isin(valid_unit_labels)]
    
    # Select the matching waveforms *before* reindexing, while the index still
    # holds each event's original row position
    valid_unit_idxs = events.index.values
    waveforms = waveforms[valid_unit_idxs]
    
    # Reindex
    events.reset_index(drop = True, inplace = True)
    
    return events, waveforms

In [None]:
def loadNSX(nsx_path):
    '''
    Load .nsx file and extract data and header.
    
    Arguments:
        nsx_path (str): path to .nsx file
        
    Returns:
        nsx_data (np.ndarray): raw neural data (30 kHz); the first element of
            the 'data' entry returned by brpylib
        header (dict): header (recording info), taken from `extended_headers`
    '''
    # Open .nsx file (use the argument, not a notebook-level variable)
    nsx_file = brpylib.NsxFile(nsx_path)

    # Extract data & header; the file is closed even if reading fails
    try:
        print('\nLoading .nsx data...')
        nsx_data = nsx_file.getdata(full_timestamps = True)['data'][0]
        header = nsx_file.extended_headers
    finally:
        # Close .nsx file
        nsx_file.close()
    
    return nsx_data, header

In [None]:
def loadSortedMat(mat_path):
    '''
    Read a spike-sorted .mat file and key its channels by micro-channel label.

    Relies on the notebook-level `pID` to look up this participant's lead
    labels in the WashU label table; each lead contributes 8 channels.
    
    Arguments:
        mat_path (str): path to *sorted.mat file
        
    Returns:
        mat_dict (dict): micro channels with sorted units (empty channels removed)
        chan_labels (list): micro channel labels (all channels)
    '''
    # Read the sorted data
    sorted_data = loadmat(mat_path, simplify_cells=True)

    # Look up the lead (bundle) labels for this participant (pID without its last two characters)
    label_table = pd.read_csv('/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits/WashU Micro Labels.csv')
    participant_rows = label_table[label_table['pID'] == pID[:-2]]
    lead_labels = participant_rows['Label'].values

    # Expand every lead into its 8 channel labels (lead + 1..8)
    chan_labels = [lead + str(contact) for lead in lead_labels for contact in range(1, 9)]

    # Discard entries whose key contains '__'
    sorted_data = {key: val for key, val in sorted_data.items() if '__' not in key}

    # Drop the Sync channel, which is the last entry of a 17- or 9-entry file
    sync_key = {17: 'Channel17', 9: 'Channel09'}.get(len(sorted_data.keys()))
    if sync_key is not None:
        sorted_data.pop(sync_key, None)

    # Re-key the remaining entries, in stored order, with the channel labels
    relabelled = {}
    for position, val in enumerate(sorted_data.values()):
        relabelled[chan_labels[position]] = val

    # Keep only channels that contain units
    mat_dict = {label: val for label, val in relabelled.items() if val.size > 0}
    
    return mat_dict, chan_labels

In [None]:
# Root folder holding the 30 kHz recordings
data_path = '/Volumes/Hippocampus/BLAESUnits/Data_30k'

# Recursively locate each file type (paths are relative to data_path), sorted by filename
nev_files = sorted(glob.glob(root_dir = data_path, pathname = '**/*.nev', recursive = True))
nsx_files = sorted(glob.glob(root_dir = data_path, pathname = '**/*.ns6', recursive = True))
sorted_mat_files = sorted(glob.glob(root_dir = data_path, pathname = '**/*sorted.mat', recursive = True))
raw_mat_files = sorted(glob.glob(root_dir = data_path, pathname = '**/*raw.mat', recursive = True))

# A session is named after the top-level folder of each .nev file, followed by
# the top-level folder of each sorted .mat file
sessions = [rel_path.split('/')[0] for rel_path in nev_files]
sessions.extend(rel_path.split('/')[0] for rel_path in sorted_mat_files)

# Data params
fs = 30000 # sampling rate

# Display the numbered list of sessions
print('Found data for %s sessions:' % (len(sessions)))
print('----------------------------')
for session_idx, session_name in enumerate(sessions):
    print('%s: %s' % (session_idx, session_name))

In [None]:
# Select the session file to analyze
file_idx = int(input('Enter the session number to analyze:'))
fileName = sessions[file_idx].split('/')[-1]
fType = fileName[0:3] # first three characters of the session name select the branch below ('UIC' or 'BJH')

# Set pID and the results folder for this session
pID = fileName
proj_path = '/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits'
save_path = os.path.join(proj_path, 'Results', pID)
if not os.path.exists(save_path):
    os.makedirs(save_path) # also creates the parent 'Results' folder if it is missing
    print('Processing session %s...' % pID)
else:
    print('Session %s already processed. Overwriting...' % pID)

# Utah files
if fType == 'UIC':
    # Locate files
    nev = os.path.join(data_path, sessions[file_idx], sessions[file_idx] + '.nev')
    nsx = os.path.join(data_path, sessions[file_idx], sessions[file_idx] + '.ns6')
        
    # Load .nev and .nsx files
    events, waveforms = loadNEV(nev)
    cont_data, header = loadNSX(nsx)

    # Adjust for Blackrock's convention of multiplying uV by 4
    waveforms = waveforms / 4

    # Get chan IDs and chan labels
    chan_IDs = [entry['ElectrodeID'] for entry in header]
    chan_labels = [entry['ElectrodeLabel'] for entry in header]

    # Create a dictionary to map chan_IDs to chan_labels
    chan_map = dict(zip(chan_IDs, chan_labels))

    # Replace electrode IDs in the events table with electrode labels
    events['Channel'] = events['Channel'].map(chan_map)
    
    # Get sync & PD channels
    sync = cont_data[chan_labels.index('Sync'),:]
    PD = cont_data[chan_labels.index('PD'),:]
    
    # # Delete cont_data to save memory
    # del cont_data


# WashU files
elif fType == 'BJH':
    # Locate files
    sorted_mat = os.path.join(data_path, sessions[file_idx], sessions[file_idx] + '_sorted.mat')
    trigger_mat = os.path.join(data_path, sessions[file_idx], 'Stim_trigger.mat')
    raw_mat = os.path.join(data_path, sessions[file_idx], sessions[file_idx] + '_raw.mat')
    
    # Load sorted.mat and raw.mat files
    mat_dict, chan_labels = loadSortedMat(sorted_mat)
    sync = loadmat(trigger_mat, simplify_cells = True)['stimTrigger']
    try:
        cont_data = loadmat(raw_mat, simplify_cells = True)['signals']
    except NotImplementedError:
        # scipy cannot read MATLAB v7.3 (HDF5) files; fall back to mat73.
        # Any other error (missing file, missing key) is allowed to surface.
        cont_data = mat73.loadmat(raw_mat)['signals'] # if saved with v7.3 compression
    
    # Get events and waveforms (one table per channel with sorted units)
    events = []
    waveforms = []
    for chan_data in mat_dict.values():
        df = pd.DataFrame(chan_data)
        # Column 0 = 1-based channel number, column 1 = unit label, column 2 = spike time multiplied by fs below
        df['Channel'] = [chan_labels[int(chan_num - 1)] for chan_num in df[0]]
        df['Unit'] = [int(unit_num) for unit_num in df[1]]
        df['TimeStamps'] = [int(spike_time * fs) for spike_time in df[2]]
        events.append(df[['TimeStamps', 'Unit', 'Channel']])
        waveforms.append(df.iloc[:, 6:-3]) # columns between the first 6 and the 3 columns added above
    
    # ignore_index gives events a 0..n-1 index that matches the row order of the
    # stacked waveform array; downstream code selects waveform rows with events.index
    events = pd.concat(events, ignore_index = True)
    waveforms = np.array(pd.concat(waveforms))
    
    # Adjust uV
    waveforms = waveforms / 4
    
else:
    # Stop here: later cells would otherwise fail on (or silently reuse) undefined session variables
    raise ValueError('Filetype not recognized: %s' % fType)

In [None]:
# Micro channels are the labels starting with 'm'; active channels are those with at least one event
micro_chans = [label for label in chan_labels if label.startswith('m')]
active_micro_chans = events['Channel'].unique().tolist()

# A lead (bundle) label is the channel label without its final character
micro_leads = {label[:-1] for label in micro_chans}
active_micro_leads = {label[:-1] for label in active_micro_chans}

In [None]:
# Export events and waveforms to the session's results folder
events_csv = os.path.join(save_path, 'Events.csv')
waveforms_csv = os.path.join(save_path, 'Waveforms.csv')
events.to_csv(events_csv)
# waveforms have different num samples (Blackrock: 48, Nihon Koden: 32)
pd.DataFrame(waveforms).to_csv(waveforms_csv)

## 3. Summarize Dataset
Generate summary statistics to characterize the dataset (e.g., # units detected).  
Export features in `.csv` and `.txt` files.

In [None]:
def summarizeData():
    '''
    Print a text summary of the current session.

    Reads notebook-level variables (pID, sync, fs, micro_leads, active_micro_leads,
    micro_chans, active_micro_chans, summaryDF) rather than taking arguments.

    Returns:
        None
    '''
    print('Session: %s' %pID)

    # Recording length is taken from the number of samples in the sync channel
    duration_min = len(sync) / fs / 60
    print('- Recording Duration: %.1f min' %duration_min)

    n_active_leads = len(active_micro_leads)
    pct_active_leads = (n_active_leads / len(micro_leads))*100
    print('- Leads (Bundles) w/ Units: %s (%.0f%%)' %(n_active_leads, pct_active_leads))

    n_active_chans = len(active_micro_chans)
    pct_active_chans = (n_active_chans / len(micro_chans))*100
    print('- Channels w/ Units: %s (%.0f%%)' %(n_active_chans, pct_active_chans))

    # One row of summaryDF per detected unit
    n_units = summaryDF.shape[0]
    print('- Units Detected: %s' %n_units)

    wf_counts = summaryDF['Waveforms']
    print('- Waveforms (Mean ± SD): %.1f ± %.1f' %(wf_counts.mean(), wf_counts.std()))
    print('\n')

    # Timestamp of when the summary was generated
    timestamp = datetime.datetime.now()
    print('Processed: ' + str(timestamp))

In [None]:
# Count the events of every (Channel, Unit) pair; the TimeStamps count is the number of waveforms
summaryDF = (
    events
    .groupby(['Channel', 'Unit'])
    .count()
    .rename(columns={'TimeStamps': 'Waveforms'})
)

# 'Channel-Unit' label for every unit, in summaryDF row order
unit_index = summaryDF.reset_index()
unit_labels = unit_index['Channel'].astype(str) + '-' + unit_index['Unit'].astype(str)

# Display data summary
summarizeData()
summaryDF

In [None]:
%%capture cap
# Run the summary again with its printed output captured in `cap` instead of displayed
summarizeData()

In [None]:
# Save the per-unit waveform counts and the captured text summary
counts_csv = os.path.join(save_path, 'WaveformCounts.csv')
summary_txt = os.path.join(save_path, 'Summary.txt')
summaryDF.to_csv(counts_csv)
with open(summary_txt, 'w') as summary_file:
    summary_file.write(str(cap))

## 4. Create Epochs
Epochs are created using the `sync` and `PD` channels within the recording to define the stim/no-stim trials, respectively. Because of the way in which the data were recorded, as well as noise/artifact present in the `sync`/`PD` channels, this section entails a considerable amount of manual tuning to ensure epochs are correctly aligned.

### 4.1 Load Trial Info

In [None]:
# Load this session's trial table
trial_info_csv = os.path.join(proj_path, 'TrialInfo', pID + '_TrialInfo.csv')
trial_info = pd.read_csv(trial_info_csv)

# Rescale both onset columns by 15 (convert from 2->30 kHz)
onset_cols = ['Onset_SC', 'Onset_PD']
trial_info[onset_cols] = trial_info[onset_cols] * 15

### 4.2 Detect Sync Channel Onsets

In [None]:
def getSyncOnsets(sync):
    '''
    Detect sync pulses within the 30 kHz recording.

    A sample is reported as a stim onset when all of the following hold:
    (1) sync is above threshold at that sample,
    (2) sync was at or below threshold at the previous sample (rising edge),
    (3) more than 5 s have passed since the last reported onset (the start of
        the recording counts as the first reference point, so onsets within
        the first 5 s are not reported).

    Reads the notebook-level `fType` (threshold selection) and `fs` (sampling rate).

    NOTE(review): an earlier version compared each sample against the mirrored
    sample from the end of the recording rather than the previous sample. The
    session-specific corrections applied to these onsets later in the notebook
    were tuned by hand, so re-check them after re-running detection.
    
    Arguments:
        sync (np.ndarray): sync channel data
        
    Returns:
        stim_onsets (list): stim onset sample indices (ints), in increasing order

    Raises:
        ValueError: if fType is neither 'UIC' nor 'BJH'
    '''

    # Define threshold for sync channel (kept per site so each can be tuned separately)
    if fType == 'UIC':
        thresh = 15000
    elif fType == 'BJH':
        thresh = 15000
    else:
        raise ValueError('Filetype not recognized.')

    # Minimum spacing between reported onsets (5 s, in samples)
    min_gap = fs * 5

    # Rising edges: above threshold now, not above threshold one sample earlier
    syncBool = (np.asarray(sync) > thresh).astype(int)
    rising_edges = np.flatnonzero((syncBool[1:] == 1) & (syncBool[:-1] == 0)) + 1

    # Keep an edge only if it is more than min_gap samples after the last kept one
    stim_onsets = []
    last_onset = 0
    for edge in rising_edges:
        if edge - last_onset > min_gap:
            stim_onsets.append(int(edge))
            last_onset = edge
    
    return stim_onsets

In [None]:
# Detect sync (stim) pulses
sync_onsets = getSyncOnsets(sync)

# Add sync_onsets to trial_info
special_cases = ['BJH02503', 'BJH02703', 'BJH04502']
stim_trial_counter = 0
if pID not in special_cases:
    for i in range(trial_info.shape[0]):
        if trial_info.loc[i, 'Condition'] == 1:
            trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter]
            stim_trial_counter += 1
        else:
            trial_info.loc[i, 'Onset_Sync'] = np.nan
else:
    if pID == 'BJH02503':
        # Sync pulses not present for 1st few trials, but LFP shows stim
        for i in range(trial_info.shape[0]):
            if trial_info.loc[i, 'Condition'] == 1:
                if stim_trial_counter < 6:
                    trial_info.loc[i, 'Onset_Sync'] = trial_info.loc[i, 'Onset_SC'] - 18000 # empirical correction
                else:
                    trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter-6]
                stim_trial_counter += 1
            else:
                trial_info.loc[i, 'Onset_Sync'] = np.nan
                
    if pID == 'BJH02703':
        # Many sync pulses not present
        fixed = False
        for i in range(trial_info.shape[0]):
            if trial_info.loc[i, 'Condition'] == 1:
                if (stim_trial_counter >= 27) and (stim_trial_counter < 40):
                    trial_info.loc[i, 'Onset_Sync'] = trial_info.loc[i, 'Onset_SC'] - 10000 # empirical correction
                    fixed = True
                else:
                    if not fixed:
                        trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter+1]
                    else:
                        trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter-13]
                stim_trial_counter += 1
                
    elif pID == 'BJH04502':
        # Missed detection of one pulse
        fixed = False
        for i in range(trial_info.shape[0]):
            if trial_info.loc[i, 'Condition'] == 1:
                if stim_trial_counter == 33:
                    trial_info.loc[i, 'Onset_Sync'] = trial_info.loc[i, 'Onset_SC'] - 13000 # empirical correction
                    fixed = True
                else:
                    if not fixed:
                        trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter]
                    else:
                        trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter-1]
                stim_trial_counter += 1
            else:
                trial_info.loc[i, 'Onset_Sync'] = np.nan

### 4.3 Define Stim/No-Stim Epochs

#### 4.3.1 Align Blackrock to BCI2000 w/ PD

In [None]:
# Utah data does not have 1:1 mapping to timing in BCI2000, thus this step includes an empirical re-alignment to a common feature
# present in both datasets, the photodiode (PD) signal. 
# Output of this cell: `matched_peaks`, assigned only for the sessions listed below.

# Get PD onsets from trial_info
PD_onsets = trial_info['Onset_PD'].values

# Find peaks in PD channel (30 kHz); some trial onsets are +PD vs. -PD spike (depending on experiment)
# distance = 150000 samples, i.e. peaks must be at least 5 s apart at fs = 30000
# NOTE: peaks_30k is only assigned for UIC sessions (fType == 'UIC' or the pID listed here)
if pID in ['UIC20230201']:
    peaks_30k, _ = find_peaks(PD*-1, height = 500, distance = 150000)
else:
    if fType == 'UIC':
        peaks_30k, _ = find_peaks(PD, height = 500, distance = 150000)

# Manual tuning (empirical correction from visual inspection)
# Each branch lists the indices of the detected peaks to keep for that session;
# the gaps between the index ranges are the peaks being discarded.
if pID == 'UIC20221301':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(2,47), np.arange(48, 93), np.arange(94, 139), np.arange(140, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20221501':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(2,47), np.arange(48, 93), np.arange(94, 139), np.arange(140, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20221701':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(3,48), np.arange(49, 94), np.arange(98, 143), np.arange(144, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]

elif pID == 'UIC20230201':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(2,82), np.arange(83, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20230601':
    matched_peaks = np.empty(len(trial_info)) * np.nan # initialize as nans; (missing PD for 17 trials)
    keep_peaks = np.concatenate((np.arange(3,67), np.arange(68, 148), np.arange(150, 230), np.arange(231, len(peaks_30k)))) # realign
    # kept peaks fill trials 16 onward; the first 16 entries stay NaN
    matched_peaks[16:] = peaks_30k[keep_peaks]
    
elif pID == 'UIC20230701':
    peaks_30k = peaks_30k[3:] # remove spurious PDs
    keep_peaks = np.concatenate((np.arange(0,12), np.arange(13,81), np.arange(82,162), np.arange(163,243), np.arange(244,324))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20230801':
    peaks_30k = peaks_30k[4:] # remove spurious PDs
    keep_peaks = np.concatenate((np.arange(0,80), np.arange(81,161), np.arange(162,242), np.arange(243,323))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20231101':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(3,83), np.arange(84, 164), np.arange(165, 245), np.arange(247, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20231401':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(2,122), np.arange(123, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20240101':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(3,123), np.arange(124, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
# O/S cohort (must remove scrambled images)
# These sessions have 4 blocks: 40 trial_info rows per block, but 45 kept PD peaks per block
if pID in ['UIC20221301', 'UIC20221501', 'UIC20221701']:
    PD_block_idxs = [np.arange(0, 40), np.arange(40, 80), np.arange(80, 120), np.arange(120, 160)]
    peaks_block_idxs = [np.arange(0, 45), np.arange(45, 90), np.arange(90, 135), np.arange(135, 180)]
    
    # Get peaks and onsets for each block
    matched_peaks_noscramb = []
    for i in range(len(PD_block_idxs)):
        peaks_block = matched_peaks[peaks_block_idxs[i]]
        PD_onsets_block = PD_onsets[PD_block_idxs[i]]

        # Adjust PD onsets to align with 1st peak in block ("blocks" are separate .dat files stitched together)
        PD_onsets_adj = PD_onsets_block + (peaks_block[0] - PD_onsets_block[0])

        # Get the values in peaks closest to PD_onsets_adj
        # (the search runs over all of matched_peaks, not only this block's peaks)
        # NOTE: the inner loop reuses the name `i`; the outer loop still advances because
        # `for` draws from its own range, and `i` is not read again in the outer body
        peaks_adj_block = []
        for i in range(len(PD_onsets_adj)):
            peaks_adj_block.append(matched_peaks[np.argmin(np.abs(matched_peaks - PD_onsets_adj[i]))])
        matched_peaks_noscramb.append(peaks_adj_block)
        
    # Flatten the per-block lists into one array (one peak per trial_info row used above)
    matched_peaks = np.array([item for sublist in matched_peaks_noscramb for item in sublist])

In [None]:
# # Visualize
# print('Detected Peaks: %i' %len(peaks_30k))
# print('Matched Peaks: %i/%i' %(len(matched_peaks), len(trial_info)))

# fig, ax = plt.subplots(figsize = (90, 3))
# plt.plot(PD)
# for x in peaks_30k:
#     plt.axvline(x = x, color = 'r', linestyle = '-')
# plt.plot(matched_peaks, PD[matched_peaks], 'x')
# plt.plot(sync_onsets, PD[sync_onsets], 'o')
# plt.show()

In [None]:
# Overwrite Utah trial_info PD onsets using matched_peaks
# (matched_peaks must already hold one entry per trial_info row; it is only
# assigned for the sessions handled explicitly in the alignment cell above)
if fType == 'UIC':
    trial_info['Onset_PD'] = matched_peaks

#### 4.3.2 Use Sync & PD to Define Stim/No-Stim Epochs

In [None]:
# Every epoch spans 4 s before to 2 s after its onset (sample indices at fs)
pre_samples = 4 * fs
post_samples = 2 * fs

stim_epochs = []
nostim_epochs = []

for trial in range(trial_info.shape[0]):
    is_stim = trial_info.loc[trial, 'Condition'] == 1

    # stim trial: epoch is anchored on the sync onset
    if is_stim:
        onset = trial_info.loc[trial, 'Onset_Sync']
        stim_epochs.append(np.arange(onset - pre_samples, onset + post_samples))
        continue

    # no stim trial: epoch is anchored on the photodiode onset, when there is one
    onset = trial_info['Onset_PD'][trial]
    if np.isnan(onset):
        nostim_epochs.append(np.zeros(fs*6)) # populate empty epochs with zeros (cannot use np.nan in pd.DataFrame)
    else:
        nostim_epochs.append(np.arange(onset - pre_samples, onset + post_samples))

# One column per trial, one row per sample index
stim_epochs = pd.DataFrame(stim_epochs).astype(int).T
nostim_epochs = pd.DataFrame(nostim_epochs).astype(int).T

stim_epochs.to_csv(os.path.join(save_path, 'StimEpochs.csv'))
nostim_epochs.to_csv(os.path.join(save_path, 'NoStimEpochs.csv'))

### 4.4 Epoch 30 kHz Continuous Data

In [None]:
# Slice the continuous data at each epoch's sample indices (one slice per epoch column),
# then stack the slices into a single array per condition
stim_epochs_30k = np.array([
    cont_data[:, stim_epochs[trial].values]
    for trial in range(stim_epochs.shape[1])
])

nostim_epochs_30k = np.array([
    cont_data[:, nostim_epochs[trial].values]
    for trial in range(nostim_epochs.shape[1])
])

## 5. Extract Peri-Stimulation Spike Counts
Calculate the # spikes in the pre-/during-/post-stim windows (1s each). Export results as `PSCounts.csv`.

In [None]:
def getPSCounts(epochs, unit):
    '''
    Get peri-stimulation spike counts for a given unit, summed over all epochs.

    Each epoch's stimulation time is taken as its first sample + 4 s. Spike times
    relative to that point (in seconds) are counted in four windows, all with
    open bounds: PreISI (< -3), Pre (-1 to 0), During (0 to 1) and Post (> 1).
    Spikes outside these windows (including exactly on a bound) are not counted.

    Reads the notebook-level `events`, `fs` and `pID`.
    
    Arguments:
        epochs (pd.DataFrame): peri-stimulation epochs (sample indices), one column
            per trial with integer column labels 0..n-1
        unit (str): unit label ('Channel-Unit')
        
    Returns:
        PSCountsDF (pd.DataFrame): single row with the PreISI/Pre/During/Post counts
            plus pID, Channel and Unit columns
    '''
    
    # Get this unit's spike times (samples) without copying the whole events table
    chan_unit = events['Channel'].astype(str) + '-' + events['Unit'].astype(str)
    spike_samples = events.loc[chan_unit == unit, 'TimeStamps'].astype('int').to_numpy()

    # Store counts for pre/during/post stim
    preISICounts = 0
    preCounts = 0
    duringCounts = 0
    postCounts = 0

    # Count spikes epoch by epoch
    for i in range(epochs.shape[1]):
        epoch_start = epochs[i].min()
        stim_start = epoch_start + (4*fs)
        epoch_end = epochs[i].max()

        # Spikes inside this epoch, with stim at t0 (samples -> sec)
        in_epoch = spike_samples[(spike_samples >= epoch_start) & (spike_samples <= epoch_end)]
        spike_times = (in_epoch - stim_start) / fs

        preISICounts += int(np.sum(spike_times < -3))
        preCounts += int(np.sum((spike_times > -1) & (spike_times < 0)))
        duringCounts += int(np.sum((spike_times > 0) & (spike_times < 1)))
        postCounts += int(np.sum(spike_times > 1))
                
    # Construct DF of counts
    PSCountsDF = pd.DataFrame({'PreISI': [preISICounts], 'Pre': [preCounts], 'During': [duringCounts], 'Post': [postCounts]})
    PSCountsDF['pID'] = pID
    PSCountsDF['Channel'] = unit.split('-')[0]
    PSCountsDF['Unit'] = unit.split('-')[1]
    
    return PSCountsDF

In [None]:
# Get peri-stimulation counts for each unit, separately for stim and no-stim epochs
stim_counts = pd.concat([getPSCounts(epochs = stim_epochs, unit = label) for label in unit_labels])
stim_counts['Condition'] = 'Stim'

nostim_counts = pd.concat([getPSCounts(epochs = nostim_epochs, unit = label) for label in unit_labels])
nostim_counts['Condition'] = 'No-Stim'

# Stack both conditions into one table and export
PSCountsDF = pd.concat([stim_counts, nostim_counts]).reset_index(drop = True)
PSCountsDF.to_csv((os.path.join(save_path, 'PSCounts.csv')))

## 6. Generate Figures

In [None]:
# Preview color palette
# ('flare' palette with one color per row of summaryDF, i.e. one per unit)
sns.color_palette('flare', summaryDF.shape[0])

### 6.1 Unit Waveforms
Use `plotUnitWaveforms()` and `plotUnitWaveforms_RAW()` to generate separate figures for each channel showing the waveforms of the units detected.

In [None]:
def plotUnitWaveforms(chan, show = True, save = False):
    '''
    Plot the mean waveform (± SD) of every unit on a channel in one figure.

    Reads the notebook-level `events`, `waveforms`, `fs`, `fType` and `save_path`.
    Waveform rows are selected with the index of `events`, so that index must
    match the row order of `waveforms`.
    
    Arguments:
        chan (str): channel label
        show (bool): show figure
        save (bool): save figure as <save_path>/Units/<chan>.pdf
        
    Returns:
        None
    '''
    
    # Parse events for chan data
    unitDF = events[events['Channel'] == chan]
    unit_idxs = unitDF.index.values
    
    # Parse waveforms using unit indices
    unit_waveforms = pd.DataFrame(waveforms[unit_idxs])
    
    # Add unit labels
    unit_waveforms['Unit'] = unitDF['Unit'].reset_index(drop=True)
    
    # Track waveform counts for legend, ordered by unit label so that each count
    # lines up with its legend handle (value_counts alone orders by frequency)
    unit_counts = unitDF['Unit'].value_counts().sort_index()
    unit_order = unit_counts.index.tolist()
    n_units = len(unit_order)
    waveform_counts = [str(x) for x in unit_counts.values.tolist()]
    
    # Melt DataFrame for plotting in Seaborn
    unit_waveforms = unit_waveforms.melt(id_vars = ['Unit'], var_name = 'Time', value_name = 'Voltage')
    
    # Figure parameters
    fig, ax = plt.subplots(1, 1, figsize = (5,2.5))
    palette = ['#ff6e61', '#ffb84d', '#6d9dc5', '#5e4b8b']
    
    # Adjust time (samples -> ms)
    unit_waveforms['Time'] = (unit_waveforms['Time'] / fs) * 1000
    
    # Plot average waveform ± SD (shaded); hue order is pinned to match the legend counts
    sns.lineplot(x = 'Time', y = 'Voltage', hue = 'Unit', hue_order = unit_order, data = unit_waveforms, palette = palette[:n_units], linewidth = 2, errorbar = 'sd', ax = ax)
    
    # Figure aesthetics
    if fType == 'UIC':
        plt.xlim([0, 1.5])
        plt.xticks([0, 0.5, 1, 1.5])
    elif fType == 'BJH':
        plt.xlim([0, 1])
        plt.xticks([0, 0.5, 1])
    plt.xlabel('Time (ms)', fontsize = 'large')
    plt.yticks([-100, -50, 0, 50, 100])
    plt.ylabel(r'Voltage ($\mu$V)', fontsize = 'large') # raw string: '\m' is not a valid escape
    plt.title(chan, fontsize = 'x-large', fontweight = 'bold')
    legend_handles, _= ax.get_legend_handles_labels()
    ax.legend(legend_handles, waveform_counts, title = 'WFs', loc = 'lower right', fontsize = 'x-small', title_fontsize = 'x-small')
    sns.despine(top = True, right = True)
    
    # Export figure
    if save:
        if not os.path.exists(os.path.join(save_path, 'Units')):
            os.mkdir(os.path.join(save_path, 'Units'))
        plt.savefig(os.path.join(save_path, 'Units', chan + '.pdf'), dpi = 1500, bbox_inches = 'tight')
    
    # Close this figure explicitly so batch runs do not accumulate open figures
    if not show:
        plt.close(fig)

In [None]:
def plotUnitWaveforms_RAW(chan, show = True, save = False):
    '''
    Plot unit waveforms and visualize raw traces (one figure per unit on the channel).

    Every individual waveform is drawn as a thin grey trace, with the unit's mean
    waveform on top. Reads the notebook-level `events`, `waveforms`, `fs`, `fType`
    and `save_path`; waveform rows are selected with the index of `events`.
    
    Arguments:
        chan (str): channel label
        show (bool): show figure
        save (bool): save figure as <save_path>/Units/RAW_<chan>-<unit>.pdf
        
    Returns:
        None
    '''
    
    # Parse events for chan data
    unitDF = events[events['Channel'] == chan]
    palette = ['#ff6e61', '#ffb84d', '#6d9dc5', '#5e4b8b']
    
    for i in unitDF['Unit'].unique():
        unitDFsub = unitDF[unitDF['Unit'] == i]
        unit_idxs = unitDFsub.index.values
        
        # Parse waveforms using unit indices (rows = spikes, columns = samples)
        unit_waveforms = np.asarray(waveforms[unit_idxs])
        
        # Track waveform count for legend
        waveform_counts = unit_waveforms.shape[0]
        
        # Time axis (samples -> ms)
        time_ms = (np.arange(unit_waveforms.shape[1]) / fs) * 1000
        
        # Figure parameters
        fig, ax = plt.subplots(1, 1, figsize = (5,2.5))
        
        # Plot every raw trace in one call (one line per column of the transposed
        # matrix); this includes the last waveform, which a range(max index) loop skips
        ax.plot(time_ms, unit_waveforms.T, lw = 0.1, color = '#D3D3D3', alpha = 0.5)
            
        # Plot average waveform; wrap the palette index so unit labels above 4 still get a color
        ax.plot(time_ms, unit_waveforms.mean(axis = 0), color = palette[(i - 1) % len(palette)], linewidth = 3, label = str(waveform_counts))
        
        # Figure aesthetics
        if fType == 'UIC':
            plt.xlim([0, 1.5])
            plt.xticks([0, 0.5, 1, 1.5])
        elif fType == 'BJH':
            plt.xlim([0, 1])
            plt.xticks([0, 0.5, 1])
        plt.xlabel('Time (ms)', fontsize = 'large')
        plt.yticks([-100, -50, 0, 50, 100])
        plt.ylabel(r'Voltage ($\mu$V)', fontsize = 'large') # raw string: '\m' is not a valid escape
        plt.title((chan + '-' + str(i)), fontsize = 'x-large', fontweight = 'bold')
        legend_handles, legend_labels = ax.get_legend_handles_labels()
        ax.legend(legend_handles, legend_labels, title = 'WFs', loc = 'lower right', fontsize = 'x-small', title_fontsize = 'x-small')
        sns.despine(top = True, right = True)
        
        # Export figure
        if save:
            if not os.path.exists(os.path.join(save_path, 'Units')):
                os.mkdir(os.path.join(save_path, 'Units'))
            plt.savefig(os.path.join(save_path, 'Units', 'RAW_' + chan + '-' + str(i) + '.pdf'), dpi = 1500, bbox_inches = 'tight')
        
        # Close this unit's figure explicitly so batch runs do not accumulate open figures
        if not show:
            plt.close(fig)

In [None]:
# For every channel with events, export the summary figure and the per-unit raw-trace figures as .pdfs
channels_with_units = events['Channel'].unique()
for chan_label in channels_with_units:
    plotUnitWaveforms(chan = chan_label, show = False, save = True)
    plotUnitWaveforms_RAW(chan = chan_label, show = False, save = True)

### 6.2 Unit Spike Rasters
Use `plotFullRaster()` to generate a raster plot showing unit activity across the entire recording.

In [None]:
def plotFullRaster(show = True, save = False):
    '''
    Plot raster of all units for duration of recording.

    The left panel shows one raster row per unit; the right panel shows each unit's
    waveform count as a bar, in the same row order and colors. Reads the
    notebook-level `events`, `summaryDF`, `sync`, `fs`, `pID` and `save_path`.
    
    Arguments:
        show (bool): show figure
        save (bool): save figure as <save_path>/Rasters/FullRaster.pdf
        
    Returns:
        None
    '''

    # Figure parameters (height scales with the number of units once there are more than 5)
    if summaryDF.shape[0] > 5:
        fig = plt.figure(figsize = (10, summaryDF.shape[0] * 0.35), constrained_layout=True)
    else:
        fig = plt.figure(figsize = (10, 3), constrained_layout=True)
    gs = gridspec.GridSpec(ncols=10, nrows=1, figure=fig)
    ax1 = fig.add_subplot(gs[0:9])
    ax2 = fig.add_subplot(gs[9])
    palette = sns.color_palette('flare_r', n_colors = summaryDF.shape[0])

    # Sort events by Channel
    rasterDF = events.copy().sort_values(by = ['Channel', 'Unit'], ascending = False)

    # Create iterator and list to track unit label and position
    i = 0
    unit_labels = []

    # Loop through each channel and unit; all spikes of a unit are drawn in one vlines call
    for chan in rasterDF['Channel'].unique():
        chanDF = rasterDF[rasterDF['Channel'] == chan]
        for unit in chanDF['Unit'].unique():
            unitDF = chanDF[chanDF['Unit'] == unit]
            unit_labels.append(str(chan) + '-' + str(unit))
            i += 1
            spike_times = unitDF['TimeStamps'].to_numpy() / fs / 60 # convert from samples -> min
            ax1.vlines(spike_times, i - 0.4, i + 0.4, linewidth = 0.25, colors = [palette[i-1]])
                
    # Add barplot for number of waveforms; `order` pins the bars to the raster's row order
    waveCountDF = summaryDF.copy().reset_index()
    waveCountDF['Chan-Unit'] = waveCountDF['Channel'].astype(str) + '-' + waveCountDF['Unit'].astype(str)
    sns.barplot(x = 'Waveforms', y = 'Chan-Unit', data = waveCountDF, order = unit_labels, ax = ax2, palette = palette)

    # Raster plot aesthetics
    ax1.set_ylim([0, summaryDF.shape[0] + 1])
    ax1.set_yticks(np.arange(1, summaryDF.shape[0] + 1), unit_labels)
    ax1.set_ylabel('')
    ax1.set_xlim([0, (len(sync) / fs / 60)])
    ax1.set_xlabel('Time (min)', fontsize = 'large')
    
    # Bar plot aesthetics (y-limits run bottom-to-top so bar k sits beside raster row k+1)
    ax2.set_ylim([-1, summaryDF.shape[0]])
    ax2.set_yticks(np.arange(0, summaryDF.shape[0]), unit_labels)
    ax2.set_ylabel('')
    ax2.set_yticklabels([])
    ax2.set_xlim([0, np.ceil(summaryDF['Waveforms'].max()/100) * 100])
    ax2.set_xticks([0, np.ceil(summaryDF['Waveforms'].max()/100) * 100])
    ax2.set_xlabel('WFs', fontsize = 'large')
    
    # Figure aesthetics
    plt.suptitle(pID, fontweight = 'bold', fontsize = 'x-large')
    sns.despine(top = True, right = True)

    # Export figure
    if save:
        if not os.path.exists(os.path.join(save_path, 'Rasters')):
            os.mkdir(os.path.join(save_path, 'Rasters'))
        plt.savefig(os.path.join(save_path, 'Rasters', 'FullRaster.pdf'), dpi = 1500, bbox_inches = 'tight')

    # Close this figure explicitly so it does not stay open when not shown
    if not show:
        plt.close(fig)

In [None]:
# Plot raster plot of all units during full recording (& export as .pdf)
# show=False closes the figure after saving, so nothing is displayed inline
plotFullRaster(show=False, save=True)

### 6.3 Peri-Stim Spike Rasters
Use `plotPSRaster()` and `plotPSRaster_Extended()` to generate a peri-stim raster plot for a unique unit.

In [None]:
def plotPSRaster(epochs, unit, color_idx, show = True, save = False):
    '''
    Plot unit raster for peri-stimulation epochs (-1 to +2 s around stimulation onset).
    
    The figure contains three stacked panels on a common time axis:
        1. across-epoch mean of the high-pass filtered 30 kHz trace from the unit's channel
        2. spike raster (one row per trial)
        3. binned spike histogram (0.1 s bins) with a KDE overlay
    
    Relies on variables defined in earlier notebook cells: events, fs, chan_labels,
    stim_epochs_30k, and (only when save = True) save_path.
    
    Arguments:
        epochs (np.ndarray): array of peri-stimulation epochs (indices)
        unit (str): unit label, formatted as '<Channel>-<Unit>'
        color_idx (int): index of color in palette (wraps around the palette length)
        show (bool): show figure
        save (bool): save figure as <save_path>/Rasters/<unit>_PSRaster.pdf
        
    Returns:
        None
    '''

    # Figure parameters
    fig = plt.figure(figsize = (4,4), constrained_layout = True)
    gs = gridspec.GridSpec(ncols=1, nrows=8, figure=fig)
    ax1 = fig.add_subplot(gs[0])    # filtered trace
    ax2 = fig.add_subplot(gs[1:6])  # raster
    ax3 = fig.add_subplot(gs[6:8])  # histogram
    
    palette = ['#ff6e61', '#ffb84d', '#6d9dc5', '#5e4b8b']
    # Wrap the index so callers looping over > len(palette) units do not raise IndexError
    color = palette[color_idx % len(palette)]

    # Get unit data (explicit copy of the filtered rows avoids SettingWithCopyWarning below)
    unitDF = events.copy()
    unitDF['Chan-Unit'] = unitDF['Channel'].astype(str) + '-' + unitDF['Unit'].astype(str)
    unitDF = unitDF[unitDF['Chan-Unit'] == unit].copy()
    unitDF['TimeStamps'] = unitDF['TimeStamps'].astype('int')
    # NOTE(review): trial count is read from axis 1 while trial i is accessed as epochs[i];
    # this matches a column-per-trial layout -- confirm against the cell that builds the epochs.
    n_trials = epochs.shape[1]
    
    # Get 30 kHz data for the unit's channel, highpass filter (150 Hz, 4th-order Butterworth), and plot
    nyquist = 0.5 * fs
    hpass_freq = 150
    hpass_cutoff = hpass_freq / nyquist
    b, a = butter(4, hpass_cutoff, btype = 'high')
    chan_idx = chan_labels.index(unit.split('-')[0])
    lfp_time = np.arange(-1, 2, 1/fs)
    # Drop the first 3 s of each epoch so the remaining samples line up with lfp_time
    # (presumably stim_epochs_30k is epochs x channels x samples -- TODO confirm)
    lfp_data = filtfilt(b, a, stim_epochs_30k[:,chan_idx,(fs*3):], axis = 1)
    ax1.plot(lfp_time, np.mean(lfp_data, axis = 0), color = color, lw = 1)

    # Plot raster
    epochSpikes = []
    for i in range(n_trials):
        epoch_start = epochs[i].min()
        stim_start = epoch_start + (4*fs) # stimulation onset is placed 4 s after the epoch start
        epoch_end = epochs[i].max()
        in_epoch = (unitDF['TimeStamps'] >= epoch_start) & (unitDF['TimeStamps'] <= epoch_end)
        epochDF = unitDF[in_epoch].reset_index(drop = True)
        epochDF['TimeAdj'] = (epochDF['TimeStamps'] - stim_start) / fs # set stim at t0, convert samples -> sec
        epochSpikes.append(epochDF)
        
        # Draw all spikes of the trial in a single call instead of one vlines() call per spike
        if not epochDF.empty:
            y_pos = i + 1
            ax2.vlines(epochDF['TimeAdj'].to_numpy(), y_pos - 0.4, y_pos + 0.4, linewidth = 2, color = color)
            
    # Plot binned FR
    spikesDF = pd.concat(epochSpikes)
    spikesDF = spikesDF.reset_index(drop = True)
    binSize = 0.1 # time (s)
    bins = np.arange(-1, 2.1, binSize)
    maxSpikes = np.histogram(spikesDF['TimeAdj'], bins = bins)[0].max()
    # Restrict to the plotted window so the KDE is not influenced by earlier spikes in the epoch
    ps_spike_win = spikesDF[spikesDF['TimeAdj'].between(-1, 2)]
    sns.histplot(ps_spike_win['TimeAdj'], bins = bins, kde = True, kde_kws = {'bw_adjust': 0.2, 'cut': 3, 'clip': [-1,2]}, color = color, ax = ax3)

    # Shade the 0-1 s window (BLA stimulation) and mark its edges
    for ax in [ax1, ax2, ax3]:
        ax.axvspan(0, 1, color = 'grey', alpha = 0.1, zorder = -10)
        ax.axvline(0, color = 'k', linestyle = '--', lw = 0.5)
        ax.axvline(1, color = 'k', linestyle = '--', lw = 0.5)
        ax.set_xlim([-1, 2.01])
    sns.despine(fig = fig, top = True, right = True)

    # Figure aesthetics
    ax1.set_xticks([-1, 0, 1, 2], ['', '', '', ''])
    ax1.set_yticks([])
    ax1.set_title(unit, fontsize = 'x-large', fontweight = 'bold', pad = 15)
    ax2.set_xlabel('')
    ax2.set_xticks([-1, 0, 1, 2], ['', '', '', ''])
    ax2.set_ylabel('Trial', fontsize = 'x-large', labelpad= 10)
    ax2.set_ylim([0, n_trials+1])
    ax2.set_yticks([1, n_trials], ['1', str(n_trials)], fontsize = 'medium')
    ax3.set_xlabel('Time (s)', fontsize = 'x-large')
    ax3.set_xticks([-1, 0, 1, 2], ['-1', '0', '1', '2'], fontsize = 'medium')
    ax3.set_ylabel('FR (Hz)', fontsize = 'x-large', labelpad = 7)
    # Axis is in spike counts; relabel the top tick as a rate: count / n_trials * 10 (10 = 1 / binSize)
    ax3.set_yticks([0, maxSpikes], [0, np.round((maxSpikes / n_trials) * 10, 1)], fontsize = 'medium')
    sns.despine(top = True, right = True, left = True, ax = ax1)

    # Export figure
    if save:
        raster_dir = os.path.join(save_path, 'Rasters')
        os.makedirs(raster_dir, exist_ok = True)
        fig.savefig(os.path.join(raster_dir, unit + '_PSRaster.pdf'), dpi = 1500, bbox_inches = 'tight')

    if not show:
        plt.close(fig)

In [None]:
def plotPSRaster_Extended(epochs, unit, color_idx, show = True, save = False):
    '''
    Plot unit raster for extended (full) peri-stimulation epochs (-4 to +2 s around stimulation onset).
    
    The figure contains three stacked panels on a common time axis:
        1. across-epoch mean of the high-pass filtered 30 kHz trace from the unit's channel
        2. spike raster (one row per trial)
        3. binned spike histogram (0.1 s bins) with a KDE overlay
    Text labels above the top panel name the trial phases (Pre-ISI, Image, Stim | No-Stim, Post-ISI).
    
    Relies on variables defined in earlier notebook cells: events, fs, chan_labels,
    stim_epochs_30k, and (only when save = True) save_path.
    
    Arguments:
        epochs (np.ndarray): array of peri-stimulation epochs (indices)
        unit (str): unit label, formatted as '<Channel>-<Unit>'
        color_idx (int): index of color in palette (wraps around the palette length)
        show (bool): show figure
        save (bool): save figure as <save_path>/Rasters/<unit>_ExtendedPSRaster.pdf
        
    Returns:
        None
    '''

    # Figure parameters
    fig = plt.figure(figsize = (8,4), constrained_layout = True)
    gs = gridspec.GridSpec(ncols=1, nrows=8, figure=fig)
    ax1 = fig.add_subplot(gs[0])    # filtered trace
    ax2 = fig.add_subplot(gs[1:6])  # raster
    ax3 = fig.add_subplot(gs[6:8])  # histogram
    
    palette = ['#ff6e61', '#ffb84d', '#6d9dc5', '#5e4b8b']
    # Wrap the index so callers looping over > len(palette) units do not raise IndexError
    color = palette[color_idx % len(palette)]

    # Get unit data (explicit copy of the filtered rows avoids SettingWithCopyWarning below)
    unitDF = events.copy()
    unitDF['Chan-Unit'] = unitDF['Channel'].astype(str) + '-' + unitDF['Unit'].astype(str)
    unitDF = unitDF[unitDF['Chan-Unit'] == unit].copy()
    unitDF['TimeStamps'] = unitDF['TimeStamps'].astype('int')
    # NOTE(review): trial count is read from axis 1 while trial i is accessed as epochs[i];
    # this matches a column-per-trial layout -- confirm against the cell that builds the epochs.
    n_trials = epochs.shape[1]
    
    # Get 30 kHz data for the unit's channel, highpass filter (150 Hz, 4th-order Butterworth), and plot
    nyquist = 0.5 * fs
    hpass_freq = 150
    hpass_cutoff = hpass_freq / nyquist
    b, a = butter(4, hpass_cutoff, btype = 'high')
    chan_idx = chan_labels.index(unit.split('-')[0])
    lfp_time = np.arange(-4, 2, 1/fs)
    # Full epoch is used here (presumably stim_epochs_30k is epochs x channels x samples -- TODO confirm)
    lfp_data = filtfilt(b, a, stim_epochs_30k[:,chan_idx,:], axis = 1)
    ax1.plot(lfp_time, np.mean(lfp_data, axis = 0), color = color, lw = 1)

    # Plot raster
    epochSpikes = []
    for i in range(n_trials):
        epoch_start = epochs[i].min()
        stim_start = epoch_start + (4*fs) # stimulation onset is placed 4 s after the epoch start
        epoch_end = epochs[i].max()
        in_epoch = (unitDF['TimeStamps'] >= epoch_start) & (unitDF['TimeStamps'] <= epoch_end)
        epochDF = unitDF[in_epoch].reset_index(drop = True)
        epochDF['TimeAdj'] = (epochDF['TimeStamps'] - stim_start) / fs # set stim at t0, convert samples -> sec
        epochSpikes.append(epochDF)
        
        # Draw all spikes of the trial in a single call instead of one vlines() call per spike
        if not epochDF.empty:
            y_pos = i + 1
            ax2.vlines(epochDF['TimeAdj'].to_numpy(), y_pos - 0.4, y_pos + 0.4, linewidth = 2, color = color)
            
    # Plot binned FR
    spikesDF = pd.concat(epochSpikes)
    spikesDF = spikesDF.reset_index(drop = True)
    binSize = 0.1 # time (s)
    bins = np.arange(-4, 2.1, binSize)
    maxSpikes = np.histogram(spikesDF['TimeAdj'], bins = bins)[0].max()
    sns.histplot(spikesDF['TimeAdj'], bins = bins, kde = True, kde_kws = {'bw_adjust': 0.2, 'cut': 3, 'clip': [-4,2]}, color = color, ax = ax3)

    # Shade the 0-1 s window (BLA stimulation) and mark the phase boundaries at -3, 0 and 1 s
    for ax in [ax1, ax2, ax3]:
        ax.axvspan(0, 1, color = 'grey', alpha = 0.1, zorder = -10)
        ax.axvline(-3, color = 'k', linestyle = '--', lw = 0.5)
        ax.axvline(0, color = 'k', linestyle = '--', lw = 0.5)
        ax.axvline(1, color = 'k', linestyle = '--', lw = 0.5)
        ax.set_xlim([-4, 2.01])
    sns.despine(fig = fig, top = True, right = True)
    
    # Add text above ax1 (positions are in axes-fraction coordinates)
    ax1.text(0.055, 1.35, 'Pre-ISI', fontsize = 'medium', transform = ax1.transAxes)
    ax1.text(0.375, 1.35, 'Image', fontsize = 'medium', transform = ax1.transAxes)
    ax1.text(0.675, 1.35, 'Stim | No-Stim', fontsize = 'medium', transform = ax1.transAxes)
    ax1.text(0.875, 1.35, 'Post-ISI', fontsize = 'medium', transform = ax1.transAxes)
        
    # Figure aesthetics
    ax1.set_xticks([-4, -3, -2, -1, 0, 1, 2], ['', '', '', '', '', '', ''])
    ax1.set_yticks([])
    ax1.set_title(unit, fontsize = 'x-large', fontweight = 'bold', pad = 35)
    ax2.set_xlabel('')
    ax2.set_xticks([-4, -3, -2, -1, 0, 1, 2], ['', '', '', '', '', '', ''])
    ax2.set_ylabel('Trial', fontsize = 'x-large', labelpad= 10)
    ax2.set_ylim([0, n_trials+1])
    ax2.set_yticks([1, n_trials], ['1', str(n_trials)], fontsize = 'medium')
    ax3.set_xlabel('Time (s)', fontsize = 'x-large', labelpad = 20)
    ax3.set_xticks([-4, -3, -2, -1, 0, 1, 2], ['-4', '-3', '-2', '-1', '0', '1', '2'], fontsize = 'medium')
    ax3.set_ylabel('FR (Hz)', fontsize = 'x-large', labelpad = 7)
    # Axis is in spike counts; relabel the top tick as a rate: count / n_trials * 10 (10 = 1 / binSize)
    ax3.set_yticks([0, maxSpikes], [0, np.round((maxSpikes / n_trials) * 10, 1)], fontsize = 'medium')
    sns.despine(top = True, right = True, left = True, ax = ax1)

    # Export figure (only when requested, consistent with plotPSRaster)
    if save:
        raster_dir = os.path.join(save_path, 'Rasters')
        os.makedirs(raster_dir, exist_ok = True)
        fig.savefig(os.path.join(raster_dir, unit + '_ExtendedPSRaster.pdf'), dpi = 1500, bbox_inches = 'tight')

    if not show:
        plt.close(fig)

In [None]:
# Plot peri-stim rasters for each unit (& export as .pdf)
# The plotting functions pick colors from a 4-entry palette, so the color index is wrapped
# with `% 4`; passing the raw unit index would fail for sessions with more than 4 units.
for i, unit_label in enumerate(unit_labels):
    plotPSRaster(epochs = stim_epochs, unit = unit_label, color_idx = i % 4, show = False, save = True)
    plotPSRaster_Extended(epochs = stim_epochs, unit = unit_label, color_idx = i % 4, show = False, save = True)

# For generating example rasters in Fig2
# plotPSRaster(epochs = stim_epochs, unit = 'mROFC7-1', color_idx = 1, show = True, save = True) # P6
# plotPSRaster(epochs = stim_epochs, unit = 'mROFC8-1', color_idx = 1, show = True, save = True) # P6
# plotPSRaster(epochs = stim_epochs, unit = 'mRHIP3-1', color_idx = 0, show = True, save = True) # P9
# plotPSRaster(epochs = stim_epochs, unit = 'mLAMY8-1', color_idx = 2, show = True, save = True) # P25

### 6.4 Epoch Validation
Use `plotEpochValidation()` to plot the sync channel for each epoch to verify alignment.

In [None]:
def plotEpochValidation(epochs, show = True, save = True):
    '''
    Plot sync data across epochs to validate alignment.
    
    Draws one small panel per epoch (8 panels per row) showing the sync channel from
    3 to 6 s into the epoch, with tick labels running from -1 to 2. The 4-5 s window
    (labelled 0 to 1) is outlined and shaded in red.
    
    Relies on variables defined in earlier notebook cells: sync, fs, and
    (only when save = True) save_path.
    
    Arguments:
        epochs: peri-stimulation epochs (indices into sync); epoch i is read as
            epochs[i] and the number of epochs as epochs.shape[1]
        show (bool): show figure
        save (bool): save figure as <save_path>/Validation/EpochsSync.pdf
        
    Returns:
        None
    '''
    
    n_cols = 8
    n_epochs = epochs.shape[1]
    n_rows = int(np.ceil(n_epochs / n_cols))
    
    # squeeze = False keeps `axes` 2-D regardless of n_rows; flatten once for linear indexing
    fig, axes = plt.subplots(n_rows, n_cols, figsize = (n_rows+2, n_rows+2), constrained_layout = True, squeeze = False)
    flat_axes = axes.flatten()
    
    for i in range(n_epochs):
        ax = flat_axes[i]
        ax.plot(sync[epochs[i]], color = 'k', linewidth = 0.5)
        # x-axis is in samples from the epoch start; 4-5 s is the window labelled 0-1
        ax.axvline(fs*4, color = 'r', linewidth = 1)
        ax.axvline(fs*5, color = 'r', linewidth = 1)
        ax.axvspan(fs*4, fs*5, color = 'r', alpha = 0.1)
        ax.set_xticks([fs*3, fs*4, fs*5, fs*6], ['-1', '0', '1', '2'], fontsize = 'x-small')
        ax.set_xlim([fs*3, fs*6])
        ax.set_yticks([])
        ax.set_title('Epoch: %s' %str(i+1), fontsize = 'x-small')
    
    # Despine every panel once (calling it inside the loop re-processed the whole figure each time)
    sns.despine(fig = fig, right = True, top = True)

    # if axis number is greater than number of epochs, remove axis
    for ax in flat_axes[n_epochs:]:
        fig.delaxes(ax)
        
    # Export figure
    if save:
        validation_dir = os.path.join(save_path, 'Validation')
        os.makedirs(validation_dir, exist_ok = True)
        fig.savefig(os.path.join(validation_dir, 'EpochsSync.pdf'), dpi = 1500, bbox_inches = 'tight')
    
    if not show:
        plt.close(fig)

In [None]:
# Check alignment of the stimulation epochs against the sync channel (figure is exported, not displayed)
plotEpochValidation(epochs = stim_epochs, save = True, show = False)
# plotEpochValidation(nostim_epochs, show = False, save = True)

In [None]:
# Final cell: report which dataset (pID) this run of the notebook processed
print('Processed %s' % pID)