From 2852992f066cfd4925f4bcb1aad7afe9fae5153a Mon Sep 17 00:00:00 2001 From: tayheau Date: Mon, 4 May 2026 13:50:45 +0200 Subject: [PATCH 1/2] add SelectSegmentEvent --- src/spikeinterface/core/baseevent.py | 5 +++++ src/spikeinterface/core/segmentutils.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/spikeinterface/core/baseevent.py b/src/spikeinterface/core/baseevent.py index 7c0fbef221..2e73774c04 100644 --- a/src/spikeinterface/core/baseevent.py +++ b/src/spikeinterface/core/baseevent.py @@ -66,6 +66,11 @@ def add_event_segment(self, event_segment): def get_num_segments(self): return len(self._event_segments) + def select_segment(self, segment_indices: int | list[int]): + from .segmentutils import SelectSegmentEvent + + return SelectSegmentEvent(self, segment_indices=segment_indices) + def get_events( self, channel_id: int | str | None = None, diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 3d99fd23c4..d8b5cf584a 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -1,5 +1,6 @@ import numpy as np +from .baseevent import BaseEvent from .baserecording import BaseRecording, BaseRecordingSegment from .basesorting import BaseSorting, BaseSortingSegment @@ -604,3 +605,21 @@ def __init__(self, sorting: BaseSorting, segment_indices: int | list[int]): select_segment_sorting = define_function_from_class(source_class=SelectSegmentSorting, name="select_segment_sorting") + +class SelectSegmentEvent(BaseEvent): + def __init__(self, event: BaseEvent, segment_indices: int | list[int]): + BaseEvent.__init__(self, event.channel_ids, event.structured_dtype) + + if isinstance(segment_indices, int): segment_indices = [segment_indices] + + num_segments = event.get_num_segments() + + if not all( 0 <= s < num_segments for s in segment_indices): + raise ValueError(f"'segment_index' must be between 0 and {num_segments - 1}") + + for seg_idx in segment_indices: + seg = event._event_segments[seg_idx] + self.add_event_segment(seg) + + self._parent = event + self._kwargs = {"event": event, "segment_indices": segment_indices} From fc5234d966e7c7daf4fc468fc1c2611572d2d304 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 May 2026 11:54:42 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/segmentutils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index d8b5cf584a..3a40901fab 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -606,19 +606,21 @@ def __init__(self, sorting: BaseSorting, segment_indices: int | list[int]): select_segment_sorting = define_function_from_class(source_class=SelectSegmentSorting, name="select_segment_sorting") + class SelectSegmentEvent(BaseEvent): def __init__(self, event: BaseEvent, segment_indices: int | list[int]): BaseEvent.__init__(self, event.channel_ids, event.structured_dtype) - if isinstance(segment_indices, int): segment_indices = [segment_indices] + if isinstance(segment_indices, int): + segment_indices = [segment_indices] num_segments = event.get_num_segments() - if not all( 0 <= s < num_segments for s in segment_indices): + if not all(0 <= s < num_segments for s in segment_indices): raise ValueError(f"'segment_index' must be between 0 and {num_segments - 1}") for seg_idx in segment_indices: - seg = event._event_segments[seg_idx] + seg = event._event_segments[seg_idx] self.add_event_segment(seg) self._parent = event