From dfc84f8c2184232655b202006ba5c4d540b7520b Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 13 Mar 2025 09:27:41 +0000 Subject: [PATCH 1/7] Allow aggregate_channels to accept a dict of recordings --- .../core/channelsaggregationrecording.py | 13 +- .../test_channelsaggregationrecording.py | 151 +++++++++++------- 2 files changed, 101 insertions(+), 63 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 4fa1d88974..ca45268458 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -13,8 +13,15 @@ class ChannelsAggregationRecording(BaseRecording): """ - def __init__(self, recording_list, renamed_channel_ids=None): + def __init__(self, recording_list_or_dict, renamed_channel_ids=None): + if isinstance(recording_list_or_dict, dict): + recording_list = list(recording_list_or_dict.values()) + elif isinstance(recording_list_or_dict, list): + recording_list = recording_list_or_dict + else: + raise TypeError("`aggregate_channels` only accepts a list of recordings or a dict whose values are all recordings.") + self._recordings = recording_list self._perform_consistency_checks() @@ -207,8 +214,8 @@ def aggregate_channels(recording_list, renamed_channel_ids=None): Parameters ---------- - recording_list: list - List of BaseRecording objects to aggregate + recording_list: list | dict + List or dict of BaseRecording objects to aggregate renamed_channel_ids: array-like If given, channel ids are renamed as provided. diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 16a91a55e1..c637e3e6b3 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -1,7 +1,6 @@ import numpy as np from spikeinterface.core import aggregate_channels - from spikeinterface.core import generate_recording @@ -21,65 +20,97 @@ def test_channelsaggregationrecording(): recording2.set_channel_locations([[20.0, 20.0], [20.0, 40.0], [20.0, 60.0]]) recording3.set_channel_locations([[40.0, 20.0], [40.0, 40.0], [40.0, 60.0]]) - # test num channels - recording_agg = aggregate_channels([recording1, recording2, recording3]) - assert len(recording_agg.get_channel_ids()) == 3 * num_channels - - assert np.allclose(recording_agg.get_times(0), recording1.get_times(0)) - - # test traces - channel_ids = recording1.get_channel_ids() - - for seg in range(num_seg): - # single channels - traces1_1 = recording1.get_traces(channel_ids=[channel_ids[1]], segment_index=seg) - traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg) - traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg) - - assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) - assert np.allclose( - traces2_0, - recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), - ) - assert np.allclose( - traces3_2, - recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), - ) - # all traces - traces1 = recording1.get_traces(segment_index=seg) - traces2 = recording2.get_traces(segment_index=seg) - traces3 = recording3.get_traces(segment_index=seg) - - assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg)) - assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg)) - assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg)) - - # test rename channels - renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)] - recording_agg_renamed = aggregate_channels( - [recording1, recording2, recording3], renamed_channel_ids=renamed_channel_ids - ) - assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids()) - - # test properties - # complete property - recording1.set_property("brain_area", ["CA1"] * num_channels) - recording2.set_property("brain_area", ["CA2"] * num_channels) - recording3.set_property("brain_area", ["CA3"] * num_channels) - - # skip for inconsistency - recording1.set_property("template", np.zeros((num_channels, 4, 30))) - recording2.set_property("template", np.zeros((num_channels, 20, 50))) - recording3.set_property("template", np.zeros((num_channels, 2, 10))) - - # incomplete property - recording1.set_property("quality", ["good"] * num_channels) - recording2.set_property("quality", ["bad"] * num_channels) - - recording_agg_prop = aggregate_channels([recording1, recording2, recording3]) - assert "brain_area" in recording_agg_prop.get_property_keys() - assert "quality" not in recording_agg_prop.get_property_keys() - print(recording_agg_prop.get_property("brain_area")) + recordings_list_possibilities = [ + [recording1, recording2, recording3], + {0: recording1, 1: recording2, 2: recording3} + ] + + for recordings_list in recordings_list_possibilities: + + # test num channels + recording_agg = aggregate_channels(recordings_list) + assert len(recording_agg.get_channel_ids()) == 3 * num_channels + + assert np.allclose(recording_agg.get_times(0), recording1.get_times(0)) + + # test traces + channel_ids = recording1.get_channel_ids() + + for seg in range(num_seg): + # single channels + traces1_1 = recording1.get_traces(channel_ids=[channel_ids[1]], segment_index=seg) + traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg) + traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg) + + assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) + assert np.allclose( + traces2_0, + recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), + ) + assert np.allclose( + traces3_2, + recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), + ) + # all traces + traces1 = recording1.get_traces(segment_index=seg) + traces2 = recording2.get_traces(segment_index=seg) + traces3 = recording3.get_traces(segment_index=seg) + + assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg)) + assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg)) + assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg)) + + # test rename channels + renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)] + recording_agg_renamed = aggregate_channels( + recordings_list, renamed_channel_ids=renamed_channel_ids + ) + assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids()) + + # test properties + # complete property + recording1.set_property("brain_area", ["CA1"] * num_channels) + recording2.set_property("brain_area", ["CA2"] * num_channels) + recording3.set_property("brain_area", ["CA3"] * num_channels) + + # skip for inconsistency + recording1.set_property("template", np.zeros((num_channels, 4, 30))) + recording2.set_property("template", np.zeros((num_channels, 20, 50))) + recording3.set_property("template", np.zeros((num_channels, 2, 10))) + + # incomplete property + recording1.set_property("quality", ["good"] * num_channels) + recording2.set_property("quality", ["bad"] * num_channels) + + recording_agg_prop = aggregate_channels(recordings_list) + assert "brain_area" in recording_agg_prop.get_property_keys() + assert "quality" not in recording_agg_prop.get_property_keys() + print(recording_agg_prop.get_property("brain_area")) + +def test_split_then_aggreate_preserve_user_property(): + """ + Checks that splitting then aggregating a recording preserves the unit_id to property mapping. + """ + + num_channels = 10 + durations = [10,5] + recording = generate_recording(num_channels=num_channels, durations=durations, set_probe=False) + + recording.set_property(key='group', values=[2,0,1,1,1,0,1,0,1,2]) + + old_properties = recording.get_property(key='group') + old_channel_ids = recording.channel_ids + old_properties_ids_dict = dict(zip(old_channel_ids,old_properties)) + + split_recordings = recording.split_by('group') + + aggregated_recording = aggregate_channels(split_recordings) + + new_properties = aggregated_recording.get_property(key='group') + new_channel_ids = aggregated_recording.channel_ids + new_properties_ids_dict = dict(zip(new_channel_ids,new_properties)) + + assert np.all(old_properties_ids_dict == new_properties_ids_dict) def test_channel_aggregation_preserve_ids(): From 1cb4a44542fcfb34f458163406de46a3c9b2c5a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Mar 2025 09:33:02 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/channelsaggregationrecording.py | 6 ++-- .../test_channelsaggregationrecording.py | 29 ++++++++++--------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index ca45268458..c339672802 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -20,8 +20,10 @@ def __init__(self, recording_list_or_dict, renamed_channel_ids=None): elif isinstance(recording_list_or_dict, list): recording_list = recording_list_or_dict else: - raise TypeError("`aggregate_channels` only accepts a list of recordings or a dict whose values are all recordings.") - + raise TypeError( + "`aggregate_channels` only accepts a list of recordings or a dict whose values are all recordings." + ) + self._recordings = recording_list self._perform_consistency_checks() diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index c637e3e6b3..4b4744fed2 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -21,8 +21,8 @@ def test_channelsaggregationrecording(): recording3.set_channel_locations([[40.0, 20.0], [40.0, 40.0], [40.0, 60.0]]) recordings_list_possibilities = [ - [recording1, recording2, recording3], - {0: recording1, 1: recording2, 2: recording3} + [recording1, recording2, recording3], + {0: recording1, 1: recording2, 2: recording3}, ] for recordings_list in recordings_list_possibilities: @@ -42,7 +42,9 @@ def test_channelsaggregationrecording(): traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg) traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg) - assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) + assert np.allclose( + traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg) + ) assert np.allclose( traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), @@ -62,9 +64,7 @@ def test_channelsaggregationrecording(): # test rename channels renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)] - recording_agg_renamed = aggregate_channels( - recordings_list, renamed_channel_ids=renamed_channel_ids - ) + recording_agg_renamed = aggregate_channels(recordings_list, renamed_channel_ids=renamed_channel_ids) assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids()) # test properties @@ -87,28 +87,29 @@ def test_channelsaggregationrecording(): assert "quality" not in recording_agg_prop.get_property_keys() print(recording_agg_prop.get_property("brain_area")) + def test_split_then_aggreate_preserve_user_property(): """ Checks that splitting then aggregating a recording preserves the unit_id to property mapping. """ num_channels = 10 - durations = [10,5] + durations = [10, 5] recording = generate_recording(num_channels=num_channels, durations=durations, set_probe=False) - - recording.set_property(key='group', values=[2,0,1,1,1,0,1,0,1,2]) - old_properties = recording.get_property(key='group') + recording.set_property(key="group", values=[2, 0, 1, 1, 1, 0, 1, 0, 1, 2]) + + old_properties = recording.get_property(key="group") old_channel_ids = recording.channel_ids - old_properties_ids_dict = dict(zip(old_channel_ids,old_properties)) + old_properties_ids_dict = dict(zip(old_channel_ids, old_properties)) - split_recordings = recording.split_by('group') + split_recordings = recording.split_by("group") aggregated_recording = aggregate_channels(split_recordings) - new_properties = aggregated_recording.get_property(key='group') + new_properties = aggregated_recording.get_property(key="group") new_channel_ids = aggregated_recording.channel_ids - new_properties_ids_dict = dict(zip(new_channel_ids,new_properties)) + new_properties_ids_dict = dict(zip(new_channel_ids, new_properties)) assert np.all(old_properties_ids_dict == new_properties_ids_dict) From 0ad72fda4670f1db8df25f78732f26fe7578ef63 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 13 Mar 2025 11:43:00 +0000 Subject: [PATCH 3/7] set group property for aggregated lists, if there are no preexisting properties --- .../core/baserecordingsnippets.py | 1 + .../core/channelsaggregationrecording.py | 18 +++++- .../test_channelsaggregationrecording.py | 60 ++++++++++++++----- 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index b224e0d282..3b8ffc7b03 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -564,6 +564,7 @@ def split_by(self, property="group", outputs="dict"): (inds,) = np.nonzero(values == value) new_channel_ids = self.get_channel_ids()[inds] subrec = self.select_channels(new_channel_ids) + subrec.set_annotation("split_by_property", value=property) if outputs == "list": recordings.append(subrec) elif outputs == "dict": diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index ca45268458..c85a93147e 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -19,9 +19,21 @@ def __init__(self, recording_list_or_dict, renamed_channel_ids=None): recording_list = list(recording_list_or_dict.values()) elif isinstance(recording_list_or_dict, list): recording_list = recording_list_or_dict + + # Check if the recordings were previously split using `split_by` + if recording_list[0].get_annotation("split_by_property") is None: + # If default 'group'ing (all equal 0), we label the recordings using the 'group' property + recording_groups = [] + for recording in recording_list: + recording_groups.extend(recording.get_property("group")) + if np.all(np.unique(recording_groups) == np.array([0])): + for group_id, recording in enumerate(recording_list): + recording.set_property("group", group_id * np.ones(recording.get_num_channels())) else: - raise TypeError("`aggregate_channels` only accepts a list of recordings or a dict whose values are all recordings.") - + raise TypeError( + "`aggregate_channels` only accepts a list of recordings or a dict whose values are all recordings." + ) + self._recordings = recording_list self._perform_consistency_checks() @@ -215,7 +227,7 @@ def aggregate_channels(recording_list, renamed_channel_ids=None): Parameters ---------- recording_list: list | dict - List or dict of BaseRecording objects to aggregate + List or dict of BaseRecording objects to aggregate. renamed_channel_ids: array-like If given, channel ids are renamed as provided. diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index c637e3e6b3..e83d51c256 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -21,8 +21,8 @@ def test_channelsaggregationrecording(): recording3.set_channel_locations([[40.0, 20.0], [40.0, 40.0], [40.0, 60.0]]) recordings_list_possibilities = [ - [recording1, recording2, recording3], - {0: recording1, 1: recording2, 2: recording3} + [recording1, recording2, recording3], + {0: recording1, 1: recording2, 2: recording3}, ] for recordings_list in recordings_list_possibilities: @@ -42,7 +42,9 @@ def test_channelsaggregationrecording(): traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg) traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg) - assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) + assert np.allclose( + traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg) + ) assert np.allclose( traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), @@ -62,9 +64,7 @@ def test_channelsaggregationrecording(): # test rename channels renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)] - recording_agg_renamed = aggregate_channels( - recordings_list, renamed_channel_ids=renamed_channel_ids - ) + recording_agg_renamed = aggregate_channels(recordings_list, renamed_channel_ids=renamed_channel_ids) assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids()) # test properties @@ -87,28 +87,29 @@ def test_channelsaggregationrecording(): assert "quality" not in recording_agg_prop.get_property_keys() print(recording_agg_prop.get_property("brain_area")) + def test_split_then_aggreate_preserve_user_property(): """ Checks that splitting then aggregating a recording preserves the unit_id to property mapping. """ num_channels = 10 - durations = [10,5] + durations = [10, 5] recording = generate_recording(num_channels=num_channels, durations=durations, set_probe=False) - - recording.set_property(key='group', values=[2,0,1,1,1,0,1,0,1,2]) - old_properties = recording.get_property(key='group') + recording.set_property(key="group", values=[2, 0, 1, 1, 1, 0, 1, 0, 1, 2]) + + old_properties = recording.get_property(key="group") old_channel_ids = recording.channel_ids - old_properties_ids_dict = dict(zip(old_channel_ids,old_properties)) + old_properties_ids_dict = dict(zip(old_channel_ids, old_properties)) - split_recordings = recording.split_by('group') + split_recordings = recording.split_by("group") aggregated_recording = aggregate_channels(split_recordings) - new_properties = aggregated_recording.get_property(key='group') + new_properties = aggregated_recording.get_property(key="group") new_channel_ids = aggregated_recording.channel_ids - new_properties_ids_dict = dict(zip(new_channel_ids,new_properties)) + new_properties_ids_dict = dict(zip(new_channel_ids, new_properties)) assert np.all(old_properties_ids_dict == new_properties_ids_dict) @@ -125,6 +126,37 @@ def test_channel_aggregation_preserve_ids(): assert list(aggregated_recording.get_channel_ids()) == ["a", "b", "c", "d", "e"] +def test_aggretion_labelling_for_lists(): + """Aggregated lists of recordings get different labels depending on their underlying labels""" + + recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) + recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False) + + # If we don't label at all, aggregation will add a 'group' label + recording1.set_property("group", [0, 0, 0, 0]) + recording2.set_property("group", [0, 0]) + aggregated_recording = aggregate_channels([recording1, recording2]) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [0, 0, 0, 0, 1, 1]) + + # If we have different group labels, these should be respected + recording1.set_property("group", [2, 2, 2, 2]) + recording2.set_property("group", [6, 6]) + aggregated_recording = aggregate_channels([recording1, recording2]) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [2, 2, 2, 2, 6, 6]) + + # If we use `split_by`, aggregation should retain the split_by property, even if we only pass the list + recording1.set_property("user_group", [6, 7, 6, 7]) + recording_list = list(recording1.split_by("user_group").values()) + aggregated_recording = aggregate_channels(recording_list) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [2, 2, 2, 2]) + user_group_property = aggregated_recording.get_property("user_group") + # Note, aggregation reorders the channel_ids into the order of the ids of each individual recording + assert np.all(user_group_property == [6, 6, 7, 7]) + + def test_channel_aggregation_does_not_preserve_ids_if_not_unique(): recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False) # To avoid location check From 09c0c5b6dffa48b7dd143130dcda72fb6737ae95 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 13 Mar 2025 12:03:27 +0000 Subject: [PATCH 4/7] fix tests? --- src/spikeinterface/core/channelsaggregationrecording.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index c85a93147e..ac20629dbf 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -25,7 +25,10 @@ def __init__(self, recording_list_or_dict, renamed_channel_ids=None): # If default 'group'ing (all equal 0), we label the recordings using the 'group' property recording_groups = [] for recording in recording_list: - recording_groups.extend(recording.get_property("group")) + if (group_property := recording.get_property("group")) is not None: + recording_groups.extend(group_property) + else: + recording_groups.extend([0]) if np.all(np.unique(recording_groups) == np.array([0])): for group_id, recording in enumerate(recording_list): recording.set_property("group", group_id * np.ones(recording.get_num_channels())) From ba0d0ac9ffd0b8474e901ab734df3a6419033ca8 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 13 Mar 2025 14:46:39 +0000 Subject: [PATCH 5/7] allow dicts to also get labelled if no labels are known --- .../core/channelsaggregationrecording.py | 39 ++++++++++++------- .../test_channelsaggregationrecording.py | 33 ++++++++++++++-- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index ac20629dbf..f34df30637 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -17,21 +17,10 @@ def __init__(self, recording_list_or_dict, renamed_channel_ids=None): if isinstance(recording_list_or_dict, dict): recording_list = list(recording_list_or_dict.values()) + recording_ids = list(recording_list_or_dict.keys()) elif isinstance(recording_list_or_dict, list): recording_list = recording_list_or_dict - - # Check if the recordings were previously split using `split_by` - if recording_list[0].get_annotation("split_by_property") is None: - # If default 'group'ing (all equal 0), we label the recordings using the 'group' property - recording_groups = [] - for recording in recording_list: - if (group_property := recording.get_property("group")) is not None: - recording_groups.extend(group_property) - else: - recording_groups.extend([0]) - if np.all(np.unique(recording_groups) == np.array([0])): - for group_id, recording in enumerate(recording_list): - recording.set_property("group", group_id * np.ones(recording.get_num_channels())) + recording_ids = range(len(recording_list)) else: raise TypeError( "`aggregate_channels` only accepts a list of recordings or a dict whose values are all recordings." @@ -39,6 +28,11 @@ def __init__(self, recording_list_or_dict, renamed_channel_ids=None): self._recordings = recording_list + splitting_known = self._is_splitting_known() + if not splitting_known: + for group_id, recording in zip(recording_ids, recording_list): + recording.set_property("group", [group_id] * recording.get_num_channels()) + self._perform_consistency_checks() sampling_frequency = recording_list[0].get_sampling_frequency() dtype = recording_list[0].get_dtype() @@ -123,6 +117,25 @@ def __init__(self, recording_list_or_dict, renamed_channel_ids=None): def recordings(self): return self._recordings + def _is_splitting_known(self): + + # If we have the `split_by_property` annotation, we know how the recording was split + if self._recordings[0].get_annotation("split_by_property") is not None: + return True + + # Check if all 'group' properties are equal to 0 + recording_groups = [] + for recording in self._recordings: + if (group_labels := recording.get_property("group")) is not None: + recording_groups.extend(group_labels) + else: + recording_groups.extend([0]) + # If so, we don't know the splitting + if np.all(np.unique(recording_groups) == np.array([0])): + return False + else: + return True + def _perform_consistency_checks(self): # Check for consistent sampling frequency across recordings diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index e83d51c256..9a30abe022 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -127,14 +127,12 @@ def test_channel_aggregation_preserve_ids(): def test_aggretion_labelling_for_lists(): - """Aggregated lists of recordings get different labels depending on their underlying labels""" + """Aggregated lists of recordings get different labels depending on their underlying `property`s""" recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False) # If we don't label at all, aggregation will add a 'group' label - recording1.set_property("group", [0, 0, 0, 0]) - recording2.set_property("group", [0, 0]) aggregated_recording = aggregate_channels([recording1, recording2]) group_property = aggregated_recording.get_property("group") assert np.all(group_property == [0, 0, 0, 0, 1, 1]) @@ -157,6 +155,35 @@ def test_aggretion_labelling_for_lists(): assert np.all(user_group_property == [6, 6, 7, 7]) +def test_aggretion_labelling_for_dicts(): + """Aggregated dicts of recordings get different labels depending on their underlying `property`s""" + + recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) + recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False) + + # If we don't label at all, aggregation will add a 'group' label based on the dict keys + aggregated_recording = aggregate_channels({0: recording1, "cat": recording2}) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [0, 0, 0, 0, "cat", "cat"]) + + # If we have different group labels, these should be respected + recording1.set_property("group", [2, 2, 2, 2]) + recording2.set_property("group", [6, 6]) + aggregated_recording = aggregate_channels({0: recording1, "cat": recording2}) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [2, 2, 2, 2, 6, 6]) + + # If we use `split_by`, aggregation should retain the split_by property, even if we pass a different dict + recording1.set_property("user_group", [6, 7, 6, 7]) + recordings_dict = recording1.split_by("user_group") + aggregated_recording = aggregate_channels(recordings_dict) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [2, 2, 2, 2]) + user_group_property = aggregated_recording.get_property("user_group") + # Note, aggregation reorders the channel_ids into the order of the ids of each individual recording + assert np.all(user_group_property == [6, 6, 7, 7]) + + def test_channel_aggregation_does_not_preserve_ids_if_not_unique(): recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False) # To avoid location check From ce947088d187b58014ae2eba176912b6c7a1f0d2 Mon Sep 17 00:00:00 2001 From: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 13 Mar 2025 14:49:21 +0000 Subject: [PATCH 6/7] Update src/spikeinterface/core/tests/test_channelsaggregationrecording.py Co-authored-by: Alessio Buccino --- .../core/tests/test_channelsaggregationrecording.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 9a30abe022..a9bb51dfed 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -126,7 +126,7 @@ def test_channel_aggregation_preserve_ids(): assert list(aggregated_recording.get_channel_ids()) == ["a", "b", "c", "d", "e"] -def test_aggretion_labelling_for_lists(): +def test_aggregation_labeling_for_lists(): """Aggregated lists of recordings get different labels depending on their underlying `property`s""" recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) From 8ff42e560addafbc4993b11b24fb9af0da63b2d0 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 14 Mar 2025 10:01:55 +0000 Subject: [PATCH 7/7] respond to zach --- .../core/channelsaggregationrecording.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index f34df30637..3947a0decc 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -1,4 +1,5 @@ from __future__ import annotations +import warnings import numpy as np @@ -236,13 +237,17 @@ def get_traces( return np.concatenate(traces, axis=1) -def aggregate_channels(recording_list, renamed_channel_ids=None): +def aggregate_channels( + recording_list_or_dict=None, + renamed_channel_ids=None, + recording_list=None, +): """ Aggregates channels of multiple recording into a single recording object Parameters ---------- - recording_list: list | dict + recording_list_or_dict: list | dict List or dict of BaseRecording objects to aggregate. renamed_channel_ids: array-like If given, channel ids are renamed as provided. @@ -252,4 +257,13 @@ def aggregate_channels(recording_list, renamed_channel_ids=None): aggregate_recording: ChannelsAggregationRecording The aggregated recording object """ - return ChannelsAggregationRecording(recording_list, renamed_channel_ids) + + if recording_list is not None: + warnings.warn( + "`recording_list` is deprecated and will be removed in 0.105.0. Please use `recording_list_or_dict` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + recording_list_or_dict = recording_list + + return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids)