Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def split_job_kwargs(mixed_kwargs):
def divide_segment_into_chunks(num_frames, chunk_size):
if chunk_size is None:
chunks = [(0, num_frames)]
elif chunk_size > num_frames:
chunks = [(0, num_frames)]
else:
n = num_frames // chunk_size

Expand Down Expand Up @@ -245,12 +247,12 @@ def ensure_chunk_size(
else:
raise ValueError("chunk_duration must be str or float")
else:
# Edge case to define single chunk per segment for n_jobs=1.
# All chunking parameters equal None mean single chunk per segment
if n_jobs == 1:
# not chunk computing
# TODO Discuss, Sam, is this something that we want to do?
# Even in single process mode, we should chunk the data to avoid loading the whole thing into memory I feel
# Am I wrong?
chunk_size = None
num_segments = recording.get_num_segments()
samples_in_larger_segment = max([recording.get_num_samples(segment) for segment in range(num_segments)])
chunk_size = samples_in_larger_segment
else:
raise ValueError("For n_jobs >1 you must specify total_memory or chunk_size or chunk_memory")

Expand Down
13 changes: 10 additions & 3 deletions src/spikeinterface/core/tests/test_job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ def test_ensure_n_jobs():


def test_ensure_chunk_size():
recording = generate_recording(num_channels=2)
recording = generate_recording(num_channels=2, durations=[5.0, 2.5])  # This is the default value for two segments
dtype = recording.get_dtype()
assert dtype == "float32"
# make serializable
recording = recording.save()

chunk_size = ensure_chunk_size(recording, total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2)
assert chunk_size == 32000000
Expand All @@ -69,6 +67,15 @@ def test_ensure_chunk_size():
chunk_size = ensure_chunk_size(recording, chunk_duration="500ms")
assert chunk_size == 15000

# Test edge case to define single chunk for n_jobs=1
chunk_size = ensure_chunk_size(recording, n_jobs=1, chunk_size=None)
chunks = divide_recording_into_chunks(recording, chunk_size)
assert len(chunks) == recording.get_num_segments()
for chunk in chunks:
segment_index, start_frame, end_frame = chunk
assert start_frame == 0
assert end_frame == recording.get_num_frames(segment_index=segment_index)


def func(segment_index, start_frame, end_frame, worker_ctx):
import os
Expand Down