From c33221d65a99f5d8366e8b85c87e080d35a11af6 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 2 Apr 2025 09:56:58 +0100 Subject: [PATCH 1/3] move deprecation --- .../core/channelsaggregationrecording.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 3947a0decc..0797313793 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -14,7 +14,15 @@ class ChannelsAggregationRecording(BaseRecording): """ - def __init__(self, recording_list_or_dict, renamed_channel_ids=None): + def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, recording_list=None): + + if recording_list is not None: + warnings.warn( + "`recording_list` is deprecated and will be removed in 0.105.0. Please use `recording_list_or_dict` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + recording_list_or_dict = recording_list if isinstance(recording_list_or_dict, dict): recording_list = list(recording_list_or_dict.values()) @@ -258,12 +266,4 @@ def aggregate_channels( The aggregated recording object """ - if recording_list is not None: - warnings.warn( - "`recording_list` is deprecated and will be removed in 0.105.0. Please use `recording_list_or_dict` instead.", - category=DeprecationWarning, - stacklevel=2, - ) - recording_list_or_dict = recording_list - - return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids) + return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids, recording_list) From c41875acba320c8eed07068bcc5a6806b6a89f4b Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 2 Apr 2025 10:24:08 +0100 Subject: [PATCH 2/3] add test --- src/spikeinterface/core/tests/test_loading.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_loading.py b/src/spikeinterface/core/tests/test_loading.py index d9c20aa505..c3643f06ef 100644 --- a/src/spikeinterface/core/tests/test_loading.py +++ b/src/spikeinterface/core/tests/test_loading.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from spikeinterface import generate_ground_truth_recording, create_sorting_analyzer, load, SortingAnalyzer, Templates +from spikeinterface import generate_ground_truth_recording, create_sorting_analyzer, load, SortingAnalyzer, Templates, aggregate_channels from spikeinterface.core.motion import Motion from spikeinterface.core.generate import generate_unit_locations, generate_templates from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal @@ -176,6 +176,24 @@ def test_load_motion(tmp_path, generate_motion_object): assert motion == motion_loaded +def test_load_aggregate_recording_from_json(generate_recording_sorting, tmp_path): + """ + Save, then load an aggregated recording using its provenance.json file. + """ + + recording, _ = generate_recording_sorting + + recording.set_property("group", [0,0,1,1]) + list_of_recs = list(recording.split_by('group').values()) + aggregated_rec = aggregate_channels(list_of_recs) + + recording_path = tmp_path / "aggregated_recording" + aggregated_rec.save_to_folder(folder=recording_path) + loaded_rec = load(recording_path / "provenance.json", base_folder=recording_path) + + assert np.all(loaded_rec.get_property('group') == recording.get_property('group')) + + @pytest.mark.streaming_extractors @pytest.mark.skipif(not HAVE_S3, reason="s3fs not installed") def test_remote_recording(): From 16dafc0f33c2b6450b7a3d7f59b4855df10d5fdc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Apr 2025 09:26:06 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_loading.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/tests/test_loading.py b/src/spikeinterface/core/tests/test_loading.py index c3643f06ef..bfaf97ec4a 100644 --- a/src/spikeinterface/core/tests/test_loading.py +++ b/src/spikeinterface/core/tests/test_loading.py @@ -1,7 +1,14 @@ import pytest import numpy as np -from spikeinterface import generate_ground_truth_recording, create_sorting_analyzer, load, SortingAnalyzer, Templates, aggregate_channels +from spikeinterface import ( + generate_ground_truth_recording, + create_sorting_analyzer, + load, + SortingAnalyzer, + Templates, + aggregate_channels, +) from spikeinterface.core.motion import Motion from spikeinterface.core.generate import generate_unit_locations, generate_templates from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal @@ -183,15 +190,15 @@ def test_load_aggregate_recording_from_json(generate_recording_sorting, tmp_path recording, _ = generate_recording_sorting - recording.set_property("group", [0,0,1,1]) - list_of_recs = list(recording.split_by('group').values()) + recording.set_property("group", [0, 0, 1, 1]) + list_of_recs = list(recording.split_by("group").values()) aggregated_rec = aggregate_channels(list_of_recs) recording_path = tmp_path / "aggregated_recording" aggregated_rec.save_to_folder(folder=recording_path) loaded_rec = load(recording_path / "provenance.json", base_folder=recording_path) - assert np.all(loaded_rec.get_property('group') == recording.get_property('group')) + assert np.all(loaded_rec.get_property("group") == recording.get_property("group")) @pytest.mark.streaming_extractors