## Objective
We subsequently tested our model on actual spikes detected by Kilosort2 in the hold-out recordings. When only considering the largest waveform amplitude electrodes, we observed precision and recall scores of …±…% and …±…% for organoids and …±…% and …±…% for mice (Sup. Fig. 2A-D). The precision and recall were still …±…% and …±…% for organoids and …±…% and …±…% for mice when we considered the 4th largest waveform amplitude electrodes (Sup. Fig. 2E-H). When interpreting these precision scores, it has to be considered that we are comparing the model against Kilosort2 detections. As such, detections from the model that Kilosort2 misses due to various possible reasons like overlapping waveforms, low amplitude or waveform shape changes will negatively impact the obtained precision scores.


In [1]:
%load_ext autoreload

In [6]:
from multiprocessing import Pool
from pathlib import Path

import numpy as np
import torch

from tqdm import tqdm

%autoreload 2
from src.run_alg.model import ModelSpikeSorter
from src.comparison import Comparison

## Run DL model on recordings

In [3]:
MODEL_PATHS = [
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/dl_models/240318/a/240318_161415_981130",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/dl_models/240318/b/240318_163253_679441",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/dl_models/240318/c/240318_165245_967091",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/dl_models/240318/d/240318_172719_805804",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/dl_models/240318/e/240318_174428_896437",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/dl_models/240318/f/240318_180745_727120",
]

REC_PATHS = [
    Path("/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592315"),
    Path("/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592318"),
    Path("/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592320"),
    Path("/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592324"),
    Path("/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592328"),
    Path("/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592330"),
]

SAMP_FREQ = 30
GAIN_TO_UV = 0.195

In [4]:
def run_dl_model(model_path, unscaled_traces_path,
                 model_traces_path, model_outputs_path,
                 device="cuda"):    
    """
    WARNING: [Torch-TensorRT] - Using an engine plan file across different models of devices is not recommended and is likely to affect performance or even cause errors
        - This is nothing unless using model on a different GPU than what created it (https://github.com/dusty-nv/jetson-inference/issues/883#issuecomment-754106437)
    
    """

    torch.backends.cudnn.benchmark = True
    np_dtype = "float16"
    
    # region Load model
    print("Loading DL model ...")
    model = ModelSpikeSorter.load(model_path) 
    sample_size = model.sample_size
    num_output_locs = model.num_output_locs
    input_scale = model.input_scale
    model = ModelSpikeSorter.load_compiled(model_path)
    # endregion
    
    # region Prepare data
    unscaled_traces = np.load(unscaled_traces_path, mmap_mode="r")
    
    num_chans, rec_duration = unscaled_traces.shape

    start_frames_all = np.arange(0, rec_duration-sample_size+1, num_output_locs)
    
    print("Allocating memory to save model traces and outputs ...")
    traces_all = np.zeros(unscaled_traces.shape, dtype=np_dtype)
    np.save(model_traces_path, traces_all)
    traces_all = np.load(model_traces_path, mmap_mode="r+")

    outputs_all = np.zeros((num_chans, start_frames_all.size*num_output_locs), dtype=np_dtype)
    np.save(model_outputs_path, outputs_all)
    outputs_all = np.load(model_outputs_path, mmap_mode="r+")
    # endregion
    
    # region Calculating inference scaling
    # if INFERENCE_SCALING_NUMERATOR is not None:
    #     window = scaled_traces[:, :PRE_MEDIAN_FRAMES]
    #     iqrs = scipy.stats.iqr(window, axis=1)
    #     median = np.median(iqrs)
    #     inference_scaling = INFERENCE_SCALING_NUMERATOR / median
    # else:
    #     inference_scaling = 1
    inference_scaling = GAIN_TO_UV
    print(f"Inference scaling: {inference_scaling}")
    # endregion
    
    # region Run model
    print("Running model ...")    
    with torch.no_grad():
        for start_frame in tqdm(start_frames_all):
            traces_torch = torch.tensor(unscaled_traces[:, start_frame:start_frame+sample_size], device=device, dtype=torch.float16)
            traces_torch -= torch.median(traces_torch, dim=1, keepdim=True).values
            outputs = model(traces_torch[:, None, :] * input_scale * inference_scaling).cpu()

            traces_all[:, start_frame:start_frame+sample_size] = traces_torch.cpu()
            outputs_all[:, start_frame:start_frame+num_output_locs] = outputs[:, 0, :]
    # endregion
    
    # region Save traces and outputs
    # np.save(model_traces_path, traces_all)
    # np.save(model_outputs_path, outputs_all)
    # endregion
    
def extract_crossings(model_outputs_path, 
                      all_crossings_path, elec_crossings_ind_path,
                      front_buffer, loc_prob_thresh_logit,
                      end_ms=None,
                      window_size=1000,
                      device="cpu"):          
    outputs = np.load(model_outputs_path, mmap_mode="r")
    num_elecs = outputs.shape[0]
    
    # Using end_ms (rel to recording traces) since only need crossings during pre-recording to form sequences, not entire duration  
    if end_ms is None:
        end_frame = outputs.shape[1]-window_size-1
    else: 
        end_frame = round(end_ms * SAMP_FREQ) - front_buffer  # -FRONT_BUFFER to convert from rec. traces to model outputs
    
    all_crossings = []  # [(elec_idx, time, amp)]
    elec_crossings_ind = [[] for _ in range(num_elecs)]  # ith element for elec idx i. Contains ind in all_crossings for elec idx i's crossings
    crossing_idx = 0
    for start_frame in tqdm(range(0, end_frame, window_size)):
        if start_frame >= end_frame: 
            break
        
        window = outputs[:, start_frame:start_frame+window_size+2]
        window = torch.tensor(window, device=device)
        
        main = window[:, 1:-1]
        greater_than_left = main > window[:, :-2]
        greater_than_right = main > window[:, 2:]
        peaks = greater_than_left & greater_than_right
        crosses = main >= loc_prob_thresh_logit
        nonzeros = torch.nonzero((peaks & crosses).T)  # .T is so that outputs are ordered based on peak_ind first then elec_ind
        for peak, elec in nonzeros:
            peak = peak.item()
            elec = elec.item()
            time_ms = (front_buffer + peak + start_frame + 1) / SAMP_FREQ  # +1 since rel. to main (which is +1 rel to window and window is rel. to start_frame)
            all_crossings.append((elec, time_ms, -1)) 
            elec_crossings_ind[elec].append(crossing_idx)
            crossing_idx += 1
            
        # if start_frame > 10000:
        #     times = [all_crossings[idx][1] for idx in elec_crossings_ind[17]]
        #     print(times)
        #     plot_spikes(times, 17)
        #     plt.show()
        #     return
    
    np.save(all_crossings_path, np.array(all_crossings, dtype=object))
    np.save(elec_crossings_ind_path, np.array(elec_crossings_ind, dtype=object))

In [5]:
for model_path, rec_path in zip(MODEL_PATHS, REC_PATHS):
    model_path = Path(model_path)
    
    print(f"Starting on {rec_path.name}")
    inter_path = model_path / "run_on_holdout_recording"
    inter_path.mkdir(exist_ok=True, parents=True)
    
    unscaled_traces_path = rec_path / "traces.npy"
    
    model_traces_path = inter_path / "model_traces.npy"
    model_outputs_path = inter_path / "model_outputs.npy"
    
    all_crossings_path = inter_path / "all_crossings.npy"
    elec_crossings_ind_path = inter_path / "elec_crossings_ind.npy"
    
    model = ModelSpikeSorter.load(model_path)
    num_elecs = np.load(unscaled_traces_path, mmap_mode="r").shape[0]
    model.compile(num_elecs, model_path)    
    
    run_dl_model(model_path, unscaled_traces_path, model_traces_path, model_outputs_path)
    extract_crossings(model_outputs_path, all_crossings_path, elec_crossings_ind_path,
                      front_buffer=model.buffer_front_sample, loc_prob_thresh_logit=model.loc_prob_thresh_logit)

Starting on probe_773592315




Loading DL model ...
Allocating memory to save model traces and outputs ...
Inference scaling: 0.195
Running model ...


100%|██████████| 199999/199999 [16:04<00:00, 207.29it/s] 
100%|██████████| 35999/35999 [02:00<00:00, 298.46it/s]


Starting on probe_773592318




Loading DL model ...
Allocating memory to save model traces and outputs ...
Inference scaling: 0.195
Running model ...


100%|██████████| 199999/199999 [16:15<00:00, 204.99it/s]
100%|██████████| 35999/35999 [03:17<00:00, 181.98it/s]


Starting on probe_773592320




Loading DL model ...
Allocating memory to save model traces and outputs ...
Inference scaling: 0.195
Running model ...


100%|██████████| 199999/199999 [09:21<00:00, 356.27it/s]
100%|██████████| 35999/35999 [01:42<00:00, 350.55it/s]


Starting on probe_773592324




Loading DL model ...
Allocating memory to save model traces and outputs ...
Inference scaling: 0.195
Running model ...


100%|██████████| 199999/199999 [12:53<00:00, 258.44it/s]
100%|██████████| 35999/35999 [01:29<00:00, 402.13it/s]


Starting on probe_773592328
Loading DL model ...




Allocating memory to save model traces and outputs ...
Inference scaling: 0.195
Running model ...


100%|██████████| 199999/199999 [09:06<00:00, 366.14it/s]
100%|██████████| 35999/35999 [02:55<00:00, 204.81it/s]


Starting on probe_773592330




Loading DL model ...
Allocating memory to save model traces and outputs ...
Inference scaling: 0.195
Running model ...


100%|██████████| 199999/199999 [09:15<00:00, 360.16it/s] 
100%|██████████| 35999/35999 [01:45<00:00, 340.32it/s]


## Measure performance

In [None]:
"""Pseudocode:
Save per model/recording

For each unit, get template.
For 0...5 highest amp electrode, count number of matches
Return [num_true_positives, num_false_postiives, num_false_negatives for each highest amp electrode]
Store all in array

[num_units, 6=(tested electrodes), 3=(num_true_postives, num_false_positives, num_false_negatives)]
"""

In [7]:
X_HIGHEST = 5  # Test up to X_HIGHEST electrode (max electrode = 0)

# For extracting unit template to determine highest elecs
N_BEFORE = N_AFTER = round(0.5 * SAMP_FREQ)
N_BEFORE_MEDIAN = N_AFTER_MEDIAN = N_AFTER_MEDIAN = round(50*SAMP_FREQ)
NUM_SPIKES_ESTIMATOR = 300  # Number of spikes needed to determine which of electrodes have largest amp in unit's template

def job(ks_unit_id):
    ks_spike_frames = KS_SPIKE_TIMES[KS_SPIKE_CLUSTERS == ks_unit_id]
    # Find amplitudes on elecs
    np.random.seed(231)
    templates = np.zeros([TRACES.shape[0], N_BEFORE+N_AFTER+1], dtype=float)
    for frame in np.random.choice(ks_spike_frames, NUM_SPIKES_ESTIMATOR):
        if frame < N_BEFORE:  # Prevent broadcasting problems
            continue
        window = TRACES[:, frame-N_BEFORE:frame+N_AFTER+1] - np.median(TRACES[:, frame-N_BEFORE_MEDIAN:frame+N_AFTER_MEDIAN+1], axis=1, keepdims=True)
        templates += window
    amps = np.max(np.abs(templates), axis=1)
    highest_elecs = np.argsort(-amps)
    ks_spike_times = ks_spike_frames / SAMP_FREQ

    perfs = []
    for elec in highest_elecs[:X_HIGHEST+1]:
        model_times = [ALL_CROSSINGS[idx][1] for idx in ELEC_CROSSINGS_IND[elec]]
        num_tp = Comparison.count_matching_events(model_times, ks_spike_times)
        perfs.append((num_tp, len(model_times)-num_tp, len(ks_spike_times)-num_tp))
    return perfs

for model_path, rec_path in zip(MODEL_PATHS, REC_PATHS):
    model_path = Path(model_path)
    
    print(f"Starting on {rec_path.name}")
    inter_path = model_path / "run_on_holdout_recording"    
    unscaled_traces_path = rec_path / "traces.npy"
    all_crossings_path = inter_path / "all_crossings.npy"
    elec_crossings_ind_path = inter_path / "elec_crossings_ind.npy"
    
    TRACES = np.load(unscaled_traces_path, mmap_mode="r")
    ALL_CROSSINGS = np.load(all_crossings_path, allow_pickle=True)
    ELEC_CROSSINGS_IND = np.load(elec_crossings_ind_path, allow_pickle=True)
    
    KS_SPIKE_TIMES = np.load(rec_path / "spikesort_matlab4/curation/first/spike_times.npy")
    KS_SPIKE_CLUSTERS = np.load(rec_path / "spikesort_matlab4/curation/first/spike_clusters.npy")
    
    print(f"Total num model detections: {len(ALL_CROSSINGS)}") # Sanity check that there are model crossings
    
    ks_unit_ids = np.unique(KS_SPIKE_CLUSTERS)
    
    with Pool(processes=16) as pool:
        all_perfs = []
        for perfs in tqdm(pool.imap(job, ks_unit_ids), total=len(ks_unit_ids)):
            all_perfs.append(perfs)
        all_perfs = np.array(all_perfs)
    np.save(inter_path / "units_highest_elecs_tp_fp_fn.npy", all_perfs)

Starting on probe_773592315
Total num model detections: 15778671


  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
100%|██████████| 220/220 [01:37<00:00,  2.25it/s]


Starting on probe_773592318
Total num model detections: 16609692


  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
100%|██████████| 257/257 [02:01<00:00,  2.12it/s]


Starting on probe_773592320
Total num model detections: 12479824


100%|██████████| 340/340 [02:06<00:00,  2.68it/s]


Starting on probe_773592324
Total num model detections: 8671967


100%|██████████| 186/186 [01:22<00:00,  2.25it/s]


Starting on probe_773592328
Total num model detections: 21499889


100%|██████████| 389/389 [02:29<00:00,  2.59it/s]


Starting on probe_773592330
Total num model detections: 12397703


  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)
100%|██████████| 198/198 [01:28<00:00,  2.23it/s]
