Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions src/spikeinterface/preprocessing/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def __init__(
margin = int(margin_ms * recording.get_sampling_frequency() / 1000)

BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate, dtype=dtype)
# in case there was a time_vector, it will be dropped for sanity.
for parent_segment in recording._recording_segments:
parent_segment.time_vector = None
self.add_recording_segment(
ResampleRecordingSegment(
parent_segment,
Expand Down Expand Up @@ -96,24 +94,44 @@ def __init__(
margin,
dtype,
):
# Do not use BasePreprocessorSegment bcause we have to reset the sampling rate!
BaseRecordingSegment.__init__(
self,
sampling_frequency=resample_rate,
t_start=parent_recording_segment.t_start,
)
self._resample_rate = resample_rate
self._parent_segment = parent_recording_segment
self._parent_rate = parent_rate
self._margin = margin
self._dtype = dtype

# Compute time_vector or t_start, following the pattern from DecimateRecordingSegment.
# Do not use BasePreprocessorSegment because we have to reset the sampling rate!
if parent_recording_segment.time_vector is not None:
parent_tv = np.asarray(parent_recording_segment.time_vector)
n_out = int(len(parent_tv) / parent_rate * resample_rate)

if parent_rate % resample_rate == 0:
q_int = int(parent_rate / resample_rate)
time_vector = parent_tv[::q_int][:n_out]
else:
warnings.warn(
"Resampling with a non-integer ratio requires interpolating the time_vector. "
"An integer ratio (parent_rate / resample_rate) is more performant."
)
parent_indices = np.linspace(0, len(parent_tv) - 1, n_out)
time_vector = np.interp(parent_indices, np.arange(len(parent_tv)), parent_tv)

BaseRecordingSegment.__init__(self, sampling_frequency=None, t_start=None, time_vector=time_vector)
else:
BaseRecordingSegment.__init__(
self, sampling_frequency=resample_rate, t_start=parent_recording_segment.t_start
)

def get_num_samples(self):
    """Return the number of samples in this resampled segment.

    When the segment carries an explicit ``time_vector``, its length is the
    sample count. Otherwise the parent's sample count is scaled by
    ``_resample_rate / _parent_rate`` and truncated to an int.

    Returns
    -------
    int
        Number of samples available after resampling.
    """
    if self.time_vector is not None:
        return len(self.time_vector)
    # Use the stored target rate rather than self.sampling_frequency: on the
    # time_vector path the base initializer is given sampling_frequency=None.
    return int(self._parent_segment.get_num_samples() / self._parent_rate * self._resample_rate)

def get_traces(self, start_frame, end_frame, channel_indices):
# get parent traces with margin
parent_start_frame, parent_end_frame = [
int((frame / self.sampling_frequency) * self._parent_rate) for frame in [start_frame, end_frame]
int((frame / self._resample_rate) * self._parent_rate) for frame in [start_frame, end_frame]
]
parent_traces, left_margin, right_margin = get_chunk_with_margin(
self._parent_segment,
Expand All @@ -126,7 +144,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
)
# get left and right margins for the resampled case
left_margin_rs, right_margin_rs = [
int((margin / self._parent_rate) * self.sampling_frequency) for margin in [left_margin, right_margin]
int((margin / self._parent_rate) * self._resample_rate) for margin in [left_margin, right_margin]
]

# get the size for the resampled traces in case of resample:
Expand All @@ -136,9 +154,9 @@ def get_traces(self, start_frame, end_frame, channel_indices):
# Check which method to use:
from scipy import signal

if np.mod(self._parent_rate, self.sampling_frequency) == 0:
if np.mod(self._parent_rate, self._resample_rate) == 0:
# Ratio between sampling frequencies
q = int(self._parent_rate / self.sampling_frequency)
q = int(self._parent_rate / self._resample_rate)
# Decimate can have issues for some cases, returning NaNs
resampled_traces = signal.decimate(parent_traces, q=q, axis=0)
# If that's the case, use signal.resample
Expand Down
92 changes: 92 additions & 0 deletions src/spikeinterface/preprocessing/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,98 @@ def test_resample_by_chunks():
plt.show()


def test_resample_preserves_t_start():
    """Check that the parent segment's ``t_start`` is carried over by ``resample``."""
    fs = 30000
    expected_t_start = 100.5
    # fs * 2 samples at rate fs, two channels.
    data = np.random.randn(fs * 2, 2).astype(np.float32)
    source = NumpyRecording(data, fs)
    source._recording_segments[0].t_start = expected_t_start

    out = resample(source, 500)
    out_segment = out._recording_segments[0]
    assert out_segment.t_start == expected_t_start
    # No time_vector was set on the parent, so none is expected on the output.
    assert not out.has_time_vector()
    first_time = out.get_times()[0]
    assert np.isclose(first_time, expected_t_start)


def test_resample_does_not_mutate_parent():
    """Check that calling ``resample`` leaves the parent's time_vector untouched."""
    fs = 30000
    num_frames = fs * 2
    data = np.random.randn(num_frames, 2).astype(np.float32)
    source = NumpyRecording(data, fs)
    # Uniformly spaced timestamps offset by 50 s.
    original_times = np.arange(num_frames, dtype="float64") / fs + 50.0
    source.set_times(original_times)

    assert source.has_time_vector()
    # The return value is irrelevant here; only the side effect on the parent matters.
    resample(source, 500)
    still_has_times = source.has_time_vector()
    assert still_has_times, "Parent time_vector was mutated by resample!"
    np.testing.assert_array_equal(source.get_times(), original_times)


def test_resample_preserves_time_vector_integer_ratio():
    """Check that an integer-ratio resample keeps the parent's time_vector, gap included."""
    fs = 30000
    target_fs = 500
    num_frames = fs * 2
    data = np.random.randn(num_frames, 2).astype(np.float32)
    source = NumpyRecording(data, fs)

    # Build timestamps with a 5-second jump starting at the midpoint
    # (simulating artifact removal).
    parent_times = np.arange(num_frames, dtype="float64") / fs
    half = num_frames // 2
    parent_times[half:] += 5.0
    source.set_times(parent_times)

    out = resample(source, target_fs)

    assert out.has_time_vector()
    out_times = out.get_times()
    num_out = out.get_num_samples()

    # The time vector and the reported sample count must agree.
    assert len(out_times) == num_out

    # Exactly one step should be much larger than the nominal sampling period.
    steps = np.diff(out_times)
    nominal_step = 1.0 / target_fs
    jump_positions = np.flatnonzero(steps > nominal_step * 2)
    assert len(jump_positions) == 1, "The gap should appear exactly once in resampled times"
    jump_size = steps[jump_positions[0]]
    assert np.isclose(jump_size, nominal_step + 5.0, atol=nominal_step)

    # The first timestamp is unchanged by resampling.
    assert np.isclose(out_times[0], parent_times[0])


def test_resample_preserves_time_vector_non_integer_ratio():
    """Check that a non-integer-ratio resample warns and still yields a time_vector."""
    fs = 30000
    target_fs = 700  # 30000 / 700 is not an integer
    num_frames = fs * 2
    data = np.random.randn(num_frames, 2).astype(np.float32)
    source = NumpyRecording(data, fs)

    # Uniformly spaced timestamps offset by 10 s.
    parent_times = np.arange(num_frames, dtype="float64") / fs + 10.0
    source.set_times(parent_times)

    import warnings as _warnings

    # Record every warning raised while resampling so the message can be inspected.
    with _warnings.catch_warnings(record=True) as caught:
        _warnings.simplefilter("always")
        out = resample(source, target_fs)
        messages = [str(item.message).lower() for item in caught]
        assert any("non-integer ratio" in text for text in messages)

    assert out.has_time_vector()
    out_times = out.get_times()
    assert len(out_times) == out.get_num_samples()
    # The first timestamp should match the parent's start within one parent sample period.
    one_parent_sample = 1.0 / fs
    assert np.isclose(out_times[0], 10.0, atol=one_parent_sample)


if __name__ == "__main__":
    # Run the resample tests directly, in file order, when executed as a script.
    for run_test in (
        test_resample_freq_domain,
        test_resample_by_chunks,
        test_resample_preserves_t_start,
        test_resample_does_not_mutate_parent,
        test_resample_preserves_time_vector_integer_ratio,
        test_resample_preserves_time_vector_non_integer_ratio,
    ):
        run_test()