diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index 27f05bb36b..7a6172369b 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size):
 
 
 def divide_recording_into_chunks(recording, chunk_size):
-    all_chunks = []
+    recording_slices = []
     for segment_index in range(recording.get_num_segments()):
         num_frames = recording.get_num_samples(segment_index)
         chunks = divide_segment_into_chunks(num_frames, chunk_size)
-        all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
-    return all_chunks
+        recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
+    return recording_slices
 
 
 def ensure_n_jobs(recording, n_jobs=1):
@@ -387,13 +387,13 @@ def __init__(
             f"chunk_duration={chunk_duration_str}",
         )
 
-    def run(self, all_chunks=None):
+    def run(self, recording_slices=None):
         """
         Runs the defined jobs.
         """
-        if all_chunks is None:
-            all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)
+        if recording_slices is None:
+            recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size)
 
         if self.handle_returns:
             returns = []
@@ -402,17 +402,17 @@ def run(self, all_chunks=None):
 
         if self.n_jobs == 1:
             if self.progress_bar:
-                all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name)
+                recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name)
 
             worker_ctx = self.init_func(*self.init_args)
-            for segment_index, frame_start, frame_stop in all_chunks:
+            for segment_index, frame_start, frame_stop in recording_slices:
                 res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
                 if self.handle_returns:
                     returns.append(res)
                 if self.gather_func is not None:
                     self.gather_func(res)
         else:
-            n_jobs = min(self.n_jobs, len(all_chunks))
+            n_jobs = min(self.n_jobs, len(recording_slices))
 
             # parallel
             with ProcessPoolExecutor(
@@ -421,10 +421,10 @@ def run(self, all_chunks=None):
                 mp_context=mp.get_context(self.mp_context),
                 initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
             ) as executor:
-                results = executor.map(function_wrapper, all_chunks)
+                results = executor.map(function_wrapper, recording_slices)
 
                 if self.progress_bar:
-                    results = tqdm(results, desc=self.job_name, total=len(all_chunks))
+                    results = tqdm(results, desc=self.job_name, total=len(recording_slices))
 
                 for res in results:
                     if self.handle_returns:
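For context, a minimal sketch (not part of this diff) of how the renamed `run()` argument is meant to be used. The worker calling convention is taken from the hunks above: `init_func(*init_args)` builds the per-worker context and `func(segment_index, frame_start, frame_stop, worker_ctx)` handles one slice. The `generate_recording` test helper and the 10_000-sample chunk size are illustrative choices, not prescribed by the diff.

```python
from spikeinterface.core import generate_recording
from spikeinterface.core.job_tools import ChunkRecordingExecutor, divide_recording_into_chunks


def init_func(recording):
    # worker_ctx is a plain dict shared by every call of func in this worker
    return dict(recording=recording)


def func(segment_index, frame_start, frame_stop, worker_ctx):
    # process one (segment_index, frame_start, frame_stop) slice
    traces = worker_ctx["recording"].get_traces(
        start_frame=frame_start, end_frame=frame_stop, segment_index=segment_index
    )
    return traces.mean()


recording = generate_recording(durations=[10.0])

executor = ChunkRecordingExecutor(
    recording, func, init_func, init_args=(recording,), handle_returns=True, n_jobs=1
)

# Every chunk as (segment_index, frame_start, frame_stop), then keep 1 slice in 2;
# run(recording_slices=None) would fall back to the full list, as before.
recording_slices = divide_recording_into_chunks(recording, chunk_size=10_000)
means = executor.run(recording_slices=recording_slices[::2])
```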
diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index d90a20902d..53c2445c77 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -489,6 +489,7 @@ def run_node_pipeline(
     names=None,
     verbose=False,
     skip_after_n_peaks=None,
+    recording_slices=None,
 ):
     """
     Machinery to compute in parallel operations on peaks and traces.
@@ -540,6 +541,10 @@ def run_node_pipeline(
     skip_after_n_peaks : None | int
         Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average.
+    recording_slices : None | list[tuple]
+        Optionally give a list of slices to run the pipeline only on some chunks of the recording.
+        It must be a list of (segment_index, frame_start, frame_stop).
+        If None (default), the function iterates over the entire duration of the recording.
 
     Returns
     -------
@@ -578,7 +583,7 @@ def run_node_pipeline(
         **job_kwargs,
     )
 
-    processor.run()
+    processor.run(recording_slices=recording_slices)
 
     outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
     return outs
diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py
index 2ab74ce51e..4aabbfd587 100644
--- a/src/spikeinterface/core/recording_tools.py
+++ b/src/spikeinterface/core/recording_tools.py
@@ -806,7 +806,7 @@ def append_noise_chunk(res):
         gather_func=append_noise_chunk,
         **job_kwargs,
     )
-    executor.run(all_chunks=recording_slices)
+    executor.run(recording_slices=recording_slices)
     noise_levels_chunks = np.stack(noise_levels_chunks)
     noise_levels = np.mean(noise_levels_chunks, axis=0)
 
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py
index deef2291c6..028eaecf12 100644
--- a/src/spikeinterface/core/tests/test_node_pipeline.py
+++ b/src/spikeinterface/core/tests/test_node_pipeline.py
@@ -4,7 +4,7 @@
 import shutil
 
 from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording
-
+from spikeinterface.core.job_tools import divide_recording_into_chunks
 
 # from spikeinterface.sortingcomponents.peak_detection import detect_peaks
 from spikeinterface.core.node_pipeline import (
@@ -191,8 +191,8 @@ def test_run_node_pipeline(cache_folder_creation):
         unpickled_node = pickle.loads(pickled_node)
 
 
-def test_skip_after_n_peaks():
-    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])
+def test_skip_after_n_peaks_and_recording_slices():
+    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205)
 
     # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
     job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)
@@ -211,18 +211,27 @@ def test_skip_after_n_peaks():
     node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True)
     nodes = [node0, node1]
 
+    # skip
     skip_after_n_peaks = 30
     some_amplitudes = run_node_pipeline(
         recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks
     )
-
     assert some_amplitudes.size >= skip_after_n_peaks
     assert some_amplitudes.size < spikes.size
 
+    # slices: 1 every 4
+    recording_slices = divide_recording_into_chunks(recording, 10_000)
+    recording_slices = recording_slices[::4]
+    some_amplitudes = run_node_pipeline(
+        recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices
+    )
+    tolerance = 1.2
+    assert some_amplitudes.size < (spikes.size // 4) * tolerance
+
 
 # the following is for testing locally with python or ipython. It is not used in ci or with pytest.
 if __name__ == "__main__":
     # folder = Path("./cache_folder/core")
     # test_run_node_pipeline(folder)
-    test_skip_after_n_peaks()
+    test_skip_after_n_peaks_and_recording_slices()
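The test above exercises `recording_slices` through the test-local `AmplitudeExtractionNode`. A self-contained sketch of the same idea using only nodes shipped in node_pipeline.py might look as follows; the synthetic peaks array, the chunk size, and the waveform window are illustrative assumptions, not taken from this diff.

```python
import numpy as np

from spikeinterface.core import generate_recording
from spikeinterface.core.job_tools import divide_recording_into_chunks
from spikeinterface.core.node_pipeline import (
    ExtractDenseWaveforms,
    PeakRetriever,
    base_peak_dtype,
    run_node_pipeline,
)

recording = generate_recording(num_channels=4, durations=[10.0])
fs = recording.sampling_frequency

# Fake, regularly spaced "peaks" on channel 0 of segment 0
peaks = np.zeros(50, dtype=base_peak_dtype)
peaks["sample_index"] = (np.arange(50) * fs / 10).astype("int64") + 100

peak_retriever = PeakRetriever(recording, peaks)
waveforms = ExtractDenseWaveforms(
    recording, parents=[peak_retriever], ms_before=1.0, ms_after=1.5, return_output=True
)
nodes = [peak_retriever, waveforms]

job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)

# Restrict the pipeline to 1 chunk in 4; only peaks falling inside the kept
# slices are processed, everything else is skipped.
recording_slices = divide_recording_into_chunks(recording, 10_000)[::4]
wfs = run_node_pipeline(
    recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices
)
```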
diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index 5b1d33b334..d03744f8f9 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -57,6 +57,7 @@ def detect_peaks(
     folder=None,
     names=None,
     skip_after_n_peaks=None,
+    recording_slices=None,
     **kwargs,
 ):
     """Peak detection based on threshold crossing in term of k x MAD.
@@ -83,6 +84,10 @@ def detect_peaks(
     skip_after_n_peaks : None | int
         Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average.
+    recording_slices : None | list[tuple]
+        Optionally give a list of slices to run the pipeline only on some chunks of the recording.
+        It must be a list of (segment_index, frame_start, frame_stop).
+        If None (default), the function iterates over the entire duration of the recording.
     {method_doc}
     {job_doc}
 
@@ -135,6 +140,7 @@ def detect_peaks(
         folder=folder,
         names=names,
         skip_after_n_peaks=skip_after_n_peaks,
+        recording_slices=recording_slices,
     )
 
     return outs