# BLAES Units MUA Preprocessing

This notebook mirrors the main preprocessing code (`BLAESUnitPrepro.ipynb`), but instead of analyzing the spiking of isolated (sorted) neurons, it counts threshold crossings in the multi-unit activity (MUA) signal.

---

> *Author: Justin Campbell (justin.campbell@hsc.utah.edu)*  
> *Version: 6/4/2025*


## 1. Setup Notebook
Import libraries, utilities, and packages. Configure notebook figure settings.

In [None]:
# Import Libraries
import os
import sys
import glob
import mat73
import datetime
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.io import loadmat
from scipy.signal import find_peaks, filtfilt, find_peaks, firwin
from scipy.fftpack import fft, ifft

# Import Blackrock Python Utilities
sys.path.append('/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits/Code/Blacrock-Python-Utilities')
import brpylib

%matplotlib inline
# %config InlineBackend.figure_format='retina'
%config InlineBackend.figure_format='svg'

## 2. Load Data

In [None]:
def loadNEV(nev_path):
    '''
    Load a Blackrock .nev file and extract sorted spike events and waveforms.

    Only events whose 'Unit' label is a valid sorted unit (1-16) are kept;
    0 (unclassified) and 255 (noise) are discarded.

    Arguments:
        nev_path (str): path to .nev file
    Returns:
        events (pd.DataFrame): spike events, one row per kept spike (RangeIndex)
        waveforms (np.ndarray): spike waveforms, row-aligned with `events`
    '''
    # Open .nev file; always close it, even if reading fails
    nev_file = brpylib.NevFile(nev_path)
    try:
        print('\nLoading .nev data...')
        nev_data = nev_file.getdata()['spike_events']
    finally:
        nev_file.close()

    # Separate waveforms from the remaining per-spike fields
    waveforms = nev_data['Waveforms']
    spikes = {k: v for k, v in nev_data.items() if k != 'Waveforms'}
    events = pd.DataFrame.from_dict(spikes)

    # Define valid unit labels (0 = unclassified, 255 = noise, 1-16 = valid units)
    valid_unit_labels = np.arange(1, 17)
    keep = events['Unit'].isin(valid_unit_labels).to_numpy()

    # Positional indices keep waveforms row-aligned with the filtered events
    waveforms = waveforms[np.flatnonzero(keep)]

    # Filter and reindex (new frame rather than in-place mutation of a slice)
    events = events.loc[keep].reset_index(drop = True)

    return events, waveforms

In [None]:
def loadNSX(nsx_path):
    '''
    Load a Blackrock .nsx file and extract the continuous data and header.

    Arguments:
        nsx_path (str): path to .nsx file

    Returns:
        nsx_data (np.ndarray): raw neural data (30 kHz); presumably
            channels x samples, with rows in extended-header order
        header (list): extended headers (recording info), one dict per channel
    '''
    # Open the file given by the argument (previously the global `nsx` was
    # used by mistake, silently ignoring `nsx_path`)
    nsx_file = brpylib.NsxFile(nsx_path)
    try:
        # Extract data & header
        print('\nLoading .nsx data...')
        nsx_data = nsx_file.getdata(full_timestamps = True)['data'][0]
        header = nsx_file.extended_headers
    finally:
        # Close .nsx file, even if reading fails
        nsx_file.close()

    return nsx_data, header

In [None]:
def loadSortedMat(mat_path, pid = None,
                  labels_csv = '/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits/WashU Micro Labels.csv'):
    '''
    Load a WashU *sorted.mat file and label its micro channels.

    Arguments:
        mat_path (str): path to *sorted.mat file
        pid (str, optional): session ID (e.g. 'BJH02503'); its last two
            characters are dropped to look up the patient in `labels_csv`.
            Defaults to the notebook-level `pID` (original behavior).
        labels_csv (str, optional): CSV with 'pID' and 'Label' columns listing
            each patient's micro lead labels.

    Returns:
        mat_dict (dict): micro channel label -> sorted-unit array
            (channels without units removed)
        chan_labels (list): micro channel labels (all channels, 8 per lead)

    Raises:
        ValueError: if the .mat file holds more channels than labelled contacts.
    '''
    if pid is None:
        pid = pID  # fall back to the notebook-level session ID

    # load the mat file
    mat_dict = loadmat(mat_path, simplify_cells = True)

    # load WashU lead labels for this patient
    washU_labels = pd.read_csv(labels_csv)
    lead_labels = washU_labels.loc[washU_labels['pID'] == pid[:-2], 'Label'].values

    # 8 contacts per lead, e.g. 'mLA' -> 'mLA1' ... 'mLA8'
    chan_labels = [lead + str(contact) for lead in lead_labels for contact in range(1, 9)]

    # drop loadmat metadata entries ('__header__', '__version__', '__globals__')
    mat_dict = {k: v for k, v in mat_dict.items() if '__' not in k}

    # remove Sync channel
    if len(mat_dict) == 17:
        mat_dict.pop('Channel17', None)
    elif len(mat_dict) == 9:
        mat_dict.pop('Channel09', None)

    # rename keys with channel names (relies on the file's channel order)
    if len(mat_dict) > len(chan_labels):
        raise ValueError('%s: %d channels in .mat but only %d labels for pID %r'
                         % (mat_path, len(mat_dict), len(chan_labels), pid))
    mat_dict = dict(zip(chan_labels, mat_dict.values()))

    # remove channels that are empty (no units)
    mat_dict = {k: v for k, v in mat_dict.items() if v.size > 0}

    return mat_dict, chan_labels

In [None]:
# Paths
data_path = '/Volumes/Hippocampus/BLAESUnits/Data_30k'

# Find relevant .nev, .nsx, and .mat files
nev_files = glob.glob(root_dir = data_path, pathname = '**/*.nev', recursive = True)
nsx_files = glob.glob(root_dir = data_path, pathname = '**/*.ns6', recursive = True)
sorted_mat_files = glob.glob(root_dir = data_path, pathname = '**/*sorted.mat', recursive = True)
raw_mat_files = glob.glob(root_dir = data_path, pathname = '**/*raw.mat', recursive = True)

# Sort based on filename
nev_files.sort()
nsx_files.sort()
sorted_mat_files.sort()
raw_mat_files.sort()

# Combine to create list of sessions
sessions = [nev_files[i].split('/')[0] for i in range(len(nev_files))]
for i in range(len(sorted_mat_files)):
    sessions.append(sorted_mat_files[i].split('/')[0])

# Data params
fs = 30000 # sampling rate

# Display
print('Found data for %s sessions:' % (len(sessions)))
print('----------------------------')
for i in range(len(sessions)):
    print('%s: %s' % (i, sessions[i]))

In [None]:
# Select the session file to analyze
file_idx = int(input('Enter the session number to analyze:'))
fileName = sessions[file_idx].split('/')[-1]
fType = fileName[0:3] # site prefix: 'UIC' (Utah, Blackrock files) or 'BJH' (WashU, .mat files)

# Session ID and output folder
pID = fileName
proj_path = '/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits'
save_path = os.path.join(proj_path, 'Results', pID)


# Utah files
if fType == 'UIC':
    # Locate files
    nev = os.path.join(data_path, sessions[file_idx], sessions[file_idx] + '.nev')
    nsx = os.path.join(data_path, sessions[file_idx], sessions[file_idx] + '.ns6')

    # Load .nev and .nsx files
    events, waveforms = loadNEV(nev)
    cont_data, header = loadNSX(nsx)

    # Adjust for Blackrock's convention of multiplying uV by 4
    waveforms = waveforms / 4

    # Get chan IDs and chan labels from the .nsx extended headers
    chan_IDs = [h['ElectrodeID'] for h in header]
    chan_labels = [h['ElectrodeLabel'] for h in header]

    # Map electrode IDs (the .nev 'Channel' values) to electrode labels
    chan_map = dict(zip(chan_IDs, chan_labels))
    events['Channel'] = events['Channel'].map(chan_map)

    # Get sync & PD channels (rows of cont_data follow the header order)
    sync = cont_data[chan_labels.index('Sync'), :]
    PD = cont_data[chan_labels.index('PD'), :]


# WashU files
elif fType == 'BJH':
    # Locate files
    sorted_mat = os.path.join(data_path, sessions[file_idx], sessions[file_idx] + '_sorted.mat')
    trigger_mat = os.path.join(data_path, sessions[file_idx], 'Stim_trigger.mat')
    raw_mat = os.path.join(data_path, sessions[file_idx], sessions[file_idx] + '_raw.mat')

    # Load sorted.mat and raw.mat files
    mat_dict, chan_labels = loadSortedMat(sorted_mat)
    sync = loadmat(trigger_mat, simplify_cells = True)['stimTrigger']
    try:
        cont_data = loadmat(raw_mat, simplify_cells = True)['signals']
    except NotImplementedError:
        # scipy cannot read MATLAB v7.3 (HDF5) files; fall back to mat73.
        # Other errors (missing file, missing key, ...) are no longer swallowed.
        cont_data = mat73.loadmat(raw_mat)['signals']

    # Get events and waveforms. Per channel, columns 0/1/2 are used as the
    # 1-based channel number, unit number, and spike time (s); columns 6 up to
    # (but excluding) the 3 label columns appended below are the waveform samples.
    events = []
    waveforms = []
    for k, v in mat_dict.items():
        df = pd.DataFrame(v)
        df['Channel'] = [chan_labels[int(c - 1)] for c in df[0]]
        df['Unit'] = df[1].astype(int)
        df['TimeStamps'] = (df[2] * fs).astype(int) # seconds -> 30 kHz samples (truncated)
        events.append(df[['TimeStamps', 'Unit', 'Channel']])
        waveforms.append(df.iloc[:, 6:-3])
    events = pd.concat(events, ignore_index = True) # unique index across channels
    waveforms = np.array(pd.concat(waveforms))

    # Adjust uV
    waveforms = waveforms / 4

else:
    # Fail loudly: every later cell depends on the variables defined above
    raise ValueError('Filetype not recognized: %r' % fType)

In [None]:
# Micro channels are the labels starting with 'm'; a micro lead (bundle) label is
# a channel label with its last character (the contact number) dropped
micro_chans = list(filter(lambda label: label.startswith('m'), chan_labels))
active_micro_chans = events['Channel'].unique().tolist()
micro_leads = {label[:-1] for label in micro_chans}
active_micro_leads = {label[:-1] for label in active_micro_chans}

## 3. Create Epochs
Epochs are created using the `sync` and `PD` channels within the recording to define the stim/no-stim trials, respectively. Because of the way in which the data were recorded, as well as noise/artifact present in the `sync`/`PD` channels, this section entails a considerable amount of manual tuning to ensure epochs are correctly aligned.

### 3.1 Load Trial Info

In [None]:
trial_info_csv = os.path.join(proj_path, 'TrialInfo', pID + '_TrialInfo.csv')
trial_info = pd.read_csv(trial_info_csv)
# Onsets are stored at 2 kHz; multiply by 15 to express them in 30 kHz samples
onset_cols = ['Onset_SC', 'Onset_PD']
trial_info[onset_cols] = trial_info[onset_cols] * 15

### 3.2 Detect Sync Channel Onsets

In [None]:
def getSyncOnsets(sync, thresh = 15000, sample_rate = None, min_interval_s = 5):
    '''
    Detect sync (stim) pulse onsets as rising threshold crossings.

    A sample i is an onset when all of the following hold:
    (1) sync[i] > thresh,
    (2) sync[i-1] <= thresh (transition from below to above threshold),
    (3) more than `min_interval_s` seconds have passed since the previous onset
        (or since the start of the recording).

    Arguments:
        sync (np.ndarray): sync channel data (1-D)
        thresh (float): detection threshold; 15000 was used for both the
            Utah ('UIC') and WashU ('BJH') recordings
        sample_rate (int, optional): sampling rate in Hz; defaults to the
            notebook-level `fs`
        min_interval_s (float): minimum spacing between onsets, in seconds

    Returns:
        stim_onsets (list of int): sample indices of detected onsets
    '''
    if sample_rate is None:
        sample_rate = fs # notebook-level sampling rate
    min_gap = sample_rate * min_interval_s

    above = np.asarray(sync) > thresh
    # Rising edges at i = 1 .. len-2. The previous version compared against
    # syncBool[-i] (a sample counted from the END of the array) instead of
    # the previous sample, so long pulses could be re-detected mid-pulse.
    rising = np.flatnonzero(above[1:-1] & ~above[:-2]) + 1

    stim_onsets = []
    last_onset = 0
    for i in rising:
        if i - last_onset > min_gap:
            stim_onsets.append(int(i))
            last_onset = i

    return stim_onsets

In [None]:
# Detect sync (stim) pulses
sync_onsets = getSyncOnsets(sync)

# Add sync_onsets to trial_info.
# Default: the n-th detected sync pulse is assigned to the n-th stim trial
# (Condition == 1) and no-stim trials get NaN. Sessions in special_cases have
# missing/undetected pulses and use hand-tuned (empirical) corrections below.
special_cases = ['BJH02503', 'BJH02703', 'BJH04502']
stim_trial_counter = 0 # 0-based index of the current stim trial
if pID not in special_cases:
    for i in range(trial_info.shape[0]):
        if trial_info.loc[i, 'Condition'] == 1:
            trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter]
            stim_trial_counter += 1
        else:
            trial_info.loc[i, 'Onset_Sync'] = np.nan
else:
    if pID == 'BJH02503':
        # Sync pulses not present for 1st few trials, but LFP shows stim
        # Stim trials 0-5 (0-based): onset = Onset_SC (30 kHz samples) - 18000 samples (0.6 s at 30 kHz).
        # Later stim trials use the detected pulses, shifted by the 6 missing pulses.
        for i in range(trial_info.shape[0]):
            if trial_info.loc[i, 'Condition'] == 1:
                if stim_trial_counter < 6:
                    trial_info.loc[i, 'Onset_Sync'] = trial_info.loc[i, 'Onset_SC'] - 18000 # empirical correction
                else:
                    trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter-6]
                stim_trial_counter += 1
            else:
                trial_info.loc[i, 'Onset_Sync'] = np.nan
                
    # NOTE(review): plain `if` rather than `elif`; equivalent here because the pID checks are mutually exclusive.
    if pID == 'BJH02703':
        # Many sync pulses not present
        # Stim trials 27-39 (0-based, 13 trials): onset = Onset_SC - 10000 samples (empirical).
        # Before that gap the pulse index is shifted by +1, after it by -13 (empirical offsets; TODO confirm).
        # NOTE(review): no-stim rows are never assigned here; they stay NaN only because pandas
        # fills the rest of a new column with NaN when 'Onset_Sync' is first created via .loc.
        fixed = False # becomes True once the stim trials without sync pulses are reached
        for i in range(trial_info.shape[0]):
            if trial_info.loc[i, 'Condition'] == 1:
                if (stim_trial_counter >= 27) and (stim_trial_counter < 40):
                    trial_info.loc[i, 'Onset_Sync'] = trial_info.loc[i, 'Onset_SC'] - 10000 # empirical correction
                    fixed = True
                else:
                    if not fixed:
                        trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter+1]
                    else:
                        trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter-13]
                stim_trial_counter += 1
                
    elif pID == 'BJH04502':
        # Missed detection of one pulse
        # Stim trial 33 (0-based): onset = Onset_SC - 13000 samples (empirical);
        # later stim trials use the detected pulse index shifted by -1.
        fixed = False # becomes True once the stim trial with the missed pulse is reached
        for i in range(trial_info.shape[0]):
            if trial_info.loc[i, 'Condition'] == 1:
                if stim_trial_counter == 33:
                    trial_info.loc[i, 'Onset_Sync'] = trial_info.loc[i, 'Onset_SC'] - 13000 # empirical correction
                    fixed = True
                else:
                    if not fixed:
                        trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter]
                    else:
                        trial_info.loc[i, 'Onset_Sync'] = sync_onsets[stim_trial_counter-1]
                stim_trial_counter += 1
            else:
                trial_info.loc[i, 'Onset_Sync'] = np.nan

### 3.3 Define Stim/No-Stim Epochs

#### 3.3.1 Align Blackrock to BCI2000 w/ PD

In [None]:
# Utah data does not have 1:1 mapping to timing in BCI2000, thus this step includes an empirical re-alignment to a common feature
# present in both datasets, the photodiode (PD) signal. 

# Get PD onsets from trial_info
PD_onsets = trial_info['Onset_PD'].values

# Find peaks in PD channel (30 kHz); some trial onsets are +PD vs. -PD spike (depending on experiment)
# Peaks must exceed 500 and be at least 150000 samples (5 s at 30 kHz) apart.
# NOTE(review): peaks_30k is only computed for Utah ('UIC') sessions; all branches below are UIC-specific.
if pID in ['UIC20230201']:
    peaks_30k, _ = find_peaks(PD*-1, height = 500, distance = 150000)
else:
    if fType == 'UIC':
        peaks_30k, _ = find_peaks(PD, height = 500, distance = 150000)

# Manual tuning (empirical correction from visual inspection)
# keep_peaks lists the hand-picked peak indices to keep (skipping peaks judged spurious)
# so that matched_peaks lines up with the BCI2000 trial sequence.
# NOTE(review): the `matched_peaks = peaks_30k` lines are immediately overwritten, and a
# Utah pID not listed here leaves matched_peaks undefined (NameError in the next cell).
if pID == 'UIC20221301':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(2,47), np.arange(48, 93), np.arange(94, 139), np.arange(140, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20221501':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(2,47), np.arange(48, 93), np.arange(94, 139), np.arange(140, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20221701':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(3,48), np.arange(49, 94), np.arange(98, 143), np.arange(144, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]

elif pID == 'UIC20230201':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(2,82), np.arange(83, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20230601':
    # NOTE(review): matched_peaks[16:] below leaves 16 leading NaNs, while this comment says 17 -- confirm.
    matched_peaks = np.empty(len(trial_info)) * np.nan # initialize as nans; (missing PD for 17 trials)
    keep_peaks = np.concatenate((np.arange(3,67), np.arange(68, 148), np.arange(150, 230), np.arange(231, len(peaks_30k)))) # realign
    matched_peaks[16:] = peaks_30k[keep_peaks]
    
elif pID == 'UIC20230701':
    peaks_30k = peaks_30k[3:] # remove spurious PDs
    keep_peaks = np.concatenate((np.arange(0,12), np.arange(13,81), np.arange(82,162), np.arange(163,243), np.arange(244,324))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20230801':
    peaks_30k = peaks_30k[4:] # remove spurious PDs
    keep_peaks = np.concatenate((np.arange(0,80), np.arange(81,161), np.arange(162,242), np.arange(243,323))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20231101':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(3,83), np.arange(84, 164), np.arange(165, 245), np.arange(247, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20231401':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(2,122), np.arange(123, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
elif pID == 'UIC20240101':
    matched_peaks = peaks_30k
    keep_peaks = np.concatenate((np.arange(3,123), np.arange(124, len(peaks_30k)))) # realign
    matched_peaks = peaks_30k[keep_peaks]
    
# O/S cohort (must remove scrambled images)
# 4 blocks: 40 trial_info onsets vs. 45 PD peaks per block (the 5 extra peaks per block are
# presumably the scrambled images). Per block, PD onsets are shifted so the first onset
# coincides with the block's first peak, then the peak nearest to each shifted onset is kept.
if pID in ['UIC20221301', 'UIC20221501', 'UIC20221701']:
    PD_block_idxs = [np.arange(0, 40), np.arange(40, 80), np.arange(80, 120), np.arange(120, 160)]
    peaks_block_idxs = [np.arange(0, 45), np.arange(45, 90), np.arange(90, 135), np.arange(135, 180)]
    
    # Get peaks and onsets for each block
    matched_peaks_noscramb = []
    for i in range(len(PD_block_idxs)):
        peaks_block = matched_peaks[peaks_block_idxs[i]]
        PD_onsets_block = PD_onsets[PD_block_idxs[i]]

        # Adjust PD onsets to align with 1st peak in block ("blocks" are separate .dat files stitched together)
        PD_onsets_adj = PD_onsets_block + (peaks_block[0] - PD_onsets_block[0])

        # Get the values in peaks closest to PD_onsets_adj
        # NOTE(review): searches all of matched_peaks (not only peaks_block) and reuses the outer
        # loop variable `i` (harmless: the outer `for` rebinds `i` on each iteration).
        peaks_adj_block = []
        for i in range(len(PD_onsets_adj)):
            peaks_adj_block.append(matched_peaks[np.argmin(np.abs(matched_peaks - PD_onsets_adj[i]))])
        matched_peaks_noscramb.append(peaks_adj_block)
        
    # Flatten the per-block lists into one array of matched peaks (160 entries)
    matched_peaks = np.array([item for sublist in matched_peaks_noscramb for item in sublist])

In [None]:
# Utah sessions only: replace the BCI2000 PD onsets in trial_info with the
# photodiode peaks matched in the 30 kHz recording
if fType == 'UIC':
    trial_info = trial_info.assign(Onset_PD = matched_peaks)

#### 3.3.2 Use Sync & PD to Define Stim/No-Stim Epochs

In [None]:
# Each epoch spans 4 s before to 2 s after the trial onset (in 30 kHz samples):
# stim trials are aligned to the sync pulse, no-stim trials to the photodiode.
pre_onset_samples = 4 * fs
post_onset_samples = 2 * fs

stim_epochs = []
nostim_epochs = []

for i in range(trial_info.shape[0]):
    # stim trial
    if trial_info.loc[i, 'Condition'] == 1:
        onset = trial_info.loc[i, 'Onset_Sync']
        stim_epochs.append(np.arange(onset - pre_onset_samples, onset + post_onset_samples))
    # no stim trial
    else:
        onset = trial_info.loc[i, 'Onset_PD']
        if np.isnan(onset):
            # Missing PD onset: placeholder of zeros (cannot use np.nan in an int DataFrame).
            # NOTE(review): placeholder epochs are still sliced downstream (sample 0 repeated).
            nostim_epochs.append(np.zeros(pre_onset_samples + post_onset_samples))
        else:
            nostim_epochs.append(np.arange(onset - pre_onset_samples, onset + post_onset_samples))

# One column per trial, one row per sample index
stim_epochs = pd.DataFrame(stim_epochs).astype(int).T
nostim_epochs = pd.DataFrame(nostim_epochs).astype(int).T

# The results folder may not exist yet for a new session
os.makedirs(save_path, exist_ok = True)
stim_epochs.to_csv(os.path.join(save_path, 'StimEpochs.csv'))
nostim_epochs.to_csv(os.path.join(save_path, 'NoStimEpochs.csv'))

### 3.4 Epoch 30 kHz Continuous Data

In [None]:
def _cut_epochs(epoch_table):
    '''Slice cont_data ([channel, sample]) with each column (trial) of an epoch table.

    Returns an array of shape (trials, channels, samples) when there is at least one trial.
    '''
    return np.array([cont_data[:, trial_samples.values] for _, trial_samples in epoch_table.items()])

stim_epochs_30k = _cut_epochs(stim_epochs)
nostim_epochs_30k = _cut_epochs(nostim_epochs)

In [None]:
# Filter epochs to just include micro channels.
# The micro-channel row indices are identical for every trial, so compute them
# once and select them for all trials in a single fancy-indexing step
# (arrays are (trials, channels, samples); an empty trial set is left as-is).
micro_chan_idxs = [chan_labels.index(x) for x in micro_chans]
stim_epochs_30k_micro = stim_epochs_30k[:, micro_chan_idxs] if stim_epochs_30k.size else stim_epochs_30k
nostim_epochs_30k_micro = nostim_epochs_30k[:, micro_chan_idxs] if nostim_epochs_30k.size else nostim_epochs_30k

In [None]:
# Separate into pre trial (1st second) and post trial (last second).
# The stim "pre" slice previously took the first 4 s (:4*fs), contradicting this
# comment and the no-stim slice; both conditions now use the same 1 s windows.
stim_epochs_30k_pre = stim_epochs_30k_micro[:, :, :fs]
stim_epochs_30k_post = stim_epochs_30k_micro[:, :, -fs:]
nostim_epochs_30k_pre = nostim_epochs_30k_micro[:, :, :fs]
nostim_epochs_30k_post = nostim_epochs_30k_micro[:, :, -fs:]

## 4. MUA

In [None]:
def extractMUA(data, chan_names = None, fs = 30000, band = (300, 3000), thresh_mult = 3.5):
    """
    Detect multiunit activity (MUA) threshold crossings in one epoch.

    Each channel is band-pass filtered (zero-phase FIR via filtfilt) and
    mean-centred. A negative threshold of -thresh_mult * RMS is computed from
    the first second (pre-trial baseline). Local minima below that threshold
    are counted as MUA events.

    Parameters:
    - data: numpy array of shape [channels x samples], sampled at `fs`
      (the notebook passes 6 s epochs: 4 s before to 2 s after onset)
    - chan_names: labels for the rows of `data`; defaults to the notebook-level
      `micro_chans`, because the epochs passed in are restricted to micro channels
      in that order (the previous code used the full `chan_labels`, which can
      mislabel rows)
    - fs: sampling rate in Hz
    - band: (low, high) band-pass edges in Hz
    - thresh_mult: multiplier applied to the baseline RMS

    Returns:
    - mua: filtered MUA signals, same shape as `data`
    - timestamps: list (one array per channel) of crossing SAMPLE INDICES,
      with crossings inside the stim window [4 s, 5 s] removed
    - thresh: threshold of the last channel (NaN if there are no channels)
    - counts_df: DataFrame with columns Chan, Pre_Spikes (crossings in [0, 1 s)),
      Post_Spikes (crossings at >= 5 s); counted before stim-window removal
    """
    num_chans = data.shape[0]
    if chan_names is None:
        chan_names = micro_chans
    stim_win = (4, 5)     # seconds; stimulation window excluded from timestamps
    baseline_end = fs     # baseline / pre window = first second
    post_start = 5 * fs   # post window = samples from 5 s onward

    # Design FIR band-pass filter
    b = firwin(numtaps = 97, cutoff = [band[0], band[1]], pass_zero = False, fs = fs)
    a = 1  # FIR filters have only b coefficients

    mua = np.zeros_like(data, dtype = float)
    timestamps = []
    pre_counts = np.zeros(num_chans)
    post_counts = np.zeros(num_chans)
    thresh = np.nan

    for ch in range(num_chans):
        filtered = filtfilt(b, a, data[ch, :].astype(float))
        filtered -= np.mean(filtered)
        mua[ch, :] = filtered

        # RMS threshold from the baseline only
        thresh = -thresh_mult * np.sqrt(np.mean(filtered[:baseline_end] ** 2))

        # Candidate events = local minima of the filtered signal
        minima, _ = find_peaks(-filtered)
        crossings = minima[filtered[minima] < thresh]

        # Half-open windows (the pre window used to include sample `fs`, one past the baseline)
        pre_counts[ch] = np.count_nonzero(crossings < baseline_end)
        post_counts[ch] = np.count_nonzero(crossings >= post_start)

        # Remove crossings inside the stim window (inclusive bounds, as before)
        in_stim = (crossings >= stim_win[0] * fs) & (crossings <= stim_win[1] * fs)
        timestamps.append(crossings[~in_stim])

    # Organize into a DataFrame
    counts_df = pd.DataFrame({
        'Chan': [chan_names[ch] for ch in range(num_chans)],
        'Pre_Spikes': pre_counts,
        'Post_Spikes': post_counts
    })

    return mua, timestamps, thresh, counts_df

In [None]:
def _mua_counts_for(epochs_micro, n_trials, condition):
    '''Run extractMUA on each trial and stack the per-channel counts for one condition.'''
    per_trial = []
    for trial in range(n_trials):
        *_, trial_counts = extractMUA(epochs_micro[trial, :, :])
        per_trial.append(trial_counts.assign(Trial = trial + 1))  # 1-based trial number
    stacked = pd.concat(per_trial, ignore_index=True)
    stacked['Condition'] = condition
    stacked['pID'] = pID
    return stacked.reset_index(drop=True)

n_trials_stim = stim_epochs_30k_pre.shape[0]
n_trials_nostim = nostim_epochs_30k_pre.shape[0]

stim_dfs = _mua_counts_for(stim_epochs_30k_micro, n_trials_stim, 'Stim')
nostim_dfs = _mua_counts_for(nostim_epochs_30k_micro, n_trials_nostim, 'NoStim')

# Combine stim and nostim DataFrames, then export MUA counts to CSV
mua_counts = pd.concat([stim_dfs, nostim_dfs], ignore_index=True).reset_index(drop=True)
mua_counts.to_csv(os.path.join(save_path, 'MUACounts.csv'))
mua_counts

In [None]:
# # Example MUA for Supplemental Figure 8

# # UIC20230601 (7), mRHCA4 (18), trial 127 
# test_idx = 19
# trial_idx = 126
# Fs = 30000

# mua, timestamps, thresh, counts_df = extractMUA(stim_epochs_30k_micro[trial_idx, :, :])

# # plot MUA
# fig, axes = plt.subplots(2, 1, figsize = (8, 2), gridspec_kw={'height_ratios': [3, 1]}, sharex = True)
# axes[1].set_position([0.125, 0.3, 0.775, 0.15])  # [left, bottom, width, height]


# axes[0].plot(mua[test_idx, :], label='Filtered MUA', lw = 0.25, color = '#545352')
# sns.rugplot(timestamps[test_idx], ax=axes[1], color='#da1b61', height=0.5, lw = 1)
# axes[0].scatter(timestamps[test_idx], mua[test_idx, (timestamps[test_idx])], color='#da1b61', s=0.005, label='Threshold Crossing', zorder = 10)



# axes[0].axvspan(xmin = 4*Fs, xmax = 5*Fs, color='grey', alpha = 0.15, zorder = 10)
# axes[0].axvline(Fs, color = 'k', linestyle = '--', lw = 0.5)
# axes[0].axvline(4*Fs, color = 'k', linestyle = '--', lw = 0.5)
# axes[0].axvline(5*Fs, color = 'k', linestyle = '--', lw = 0.5)
# # axes[0].axhline(thresh, color = '#da1b61', linestyle = '-', lw = 0.5, label='Threshold', zorder = -1) # subtract 8 to account for lw

# for ax in axes:
#     ax.tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False)
#     sns.despine(ax=ax, top=True, right=True, left=True, bottom=True)
#     ax.set_xlim(-1000, (6 * Fs)+1000)  # Set x-axis limits to 6 seconds (180000 samples)
# # axes[0].set_ylim(-400, 400)
# # axes[0].text(-0.045, 0.42, '~100 $\mu$V', fontsize = 'x-small', rotation = 90, transform = axes[0].transAxes)
# # axes[1].text(0.07, 0.175, '1 s', fontsize = 'x-small', transform = axes[0].transAxes)
# axes[0].text(0.05, 1.1, 'Pre-ISI', fontsize = 'small', transform = axes[0].transAxes)
# axes[0].text(0.325, 1.1, 'Image', fontsize = 'small', transform = axes[0].transAxes)
# axes[0].text(0.68, 1.1, 'Stim | No-Stim', fontsize = 'small', transform = axes[0].transAxes)
# axes[0].text(0.88, 1.1, 'Post-ISI', fontsize = 'small', transform = axes[0].transAxes)

# plt.savefig(os.path.join('/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits/Presentations/Methods Figures/MUAEx.pdf'), dpi=1000, bbox_inches='tight')
# plt.savefig(os.path.join('/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits/Presentations/Methods Figures/MUAEx.png'), dpi=1000, bbox_inches='tight')

# plt.show()