From 655e7039c88cae93388df9e0e39363fc981638f0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 14 Oct 2025 11:29:07 +0200 Subject: [PATCH 1/3] faster write_binary_recording() --- src/spikeinterface/core/recording_tools.py | 73 ++++++++++++++-------- 1 file changed, 48 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 4b4e1eb3cd..757069f33d 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -61,7 +61,7 @@ def _init_binary_worker(recording, file_path_dict, dtype, byte_offest): worker_ctx["byte_offset"] = byte_offest worker_ctx["dtype"] = np.dtype(dtype) - file_dict = {segment_index: open(file_path, "r+") for segment_index, file_path in file_path_dict.items()} + file_dict = {segment_index: open(file_path, "rb+") for segment_index, file_path in file_path_dict.items()} worker_ctx["file_dict"] = file_dict return worker_ctx @@ -140,6 +140,47 @@ def write_binary_recording( executor.run() +# # used by write_binary_recording + ChunkRecordingExecutor +# def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): +# # recover variables of the worker +# recording = worker_ctx["recording"] +# dtype = worker_ctx["dtype"] +# byte_offset = worker_ctx["byte_offset"] +# file = worker_ctx["file_dict"][segment_index] + +# num_channels = recording.get_num_channels() +# dtype_size_bytes = np.dtype(dtype).itemsize + +# # Calculate byte offsets for the start and end frames relative to the entire recording +# start_byte = byte_offset + start_frame * num_channels * dtype_size_bytes +# end_byte = byte_offset + end_frame * num_channels * dtype_size_bytes + +# # The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY +# memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) +# memmap_offset *= mmap.ALLOCATIONGRANULARITY + +# # This maps in bytes the region of the memmap that corresponds to the chunk +# length = (end_byte - start_byte) + start_offset +# memmap_obj = mmap.mmap(file.fileno(), length=length, access=mmap.ACCESS_WRITE, offset=memmap_offset) + +# # To use numpy semantics we use the array interface of the memmap object +# num_frames = end_frame - start_frame +# shape = (num_frames, num_channels) +# memmap_array = np.ndarray(shape=shape, dtype=dtype, buffer=memmap_obj, offset=start_offset) + +# # Extract the traces and store them in the memmap array +# traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) + +# if traces.dtype != dtype: +# traces = traces.astype(dtype, copy=False) + +# memmap_array[...] = traces + +# memmap_obj.flush() + +# memmap_obj.close() + + # used by write_binary_recording + ChunkRecordingExecutor def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker @@ -153,32 +194,14 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): # Calculate byte offsets for the start and end frames relative to the entire recording start_byte = byte_offset + start_frame * num_channels * dtype_size_bytes - end_byte = byte_offset + end_frame * num_channels * dtype_size_bytes - - # The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY - memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) - memmap_offset *= mmap.ALLOCATIONGRANULARITY - - # This maps in bytes the region of the memmap that corresponds to the chunk - length = (end_byte - start_byte) + start_offset - memmap_obj = mmap.mmap(file.fileno(), length=length, access=mmap.ACCESS_WRITE, offset=memmap_offset) - - # To use numpy semantics we use the array interface of the memmap object - num_frames = end_frame - start_frame - shape = (num_frames, num_channels) - memmap_array = np.ndarray(shape=shape, dtype=dtype, buffer=memmap_obj, offset=start_offset) + # end_byte = byte_offset + end_frame * num_channels * dtype_size_bytes - # Extract the traces and store them in the memmap array traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - - if traces.dtype != dtype: - traces = traces.astype(dtype, copy=False) - - memmap_array[...] = traces - - memmap_obj.flush() - - memmap_obj.close() + + traces = traces.astype(dtype, order="c", copy=False) + + file.seek(start_byte) + file.write(traces.data) write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc) From 3e234acbb7c82dc8e942a4b482e80273dcae0945 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 15 Oct 2025 09:53:35 +0200 Subject: [PATCH 2/3] fix test --- src/spikeinterface/core/recording_tools.py | 47 ++----------------- .../core/tests/test_baserecording.py | 13 +++-- 2 files changed, 13 insertions(+), 47 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 757069f33d..dfba7d1f22 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -140,47 +140,6 @@ def write_binary_recording( executor.run() -# # used by write_binary_recording + ChunkRecordingExecutor -# def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): -# # recover variables of the worker -# recording = worker_ctx["recording"] -# dtype = worker_ctx["dtype"] -# byte_offset = worker_ctx["byte_offset"] -# file = worker_ctx["file_dict"][segment_index] - -# num_channels = recording.get_num_channels() -# dtype_size_bytes = np.dtype(dtype).itemsize - -# # Calculate byte offsets for the start and end frames relative to the entire recording -# start_byte = byte_offset + start_frame * num_channels * dtype_size_bytes -# end_byte = byte_offset + end_frame * num_channels * dtype_size_bytes - -# # The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY -# memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) -# memmap_offset *= mmap.ALLOCATIONGRANULARITY - -# # This maps in bytes the region of the memmap that corresponds to the chunk -# length = (end_byte - start_byte) + start_offset -# memmap_obj = mmap.mmap(file.fileno(), length=length, access=mmap.ACCESS_WRITE, offset=memmap_offset) - -# # To use numpy semantics we use the array interface of the memmap object -# num_frames = end_frame - start_frame -# shape = (num_frames, num_channels) -# memmap_array = np.ndarray(shape=shape, dtype=dtype, buffer=memmap_obj, offset=start_offset) - -# # Extract the traces and store them in the memmap array -# traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - -# if traces.dtype != dtype: -# traces = traces.astype(dtype, copy=False) - -# memmap_array[...] = traces - -# memmap_obj.flush() - -# memmap_obj.close() - - # used by write_binary_recording + ChunkRecordingExecutor def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker @@ -192,16 +151,16 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): num_channels = recording.get_num_channels() dtype_size_bytes = np.dtype(dtype).itemsize - # Calculate byte offsets for the start and end frames relative to the entire recording + # Calculate byte offsets for the start frames relative to the entire recording start_byte = byte_offset + start_frame * num_channels * dtype_size_bytes - # end_byte = byte_offset + end_frame * num_channels * dtype_size_bytes traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - traces = traces.astype(dtype, order="c", copy=False) file.seek(start_byte) file.write(traces.data) + # flush is important!! + file.flush() write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index fc01122269..cdd5897675 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -26,11 +26,14 @@ def test_BaseRecording(create_cache_folder): num_samples = 30 sampling_frequency = 10000 dtype = "int16" + seed = None + rng = np.random.default_rng(seed=seed) file_paths = [cache_folder / f"test_base_recording_{i}.raw" for i in range(num_seg)] for i in range(num_seg): a = np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan)) - a[:] = np.random.randn(*a.shape).astype(dtype) + a[:] = rng.normal(scale=5000, size=a.shape).astype(dtype) + rec = BinaryRecordingExtractor( file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_chan, dtype=dtype ) @@ -201,6 +204,7 @@ def test_BaseRecording(create_cache_folder): positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) traces2 = rec2.get_traces(segment_index=0) + assert np.array_equal(traces2, rec_p.get_traces(segment_index=0)) # from probeinterface.plotting import plot_probe_group, plot_probe @@ -468,5 +472,8 @@ def test_time_slice_with_time_vector(): if __name__ == "__main__": - # test_BaseRecording() - test_interleaved_probegroups() + import tempfile + tmp_path = Path(tempfile.mkdtemp()) + + test_BaseRecording(tmp_path) + # test_interleaved_probegroups() From 02b1af56775c08d7102e019909bfda0687aeca3f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Oct 2025 16:02:52 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/recording_tools.py | 2 +- src/spikeinterface/core/tests/test_baserecording.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index dfba7d1f22..fd95b11e6a 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -156,7 +156,7 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) traces = traces.astype(dtype, order="c", copy=False) - + file.seek(start_byte) file.write(traces.data) # flush is important!! diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index cdd5897675..9de800b33d 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -473,6 +473,7 @@ def test_time_slice_with_time_vector(): if __name__ == "__main__": import tempfile + tmp_path = Path(tempfile.mkdtemp()) test_BaseRecording(tmp_path)