Merged

Changes from all commits
22 changes: 11 additions & 11 deletions src/spikeinterface/core/job_tools.py
@@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size):


 def divide_recording_into_chunks(recording, chunk_size):
-    all_chunks = []
+    recording_slices = []
     for segment_index in range(recording.get_num_segments()):
         num_frames = recording.get_num_samples(segment_index)
         chunks = divide_segment_into_chunks(num_frames, chunk_size)
-        all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
-    return all_chunks
+        recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
+    return recording_slices


 def ensure_n_jobs(recording, n_jobs=1):
@@ -387,13 +387,13 @@ def __init__(
                 f"chunk_duration={chunk_duration_str}",
             )

-    def run(self, all_chunks=None):
+    def run(self, recording_slices=None):
         """
         Runs the defined jobs.
         """

-        if all_chunks is None:
-            all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)
+        if recording_slices is None:
+            recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size)

         if self.handle_returns:
             returns = []
@@ -402,17 +402,17 @@ def run(self, all_chunks=None):

         if self.n_jobs == 1:
             if self.progress_bar:
-                all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name)
+                recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name)

             worker_ctx = self.init_func(*self.init_args)
-            for segment_index, frame_start, frame_stop in all_chunks:
+            for segment_index, frame_start, frame_stop in recording_slices:
                 res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
                 if self.handle_returns:
                     returns.append(res)
                 if self.gather_func is not None:
                     self.gather_func(res)
         else:
-            n_jobs = min(self.n_jobs, len(all_chunks))
+            n_jobs = min(self.n_jobs, len(recording_slices))

             # parallel
             with ProcessPoolExecutor(
@@ -421,10 +421,10 @@ def run(self, all_chunks=None):
                 mp_context=mp.get_context(self.mp_context),
                 initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
             ) as executor:
-                results = executor.map(function_wrapper, all_chunks)
+                results = executor.map(function_wrapper, recording_slices)

                 if self.progress_bar:
-                    results = tqdm(results, desc=self.job_name, total=len(all_chunks))
+                    results = tqdm(results, desc=self.job_name, total=len(recording_slices))

                 for res in results:
                     if self.handle_returns:
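For reference, here is a minimal usage sketch of the renamed argument (not part of the diff). The generated recording, the toy worker functions, and the exact set of keyword arguments passed to `ChunkRecordingExecutor` are illustrative assumptions rather than code from this PR:

```python
# Hypothetical sketch: run a ChunkRecordingExecutor on an explicit subset of
# (segment_index, frame_start, frame_stop) slices instead of the full recording.
from spikeinterface.core import generate_recording
from spikeinterface.core.job_tools import ChunkRecordingExecutor, divide_recording_into_chunks

recording = generate_recording(num_channels=4, durations=[10.0])


def init_func():
    # build the per-worker context once per worker
    return {}


def func(segment_index, frame_start, frame_stop, worker_ctx):
    # toy job: just report how many frames the chunk contains
    return frame_stop - frame_start


executor = ChunkRecordingExecutor(
    recording,
    func,
    init_func,
    init_args=(),
    handle_returns=True,
    n_jobs=1,
    chunk_size=10_000,
    job_name="demo",
)

# process only every other chunk
recording_slices = divide_recording_into_chunks(recording, chunk_size=10_000)
chunk_lengths = executor.run(recording_slices=recording_slices[::2])
```

With `handle_returns=True`, `run()` returns one value per processed slice, so `chunk_lengths` here has half as many entries as `recording_slices`.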
7 changes: 6 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
@@ -489,6 +489,7 @@ def run_node_pipeline(
     names=None,
     verbose=False,
     skip_after_n_peaks=None,
+    recording_slices=None,
 ):
     """
     Machinery to compute in parallel operations on peaks and traces.
@@ -540,6 +541,10 @@
     skip_after_n_peaks : None | int
         Skip the computation after n_peaks.
         This is not exact because internally this skip is done per worker on average.
+    recording_slices : None | list[tuple]
+        Optionally give a list of slices to run the pipeline only on some chunks of the recording.
+        It must be a list of (segment_index, frame_start, frame_stop).
+        If None (default), the function iterates over the entire duration of the recording.

     Returns
     -------
@@ -578,7 +583,7 @@
         **job_kwargs,
     )

-    processor.run()
+    processor.run(recording_slices=recording_slices)

     outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
     return outs
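The new parameter simply forwards these slices to the executor. Below is a hedged sketch of calling `run_node_pipeline` on a subset of chunks, assuming `recording` and a list of pipeline `nodes` are already built (the updated test_node_pipeline.py further down shows a complete setup); it mirrors the "1 every 4" pattern used in that test:

```python
# Sketch only: `recording` and `nodes` are assumed to exist already.
from spikeinterface.core.job_tools import divide_recording_into_chunks
from spikeinterface.core.node_pipeline import run_node_pipeline

job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)

# keep one chunk of 10_000 samples out of four
recording_slices = divide_recording_into_chunks(recording, 10_000)
recording_slices = recording_slices[::4]

outputs = run_node_pipeline(
    recording,
    nodes,
    job_kwargs,
    gather_mode="memory",
    recording_slices=recording_slices,
)
```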
2 changes: 1 addition & 1 deletion src/spikeinterface/core/recording_tools.py
@@ -806,7 +806,7 @@ def append_noise_chunk(res):
             gather_func=append_noise_chunk,
             **job_kwargs,
         )
-        executor.run(all_chunks=recording_slices)
+        executor.run(recording_slices=recording_slices)
         noise_levels_chunks = np.stack(noise_levels_chunks)
         noise_levels = np.mean(noise_levels_chunks, axis=0)

19 changes: 14 additions & 5 deletions src/spikeinterface/core/tests/test_node_pipeline.py
@@ -4,7 +4,7 @@
 import shutil

 from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording
-
+from spikeinterface.core.job_tools import divide_recording_into_chunks

 # from spikeinterface.sortingcomponents.peak_detection import detect_peaks
 from spikeinterface.core.node_pipeline import (
@@ -191,8 +191,8 @@ def test_run_node_pipeline(cache_folder_creation):
     unpickled_node = pickle.loads(pickled_node)


-def test_skip_after_n_peaks():
-    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])
+def test_skip_after_n_peaks_and_recording_slices():
+    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205)

     # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
     job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)
@@ -211,18 +211,27 @@ def test_skip_after_n_peaks():
     node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True)
     nodes = [node0, node1]

+    # skip
     skip_after_n_peaks = 30
     some_amplitudes = run_node_pipeline(
         recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks
     )

     assert some_amplitudes.size >= skip_after_n_peaks
     assert some_amplitudes.size < spikes.size

+    # slices : 1 every 4
+    recording_slices = divide_recording_into_chunks(recording, 10_000)
+    recording_slices = recording_slices[::4]
+    some_amplitudes = run_node_pipeline(
+        recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices
+    )
+    tolerance = 1.2
+    assert some_amplitudes.size < (spikes.size // 4) * tolerance
+

 # the following is for testing locally with python or ipython. It is not used in ci or with pytest.
 if __name__ == "__main__":
     # folder = Path("./cache_folder/core")
     # test_run_node_pipeline(folder)

-    test_skip_after_n_peaks()
+    test_skip_after_n_peaks_and_recording_slices()
6 changes: 6 additions & 0 deletions src/spikeinterface/sortingcomponents/peak_detection.py
@@ -57,6 +57,7 @@ def detect_peaks(
     folder=None,
     names=None,
     skip_after_n_peaks=None,
+    recording_slices=None,
     **kwargs,
 ):
     """Peak detection based on threshold crossing in terms of k x MAD.
@@ -83,6 +84,10 @@ def detect_peaks(
     skip_after_n_peaks : None | int
         Skip the computation after n_peaks.
         This is not exact because internally this skip is done per worker on average.
+    recording_slices : None | list[tuple]
+        Optionally give a list of slices to run the pipeline only on some chunks of the recording.
+        It must be a list of (segment_index, frame_start, frame_stop).
+        If None (default), the function iterates over the entire duration of the recording.

     {method_doc}
     {job_doc}
@@ -135,6 +140,7 @@ def detect_peaks(
         folder=folder,
         names=names,
         skip_after_n_peaks=skip_after_n_peaks,
+        recording_slices=recording_slices,
     )
     return outs

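At the user-facing level, `detect_peaks` passes the argument straight through to `run_node_pipeline`. Below is a sketch of restricting peak detection to a subset of chunks; the generated recording and the "by_channel" detection parameters are assumptions for illustration, not part of this PR:

```python
# Illustrative sketch: detect peaks on every fourth chunk only.
from spikeinterface import generate_ground_truth_recording
from spikeinterface.core.job_tools import divide_recording_into_chunks
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

recording, _ = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

recording_slices = divide_recording_into_chunks(recording, 10_000)[::4]

peaks = detect_peaks(
    recording,
    method="by_channel",
    detect_threshold=5,
    recording_slices=recording_slices,
    chunk_duration="0.5s",
    n_jobs=1,
    progress_bar=False,
)
# `peaks` only contains events from the selected slices.
```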