Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def split_by(self, property="group", outputs="dict"):
(inds,) = np.nonzero(values == value)
new_channel_ids = self.get_channel_ids()[inds]
subrec = self.select_channels(new_channel_ids)
subrec.set_annotation("split_by_property", value=property)
if outputs == "list":
recordings.append(subrec)
elif outputs == "dict":
Expand Down
59 changes: 54 additions & 5 deletions src/spikeinterface/core/channelsaggregationrecording.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
import warnings

import numpy as np

Expand All @@ -13,10 +14,26 @@ class ChannelsAggregationRecording(BaseRecording):

"""

def __init__(self, recording_list, renamed_channel_ids=None):
def __init__(self, recording_list_or_dict, renamed_channel_ids=None):

if isinstance(recording_list_or_dict, dict):
recording_list = list(recording_list_or_dict.values())
recording_ids = list(recording_list_or_dict.keys())
elif isinstance(recording_list_or_dict, list):
recording_list = recording_list_or_dict
recording_ids = range(len(recording_list))
else:
raise TypeError(
"`aggregate_channels` only accepts a list of recordings or a dict whose values are all recordings."
)

self._recordings = recording_list

splitting_known = self._is_splitting_known()
if not splitting_known:
for group_id, recording in zip(recording_ids, recording_list):
recording.set_property("group", [group_id] * recording.get_num_channels())

self._perform_consistency_checks()
sampling_frequency = recording_list[0].get_sampling_frequency()
dtype = recording_list[0].get_dtype()
Expand Down Expand Up @@ -101,6 +118,25 @@ def __init__(self, recording_list, renamed_channel_ids=None):
def recordings(self):
    """Return the list of parent recordings being aggregated (as stored in ``self._recordings``)."""
    return self._recordings

def _is_splitting_known(self):

# If we have the `split_by_property` annotation, we know how the recording was split
if self._recordings[0].get_annotation("split_by_property") is not None:
return True

# Check if all 'group' properties are equal to 0
recording_groups = []
for recording in self._recordings:
if (group_labels := recording.get_property("group")) is not None:
recording_groups.extend(group_labels)
else:
recording_groups.extend([0])
# If so, we don't know the splitting
if np.all(np.unique(recording_groups) == np.array([0])):
return False
else:
return True

def _perform_consistency_checks(self):

# Check for consistent sampling frequency across recordings
Expand Down Expand Up @@ -201,14 +237,18 @@ def get_traces(
return np.concatenate(traces, axis=1)


def aggregate_channels(
    recording_list_or_dict=None,
    renamed_channel_ids=None,
    recording_list=None,
):
    """
    Aggregates channels of multiple recordings into a single recording object.

    Parameters
    ----------
    recording_list_or_dict : list | dict
        List or dict of BaseRecording objects to aggregate. When a dict is
        passed, its keys are used as the "group" labels of the aggregated
        channels, unless the recordings already carry grouping information.
    renamed_channel_ids : array-like
        If given, channel ids are renamed as provided.
    recording_list : list, deprecated
        Deprecated alias of `recording_list_or_dict`; will be removed in 0.105.0.

    Returns
    -------
    aggregate_recording : ChannelsAggregationRecording
        The aggregated recording object.

    Raises
    ------
    ValueError
        If both `recording_list_or_dict` and the deprecated `recording_list`
        are provided.
    """
    if recording_list is not None:
        # Backwards-compatibility shim for the old keyword name.
        warnings.warn(
            "`recording_list` is deprecated and will be removed in 0.105.0. Please use `recording_list_or_dict` instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        if recording_list_or_dict is not None:
            # Passing both is contradictory; fail loudly rather than silently
            # preferring one of them.
            raise ValueError("Pass only one of `recording_list_or_dict` and `recording_list`.")
        recording_list_or_dict = recording_list

    return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids)
210 changes: 150 additions & 60 deletions src/spikeinterface/core/tests/test_channelsaggregationrecording.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np

from spikeinterface.core import aggregate_channels

from spikeinterface.core import generate_recording


Expand All @@ -21,65 +20,98 @@ def test_channelsaggregationrecording():
recording2.set_channel_locations([[20.0, 20.0], [20.0, 40.0], [20.0, 60.0]])
recording3.set_channel_locations([[40.0, 20.0], [40.0, 40.0], [40.0, 60.0]])

# test num channels
recording_agg = aggregate_channels([recording1, recording2, recording3])
assert len(recording_agg.get_channel_ids()) == 3 * num_channels

assert np.allclose(recording_agg.get_times(0), recording1.get_times(0))

# test traces
channel_ids = recording1.get_channel_ids()

for seg in range(num_seg):
# single channels
traces1_1 = recording1.get_traces(channel_ids=[channel_ids[1]], segment_index=seg)
traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg)
traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg)

assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg))
assert np.allclose(
traces2_0,
recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg),
)
assert np.allclose(
traces3_2,
recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg),
)
# all traces
traces1 = recording1.get_traces(segment_index=seg)
traces2 = recording2.get_traces(segment_index=seg)
traces3 = recording3.get_traces(segment_index=seg)

assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg))
assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg))
assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg))

# test rename channels
renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)]
recording_agg_renamed = aggregate_channels(
[recording1, recording2, recording3], renamed_channel_ids=renamed_channel_ids
)
assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids())

# test properties
# complete property
recording1.set_property("brain_area", ["CA1"] * num_channels)
recording2.set_property("brain_area", ["CA2"] * num_channels)
recording3.set_property("brain_area", ["CA3"] * num_channels)

# skip for inconsistency
recording1.set_property("template", np.zeros((num_channels, 4, 30)))
recording2.set_property("template", np.zeros((num_channels, 20, 50)))
recording3.set_property("template", np.zeros((num_channels, 2, 10)))

# incomplete property
recording1.set_property("quality", ["good"] * num_channels)
recording2.set_property("quality", ["bad"] * num_channels)

recording_agg_prop = aggregate_channels([recording1, recording2, recording3])
assert "brain_area" in recording_agg_prop.get_property_keys()
assert "quality" not in recording_agg_prop.get_property_keys()
print(recording_agg_prop.get_property("brain_area"))
recordings_list_possibilities = [
[recording1, recording2, recording3],
{0: recording1, 1: recording2, 2: recording3},
]

for recordings_list in recordings_list_possibilities:

# test num channels
recording_agg = aggregate_channels(recordings_list)
assert len(recording_agg.get_channel_ids()) == 3 * num_channels

assert np.allclose(recording_agg.get_times(0), recording1.get_times(0))

# test traces
channel_ids = recording1.get_channel_ids()

for seg in range(num_seg):
# single channels
traces1_1 = recording1.get_traces(channel_ids=[channel_ids[1]], segment_index=seg)
traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg)
traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg)

assert np.allclose(
traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)
)
assert np.allclose(
traces2_0,
recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg),
)
assert np.allclose(
traces3_2,
recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg),
)
# all traces
traces1 = recording1.get_traces(segment_index=seg)
traces2 = recording2.get_traces(segment_index=seg)
traces3 = recording3.get_traces(segment_index=seg)

assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg))
assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg))
assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg))

# test rename channels
renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)]
recording_agg_renamed = aggregate_channels(recordings_list, renamed_channel_ids=renamed_channel_ids)
assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids())

# test properties
# complete property
recording1.set_property("brain_area", ["CA1"] * num_channels)
recording2.set_property("brain_area", ["CA2"] * num_channels)
recording3.set_property("brain_area", ["CA3"] * num_channels)

# skip for inconsistency
recording1.set_property("template", np.zeros((num_channels, 4, 30)))
recording2.set_property("template", np.zeros((num_channels, 20, 50)))
recording3.set_property("template", np.zeros((num_channels, 2, 10)))

# incomplete property
recording1.set_property("quality", ["good"] * num_channels)
recording2.set_property("quality", ["bad"] * num_channels)

recording_agg_prop = aggregate_channels(recordings_list)
assert "brain_area" in recording_agg_prop.get_property_keys()
assert "quality" not in recording_agg_prop.get_property_keys()
print(recording_agg_prop.get_property("brain_area"))


def test_split_then_aggreate_preserve_user_property():
    """
    Checks that splitting then aggregating a recording preserves the unit_id to property mapping.
    """
    recording = generate_recording(num_channels=10, durations=[10, 5], set_probe=False)
    recording.set_property(key="group", values=[2, 0, 1, 1, 1, 0, 1, 0, 1, 2])

    # Snapshot channel_id -> group mapping before the round-trip.
    expected_mapping = dict(zip(recording.channel_ids, recording.get_property(key="group")))

    aggregated_recording = aggregate_channels(recording.split_by("group"))

    observed_mapping = dict(
        zip(aggregated_recording.channel_ids, aggregated_recording.get_property(key="group"))
    )

    assert np.all(expected_mapping == observed_mapping)


def test_channel_aggregation_preserve_ids():
Expand All @@ -94,6 +126,64 @@ def test_channel_aggregation_preserve_ids():
assert list(aggregated_recording.get_channel_ids()) == ["a", "b", "c", "d", "e"]


def test_aggregation_labeling_for_lists():
    """Aggregated lists of recordings get different labels depending on their underlying `property`s"""

    rec_a = generate_recording(num_channels=4, durations=[20], set_probe=False)
    rec_b = generate_recording(num_channels=2, durations=[20], set_probe=False)

    # With no labels at all, aggregation adds a default 'group' label per input recording
    aggregated = aggregate_channels([rec_a, rec_b])
    assert np.all(aggregated.get_property("group") == [0, 0, 0, 0, 1, 1])

    # Pre-existing 'group' labels must be kept as-is
    rec_a.set_property("group", [2, 2, 2, 2])
    rec_b.set_property("group", [6, 6])
    aggregated = aggregate_channels([rec_a, rec_b])
    assert np.all(aggregated.get_property("group") == [2, 2, 2, 2, 6, 6])

    # After `split_by`, the split property survives aggregation even when only the list is passed
    rec_a.set_property("user_group", [6, 7, 6, 7])
    split_list = list(rec_a.split_by("user_group").values())
    aggregated = aggregate_channels(split_list)
    assert np.all(aggregated.get_property("group") == [2, 2, 2, 2])
    # Aggregation concatenates channels recording-by-recording, so the ids are reordered
    assert np.all(aggregated.get_property("user_group") == [6, 6, 7, 7])


def test_aggretion_labelling_for_dicts():
    """Aggregated dicts of recordings get different labels depending on their underlying `property`s"""

    rec_a = generate_recording(num_channels=4, durations=[20], set_probe=False)
    rec_b = generate_recording(num_channels=2, durations=[20], set_probe=False)

    # Without any labels, the dict keys become the 'group' labels
    aggregated = aggregate_channels({0: rec_a, "cat": rec_b})
    assert np.all(aggregated.get_property("group") == [0, 0, 0, 0, "cat", "cat"])

    # Pre-existing 'group' labels take precedence over the dict keys
    rec_a.set_property("group", [2, 2, 2, 2])
    rec_b.set_property("group", [6, 6])
    aggregated = aggregate_channels({0: rec_a, "cat": rec_b})
    assert np.all(aggregated.get_property("group") == [2, 2, 2, 2, 6, 6])

    # After `split_by`, the split property survives aggregation of the resulting dict
    rec_a.set_property("user_group", [6, 7, 6, 7])
    aggregated = aggregate_channels(rec_a.split_by("user_group"))
    assert np.all(aggregated.get_property("group") == [2, 2, 2, 2])
    # Aggregation concatenates channels recording-by-recording, so the ids are reordered
    assert np.all(aggregated.get_property("user_group") == [6, 6, 7, 7])


def test_channel_aggregation_does_not_preserve_ids_if_not_unique():

recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False) # To avoid location check
Expand Down