diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index b224e0d282..3b8ffc7b03 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -564,6 +564,7 @@ def split_by(self, property="group", outputs="dict"): (inds,) = np.nonzero(values == value) new_channel_ids = self.get_channel_ids()[inds] subrec = self.select_channels(new_channel_ids) + subrec.set_annotation("split_by_property", value=property) if outputs == "list": recordings.append(subrec) elif outputs == "dict": diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 4fa1d88974..3947a0decc 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -1,4 +1,5 @@ from __future__ import annotations +import warnings import numpy as np @@ -13,10 +14,26 @@ class ChannelsAggregationRecording(BaseRecording): """ - def __init__(self, recording_list, renamed_channel_ids=None): + def __init__(self, recording_list_or_dict, renamed_channel_ids=None): + + if isinstance(recording_list_or_dict, dict): + recording_list = list(recording_list_or_dict.values()) + recording_ids = list(recording_list_or_dict.keys()) + elif isinstance(recording_list_or_dict, list): + recording_list = recording_list_or_dict + recording_ids = range(len(recording_list)) + else: + raise TypeError( + "`aggregate_channels` only accepts a list of recordings or a dict whose values are all recordings." 
+ ) self._recordings = recording_list + splitting_known = self._is_splitting_known() + if not splitting_known: + for group_id, recording in zip(recording_ids, recording_list): + recording.set_property("group", [group_id] * recording.get_num_channels()) + self._perform_consistency_checks() sampling_frequency = recording_list[0].get_sampling_frequency() dtype = recording_list[0].get_dtype() @@ -101,6 +118,25 @@ def __init__(self, recording_list, renamed_channel_ids=None): def recordings(self): return self._recordings + def _is_splitting_known(self): + + # If we have the `split_by_property` annotation, we know how the recording was split + if self._recordings[0].get_annotation("split_by_property") is not None: + return True + + # Check if all 'group' properties are equal to 0 + recording_groups = [] + for recording in self._recordings: + if (group_labels := recording.get_property("group")) is not None: + recording_groups.extend(group_labels) + else: + recording_groups.extend([0]) + # If so, we don't know the splitting + if np.all(np.unique(recording_groups) == np.array([0])): + return False + else: + return True + def _perform_consistency_checks(self): # Check for consistent sampling frequency across recordings @@ -201,14 +237,18 @@ def get_traces( return np.concatenate(traces, axis=1) -def aggregate_channels(recording_list, renamed_channel_ids=None): +def aggregate_channels( + recording_list_or_dict=None, + renamed_channel_ids=None, + recording_list=None, +): """ Aggregates channels of multiple recording into a single recording object Parameters ---------- - recording_list: list - List of BaseRecording objects to aggregate + recording_list_or_dict: list | dict + List or dict of BaseRecording objects to aggregate. renamed_channel_ids: array-like If given, channel ids are renamed as provided. 
@@ -217,4 +257,13 @@ def aggregate_channels(recording_list, renamed_channel_ids=None): aggregate_recording: ChannelsAggregationRecording The aggregated recording object """ - return ChannelsAggregationRecording(recording_list, renamed_channel_ids) + + if recording_list is not None: + warnings.warn( + "`recording_list` is deprecated and will be removed in 0.105.0. Please use `recording_list_or_dict` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + recording_list_or_dict = recording_list + + return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids) diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 16a91a55e1..a9bb51dfed 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -1,7 +1,6 @@ import numpy as np from spikeinterface.core import aggregate_channels - from spikeinterface.core import generate_recording @@ -21,65 +20,98 @@ def test_channelsaggregationrecording(): recording2.set_channel_locations([[20.0, 20.0], [20.0, 40.0], [20.0, 60.0]]) recording3.set_channel_locations([[40.0, 20.0], [40.0, 40.0], [40.0, 60.0]]) - # test num channels - recording_agg = aggregate_channels([recording1, recording2, recording3]) - assert len(recording_agg.get_channel_ids()) == 3 * num_channels - - assert np.allclose(recording_agg.get_times(0), recording1.get_times(0)) - - # test traces - channel_ids = recording1.get_channel_ids() - - for seg in range(num_seg): - # single channels - traces1_1 = recording1.get_traces(channel_ids=[channel_ids[1]], segment_index=seg) - traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg) - traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg) - - assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) - 
assert np.allclose( - traces2_0, - recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), - ) - assert np.allclose( - traces3_2, - recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), - ) - # all traces - traces1 = recording1.get_traces(segment_index=seg) - traces2 = recording2.get_traces(segment_index=seg) - traces3 = recording3.get_traces(segment_index=seg) - - assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg)) - assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg)) - assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg)) - - # test rename channels - renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)] - recording_agg_renamed = aggregate_channels( - [recording1, recording2, recording3], renamed_channel_ids=renamed_channel_ids - ) - assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids()) - - # test properties - # complete property - recording1.set_property("brain_area", ["CA1"] * num_channels) - recording2.set_property("brain_area", ["CA2"] * num_channels) - recording3.set_property("brain_area", ["CA3"] * num_channels) - - # skip for inconsistency - recording1.set_property("template", np.zeros((num_channels, 4, 30))) - recording2.set_property("template", np.zeros((num_channels, 20, 50))) - recording3.set_property("template", np.zeros((num_channels, 2, 10))) - - # incomplete property - recording1.set_property("quality", ["good"] * num_channels) - recording2.set_property("quality", ["bad"] * num_channels) - - recording_agg_prop = aggregate_channels([recording1, recording2, recording3]) - assert "brain_area" in recording_agg_prop.get_property_keys() - assert "quality" not in recording_agg_prop.get_property_keys() - print(recording_agg_prop.get_property("brain_area")) + 
recordings_list_possibilities = [ + [recording1, recording2, recording3], + {0: recording1, 1: recording2, 2: recording3}, + ] + + for recordings_list in recordings_list_possibilities: + + # test num channels + recording_agg = aggregate_channels(recordings_list) + assert len(recording_agg.get_channel_ids()) == 3 * num_channels + + assert np.allclose(recording_agg.get_times(0), recording1.get_times(0)) + + # test traces + channel_ids = recording1.get_channel_ids() + + for seg in range(num_seg): + # single channels + traces1_1 = recording1.get_traces(channel_ids=[channel_ids[1]], segment_index=seg) + traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg) + traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg) + + assert np.allclose( + traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg) + ) + assert np.allclose( + traces2_0, + recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), + ) + assert np.allclose( + traces3_2, + recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), + ) + # all traces + traces1 = recording1.get_traces(segment_index=seg) + traces2 = recording2.get_traces(segment_index=seg) + traces3 = recording3.get_traces(segment_index=seg) + + assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg)) + assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg)) + assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg)) + + # test rename channels + renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)] + recording_agg_renamed = aggregate_channels(recordings_list, renamed_channel_ids=renamed_channel_ids) + assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids()) + + # test properties + # 
complete property
+        recording1.set_property("brain_area", ["CA1"] * num_channels)
+        recording2.set_property("brain_area", ["CA2"] * num_channels)
+        recording3.set_property("brain_area", ["CA3"] * num_channels)
+
+        # skip for inconsistency
+        recording1.set_property("template", np.zeros((num_channels, 4, 30)))
+        recording2.set_property("template", np.zeros((num_channels, 20, 50)))
+        recording3.set_property("template", np.zeros((num_channels, 2, 10)))
+
+        # incomplete property
+        recording1.set_property("quality", ["good"] * num_channels)
+        recording2.set_property("quality", ["bad"] * num_channels)
+
+        recording_agg_prop = aggregate_channels(recordings_list)
+        assert "brain_area" in recording_agg_prop.get_property_keys()
+        assert "quality" not in recording_agg_prop.get_property_keys()
+        print(recording_agg_prop.get_property("brain_area"))
+
+
+def test_split_then_aggreate_preserve_user_property():
+    """
+    Checks that splitting then aggregating a recording preserves the channel_id to property mapping.
+ """ + + num_channels = 10 + durations = [10, 5] + recording = generate_recording(num_channels=num_channels, durations=durations, set_probe=False) + + recording.set_property(key="group", values=[2, 0, 1, 1, 1, 0, 1, 0, 1, 2]) + + old_properties = recording.get_property(key="group") + old_channel_ids = recording.channel_ids + old_properties_ids_dict = dict(zip(old_channel_ids, old_properties)) + + split_recordings = recording.split_by("group") + + aggregated_recording = aggregate_channels(split_recordings) + + new_properties = aggregated_recording.get_property(key="group") + new_channel_ids = aggregated_recording.channel_ids + new_properties_ids_dict = dict(zip(new_channel_ids, new_properties)) + + assert np.all(old_properties_ids_dict == new_properties_ids_dict) def test_channel_aggregation_preserve_ids(): @@ -94,6 +126,64 @@ def test_channel_aggregation_preserve_ids(): assert list(aggregated_recording.get_channel_ids()) == ["a", "b", "c", "d", "e"] +def test_aggregation_labeling_for_lists(): + """Aggregated lists of recordings get different labels depending on their underlying `property`s""" + + recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) + recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False) + + # If we don't label at all, aggregation will add a 'group' label + aggregated_recording = aggregate_channels([recording1, recording2]) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [0, 0, 0, 0, 1, 1]) + + # If we have different group labels, these should be respected + recording1.set_property("group", [2, 2, 2, 2]) + recording2.set_property("group", [6, 6]) + aggregated_recording = aggregate_channels([recording1, recording2]) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [2, 2, 2, 2, 6, 6]) + + # If we use `split_by`, aggregation should retain the split_by property, even if we only pass the list + 
recording1.set_property("user_group", [6, 7, 6, 7]) + recording_list = list(recording1.split_by("user_group").values()) + aggregated_recording = aggregate_channels(recording_list) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [2, 2, 2, 2]) + user_group_property = aggregated_recording.get_property("user_group") + # Note, aggregation reorders the channel_ids into the order of the ids of each individual recording + assert np.all(user_group_property == [6, 6, 7, 7]) + + +def test_aggretion_labelling_for_dicts(): + """Aggregated dicts of recordings get different labels depending on their underlying `property`s""" + + recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) + recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False) + + # If we don't label at all, aggregation will add a 'group' label based on the dict keys + aggregated_recording = aggregate_channels({0: recording1, "cat": recording2}) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [0, 0, 0, 0, "cat", "cat"]) + + # If we have different group labels, these should be respected + recording1.set_property("group", [2, 2, 2, 2]) + recording2.set_property("group", [6, 6]) + aggregated_recording = aggregate_channels({0: recording1, "cat": recording2}) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [2, 2, 2, 2, 6, 6]) + + # If we use `split_by`, aggregation should retain the split_by property, even if we pass a different dict + recording1.set_property("user_group", [6, 7, 6, 7]) + recordings_dict = recording1.split_by("user_group") + aggregated_recording = aggregate_channels(recordings_dict) + group_property = aggregated_recording.get_property("group") + assert np.all(group_property == [2, 2, 2, 2]) + user_group_property = aggregated_recording.get_property("user_group") + # Note, aggregation reorders the channel_ids into 
the order of the ids of each individual recording + assert np.all(user_group_property == [6, 6, 7, 7]) + + def test_channel_aggregation_does_not_preserve_ids_if_not_unique(): recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False) # To avoid location check