# Objective
Extract each unit's spike waveforms from the recording at the given spike times, and save them as one `.npy` file per unit.

In [1]:
# region Set up notebook imports
%load_ext autoreload
%autoreload
# Allow for imports of other scripts
import sys
PATH = "/data/MEAprojects/PropSignal"
if PATH not in sys.path:
    sys.path.append(PATH)
# Reload a module after changes have been made
from importlib import reload
# endregion

import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
from tqdm import tqdm

from src.recording import Recording
from src.sorters.prop_signal import PropSignal
from src import utils

## Setup
#### SPIKE_TIMES must be in ms

`SPIKE_TIMES` must have the format `[[spike times for unit 1], [spike times for unit 2], ...]`.
`SAVE_PATH` is the folder where the waveforms are saved, as one `<unit index>.npy` file per unit.

In [2]:
# Recording to cut waveforms from. freq_min/freq_max presumably configure a
# band-pass filter in Hz -- confirm against src.recording.Recording.
RECORDING = Recording(utils.PATH_REC_DL.format('2953'), freq_min=300, freq_max=6000)

# Every input and output of this run lives under one directory; change it here only.
PROP_SIGNAL_DIR = Path("/data/MEAprojects/DLSpikeSorter/models/v0_4_4/2953/230101_133514_582221/log/windows_200_120/prop_signal")
# allow_pickle is required because the file holds a (presumably ragged) list of
# per-unit spike trains. Unpickling can execute code: only load files you trust.
SPIKE_TIMES = np.load(PROP_SIGNAL_DIR / "propagating_times.npy", allow_pickle=True)
SAVE_PATH = PROP_SIGNAL_DIR / "waveforms"

# Half-window on each side of a spike, in ms. Spike times are in ms and are
# converted with `time * get_sampling_frequency()`, so the sampling frequency
# is used here as samples per ms. That makes N_BEFORE / N_AFTER sample counts.
WINDOW_MS = 2
N_BEFORE = int(WINDOW_MS * RECORDING.get_sampling_frequency())
N_AFTER = int(WINDOW_MS * RECORDING.get_sampling_frequency())

In [3]:
# Create the output folder, plus any missing parents. This is a no-op when it
# already exists, so the cell can be re-run safely.
SAVE_PATH.mkdir(exist_ok=True, parents=True)

## Extract waveforms

In [4]:
def extract_waveforms(spike_train,
                      recording, n_before, n_after,
                      max_spikes=None, seed=None):
    """Cut a fixed-length, all-channel snippet around every spike of one unit.

    Parameters
    ----------
    spike_train : sequence of float
        Spike times in ms. Each is converted to a sample index with
        ``recording.get_sampling_frequency()``, which is therefore used as
        samples per ms.
    recording : Recording-like
        Must provide ``get_num_channels()``, ``get_sampling_frequency()`` and
        ``get_traces_filt(start_frame, end_frame)``. The last is assumed to
        return an ``(n_channels, n_frames)`` array clipped to the recording's
        length -- TODO confirm in src.recording.
    n_before, n_after : int
        Number of samples kept before / after the spike sample.
    max_spikes : int, optional
        If given and the unit has more spikes, use a random subset of this
        size, drawn without replacement.
    seed : int, optional
        Seed for that subsampling. ``None`` (default) keeps the previous
        behaviour of drawing from NumPy's global RNG.

    Returns
    -------
    np.ndarray
        ``float32`` array of shape
        ``(n_spikes, n_channels, n_before + 1 + n_after)``. Samples that fall
        outside the recording are left as zero.
    """
    spike_train = np.asarray(spike_train)
    if max_spikes is not None and len(spike_train) > max_spikes:
        rng = np.random if seed is None else np.random.default_rng(seed)
        spike_train = spike_train[rng.choice(len(spike_train), size=max_spikes, replace=False)]

    n_samples = n_before + 1 + n_after
    waveforms = np.zeros((len(spike_train), recording.get_num_channels(), n_samples), dtype="float32")

    sf = recording.get_sampling_frequency()
    for i, spike_ms in enumerate(spike_train):
        # Round to the nearest sample. Truncating would shift spikes one sample
        # early whenever the float product lands just below an integer.
        center = int(round(spike_ms * sf))
        start = center - n_before
        # Never request a negative frame. Instead, shift where the returned
        # samples land inside the window.
        offset = max(-start, 0)
        traces = recording.get_traces_filt(start + offset, center + n_after + 1)
        # Place the samples at their true position. Anything the recording
        # could not supply (either edge) stays zero.
        waveforms[i, :, offset:offset + traces.shape[1]] = traces

    return waveforms

def extract_waveforms_job(unit_idx):
    """Pool worker: extract one unit's waveforms and save them to disk.

    Uses the notebook-level SPIKE_TIMES, RECORDING, N_BEFORE, N_AFTER and
    SAVE_PATH. Writes ``SAVE_PATH / "<unit_idx>.npy"`` and returns None.
    """
    out_file = SAVE_PATH / f"{unit_idx}.npy"
    unit_spikes = SPIKE_TIMES[unit_idx]
    unit_waveforms = extract_waveforms(unit_spikes, RECORDING, N_BEFORE, N_AFTER)
    np.save(out_file, unit_waveforms)

In [5]:
# One task per unit. Each worker writes "<unit_idx>.npy" itself and returns
# None, so the iterator is consumed only to drive the progress bar.
# NOTE(review): workers rely on inheriting the notebook globals (RECORDING,
# SPIKE_TIMES, ...). That works with the fork start method (Linux default);
# confirm before running under spawn (macOS/Windows).
unit_indices = range(len(SPIKE_TIMES))
with Pool(processes=12) as pool:
    progress = tqdm(pool.imap(extract_waveforms_job, unit_indices), total=len(unit_indices))
    for _ in progress:
        pass

100%|██████████| 90/90 [02:25<00:00,  1.61s/it]
