diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 3947a0decc..0797313793 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -14,7 +14,15 @@ class ChannelsAggregationRecording(BaseRecording): """ - def __init__(self, recording_list_or_dict, renamed_channel_ids=None): + def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, recording_list=None): + + if recording_list is not None: + warnings.warn( + "`recording_list` is deprecated and will be removed in 0.105.0. Please use `recording_list_or_dict` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + recording_list_or_dict = recording_list if isinstance(recording_list_or_dict, dict): recording_list = list(recording_list_or_dict.values()) @@ -258,12 +266,4 @@ def aggregate_channels( The aggregated recording object """ - if recording_list is not None: - warnings.warn( - "`recording_list` is deprecated and will be removed in 0.105.0. Please use `recording_list_or_dict` instead.", - category=DeprecationWarning, - stacklevel=2, - ) - recording_list_or_dict = recording_list - - return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids) + return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids, recording_list) diff --git a/src/spikeinterface/core/tests/test_loading.py b/src/spikeinterface/core/tests/test_loading.py index d9c20aa505..bfaf97ec4a 100644 --- a/src/spikeinterface/core/tests/test_loading.py +++ b/src/spikeinterface/core/tests/test_loading.py @@ -1,7 +1,14 @@ import pytest import numpy as np -from spikeinterface import generate_ground_truth_recording, create_sorting_analyzer, load, SortingAnalyzer, Templates +from spikeinterface import ( + generate_ground_truth_recording, + create_sorting_analyzer, + load, + SortingAnalyzer, + Templates, + aggregate_channels, +) from spikeinterface.core.motion import Motion from spikeinterface.core.generate import generate_unit_locations, generate_templates from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal @@ -176,6 +183,24 @@ def test_load_motion(tmp_path, generate_motion_object): assert motion == motion_loaded +def test_load_aggregate_recording_from_json(generate_recording_sorting, tmp_path): + """ + Save, then load an aggregated recording using its provenance.json file. + """ + + recording, _ = generate_recording_sorting + + recording.set_property("group", [0, 0, 1, 1]) + list_of_recs = list(recording.split_by("group").values()) + aggregated_rec = aggregate_channels(list_of_recs) + + recording_path = tmp_path / "aggregated_recording" + aggregated_rec.save_to_folder(folder=recording_path) + loaded_rec = load(recording_path / "provenance.json", base_folder=recording_path) + + assert np.all(loaded_rec.get_property("group") == recording.get_property("group")) + + @pytest.mark.streaming_extractors @pytest.mark.skipif(not HAVE_S3, reason="s3fs not installed") def test_remote_recording():