## Objective
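Detect negative threshold crossings on every channel of the band-pass-filtered recording. The noise level is an RMS estimate recomputed in each window (chunk) of `CHUNK_SIZE` frames, and a crossing is any peak below `-SPIKE_AMP_THRESH × RMS`.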

In [1]:
# region Set up notebook imports
%load_ext autoreload
%autoreload
# Allow for imports of other scripts
import sys
PATH = "/data/MEAprojects/PropSignal"
if PATH not in sys.path:
    sys.path.append(PATH)
# Reload a module after changes have been made
from importlib import reload
# endregion

import os
from multiprocessing import Pool
from pathlib import Path

import numpy as np
from scipy.signal import find_peaks
from spikeinterface.extractors import NwbRecordingExtractor, MaxwellRecordingExtractor
from spikeinterface.preprocessing import bandpass_filter, scale
from tqdm import tqdm

In [2]:
# Detection threshold, in multiples of each channel's per-chunk RMS.
SPIKE_AMP_THRESH = 5
# Output file for the per-channel crossing times. The folder name encodes the threshold,
# so derive it from SPIKE_AMP_THRESH instead of repeating the number by hand.
THRESH_CROSSINGS_PATH = f"/data/MEAprojects/dandi/000034/sub-mouse412804/prop_signal/thresh_{SPIKE_AMP_THRESH}/crossings.npy"
# Frames per chunk; the RMS noise estimate is recomputed for every chunk.
CHUNK_SIZE = 1000

# Input recording and band-pass cutoffs (Hz) passed to spikeinterface's bandpass_filter.
REC_PATH = "/data/MEAprojects/dandi/000034/sub-mouse412804/sub-mouse412804_ecephys.nwb"
FREQ_MIN = 300
FREQ_MAX = 3000

In [3]:
if REC_PATH.endswith(".h5"):
    # Point HDF5 at the plugin directory *before* the Maxwell file is opened, so the
    # plugin (presumably Maxwell's compression filter) is findable when data is read.
    os.environ['HDF5_PLUGIN_PATH'] = '/home/mea/SpikeSorting/spikeinterface'
    rec = MaxwellRecordingExtractor(REC_PATH)
else:
    rec = NwbRecordingExtractor(REC_PATH)
# Apply the per-channel gains/offsets, then band-pass filter between FREQ_MIN and FREQ_MAX.
rec = scale(rec, rec.get_channel_gains(), rec.get_channel_offsets(), dtype="float32")
rec = bandpass_filter(rec, freq_min=FREQ_MIN, freq_max=FREQ_MAX, dtype="float32")

# Sampling rate in kHz (samples per millisecond): frame_index / fs is a time in ms.
fs = rec.get_sampling_frequency() / 1000

In [4]:
def thresh_crossings_chunk(start_frame, *, recording=None, chunk_size=None,
                           thresh=None, samples_per_ms=None):
    """Return the negative threshold crossings of every channel in one chunk.

    The chunk spans frames [start_frame, start_frame + chunk_size). For each channel,
    the noise is the RMS of that channel within the chunk. A crossing is a peak of
    -trace whose height is at least ``thresh * noise``.

    Args:
        start_frame: first frame of the chunk.
        recording: object with ``get_traces(start_frame=, end_frame=, return_scaled=)``
            returning a (frames, channels) array. Defaults to the notebook's ``rec``.
        chunk_size: number of frames in the chunk. Defaults to ``CHUNK_SIZE``.
        thresh: threshold in multiples of the RMS. Defaults to ``SPIKE_AMP_THRESH``.
        samples_per_ms: divisor converting absolute frame indices to times. Defaults
            to the notebook's ``fs`` (sampling rate / 1000), giving times in ms.

    Returns:
        List with one 1-D float array per channel: the crossing times of that channel.
    """
    # Fall back to the notebook globals so `pool.imap(thresh_crossings_chunk, starts)` keeps working.
    recording = rec if recording is None else recording
    chunk_size = CHUNK_SIZE if chunk_size is None else chunk_size
    thresh = SPIKE_AMP_THRESH if thresh is None else thresh
    samples_per_ms = fs if samples_per_ms is None else samples_per_ms

    traces_all = recording.get_traces(start_frame=start_frame, end_frame=start_frame + chunk_size,
                                      return_scaled=False)

    crossings = []
    for trace in traces_all.T:
        noise = np.sqrt(np.mean(np.square(trace)))
        # Negate so negative deflections become peaks. find_peaks never reports the first or
        # last sample of the chunk; the caller overlaps chunks to cover those boundaries.
        peak_idx = find_peaks(-trace, height=thresh * noise)[0]
        crossings.append((peak_idx + start_frame) / samples_per_ms)
    return crossings

In [5]:
# Scan the whole recording in parallel. Chunk starts are CHUNK_SIZE - 2 frames apart, so
# consecutive chunks overlap by 2 frames. find_peaks cannot report a chunk's first or last
# sample, and the overlap makes sure no frame is missed at a chunk boundary.
n_frames = rec.get_total_samples()
assert n_frames >= CHUNK_SIZE, "recording is shorter than one chunk"
tasks = list(range(0, n_frames - CHUNK_SIZE + 1, CHUNK_SIZE - 2))
# The stride rarely lands exactly on the end of the recording. Add one last chunk flush
# with the end so the final frames are scanned too (duplicate times are removed below).
if tasks[-1] != n_frames - CHUNK_SIZE:
    tasks.append(n_frames - CHUNK_SIZE)

spike_times = [[] for _ in range(rec.get_num_channels())]
with Pool(processes=20) as pool:
    imap_chunksize = max(1, len(tasks) // 1000)  # Pool.imap rejects a chunksize of 0
    for crossings in tqdm(pool.imap(thresh_crossings_chunk, tasks, imap_chunksize), total=len(tasks)):
        for channel_idx, channel_times in enumerate(spike_times):
            channel_times.extend(crossings[channel_idx])
# Sort each channel's times and drop duplicates found in overlapping chunks.
spike_times = [np.unique(st) for st in spike_times]

100%|██████████| 12000/12000 [00:17<00:00, 678.32it/s]


In [6]:
# Channels have different numbers of crossings, so store them as a 1-D object array
# (one float array per channel). Building it explicitly avoids NumPy's ragged-list
# ValueError (NumPy >= 1.24), and the on-disk layout stays the same even if every channel
# happens to have the same count. Load with `np.load(path, allow_pickle=True)`.
spike_times_obj = np.empty(len(spike_times), dtype=object)
for channel_idx, channel_times in enumerate(spike_times):
    spike_times_obj[channel_idx] = channel_times
Path(THRESH_CROSSINGS_PATH).parent.mkdir(parents=True, exist_ok=True)
np.save(THRESH_CROSSINGS_PATH, spike_times_obj)