In [16]:
import warnings

import numpy as np
import spikeinterface as si
from spikeinterface.core.node_pipeline import PipelineNode, run_node_pipeline
from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel

In [2]:
# Open the compressed Zarr store for Record Node 105, stream Neuropix-PXI-100.50205.
raw_recording = si.read_zarr(
    "../data/ecephys_733583_2024-11-15_12-21-20/ecephys/ecephys_compressed/"
    "experiment1_Record Node 105#Neuropix-PXI-100.50205.zarr/"
)

In [3]:
# Rich repr of the loaded recording, as a quick sanity check of what was read.
raw_recording

In [4]:
# Metadata of the first probe; its "part_number" is what selects the saturation threshold further down.
raw_recording.get_annotation("probes_info")[0]

{'dock': '1',
 'manufacturer': 'IMEC',
 'model_name': 'Neuropixels 2.0 - Single Shank - Prototype',
 'name': '50205',
 'part_number': 'NP2000',
 'port': '2',
 'serial_number': '20403319134',
 'slot': '2'}

In [9]:
# Saturation level per probe part number, labelled as microvolts.
# NOTE(review): 0.5-0.62 x 1e6 uV is 0.5-0.62 V. If these are ADC-side levels rather
# than input-referred voltages, dividing by per-channel gains yields thresholds that
# no raw sample can reach — confirm the units.
# Groups are listed in the same order as the original literal, so key order is kept.
saturation_thresholds_uv = {
    # NP1.0 probes
    **dict.fromkeys(
        ["PRB_1_4_0480_1", "PRB_1_4_0480_1_C", "PRB_1_2_0480_2", "NP1010"],
        0.6 * 1e6,
    ),
    # NHP probes
    **dict.fromkeys(
        ["NP1015", "NP1016", "NP1022", "NP1030", "NP1031", "NP1032"],
        0.6 * 1e6,
    ),
    # NP2.0
    **dict.fromkeys(["NP2000", "NP2010"], 0.5 * 1e6),
    **dict.fromkeys(["NP2013", "NP2014", "NP2003", "NP2004"], 0.62 * 1e6),
    **dict.fromkeys(["PRB2_1_2_0640_0", "PRB2_4_2_0640_0"], 0.5 * 1e6),
    # Other probes
    **dict.fromkeys(
        [
            "NP1100",  # Ultra probe - 1 bank
            "NP1110",  # Ultra probe - 16 banks
            "NP1121",  # Ultra probe - beta configuration
            "NP1300",  # Opto probe
        ],
        0.6 * 1e6,
    ),
}

In [11]:
# Working handle used by the cells below; a later cell re-binds it to a 120 s slice.
recording = raw_recording

In [None]:
def find_saturation_events(recording, exclude_sweep_ms=1.0, thresholds_uv=None, **job_kwargs):
    """Find negative and positive crossings of the probe's saturation level on every channel.

    The threshold is looked up from the part number of the first entry of the
    recording's ``"probes_info"`` annotation and converted from uV to raw units
    with each channel's own gain, because the node pipeline runs on unscaled traces.

    Parameters
    ----------
    recording : BaseRecording
        Recording carrying a ``"probes_info"`` annotation and channel gains.
    exclude_sweep_ms : float, default 1.0
        Minimum spacing, in milliseconds, between two events on the same channel
        (forwarded to ``DetectPeakByChannel``, which converts it to samples).
    thresholds_uv : dict or None, default None
        Mapping from probe part number to saturation threshold in uV. When None,
        the notebook-level ``saturation_thresholds_uv`` table is used.
    **job_kwargs
        Job kwargs for ``run_node_pipeline``; keys not given fall back to
        ``si.get_global_job_kwargs()``.

    Returns
    -------
    neg_events, pos_events
        Structured arrays from the pipeline for the negative- and positive-threshold
        detectors respectively.

    Raises
    ------
    KeyError
        If the probe part number has no entry in the threshold table.
    ValueError
        If the recording has no channel gains.
    """
    if thresholds_uv is None:
        thresholds_uv = saturation_thresholds_uv

    part_number = recording.get_annotation("probes_info")[0]["part_number"]
    if part_number not in thresholds_uv:
        raise KeyError(f"No saturation threshold for probe part number {part_number!r}")
    threshold_uv = thresholds_uv[part_number]

    channel_gains = recording.get_channel_gains()
    if channel_gains is None:
        raise ValueError("recording has no channel gains; cannot convert the uV threshold to raw units")
    # Per-channel conversion to raw units (not just channel 0's gain).
    # NOTE(review): assumes zero channel offsets (uV = raw * gain) — confirm.
    abs_thresholds = threshold_uv / np.asarray(channel_gains, dtype="float64")

    # A threshold beyond the raw integer range can never be crossed: say so instead of
    # silently returning empty results.
    traces_dtype = np.dtype(recording.get_dtype())
    if np.issubdtype(traces_dtype, np.integer):
        dtype_info = np.iinfo(traces_dtype)
        if np.all(abs_thresholds > max(-int(dtype_info.min), int(dtype_info.max))):
            warnings.warn(
                f"Saturation threshold of {threshold_uv} uV is outside the {traces_dtype} range on every "
                "channel, so no saturation events can be detected; check the threshold units.",
                stacklevel=2,
            )

    # With detect_threshold=1 the detector's absolute threshold equals `noise_levels`;
    # passing exclude_sweep_ms through the constructor lets the detector convert it to samples.
    detector_kwargs = dict(detect_threshold=1, exclude_sweep_ms=exclude_sweep_ms, noise_levels=abs_thresholds)
    nodes = [
        DetectPeakByChannel(recording, peak_sign="neg", **detector_kwargs),
        DetectPeakByChannel(recording, peak_sign="pos", **detector_kwargs),
    ]

    # Explicitly passed job kwargs take precedence over the global defaults.
    job_kwargs = {**si.get_global_job_kwargs(), **job_kwargs}
    neg_events, pos_events = run_node_pipeline(
        recording,
        nodes,
        job_kwargs,
        job_name="finding saturation events",
        squeeze_output=True,
    )
    return neg_events, pos_events

In [19]:
from spikeinterface.core.node_pipeline import PipelineNode, run_node_pipeline
from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel

In [29]:
# Run chunked processing on all available cores (n_jobs=-1), and keep the resulting
# global job kwargs so they can be passed explicitly to run_node_pipeline below.
si.set_global_job_kwargs(n_jobs=-1)
job_kwargs = si.get_global_job_kwargs()

In [30]:
# Saturation threshold (uV) for the recording's probe, looked up by part number.
# NOTE(review): only probes_info[0] is used; a multi-probe recording would need a per-probe lookup.
part_number = recording.get_annotation("probes_info")[0]["part_number"]
if part_number not in saturation_thresholds_uv:
    raise KeyError(
        f"No saturation threshold for probe part number {part_number!r}; "
        "add it to saturation_thresholds_uv."
    )
threshold_uv = saturation_thresholds_uv[part_number]

In [31]:
# Threshold selected for this probe's part number (uV, per the table's naming).
threshold_uv

500000.0

In [36]:
# Restrict the analysis to the first ANALYSIS_DURATION_S seconds, or the whole recording
# if it is shorter (frame_slice rejects an end frame past the last sample).
# Slicing from raw_recording rather than re-binding `recording` from itself keeps the
# cell idempotent under re-runs.
ANALYSIS_DURATION_S = 120
recording = raw_recording.frame_slice(
    0,
    min(
        int(ANALYSIS_DURATION_S * raw_recording.sampling_frequency),
        raw_recording.get_num_samples(segment_index=0),
    ),
)

In [37]:
# Minimum spacing between two detected events on the same channel, in milliseconds.
exclude_sweep_ms = 1

num_channels = recording.get_num_channels()
# The pipeline runs on raw (unscaled) traces, so convert the uV threshold to raw units
# with each channel's own gain rather than broadcasting channel 0's gain.
# NOTE(review): assumes zero channel offsets (uV = raw * gain) — confirm.
abs_thresholds = threshold_uv / np.asarray(recording.get_channel_gains(), dtype="float64")
assert abs_thresholds.shape == (num_channels,), "expected one channel gain per channel"

# A threshold beyond the raw integer range can never be crossed; flag it rather than
# silently producing empty outputs.
traces_dtype = np.dtype(recording.get_dtype())
if np.issubdtype(traces_dtype, np.integer):
    dtype_info = np.iinfo(traces_dtype)
    if np.all(abs_thresholds > max(-int(dtype_info.min), int(dtype_info.max))):
        warnings.warn(
            f"Saturation threshold of {threshold_uv} uV is outside the {traces_dtype} range on every "
            "channel, so no saturation events can be detected; check the threshold units."
        )

# Configure the detectors through their constructor instead of overwriting `.args`:
# with detect_threshold=1 the absolute threshold equals `noise_levels`, and the detector
# converts exclude_sweep_ms to a sample count itself (the `.args` tuple expects samples).
detector_kwargs = dict(detect_threshold=1, exclude_sweep_ms=exclude_sweep_ms, noise_levels=abs_thresholds)
saturation_neg = DetectPeakByChannel(recording, peak_sign="neg", **detector_kwargs)
saturation_pos = DetectPeakByChannel(recording, peak_sign="pos", **detector_kwargs)

nodes = [saturation_neg, saturation_pos]

job_name = "finding saturation events"
squeeze_output = True

outs = run_node_pipeline(
    recording,
    nodes,
    job_kwargs,
    job_name=job_name,
    squeeze_output=squeeze_output,
)


finding saturation events:   0%|          | 0/120 [00:00<?, ?it/s]

In [39]:
# `outs` follows the order of `nodes`, so index 1 should be the saturation_pos (positive-threshold) events.
outs[1]

array([],
      dtype=[('sample_index', '<i8'), ('channel_index', '<i8'), ('amplitude', '<f8'), ('segment_index', '<i8')])

In [None]:
class SaturationDetector(PipelineNode):
    """Pipeline node intended to flag samples at a probe's saturation level (work in progress).

    Only the constructor is implemented; ``get_dtype`` and ``compute`` still raise
    ``NotImplementedError``.

    Parameters
    ----------
    recording : BaseRecording
        Recording the node runs on; its channel gains are cached on the instance.
    saturation_threshold : float
        Saturation level, presumably in uV like ``saturation_thresholds_uv`` — TODO confirm.
    """

    def __init__(
        self,
        recording: si.BaseRecording,
        saturation_threshold: float,
    ):
        # The annotation goes through the already-imported `si` module: a bare
        # `BaseRecording` is evaluated at definition time and was never imported.
        PipelineNode.__init__(self, recording=recording, return_output=True)
        # Presumably cached for converting the threshold to raw units in compute() — verify.
        self._channel_gains = recording.get_channel_gains()
        self.saturation_threshold = saturation_threshold
        self._kwargs = dict(saturation_threshold=saturation_threshold)

    def get_trace_margin(self):
        """Return the number of extra samples needed on each side of a chunk (none)."""
        # Can optionally be overridden.
        return 0

    def get_dtype(self):
        """Return the dtype of this node's output (not implemented yet)."""
        raise NotImplementedError

    def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *args):
        """Process one chunk of traces (not implemented yet)."""
        raise NotImplementedError