From b083bb8b29d6bd2fe3638254238fe921bd2808a6 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Mon, 3 Nov 2025 21:37:29 +0100
Subject: [PATCH 1/2] Add sparsity_mask option in estimate_templates_with_accumulator()

---
 .../core/tests/test_waveform_tools.py     | 30 ++++++++++----
 src/spikeinterface/core/waveform_tools.py | 39 +++++++++++++++----
 2 files changed, 54 insertions(+), 15 deletions(-)

diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py
index 3c978181dd..bc0e93d8d3 100644
--- a/src/spikeinterface/core/tests/test_waveform_tools.py
+++ b/src/spikeinterface/core/tests/test_waveform_tools.py
@@ -227,16 +227,30 @@ def test_estimate_templates():
 
     job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
 
+    # mask with different sparsity per unit
+    sparsity_mask = np.ones((sorting.unit_ids.size, recording.channel_ids.size), dtype=bool)
+    sparsity_mask[:4, :recording.channel_ids.size//2 -1 ] = False
+    sparsity_mask[4:, recording.channel_ids.size//2:] = False
+
+
     for operator in ("average", "median"):
-        templates = estimate_templates(
+        templates_array = estimate_templates(
             recording, spikes, sorting.unit_ids, nbefore, nafter, operator=operator, return_in_uV=True, **job_kwargs
         )
         # print(templates.shape)
-        assert templates.shape[0] == sorting.unit_ids.size
-        assert templates.shape[1] == nbefore + nafter
-        assert templates.shape[2] == recording.get_num_channels()
+        assert templates_array.shape[0] == sorting.unit_ids.size
+        assert templates_array.shape[1] == nbefore + nafter
+        assert templates_array.shape[2] == recording.get_num_channels()
 
-        assert np.any(templates != 0)
+        assert np.any(templates_array != 0)
+
+        sparse_templates_array = estimate_templates(
+            recording, spikes, sorting.unit_ids, nbefore, nafter, operator=operator,
+            return_in_uV=True, sparsity_mask=sparsity_mask, **job_kwargs
+        )
+        n_chan = np.max(np.sum(sparsity_mask, axis=1))
+        assert n_chan == sparse_templates_array.shape[2]
+        assert np.any(sparse_templates_array == 0)
 
     # import matplotlib.pyplot as plt
     # fig, ax = plt.subplots()
@@ -247,7 +261,7 @@
 
 
 if __name__ == "__main__":
-    cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core"
-    test_waveform_tools(cache_folder)
-    test_estimate_templates_with_accumulator()
+    # cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core"
+    # test_waveform_tools(cache_folder)
+    # test_estimate_templates_with_accumulator()
     test_estimate_templates()
diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index a055a68696..a405229d11 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -744,12 +744,13 @@ def estimate_templates(
     operator: str = "average",
     return_scaled=None,
     return_in_uV=True,
+    sparsity_mask=None,
     job_name=None,
     **job_kwargs,
 ):
     """
     Estimate dense templates with "average" or "median".
-    If "average" internally estimate_templates_with_accumulator() is used to saved memory/
+    If "average" internally estimate_templates_with_accumulator() is used to save memory.
 
     Parameters
     ----------
@@ -770,6 +771,8 @@ def estimate_templates(
     return_in_uV : bool, default: True
         If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
         traces are scaled to uV
+    sparsity_mask : None or array of bool, default: None
+        If not None, the shape must be (len(unit_ids), len(channel_ids))
 
     Returns
     -------
@@ -791,7 +794,7 @@
 
     if operator == "average":
         templates_array = estimate_templates_with_accumulator(
-            recording, spikes, unit_ids, nbefore, nafter, return_in_uV=return_in_uV, job_name=job_name, **job_kwargs
+            recording, spikes, unit_ids, nbefore, nafter, return_in_uV=return_in_uV, sparsity_mask=sparsity_mask, job_name=job_name, **job_kwargs
         )
     elif operator == "median":
         all_waveforms, wf_array_info = extract_waveforms_to_single_buffer(
@@ -802,6 +805,7 @@ def estimate_templates(
             nafter,
             mode="shared_memory",
             return_in_uV=return_in_uV,
+            sparsity_mask=sparsity_mask,
             copy=False,
             **job_kwargs,
         )
@@ -828,6 +832,7 @@ def estimate_templates_with_accumulator(
     nafter: int,
     return_scaled=None,
     return_in_uV=True,
+    sparsity_mask=None,
     job_name=None,
     return_std: bool = False,
     verbose: bool = False,
@@ -859,6 +864,8 @@ def estimate_templates_with_accumulator(
     return_in_uV : bool, default: True
         If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
         traces are scaled to uV
+    sparsity_mask : None or array of bool, default: None
+        If not None, the shape must be (len(unit_ids), len(channel_ids))
     return_std: bool, default: False
         If True, the standard deviation is also computed.
 
@@ -882,10 +889,14 @@ def estimate_templates_with_accumulator(
     job_kwargs = fix_job_kwargs(job_kwargs)
     num_worker = job_kwargs["n_jobs"]
 
-    num_chans = recording.get_num_channels()
+    if sparsity_mask is None:
+        num_chans = int(recording.get_num_channels())
+    else:
+        num_chans = int(max(np.sum(sparsity_mask, axis=1)))  # This is a numpy scalar, so we cast to int
     num_units = len(unit_ids)
-
+    
     shape = (num_worker, num_units, nbefore + nafter, num_chans)
+
     dtype = np.dtype("float32")
     waveform_accumulator_per_worker, shm = make_shared_array(shape, dtype)
     shm_name = shm.name
@@ -909,6 +920,7 @@ def estimate_templates_with_accumulator(
         nbefore,
         nafter,
         return_in_uV,
+        sparsity_mask,
     )
 
     if job_name is None:
@@ -965,6 +977,7 @@ def _init_worker_estimate_templates(
     nbefore,
     nafter,
     return_in_uV,
+    sparsity_mask,
 ):
     worker_dict = {}
     worker_dict["recording"] = recording
@@ -972,6 +985,7 @@ def _init_worker_estimate_templates(
     worker_dict["nbefore"] = nbefore
     worker_dict["nafter"] = nafter
     worker_dict["return_in_uV"] = return_in_uV
+    worker_dict["sparsity_mask"] = sparsity_mask
 
     from multiprocessing.shared_memory import SharedMemory
     import multiprocessing
@@ -1009,6 +1023,7 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
     waveform_squared_accumulator_per_worker = worker_dict.get("waveform_squared_accumulator_per_worker", None)
     worker_index = worker_dict["worker_index"]
     return_in_uV = worker_dict["return_in_uV"]
+    sparsity_mask = worker_dict["sparsity_mask"]
 
     seg_size = recording.get_num_samples(segment_index=segment_index)
 
@@ -1040,6 +1055,16 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
             unit_index = spikes[spike_index]["unit_index"]
             wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :]
 
-            waveform_accumulator_per_worker[worker_index, unit_index, :, :] += wf
-            if waveform_squared_accumulator_per_worker is not None:
-                waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :] += wf**2
+            if sparsity_mask is None:
+                waveform_accumulator_per_worker[worker_index, unit_index, :, :] += wf
+                if waveform_squared_accumulator_per_worker is not None:
+                    waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :] += wf**2
+
+
+            else:
+                mask = sparsity_mask[unit_index, :]
+                wf = wf[:, mask]
+                waveform_accumulator_per_worker[worker_index, unit_index, :, : wf.shape[1]] += wf
+                if waveform_squared_accumulator_per_worker is not None:
+                    waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :wf.shape[1]] += wf**2
+

From d7359366b243693c30bc52313485c66f717d32f4 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 3 Nov 2025 20:38:24 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../core/tests/test_waveform_tools.py     | 16 +++++++++++-----
 src/spikeinterface/core/waveform_tools.py | 16 +++++++++++-----
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py
index bc0e93d8d3..c128e598a9 100644
--- a/src/spikeinterface/core/tests/test_waveform_tools.py
+++ b/src/spikeinterface/core/tests/test_waveform_tools.py
@@ -229,9 +229,8 @@ def test_estimate_templates():
 
     # mask with different sparsity per unit
     sparsity_mask = np.ones((sorting.unit_ids.size, recording.channel_ids.size), dtype=bool)
-    sparsity_mask[:4, :recording.channel_ids.size//2 -1 ] = False
-    sparsity_mask[4:, recording.channel_ids.size//2:] = False
-
+    sparsity_mask[:4, : recording.channel_ids.size // 2 - 1] = False
+    sparsity_mask[4:, recording.channel_ids.size // 2 :] = False
 
     for operator in ("average", "median"):
         templates_array = estimate_templates(
@@ -245,8 +244,15 @@ def test_estimate_templates():
         assert np.any(templates_array != 0)
 
         sparse_templates_array = estimate_templates(
-            recording, spikes, sorting.unit_ids, nbefore, nafter, operator=operator,
-            return_in_uV=True, sparsity_mask=sparsity_mask, **job_kwargs
+            recording,
+            spikes,
+            sorting.unit_ids,
+            nbefore,
+            nafter,
+            operator=operator,
+            return_in_uV=True,
+            sparsity_mask=sparsity_mask,
+            **job_kwargs,
         )
         n_chan = np.max(np.sum(sparsity_mask, axis=1))
         assert n_chan == sparse_templates_array.shape[2]
diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index a405229d11..920de33c2f 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -794,7 +794,15 @@
 
     if operator == "average":
         templates_array = estimate_templates_with_accumulator(
-            recording, spikes, unit_ids, nbefore, nafter, return_in_uV=return_in_uV, sparsity_mask=sparsity_mask, job_name=job_name, **job_kwargs
+            recording,
+            spikes,
+            unit_ids,
+            nbefore,
+            nafter,
+            return_in_uV=return_in_uV,
+            sparsity_mask=sparsity_mask,
+            job_name=job_name,
+            **job_kwargs,
         )
     elif operator == "median":
         all_waveforms, wf_array_info = extract_waveforms_to_single_buffer(
@@ -894,7 +902,7 @@ def estimate_templates_with_accumulator(
     else:
         num_chans = int(max(np.sum(sparsity_mask, axis=1)))  # This is a numpy scalar, so we cast to int
     num_units = len(unit_ids)
-    
+
     shape = (num_worker, num_units, nbefore + nafter, num_chans)
 
     dtype = np.dtype("float32")
@@ -1060,11 +1068,9 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
             if waveform_squared_accumulator_per_worker is not None:
                 waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :] += wf**2
 
-
             else:
                 mask = sparsity_mask[unit_index, :]
                 wf = wf[:, mask]
                 waveform_accumulator_per_worker[worker_index, unit_index, :, : wf.shape[1]] += wf
                 if waveform_squared_accumulator_per_worker is not None:
-                    waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :wf.shape[1]] += wf**2
-
+                    waveform_squared_accumulator_per_worker[worker_index, unit_index, :, : wf.shape[1]] += wf**2
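
Usage sketch (not part of the patches above): the snippet below shows how the new sparsity_mask argument is intended to be called, mirroring the updated test. It relies on the existing spikeinterface.core helpers generate_ground_truth_recording and Sorting.to_spike_vector(); the channel counts, window lengths, and mask layout are arbitrary illustrative values, not anything prescribed by the PR.

    # Sketch of the sparsity_mask option added by PATCH 1/2; sizes are arbitrary.
    import numpy as np
    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.core.waveform_tools import estimate_templates

    # small synthetic recording/sorting pair from spikeinterface's generators
    recording, sorting = generate_ground_truth_recording(num_channels=8, num_units=5, durations=[5.0], seed=42)
    spikes = sorting.to_spike_vector()
    nbefore, nafter = 20, 40

    # one row per unit, one column per channel; True marks the channels kept for that unit
    sparsity_mask = np.zeros((sorting.unit_ids.size, recording.channel_ids.size), dtype=bool)
    sparsity_mask[:, :4] = True  # here every unit keeps the first 4 channels

    templates_array = estimate_templates(
        recording, spikes, sorting.unit_ids, nbefore, nafter,
        operator="average", return_in_uV=False, sparsity_mask=sparsity_mask,
    )
    # the channel axis is trimmed to the largest number of kept channels across units
    assert templates_array.shape == (sorting.unit_ids.size, nbefore + nafter, 4)

return_in_uV=False is used here only so the sketch does not depend on gain/offset metadata being present on the generated recording; with a scaled recording the default of True behaves as in the updated test.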