## Objective
Simulate a real-time experiment by replaying an organoid recording, and test the speed of DL model detection + prop sorting. For each detection, record the time elapsed from when the spike occurred until it was detected, and plot these durations as a distribution.

# TODO
Keep multiple workers alive and break up model outputs to find peaks?
Write python code and ask ChatGPT to convert to C++?

Maybe ditch the padding on the ends to detect spikes in first and last frame of input if windows overlap

End with step function, sum values, conv layer?

Break up conv layers so model only outputs 1 value per channel?

In [7]:
# region Set up notebook imports
%load_ext autoreload
%autoreload
# Allow for imports of other scripts
import sys
PATH = "/data/MEAprojects/DLSpikeSorter"
if PATH not in sys.path:
    sys.path.append(PATH)
# Reload a module after changes have been made
from importlib import reload
# endregion

from time import perf_counter
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.nn.functional import pad as torch_pad
import torch_tensorrt

import scipy
from scipy.signal import find_peaks


from src.model import ModelSpikeSorter
from src import utils

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Setup

In [4]:
# Recording "2953", opened as a read-only memory map (mmap_mode="r") so only the slices that are
# actually indexed get read from disk
RECORDING = np.load(
    utils.PATH_REC_DL_NP.format("2953"),
    mmap_mode="r",
)

# Load the saved model, switch it to eval mode, move it to the GPU and convert it to float16
MODEL = (
    ModelSpikeSorter.load("/data/MEAprojects/DLSpikeSorter/models/v0_4_4/2953/230101_133514_582221")
    .eval()
    .cuda()
    .to(torch.float16)
)

# Cache two model attributes under notebook-level names.
# NOTE(review): presumably the detection threshold (in logits) and the offset that maps an output
# index back to an input frame -- confirm against ModelSpikeSorter
loc_prob_thresh_logit = MODEL.loc_prob_thresh_logit
logit_to_loc_add = MODEL.loc_first_frame

### Setup torch_tensorrt

## Start simulation

#### Simulation pseudocode
    detected_spikes = []
    sim_time = window_duration  # Simulation clock. 10ms = duration of model window inputs
    while sim_time <= recording_duration:
        ## Run simulation frame
        get numpy window from (sim_time - window_duration to sim_time)

        detected_spikes_cur = []

        processing_time_start = 0
        subtract mean from window
        convert window to torch and cuda
        get model outputs
        get model predictions
        plug predictions into prop signal
        check if spikes were detected

        for every detected spike:
            get processing_time_end
            processing_time = processing_time_end - processing_time_start
            detected_spikes_cur.append(processing_time)

        ## Determine detection delay and true/false positive for analysis
        for every detected spike in detected_spikes_cur: (in ascending order of spike_time_in_window)
            time_spike_was_detected = sim_time + processing_time
            predicted_spike_time = sim_time - window_duration + spike_time_in_window

            # check if predicted spike matches with a spike in [sim_time - window_duration, sim_time]
            # only works because front_buffer_sample and end_buffer_sample > match time (2ms >> 0.4ms)
            if closest real spike is close to predicted_spike_time:
                true_positive = true
                mark real spike as matched
            else:
                true_positive = false

            # delay relative to real spike: detection_delay = time_spike_was_detected - real spike time
            # but the model's predicted spike times are probably more accurate than Kilosort's
            detection_delay = time_spike_was_detected - predicted_spike_time

            detected_spikes.append((true_positive, detection_delay))

        ## Continue simulation
        sim_time += processing_time  # processing time of the whole frame; must be measured even when no spike was detected


In [5]:
def warmup(model, inputs, n_runs, verbose=False):
    """
    Run throwaway forward passes so the GPU is warmed up before anything is timed.

    :param model: callable invoked once per run on the CUDA copy of `inputs`; its result is discarded
    :param inputs: object with a `.to(device)` method (e.g. a torch tensor); moved to "cuda" on every run
    :param n_runs: number of forward passes to execute
    :param verbose: if True, print a message before the first pass
    """
    if verbose:
        print("Warming up ...")

    # Gradients are not needed for warm-up passes
    with torch.no_grad():
        for _run in range(n_runs):
            gpu_inputs = inputs.to("cuda")
            model(gpu_inputs)
            # Wait for the GPU to finish this pass before starting the next one
            torch.cuda.synchronize()
# Warm up with 100 forward passes on samples 1000-1199 of the recording
# (a singleton axis is inserted at position 1 before the tensor is handed to the model)
warmup(
    model=MODEL,
    inputs=torch.tensor(RECORDING[:, None, 1000:1200], dtype=torch.float16),
    n_runs=100,
)

In [1]:
# Warm up first so one-time GPU start-up costs do not leak into the measurement below
warmup(MODEL, torch.tensor(RECORDING[:, None, 1000:1200], dtype=torch.float16, device="cuda"), 100)

traces = RECORDING[:, 2000:2200]

# processing_time_start = perf_counter()
MODEL = MODEL.to("cuda")
# Subtract the mean along axis 1 from every row, then insert a singleton axis at position 1
traces = traces - np.mean(traces, axis=1, keepdims=True)
traces = torch.tensor(traces[:, None, :], device="cuda", dtype=torch.float16)
with torch.no_grad():
    outputs = MODEL(traces)

# CUDA work is queued asynchronously: MODEL(traces) can return before the GPU has finished.
# Without this synchronize, outputs.cpu() would first block on the remainder of the forward pass,
# so the timer would measure an undefined mix of compute and copy time.
torch.cuda.synchronize()
processing_time_start = perf_counter()

outputs = outputs.cpu()
processing_time_end = perf_counter()

# Device-to-host copy time in milliseconds (this expression is the cell's displayed output)
(processing_time_end - processing_time_start) * 1000

# # outputs = torch_pad(outputs.cpu(), (1, 1), value=-np.inf)
# for channel in outputs:
#     peaks = find_peaks(channel, height=loc_prob_thresh_logit)
# processing_time_end = perf_counter()
#
# (processing_time_end - processing_time_start) * 1000

NameError: name 'warmup' is not defined

In [6]:
# Grab the network stored on MODEL.model so that layers can be appended to it
test = MODEL.model
# TODO: the call that used to be here was left unfinished (`test.add_module(torch.nn.)`), which is a
# SyntaxError and stops "Run All". add_module takes a name and a module instance, e.g.
#   test.add_module("<name>", torch.nn.<Layer>(...))
# Fill this in once the extra layer has been chosen.
# Display the current architecture
test

ModelTuning(
  (last_layer): Conv1d(50, 1, kernel_size=(21,), stride=(1,))
  (conv): Sequential(
    (0): Conv1d(1, 50, kernel_size=(21,), stride=(1,))
    (1): ReLU()
    (2): Conv1d(50, 50, kernel_size=(21,), stride=(1,))
    (3): ReLU()
    (4): Conv1d(50, 50, kernel_size=(21,), stride=(1,))
    (5): ReLU()
    (6): Conv1d(50, 1, kernel_size=(21,), stride=(1,))
  )
  (noise): Flatten(start_dim=1, end_dim=-1)
)

## Plot

## Need to test speed outside of a Jupyter notebook