Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions src/spikeinterface/core/channelsaggregationrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,31 @@ class ChannelsAggregationRecording(BaseRecording):
"""

def __init__(self, recording_list, renamed_channel_ids=None):

# Generate a default list of channel ids that are unique and consecutive numbers as strings.
channel_map = {}
num_all_channels = sum(rec.get_num_channels() for rec in recording_list)

num_all_channels = sum([rec.get_num_channels() for rec in recording_list])
if renamed_channel_ids is not None:
assert len(np.unique(renamed_channel_ids)) == num_all_channels, (
"'renamed_channel_ids' doesn't have the " "right size or has duplicates!"
)
assert (
len(np.unique(renamed_channel_ids)) == num_all_channels
), "'renamed_channel_ids' doesn't have the right size or has duplicates!"
channel_ids = list(renamed_channel_ids)
else:
channel_ids = list(np.arange(num_all_channels))
# Collect channel IDs from all recordings
all_channels_have_same_type = np.unique([rec.channel_ids.dtype for rec in recording_list]).size == 1
all_channel_ids_are_unique = False
if all_channels_have_same_type:
combined_ids = np.concatenate([rec.channel_ids for rec in recording_list])
all_channel_ids_are_unique = np.unique(combined_ids).size == num_all_channels

if all_channels_have_same_type and all_channel_ids_are_unique:
channel_ids = combined_ids
else:
# If IDs are not unique or not of the same type, use default as stringify IDs
default_channel_ids = [str(i) for i in range(num_all_channels)]
channel_ids = default_channel_ids

# channel map maps channel indices that are used to get traces
ch_id = 0
for r_i, recording in enumerate(recording_list):
single_channel_ids = recording.get_channel_ids()
Expand All @@ -49,7 +62,9 @@ def __init__(self, recording_list, renamed_channel_ids=None):
break

if not (ok1 and ok2 and ok3 and ok4):
raise ValueError("Sortings don't have the same sampling_frequency/num_segments/dtype/num samples")
raise ValueError(
"Recordings do not have consistent sampling frequency, number of segments, data type, or number of samples."
)

BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype)

Expand Down Expand Up @@ -91,7 +106,7 @@ def __init__(self, recording_list, renamed_channel_ids=None):
self.add_recording_segment(sub_segment)

self._recordings = recording_list
self._kwargs = {"recording_list": [rec for rec in recording_list], "renamed_channel_ids": renamed_channel_ids}
self._kwargs = {"recording_list": recording_list, "renamed_channel_ids": renamed_channel_ids}

@property
def recordings(self):
Expand Down Expand Up @@ -173,11 +188,11 @@ def aggregate_channels(recording_list, renamed_channel_ids=None):
recording_list: list
List of BaseRecording objects to aggregate
renamed_channel_ids: array-like
If given, channel ids are renamed as provided. If None, unit ids are sequential integers.
If given, channel ids are renamed as provided.

Returns
-------
aggregate_recording: UnitsAggregationSorting
The aggregated sorting object
aggregate_recording: ChannelsAggregationRecording
The aggregated recording object
"""
return ChannelsAggregationRecording(recording_list, renamed_channel_ids)
49 changes: 42 additions & 7 deletions src/spikeinterface/core/tests/test_channelsaggregationrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def test_channelsaggregationrecording():

# test num channels
recording_agg = aggregate_channels([recording1, recording2, recording3])
print(recording_agg)
assert len(recording_agg.get_channel_ids()) == 3 * num_channels

assert np.allclose(recording_agg.get_times(0), recording1.get_times(0))
Expand All @@ -37,21 +36,21 @@ def test_channelsaggregationrecording():
traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg)
traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg)

assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[channel_ids[1]], segment_index=seg))
assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg))
assert np.allclose(
traces2_0, recording_agg.get_traces(channel_ids=[num_channels + channel_ids[0]], segment_index=seg)
traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg)
)
assert np.allclose(
traces3_2, recording_agg.get_traces(channel_ids=[2 * num_channels + channel_ids[2]], segment_index=seg)
traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg)
)
# all traces
traces1 = recording1.get_traces(segment_index=seg)
traces2 = recording2.get_traces(segment_index=seg)
traces3 = recording3.get_traces(segment_index=seg)

assert np.allclose(traces1, recording_agg.get_traces(channel_ids=[0, 1, 2], segment_index=seg))
assert np.allclose(traces2, recording_agg.get_traces(channel_ids=[3, 4, 5], segment_index=seg))
assert np.allclose(traces3, recording_agg.get_traces(channel_ids=[6, 7, 8], segment_index=seg))
assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg))
assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg))
assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg))

# test rename channels
renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)]
Expand Down Expand Up @@ -81,5 +80,41 @@ def test_channelsaggregationrecording():
print(recording_agg_prop.get_property("brain_area"))


def test_channel_aggregation_preserve_ids():
    """Aggregation keeps the original channel ids when they are unique and share one dtype."""
    # set_probe=False avoids the probe/location consistency check during aggregation
    rec_a = generate_recording(num_channels=3, durations=[10], set_probe=False)
    rec_a = rec_a.rename_channels(new_channel_ids=["a", "b", "c"])
    rec_b = generate_recording(num_channels=2, durations=[10], set_probe=False)
    rec_b = rec_b.rename_channels(new_channel_ids=["d", "e"])

    aggregated = aggregate_channels([rec_a, rec_b])

    # 3 + 2 channels, and the unique string ids survive unchanged, in order
    assert aggregated.get_num_channels() == 5
    assert list(aggregated.get_channel_ids()) == ["a", "b", "c", "d", "e"]


def test_channel_aggregation_does_not_preserve_ids_if_not_unique():
    """Duplicate channel ids across recordings force a fall-back to default stringified ids."""
    # set_probe=False avoids the probe/location consistency check during aggregation
    rec_a = generate_recording(num_channels=3, durations=[10], set_probe=False)
    rec_a = rec_a.rename_channels(new_channel_ids=["a", "b", "c"])
    rec_b = generate_recording(num_channels=2, durations=[10], set_probe=False)
    # "a" and "b" collide with rec_a's ids on purpose
    rec_b = rec_b.rename_channels(new_channel_ids=["a", "b"])

    aggregated = aggregate_channels([rec_a, rec_b])

    # colliding ids are discarded in favor of consecutive string ids "0".."4"
    assert aggregated.get_num_channels() == 5
    assert list(aggregated.get_channel_ids()) == ["0", "1", "2", "3", "4"]


def test_channel_aggregation_does_not_preserve_ids_not_the_same_type():
    """Mixed id dtypes (str vs int) across recordings force a fall-back to default stringified ids."""
    # set_probe=False avoids the probe/location consistency check during aggregation
    rec_a = generate_recording(num_channels=3, durations=[10], set_probe=False)
    rec_a = rec_a.rename_channels(new_channel_ids=["a", "b", "c"])
    rec_b = generate_recording(num_channels=2, durations=[10], set_probe=False)
    # integer ids here deliberately mismatch rec_a's string ids
    rec_b = rec_b.rename_channels(new_channel_ids=[1, 2])

    aggregated = aggregate_channels([rec_a, rec_b])

    # dtype mismatch is discarded in favor of consecutive string ids "0".."4"
    assert aggregated.get_num_channels() == 5
    assert list(aggregated.get_channel_ids()) == ["0", "1", "2", "3", "4"]


if __name__ == "__main__":
    # Quick local debugging entry point (pytest discovers these automatically).
    # Run every test in this module, not just the first one, so a direct
    # `python test_channelsaggregationrecording.py` exercises the new id-preservation cases too.
    test_channelsaggregationrecording()
    test_channel_aggregation_preserve_ids()
    test_channel_aggregation_does_not_preserve_ids_if_not_unique()
    test_channel_aggregation_does_not_preserve_ids_not_the_same_type()