Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,11 @@ def register_recording(self, recording, check_spike_frames: bool = True):
"Might be necessary for further postprocessing."
)
self._recording = recording
# Copy the recording's start times into the sorting segments. This way,
# the sorting preserves the start time even if the recording is later
# detached (e.g. analyzer saved and reloaded without the recording).
for segment_index, segment in enumerate(self.segments):
segment._t_start = recording.get_start_time(segment_index=segment_index)

@property
def sorting_info(self):
Expand All @@ -352,6 +357,66 @@ def has_time_vector(self, segment_index: int | None = None) -> bool:
else:
return False

def get_start_time(self, segment_index: int | None = None) -> float:
"""Get the start time of the sorting segment.

Parameters
----------
segment_index : int or None, default: None
The segment index (required for multi-segment)

Returns
-------
float
The start time in seconds
"""
segment_index = self._check_segment_index(segment_index)
segment = self.segments[segment_index]
return segment._t_start if segment._t_start is not None else 0.0

def get_end_time(self, segment_index: int | None = None) -> float:
"""Get the end time of the sorting segment.

If a recording is registered, returns the recording's end time.
Otherwise returns the time of the last spike in the segment.

Parameters
----------
segment_index : int or None, default: None
The segment index (required for multi-segment)

Returns
-------
float
The end time in seconds
"""
segment_index = self._check_segment_index(segment_index)
if self.has_recording():
return self._recording.get_end_time(segment_index=segment_index)
else:
last_spike_frame = self.get_last_spike_frame(segment_index=segment_index)
return self.sample_index_to_time(last_spike_frame, segment_index=segment_index)

def get_last_spike_frame(self, segment_index: int | None = None) -> int:
"""Get the frame index of the last spike in a segment across all units.

Parameters
----------
segment_index : int or None, default: None
The segment index (required for multi-segment)

Returns
-------
int
The frame index of the last spike, or 0 if no spikes exist.
"""
segment_index = self._check_segment_index(segment_index)
spike_vector = self.to_spike_vector(concatenated=False)
spikes_in_segment = spike_vector[segment_index]
if len(spikes_in_segment) == 0:
return 0
return int(np.max(spikes_in_segment["sample_index"]))

def get_times(self, segment_index=None):
"""
Get time vector for a registered recording segment.
Expand Down
58 changes: 58 additions & 0 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,61 @@ def test_shift_times_with_None_as_t_start():
assert recording.segments[0].t_start is None
recording.shift_times(shift=1.0) # Shift by one seconds should not generate an error
assert recording.get_start_time() == 1.0


class TestSortingTimeNoRecording:
"""Tests for time methods on BaseSorting without a registered recording."""

def test_get_start_time_default(self):
sorting = generate_sorting(num_units=5, durations=[10])
assert sorting.get_start_time(segment_index=0) == 0.0

def test_get_end_time_is_last_spike(self):
sorting = generate_sorting(num_units=5, durations=[10])
last_frame = sorting.get_last_spike_frame(segment_index=0)
expected_time = last_frame / sorting.get_sampling_frequency()
assert sorting.get_end_time(segment_index=0) == expected_time

def test_get_start_time_with_t_start(self):
sorting = generate_sorting(num_units=5, durations=[10])
sorting.segments[0]._t_start = 100.0
assert sorting.get_start_time(segment_index=0) == 100.0


class TestSortingTimeWithRecording:
"""
Tests for time methods on BaseSorting with a registered recording.
The key invariant: the recording is the source of truth for timestamps.
"""

def test_get_start_end_time(self):
recording = generate_recording(num_channels=4, durations=[10])
sorting = generate_sorting(num_units=5, durations=[10])
sorting.register_recording(recording)

assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0)
assert sorting.get_end_time(segment_index=0) == recording.get_end_time(segment_index=0)

def test_register_recording_copies_start_times(self):
"""Registering a recording copies its start times into the sorting segments."""
sorting = generate_sorting(num_units=5, durations=[10])
sorting.segments[0]._t_start = 100.0

recording = generate_recording(num_channels=4, durations=[10])
recording.shift_times(shift=50.0)
sorting.register_recording(recording)

# _t_start now mirrors the recording's start time, preserving it across
# save/load cycles even when the recording is not attached.
assert sorting.segments[0]._t_start == recording.get_start_time(segment_index=0)
assert sorting.get_start_time(segment_index=0) == 50.0

def test_with_recording_shifted_start(self):
"""Recording with a non-zero t_start is reflected in the sorting."""
recording = generate_recording(num_channels=4, durations=[10])
recording.shift_times(shift=50.0)

sorting = generate_sorting(num_units=5, durations=[10])
sorting.register_recording(recording)

assert sorting.get_start_time(segment_index=0) == 50.0
Loading