From 3e8ad1cd95309b6f4b33b09bfddcf99129fc5907 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 19 Feb 2025 16:11:22 +0000 Subject: [PATCH 01/20] add rec or dict of rects --- src/spikeinterface/core/core_tools.py | 23 +++++++++++++++++++ .../preprocessing/common_reference.py | 6 +++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 2ee378e870..ecbe2d92b2 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -13,6 +13,29 @@ import numpy as np +def _is_documented_by(original): + def wrapper(target): + target.__doc__ = original.__doc__ + return target + + return wrapper + + +def _make_pp_from_rec_or_dict(recording_or_dict_of_recordings, source_class, **args): + from spikeinterface.core.baserecording import BaseRecording + + if isinstance(recording_or_dict_of_recordings, dict): + pp_dict = { + property_id: source_class(recording, **args) + for property_id, recording in recording_or_dict_of_recordings.items() + } + return pp_dict + elif isinstance(recording_or_dict_of_recordings, BaseRecording): + return source_class(recording_or_dict_of_recordings, **args) + else: + raise TypeError("You must supply a recording or a dictionary of recordings") + + def define_function_from_class(source_class, name): "Wrapper to change the name of a class" diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index b9bc1b4b53..8041704e26 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -2,7 +2,7 @@ import numpy as np from typing import Optional, Literal -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import _make_pp_from_rec_or_dict, is_documented_by from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import get_closest_channels @@ -259,4 +259,6 @@ def slice_groups(self, channel_indices): return zip(group_indices, selected_channels, group_channels) -common_reference = define_function_from_class(source_class=CommonReferenceRecording, name="common_reference") +@_is_documented_by(CommonReferenceRecording) +def common_reference(recording, **args): + return _make_pp_from_rec_or_dict(recording, CommonReferenceRecording, **args) From 95fa8fe0646a8ac800193b1618aa84e0468419a7 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 19 Feb 2025 16:14:43 +0000 Subject: [PATCH 02/20] oups --- src/spikeinterface/preprocessing/common_reference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 8041704e26..773454891a 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -2,7 +2,7 @@ import numpy as np from typing import Optional, Literal -from spikeinterface.core.core_tools import _make_pp_from_rec_or_dict, is_documented_by +from spikeinterface.core.core_tools import _make_pp_from_rec_or_dict, _is_documented_by from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import get_closest_channels From 6bf1715c87c5cf93aa18904b24399285723867af Mon Sep 17 00:00:00 2001 From: chrishalcrow 
<57948917+chrishalcrow@users.noreply.github.com> Date: Fri, 21 Feb 2025 09:59:13 +0000 Subject: [PATCH 03/20] Another go using decorators --- src/spikeinterface/core/core_tools.py | 50 ++++++++++++------- .../preprocessing/common_reference.py | 6 +-- .../preprocessing/silence_periods.py | 4 +- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index ecbe2d92b2..1e068866f0 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -9,31 +9,47 @@ import importlib from math import prod from collections import namedtuple +import inspect import numpy as np -def _is_documented_by(original): - def wrapper(target): - target.__doc__ = original.__doc__ - return target +def dict_of_source_classes(rec_or_dict_of_recs, source_class, *args, **kwargs): + preprocessed_recordings_dict = { + property_id: source_class(recording, *args, **kwargs) for property_id, recording in rec_or_dict_of_recs.items() + } + return preprocessed_recordings_dict - return wrapper +def define_rec_or_dict_function(source_class, name): -def _make_pp_from_rec_or_dict(recording_or_dict_of_recordings, source_class, **args): - from spikeinterface.core.baserecording import BaseRecording + from spikeinterface.core import BaseRecording - if isinstance(recording_or_dict_of_recordings, dict): - pp_dict = { - property_id: source_class(recording, **args) - for property_id, recording in recording_or_dict_of_recordings.items() - } - return pp_dict - elif isinstance(recording_or_dict_of_recordings, BaseRecording): - return source_class(recording_or_dict_of_recordings, **args) - else: - raise TypeError("You must supply a recording or a dictionary of recordings") + def source_class_or_dict_of_sources_classes(*args, **kwargs): + + recording_in_kwargs = False + if rec_or_dict_of_recs := kwargs.get("recording"): + recording_in_kwargs = True + else: + rec_or_dict_of_recs = args[0] + + if isinstance(rec_or_dict_of_recs, BaseRecording): + return source_class(*args, **kwargs) + else: + # Edit args & kwargs to pass the dict of recordings but _not_ the original recording + new_kwargs = {key: kwarg for key, kwarg in kwargs.items() if key != "recording"} + if recording_in_kwargs: + new_args = args + else: + new_args = args[1:] + + return dict_of_source_classes(rec_or_dict_of_recs, source_class, *new_args, **new_kwargs) + + source_class_or_dict_of_sources_classes.__signature__ = inspect.signature(source_class) + source_class_or_dict_of_sources_classes.__doc__ = source_class.__doc__ + source_class_or_dict_of_sources_classes.__name__ = name + + return source_class_or_dict_of_sources_classes def define_function_from_class(source_class, name): diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 773454891a..88914837c9 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -2,7 +2,7 @@ import numpy as np from typing import Optional, Literal -from spikeinterface.core.core_tools import _make_pp_from_rec_or_dict, _is_documented_by +from spikeinterface.core.core_tools import define_rec_or_dict_function from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import get_closest_channels @@ -259,6 +259,4 @@ def slice_groups(self, channel_indices): return zip(group_indices, selected_channels, group_channels) -@_is_documented_by(CommonReferenceRecording) -def 
common_reference(recording, **args): - return _make_pp_from_rec_or_dict(recording, CommonReferenceRecording, **args) +common_reference = define_rec_or_dict_function(CommonReferenceRecording, name="common_reference") diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 00d9a1a407..c499053736 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_rec_or_dict_function from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import get_random_data_chunks, get_noise_levels @@ -137,4 +137,4 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -silence_periods = define_function_from_class(source_class=SilencedPeriodsRecording, name="silence_periods") +silence_periods = define_rec_or_dict_function(SilencedPeriodsRecording, name="silence_periods") From 4bc7f155d00731e01da76f4d512dd62ba9f63e8e Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 25 Feb 2025 09:53:05 +0000 Subject: [PATCH 04/20] add tests for pp_rec --- .../tests/test_grouped_preprocessing.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py diff --git a/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py b/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py new file mode 100644 index 0000000000..45ff1bf2d4 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py @@ -0,0 +1,68 @@ +from spikeinterface.core import generate_recording +from spikeinterface.preprocessing import common_reference, silence_periods + +import numpy as np + + +def get_some_traces(recording): + return recording.get_traces(start_frame=10, end_frame=12) + + +def check_recordings_are_equal(recording_1, recording_2): + assert np.all(get_some_traces(recording_1) == get_some_traces(recording_2)) + assert recording_1._kwargs == recording_2._kwargs + + +def test_grouped_preprocessing(): + """Here we make a dict of two recordings and apply preprocessing steps directly + to the dict. This should give the same result as applying preprocessing steps to + each recording separately. + + The arg/kwarg logic in `source_class_or_dict_of_sources_classes` is non-trivial, + so we test some arg/kwarg possibilities here. 
+ """ + + recording_1 = generate_recording(seed=1205, durations=[5]) + recording_2 = generate_recording(seed=1205, durations=[6]) + + dict_of_recordings = {"one": recording_1, "two": recording_2} + + # First use dict_of_recordings as an arg + operator = "average" + dict_of_preprocessed_recordings = common_reference(dict_of_recordings, operator=operator) + + pp_recording_1 = common_reference(recording_1, operator=operator) + pp_recording_2 = common_reference(recording_2, operator=operator) + + check_recordings_are_equal(dict_of_preprocessed_recordings["one"], pp_recording_1) + check_recordings_are_equal(dict_of_preprocessed_recordings["two"], pp_recording_2) + + # Re-try using recording as a kwarg + dict_of_preprocessed_recordings = common_reference(recording=dict_of_recordings, operator=operator) + check_recordings_are_equal(dict_of_preprocessed_recordings["one"], pp_recording_1) + check_recordings_are_equal(dict_of_preprocessed_recordings["two"], pp_recording_2) + + # Now try a `silence periods` which has two args + list_periods = [[1, 2]] + mode = "noise" + + sp_recording_1 = silence_periods(recording_1, list_periods=list_periods, mode=mode) + sp_recording_2 = silence_periods(recording_2, list_periods=list_periods, mode=mode) + + dict_of_preprocessed_recordings = silence_periods(dict_of_recordings, list_periods, mode=mode) + check_recordings_are_equal(dict_of_preprocessed_recordings["one"], sp_recording_1) + check_recordings_are_equal(dict_of_preprocessed_recordings["two"], sp_recording_2) + + dict_of_preprocessed_recordings = silence_periods(dict_of_recordings, list_periods=list_periods, mode=mode) + check_recordings_are_equal(dict_of_preprocessed_recordings["one"], sp_recording_1) + check_recordings_are_equal(dict_of_preprocessed_recordings["two"], sp_recording_2) + + dict_of_preprocessed_recordings = silence_periods( + recording=dict_of_recordings, list_periods=list_periods, mode=mode + ) + check_recordings_are_equal(dict_of_preprocessed_recordings["one"], sp_recording_1) + check_recordings_are_equal(dict_of_preprocessed_recordings["two"], sp_recording_2) + + +if __name__ == "__main__": + test_grouped_preprocessing() From 1f919e592eef8f79d735af0b4bb5ed37d98bc225 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 25 Feb 2025 11:41:17 +0000 Subject: [PATCH 05/20] use `check_recordings_equal` --- .../tests/test_grouped_preprocessing.py | 55 +++++++++---------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py b/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py index 45ff1bf2d4..569f2130f7 100644 --- a/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py +++ b/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py @@ -1,16 +1,6 @@ from spikeinterface.core import generate_recording from spikeinterface.preprocessing import common_reference, silence_periods - -import numpy as np - - -def get_some_traces(recording): - return recording.get_traces(start_frame=10, end_frame=12) - - -def check_recordings_are_equal(recording_1, recording_2): - assert np.all(get_some_traces(recording_1) == get_some_traces(recording_2)) - assert recording_1._kwargs == recording_2._kwargs +from spikeinterface.core.testing import check_recordings_equal def test_grouped_preprocessing(): @@ -22,8 +12,10 @@ def test_grouped_preprocessing(): so we test some arg/kwarg possibilities here. 
""" - recording_1 = generate_recording(seed=1205, durations=[5]) - recording_2 = generate_recording(seed=1205, durations=[6]) + seed = 1205 + + recording_1 = generate_recording(seed=seed, durations=[5]) + recording_2 = generate_recording(seed=seed, durations=[6]) dict_of_recordings = {"one": recording_1, "two": recording_2} @@ -34,34 +26,37 @@ def test_grouped_preprocessing(): pp_recording_1 = common_reference(recording_1, operator=operator) pp_recording_2 = common_reference(recording_2, operator=operator) - check_recordings_are_equal(dict_of_preprocessed_recordings["one"], pp_recording_1) - check_recordings_are_equal(dict_of_preprocessed_recordings["two"], pp_recording_2) + check_recordings_equal(dict_of_preprocessed_recordings["one"], pp_recording_1) + check_recordings_equal(dict_of_preprocessed_recordings["two"], pp_recording_2) # Re-try using recording as a kwarg dict_of_preprocessed_recordings = common_reference(recording=dict_of_recordings, operator=operator) - check_recordings_are_equal(dict_of_preprocessed_recordings["one"], pp_recording_1) - check_recordings_are_equal(dict_of_preprocessed_recordings["two"], pp_recording_2) + check_recordings_equal(dict_of_preprocessed_recordings["one"], pp_recording_1) + check_recordings_equal(dict_of_preprocessed_recordings["two"], pp_recording_2) # Now try a `silence periods` which has two args - list_periods = [[1, 2]] + list_periods = [[1, 10]] mode = "noise" - sp_recording_1 = silence_periods(recording_1, list_periods=list_periods, mode=mode) - sp_recording_2 = silence_periods(recording_2, list_periods=list_periods, mode=mode) - - dict_of_preprocessed_recordings = silence_periods(dict_of_recordings, list_periods, mode=mode) - check_recordings_are_equal(dict_of_preprocessed_recordings["one"], sp_recording_1) - check_recordings_are_equal(dict_of_preprocessed_recordings["two"], sp_recording_2) + sp_recording_1 = silence_periods(recording_1, list_periods=list_periods, mode=mode, seed=seed) + sp_recording_2 = silence_periods(recording_2, list_periods=list_periods, mode=mode, seed=seed) - dict_of_preprocessed_recordings = silence_periods(dict_of_recordings, list_periods=list_periods, mode=mode) - check_recordings_are_equal(dict_of_preprocessed_recordings["one"], sp_recording_1) - check_recordings_are_equal(dict_of_preprocessed_recordings["two"], sp_recording_2) + dict_of_silence_period_recordings = silence_periods(dict_of_recordings, list_periods, mode=mode, seed=seed) + check_recordings_equal(dict_of_silence_period_recordings["one"], sp_recording_1) + check_recordings_equal(dict_of_silence_period_recordings["two"], sp_recording_2) dict_of_preprocessed_recordings = silence_periods( - recording=dict_of_recordings, list_periods=list_periods, mode=mode + dict_of_recordings, list_periods=list_periods, mode=mode, seed=seed + ) + check_recordings_equal(dict_of_silence_period_recordings["one"], sp_recording_1) + check_recordings_equal(dict_of_silence_period_recordings["two"], sp_recording_2) + + # Now pass dict_of_recordings as a kwarg + dict_of_silence_period_recordings = silence_periods( + recording=dict_of_recordings, list_periods=list_periods, mode=mode, seed=seed ) - check_recordings_are_equal(dict_of_preprocessed_recordings["one"], sp_recording_1) - check_recordings_are_equal(dict_of_preprocessed_recordings["two"], sp_recording_2) + check_recordings_equal(dict_of_silence_period_recordings["one"], sp_recording_1) + check_recordings_equal(dict_of_silence_period_recordings["two"], sp_recording_2) if __name__ == "__main__": From 
ca37f96c883b8938a0c16782d35dcf07a1c9a4a2 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 3 Mar 2025 09:29:10 +0000 Subject: [PATCH 06/20] respond to zach --- src/spikeinterface/core/core_tools.py | 14 ++++++++++---- .../preprocessing/common_reference.py | 4 ++-- .../preprocessing/silence_periods.py | 4 ++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 1e068866f0..339044ae26 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -14,14 +14,20 @@ import numpy as np -def dict_of_source_classes(rec_or_dict_of_recs, source_class, *args, **kwargs): +def create_dict_of_source_classes(dict_of_recs, source_class, *args, **kwargs): + """Given a dict of recordings, return a dict of recordings with `source_class` applied to them.""" preprocessed_recordings_dict = { - property_id: source_class(recording, *args, **kwargs) for property_id, recording in rec_or_dict_of_recs.items() + property_id: source_class(recording, *args, **kwargs) for property_id, recording in dict_of_recs.items() } return preprocessed_recordings_dict -def define_rec_or_dict_function(source_class, name): +def define_function_or_dict_from_class(source_class, name): + """ + Depending on whether `source_class` is passed a `Recording` object or a dict of + `Recording` objects, this function will return `source_class` or a dict of + `source_class` objects to match the input. + """ from spikeinterface.core import BaseRecording @@ -43,7 +49,7 @@ def source_class_or_dict_of_sources_classes(*args, **kwargs): else: new_args = args[1:] - return dict_of_source_classes(rec_or_dict_of_recs, source_class, *new_args, **new_kwargs) + return create_dict_of_source_classes(rec_or_dict_of_recs, source_class, *new_args, **new_kwargs) source_class_or_dict_of_sources_classes.__signature__ = inspect.signature(source_class) source_class_or_dict_of_sources_classes.__doc__ = source_class.__doc__ diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 88914837c9..c7447a4457 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -2,7 +2,7 @@ import numpy as np from typing import Optional, Literal -from spikeinterface.core.core_tools import define_rec_or_dict_function +from spikeinterface.core.core_tools import define_function_or_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import get_closest_channels @@ -259,4 +259,4 @@ def slice_groups(self, channel_indices): return zip(group_indices, selected_channels, group_channels) -common_reference = define_rec_or_dict_function(CommonReferenceRecording, name="common_reference") +common_reference = define_function_or_dict_from_class(CommonReferenceRecording, name="common_reference") diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index c499053736..e89ece2fa4 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_rec_or_dict_function +from spikeinterface.core.core_tools import define_function_or_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import 
get_random_data_chunks, get_noise_levels @@ -137,4 +137,4 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -silence_periods = define_rec_or_dict_function(SilencedPeriodsRecording, name="silence_periods") +silence_periods = define_function_or_dict_from_class(SilencedPeriodsRecording, name="silence_periods") From 29a8ca1f75f1cb81b9f66289965571ae7c17fd65 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 4 Mar 2025 15:48:42 +0000 Subject: [PATCH 07/20] reply to sam --- src/spikeinterface/core/core_tools.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 339044ae26..9c195a9dc7 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -14,14 +14,6 @@ import numpy as np -def create_dict_of_source_classes(dict_of_recs, source_class, *args, **kwargs): - """Given a dict of recordings, return a dict of recordings with `source_class` applied to them.""" - preprocessed_recordings_dict = { - property_id: source_class(recording, *args, **kwargs) for property_id, recording in dict_of_recs.items() - } - return preprocessed_recordings_dict - - def define_function_or_dict_from_class(source_class, name): """ Depending on whether `source_class` is passed a `Recording` object or a dict of @@ -41,7 +33,7 @@ def source_class_or_dict_of_sources_classes(*args, **kwargs): if isinstance(rec_or_dict_of_recs, BaseRecording): return source_class(*args, **kwargs) - else: + elif isinstance(rec_or_dict_of_recs, dict): # Edit args & kwargs to pass the dict of recordings but _not_ the original recording new_kwargs = {key: kwarg for key, kwarg in kwargs.items() if key != "recording"} if recording_in_kwargs: @@ -49,7 +41,14 @@ def source_class_or_dict_of_sources_classes(*args, **kwargs): else: new_args = args[1:] - return create_dict_of_source_classes(rec_or_dict_of_recs, source_class, *new_args, **new_kwargs) + preprocessed_recordings_dict = { + property_id: source_class(recording, *new_args, **new_kwargs) + for property_id, recording in rec_or_dict_of_recs.items() + } + + return preprocessed_recordings_dict + else: + raise TypeError(f"The function `{name}` only accepts a recording or a dict of recordings.") source_class_or_dict_of_sources_classes.__signature__ = inspect.signature(source_class) source_class_or_dict_of_sources_classes.__doc__ = source_class.__doc__ From c2be8085cfc16f47d39e66edc12fc327467b1e74 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 7 Mar 2025 15:06:47 +0100 Subject: [PATCH 08/20] Change define_function* name and propagate to all preprocessing --- src/spikeinterface/preprocessing/astype.py | 4 ++-- .../preprocessing/average_across_direction.py | 4 ++-- src/spikeinterface/preprocessing/clip.py | 8 +++++--- src/spikeinterface/preprocessing/common_reference.py | 4 ++-- src/spikeinterface/preprocessing/decimate.py | 4 ++-- .../deepinterpolation/deepinterpolation.py | 6 ++++-- src/spikeinterface/preprocessing/depth_order.py | 4 ++-- .../preprocessing/directional_derivative.py | 4 ++-- src/spikeinterface/preprocessing/filter.py | 10 +++++----- src/spikeinterface/preprocessing/filter_gaussian.py | 4 ++-- .../preprocessing/highpass_spatial_filter.py | 4 ++-- .../preprocessing/interpolate_bad_channels.py | 4 ++-- src/spikeinterface/preprocessing/normalize_scale.py | 10 +++++----- src/spikeinterface/preprocessing/phase_shift.py | 4 ++-- 
src/spikeinterface/preprocessing/rectify.py | 4 ++-- src/spikeinterface/preprocessing/remove_artifacts.py | 6 ++++-- src/spikeinterface/preprocessing/resample.py | 4 ++-- src/spikeinterface/preprocessing/silence_periods.py | 4 ++-- src/spikeinterface/preprocessing/unsigned_to_signed.py | 6 ++++-- src/spikeinterface/preprocessing/whiten.py | 4 ++-- src/spikeinterface/preprocessing/zero_channel_pad.py | 8 +++++--- 21 files changed, 60 insertions(+), 50 deletions(-) diff --git a/src/spikeinterface/preprocessing/astype.py b/src/spikeinterface/preprocessing/astype.py index a05610ea2e..cace73a0a3 100644 --- a/src/spikeinterface/preprocessing/astype.py +++ b/src/spikeinterface/preprocessing/astype.py @@ -2,7 +2,7 @@ import numpy as np -from ..core.core_tools import define_function_from_class +from ..core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from .filter import fix_dtype @@ -80,4 +80,4 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -astype = define_function_from_class(source_class=AstypeRecording, name="astype") +astype = define_function_handling_dict_from_class(source_class=AstypeRecording, name="astype") diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 88c5f7301a..ce23cb3f49 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -3,7 +3,7 @@ import numpy as np from spikeinterface.core import BaseRecording, BaseRecordingSegment from .basepreprocessor import BasePreprocessorSegment -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class class AverageAcrossDirectionRecording(BaseRecording): @@ -139,7 +139,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -average_across_direction = define_function_from_class( +average_across_direction = define_function_handling_dict_from_class( source_class=AverageAcrossDirectionRecording, name="average_across_direction", ) diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index 47a4a20d21..3bc35b25a8 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import get_random_data_chunks @@ -169,5 +169,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces -clip = define_function_from_class(source_class=ClipRecording, name="clip") -blank_staturation = define_function_from_class(source_class=BlankSaturationRecording, name="blank_staturation") +clip = define_function_handling_dict_from_class(source_class=ClipRecording, name="clip") +blank_staturation = define_function_handling_dict_from_class( + source_class=BlankSaturationRecording, name="blank_staturation" +) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index c7447a4457..3b572cf90f 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -2,7 
+2,7 @@ import numpy as np from typing import Optional, Literal -from spikeinterface.core.core_tools import define_function_or_dict_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import get_closest_channels @@ -259,4 +259,4 @@ def slice_groups(self, channel_indices): return zip(group_indices, selected_channels, group_channels) -common_reference = define_function_or_dict_from_class(CommonReferenceRecording, name="common_reference") +common_reference = define_function_handling_dict_from_class(CommonReferenceRecording, name="common_reference") diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index d5fc9d2025..e857d6edf3 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -2,7 +2,7 @@ import numpy as np from spikeinterface.core.core_tools import ( - define_function_from_class, + define_function_handling_dict_from_class, ) from .basepreprocessor import BasePreprocessor @@ -136,4 +136,4 @@ def get_traces(self, start_frame, end_frame, channel_indices): ].astype(self._dtype) -decimate = define_function_from_class(source_class=DecimateRecording, name="decimate") +decimate = define_function_handling_dict_from_class(source_class=DecimateRecording, name="decimate") diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index f58bc5b578..5d9beac047 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -5,7 +5,7 @@ from packaging.version import parse from .tf_utils import has_tf, import_tf -from ...core.core_tools import define_function_from_class +from ...core.core_tools import define_function_handling_dict_from_class from ..basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -194,4 +194,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -deepinterpolate = define_function_from_class(source_class=DeepInterpolatedRecording, name="deepinterpolate") +deepinterpolate = define_function_handling_dict_from_class( + source_class=DeepInterpolatedRecording, name="deepinterpolate" +) diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index a112774fb1..f9c1f095f5 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -1,7 +1,7 @@ from __future__ import annotations from ..core import order_channels_by_depth, ChannelSliceRecording -from ..core.core_tools import define_function_from_class +from ..core.core_tools import define_function_handling_dict_from_class class DepthOrderRecording(ChannelSliceRecording): @@ -43,4 +43,4 @@ def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y"), fl ) -depth_order = define_function_from_class(source_class=DepthOrderRecording, name="depth_order") +depth_order = define_function_handling_dict_from_class(source_class=DepthOrderRecording, name="depth_order") diff --git a/src/spikeinterface/preprocessing/directional_derivative.py b/src/spikeinterface/preprocessing/directional_derivative.py index 3a6a480f59..c945e4e6d4 100644 --- a/src/spikeinterface/preprocessing/directional_derivative.py +++ 
b/src/spikeinterface/preprocessing/directional_derivative.py @@ -3,7 +3,7 @@ import numpy as np from spikeinterface.core import BaseRecording, BaseRecordingSegment from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class class DirectionalDerivativeRecording(BasePreprocessor): @@ -136,6 +136,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -directional_derivative = define_function_from_class( +directional_derivative = define_function_handling_dict_from_class( source_class=DirectionalDerivativeRecording, name="directional_derivative" ) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index a67d163d3d..66d757735a 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import get_chunk_with_margin @@ -322,10 +322,10 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): # functions for API -filter = define_function_from_class(source_class=FilterRecording, name="filter") -bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter") -notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") -highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") +filter = define_function_handling_dict_from_class(source_class=FilterRecording, name="filter") +bandpass_filter = define_function_handling_dict_from_class(source_class=BandpassFilterRecording, name="bandpass_filter") +notch_filter = define_function_handling_dict_from_class(source_class=NotchFilterRecording, name="notch_filter") +highpass_filter = define_function_handling_dict_from_class(source_class=HighpassFilterRecording, name="highpass_filter") def causal_filter( diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py index b16df9be69..b053ef6533 100644 --- a/src/spikeinterface/preprocessing/filter_gaussian.py +++ b/src/spikeinterface/preprocessing/filter_gaussian.py @@ -5,7 +5,7 @@ import numpy as np from spikeinterface.core import BaseRecording, BaseRecordingSegment, get_chunk_with_margin, normal_pdf -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -136,4 +136,4 @@ def _create_gaussian(self, N: int, cutoff_f: float): return gaussian -gaussian_filter = define_function_from_class(source_class=GaussianFilterRecording, name="gaussian_filter") +gaussian_filter = define_function_handling_dict_from_class(source_class=GaussianFilterRecording, name="gaussian_filter") diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index 86836f262b..d2074bce43 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -5,7 +5,7 @@ from 
.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from .filter import fix_dtype from ..core import order_channels_by_depth, get_chunk_with_margin -from ..core.core_tools import define_function_from_class +from ..core.core_tools import define_function_handling_dict_from_class class HighpassSpatialFilterRecording(BasePreprocessor): @@ -245,7 +245,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -highpass_spatial_filter = define_function_from_class( +highpass_spatial_filter = define_function_handling_dict_from_class( source_class=HighpassSpatialFilterRecording, name="highpass_spatial_filter" ) diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 508868e0bb..87bb0f936b 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -3,7 +3,7 @@ import numpy as np from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing import preprocessing_tools @@ -113,6 +113,6 @@ def estimate_recommended_sigma_um(recording): return scipy.stats.mode(np.diff(np.unique(y_sorted)), keepdims=False)[0] -interpolate_bad_channels = define_function_from_class( +interpolate_bad_channels = define_function_handling_dict_from_class( source_class=InterpolateBadChannelsRecording, name="interpolate_bad_channels" ) diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index d464c95f4f..c53aceb17a 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -326,9 +326,9 @@ def __init__( # functions for API -normalize_by_quantile = define_function_from_class( +normalize_by_quantile = define_function_handling_dict_from_class( source_class=NormalizeByQuantileRecording, name="normalize_by_quantile" ) -scale = define_function_from_class(source_class=ScaleRecording, name="scale") -center = define_function_from_class(source_class=CenterRecording, name="center") -zscore = define_function_from_class(source_class=ZScoreRecording, name="zscore") +scale = define_function_handling_dict_from_class(source_class=ScaleRecording, name="scale") +center = define_function_handling_dict_from_class(source_class=CenterRecording, name="center") +zscore = define_function_handling_dict_from_class(source_class=ZScoreRecording, name="zscore") diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 664964fcf2..5eccdace8f 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from ..core import get_chunk_with_margin @@ -108,7 +108,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -phase_shift = 
define_function_from_class(source_class=PhaseShiftRecording, name="phase_shift") +phase_shift = define_function_handling_dict_from_class(source_class=PhaseShiftRecording, name="phase_shift") def apply_frequency_shift(signal, shift_samples, axis=0): diff --git a/src/spikeinterface/preprocessing/rectify.py b/src/spikeinterface/preprocessing/rectify.py index aea866452b..96d68dda90 100644 --- a/src/spikeinterface/preprocessing/rectify.py +++ b/src/spikeinterface/preprocessing/rectify.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -27,4 +27,4 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -rectify = define_function_from_class(source_class=RectifyRecording, name="rectify") +rectify = define_function_handling_dict_from_class(source_class=RectifyRecording, name="rectify") diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index aa1746df25..1c1ee8bafa 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -4,7 +4,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.core import NumpySorting, estimate_templates @@ -444,4 +444,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -remove_artifacts = define_function_from_class(source_class=RemoveArtifactsRecording, name="remove_artifacts") +remove_artifacts = define_function_handling_dict_from_class( + source_class=RemoveArtifactsRecording, name="remove_artifacts" +) diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index f076646fdb..ccab7df332 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -4,7 +4,7 @@ import warnings from spikeinterface.core.core_tools import ( - define_function_from_class, + define_function_handling_dict_from_class, recursive_key_finder, ) @@ -154,7 +154,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): return resampled_traces.astype(self._dtype) -resample = define_function_from_class(source_class=ResampleRecording, name="resample") +resample = define_function_handling_dict_from_class(source_class=ResampleRecording, name="resample") # Some helpers to do checks diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index e89ece2fa4..1921835b37 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_or_dict_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..core import get_random_data_chunks, get_noise_levels @@ -137,4 +137,4 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -silence_periods = define_function_or_dict_from_class(SilencedPeriodsRecording, name="silence_periods") 
+silence_periods = define_function_handling_dict_from_class(SilencedPeriodsRecording, name="silence_periods") diff --git a/src/spikeinterface/preprocessing/unsigned_to_signed.py b/src/spikeinterface/preprocessing/unsigned_to_signed.py index 244fab1bd9..c5291da358 100644 --- a/src/spikeinterface/preprocessing/unsigned_to_signed.py +++ b/src/spikeinterface/preprocessing/unsigned_to_signed.py @@ -2,7 +2,7 @@ import numpy as np -from ..core.core_tools import define_function_from_class +from ..core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -66,4 +66,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -unsigned_to_signed = define_function_from_class(source_class=UnsignedToSignedRecording, name="unsigned_to_signed") +unsigned_to_signed = define_function_handling_dict_from_class( + source_class=UnsignedToSignedRecording, name="unsigned_to_signed" +) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 00c454f8f3..7768cbb948 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -3,7 +3,7 @@ import numpy as np from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from ..core import get_random_data_chunks, get_channel_distances from .filter import fix_dtype @@ -285,4 +285,4 @@ def compute_sklearn_covariance_matrix(data, regularize_kwargs): # function for API -whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") +whiten = define_function_handling_dict_from_class(source_class=WhitenRecording, name="whiten") diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index ab1c90dfd9..c06baf525a 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -6,7 +6,7 @@ from spikeinterface.core import BaseRecording, BaseRecordingSegment from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class class TracePaddedRecording(BasePreprocessor): @@ -201,5 +201,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # function for API -zero_channel_pad = define_function_from_class(source_class=ZeroChannelPaddedRecording, name="zero_channel_pad") -pad_traces = define_function_from_class(source_class=TracePaddedRecording, name="pad_traces") +zero_channel_pad = define_function_handling_dict_from_class( + source_class=ZeroChannelPaddedRecording, name="zero_channel_pad" +) +pad_traces = define_function_handling_dict_from_class(source_class=TracePaddedRecording, name="pad_traces") From 8a1a9fe181199b17425547e1a6759c2935462c80 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 7 Mar 2025 15:09:39 +0100 Subject: [PATCH 09/20] oups --- src/spikeinterface/core/core_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 9c195a9dc7..5a688e4869 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -14,7 +14,7 @@ import numpy as np -def 
define_function_or_dict_from_class(source_class, name): +def define_function_handling_dict_from_class(source_class, name): """ Depending on whether `source_class` is passed a `Recording` object or a dict of `Recording` objects, this function will return `source_class` or a dict of From 543ab272b2451f5925c9bfad7fbb0220d769848a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 11 Mar 2025 18:58:29 +0100 Subject: [PATCH 10/20] debug deepinterpolation --- .github/workflows/deepinterpolation.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/deepinterpolation.yml b/.github/workflows/deepinterpolation.yml index cdbef0f0e0..3af34d4955 100644 --- a/.github/workflows/deepinterpolation.yml +++ b/.github/workflows/deepinterpolation.yml @@ -44,6 +44,10 @@ jobs: pip install deepinterpolation@git+https://github.com/AllenInstitute/deepinterpolation.git pip install protobuf==3.20.* pip install -e .[full,test_core] + - name: Pip list + if: ${{ steps.modules-changed.outputs.DEEPINTERPOLATION_CHANGED == 'true' }} + run: | + pip list - name: Test DeepInterpolation with pytest if: ${{ steps.modules-changed.outputs.DEEPINTERPOLATION_CHANGED == 'true' }} run: | From 5727d018f819bc57c6c1ebccc9d118508565508e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 11 Mar 2025 19:01:12 +0100 Subject: [PATCH 11/20] debug deepinterpolation 2 --- .github/workflows/deepinterpolation.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/deepinterpolation.yml b/.github/workflows/deepinterpolation.yml index 3af34d4955..1b5ff48037 100644 --- a/.github/workflows/deepinterpolation.yml +++ b/.github/workflows/deepinterpolation.yml @@ -43,6 +43,7 @@ jobs: pip install tensorflow==2.7.0 pip install deepinterpolation@git+https://github.com/AllenInstitute/deepinterpolation.git pip install protobuf==3.20.* + pip install numpy==1.26.4 pip install -e .[full,test_core] - name: Pip list if: ${{ steps.modules-changed.outputs.DEEPINTERPOLATION_CHANGED == 'true' }} From c822995aff07d7a3b7d7dbc0a2f099c9c7227f29 Mon Sep 17 00:00:00 2001 From: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 12 Mar 2025 10:13:53 +0000 Subject: [PATCH 12/20] Update src/spikeinterface/preprocessing/clip.py --- src/spikeinterface/preprocessing/clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index 3bc35b25a8..549706bee5 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -171,5 +171,5 @@ def get_traces(self, start_frame, end_frame, channel_indices): clip = define_function_handling_dict_from_class(source_class=ClipRecording, name="clip") blank_staturation = define_function_handling_dict_from_class( - source_class=BlankSaturationRecording, name="blank_staturation" + source_class=BlankSaturationRecording, name="blank_saturation" ) From c50a773602742f104cfc5dc77d546bda700f8646 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 14 Mar 2025 14:53:05 +0000 Subject: [PATCH 13/20] update docs --- doc/how_to/process_by_channel_group.rst | 34 ++++++++----------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/doc/how_to/process_by_channel_group.rst b/doc/how_to/process_by_channel_group.rst index 08a87ab738..2b9e0e12a9 100644 --- a/doc/how_to/process_by_channel_group.rst +++ b/doc/how_to/process_by_channel_group.rst @@ -87,40 +87,28 @@ Splitting a recording by channel group returns a dictionary containing separate 
 Preprocessing a Recording by Channel Group
 ------------------------------------------
 
-The essence of preprocessing by channel group is to first split the recording
-into separate recordings, perform the preprocessing steps, then aggregate
-the channels back together.
-
-In the below example, we loop over the split recordings, preprocessing each channel group
-individually. At the end, we use the :py:func:`~aggregate_channels` function
-to combine the separate channel group recordings back together.
+If a preprocessing function is given a dictionary of recordings, it will apply the preprocessing
+separately to each recording in the dict, and return a dictionary of preprocessed recordings.
+Hence we can pass the ``split_recording_dict`` in the same way as we would pass a single recording
+to any preprocessing function.
 
 .. code-block:: python
 
-    preprocessed_recordings = []
-
-    # loop over the recordings contained in the dictionary
-    for chan_group_rec in split_recordings_dict.values():
-
-        # Apply the preprocessing steps to the channel group in isolation
-        shifted_recording = spre.phase_shift(chan_group_rec)
+    shifted_recordings = spre.phase_shift(split_recording_dict)
+    filtered_recording = spre.bandpass_filter(shifted_recording)
+    referenced_recording = spre.common_reference(filtered_recording)
 
-        filtered_recording = spre.bandpass_filter(shifted_recording)
+We can then aggregate the recordings back together using the ``aggregate_channels`` function
 
-        referenced_recording = spre.common_reference(filtered_recording)
-
-        preprocessed_recordings.append(referenced_recording)
+.. code-block:: python
 
-    # Combine our preprocessed channel groups back together
-    combined_preprocessed_recording = aggregate_channels(preprocessed_recordings)
+    combined_preprocessed_recording = aggregate_channels(referenced_recording)
 
 Now, when ``combined_preprocessed_recording`` is used in sorting, plotting, or whenever
 calling its :py:func:`~get_traces` method, the data will have been
 preprocessed separately per-channel group (then concatenated back together under the hood).
 
-It is strongly recommended to use the above structure to preprocess by channel group.
-
 .. note::
 
   The splitting and aggregation of channels for preprocessing is flexible.
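Taken together, the patches above let each preprocessing function accept either a single recording or a dict of recordings. The snippet below is a minimal end-to-end sketch of the workflow the updated how-to page and `test_grouped_preprocessing` describe; it is not part of the patch series. It assumes the standard spikeinterface API for `generate_recording`, `split_by` and `aggregate_channels`, it skips `phase_shift` (the synthetic recording carries no inter-sample-shift information), and it passes a plain list to `aggregate_channels` since dict support for aggregation is handled outside these patches. All local variable names are illustrative only.

    # Sketch only: relies on the dict-handling preprocessing wrappers introduced above.
    from spikeinterface.core import generate_recording, aggregate_channels
    import spikeinterface.preprocessing as spre

    # Toy recording with two channel groups
    recording = generate_recording(num_channels=4, durations=[5], seed=0)
    recording.set_property("group", [0, 0, 1, 1])

    # Splitting by group returns a dict of recordings, e.g. {0: rec_group0, 1: rec_group1}
    split_recording_dict = recording.split_by("group")

    # Each preprocessing call maps over the dict and returns a dict of
    # preprocessed recordings, one entry per group
    filtered_dict = spre.bandpass_filter(split_recording_dict)
    referenced_dict = spre.common_reference(filtered_dict, operator="median")

    # Aggregate the per-group recordings back into a single recording
    combined = aggregate_channels(list(referenced_dict.values()))
    print(combined.get_num_channels())  # 4

The equality checks in `test_grouped_preprocessing` above assert exactly this behaviour: preprocessing the dict in a single call gives the same result as preprocessing each group's recording separately.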
From 5c2c7a363ab9b693ce447c63285d2c70a7706d71 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 14 Mar 2025 14:57:27 +0000 Subject: [PATCH 14/20] deepinterpolation bug --- .../preprocessing/deepinterpolation/deepinterpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index 6d7c227534..aa3f9a34cc 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -5,7 +5,7 @@ from packaging.version import parse from .tf_utils import has_tf, import_tf -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment From 6a800c3070e1e88bd6859ee92fbaf4f6b2e2771a Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 14 Mar 2025 15:08:56 +0000 Subject: [PATCH 15/20] Revert "Merge branch 'group-splitting-in-pp' of https://github.com/chrishalcrow/spikeinterface into group-splitting-in-pp" This reverts commit 3c40783bc7f6faedb7e0e0562cb6346add5766a3, reversing changes made to 5c2c7a363ab9b693ce447c63285d2c70a7706d71. I did something weird --- .github/scripts/check_kilosort4_releases.py | 9 +- .github/scripts/test_kilosort4_ci.py | 19 +- .../core/baserecordingsnippets.py | 1 - .../core/channelsaggregationrecording.py | 59 +---- .../test_channelsaggregationrecording.py | 210 +++++------------- .../sorters/external/kilosort4.py | 30 +-- src/spikeinterface/widgets/sorting_summary.py | 2 +- 7 files changed, 84 insertions(+), 246 deletions(-) diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 8bd0163e3a..7a6368f3cf 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -1,9 +1,10 @@ import os +import re from pathlib import Path import requests import json from packaging.version import parse - +import spikeinterface def get_pypi_versions(package_name): """ @@ -15,10 +16,8 @@ def get_pypi_versions(package_name): response.raise_for_status() data = response.json() versions = list(sorted(data["releases"].keys())) - # Filter out versions that are less than 4.0.16 and different from 4.0.26 and 4.0.27 - # (buggy - https://github.com/MouseLand/Kilosort/releases/tag/v4.0.26) - versions = [ver for ver in versions if parse(ver) >= parse("4.0.16") and - parse(ver) not in [parse("4.0.26"), parse("4.0.27")]] + # Filter out versions that are less than 4.0.16 + versions = [ver for ver in versions if parse(ver) >= parse("4.0.16")] return versions diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 16e3c1ec7d..e19faccb6e 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -258,22 +258,17 @@ def test_compute_preprocessing_arguments(self): self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"]) def test_compute_drift_location_arguments(self): - expected_arguments = ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"] - if parse(kilosort.__version__) >= parse("4.0.28"): - expected_arguments += ["verbose"] - self._check_arguments(compute_drift_correction, expected_arguments) + self._check_arguments( + 
compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"] + ) def test_detect_spikes_arguments(self): - expected_arguments = ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"] - if parse(kilosort.__version__) >= parse("4.0.28"): - expected_arguments += ["verbose"] - self._check_arguments(detect_spikes, expected_arguments) + self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]) def test_cluster_spikes_arguments(self): - expected_arguments = ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"] - if parse(kilosort.__version__) >= parse("4.0.28"): - expected_arguments += ["verbose"] - self._check_arguments(cluster_spikes, expected_arguments) + self._check_arguments( + cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"] + ) def test_save_sorting_arguments(self): expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 3b8ffc7b03..b224e0d282 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -564,7 +564,6 @@ def split_by(self, property="group", outputs="dict"): (inds,) = np.nonzero(values == value) new_channel_ids = self.get_channel_ids()[inds] subrec = self.select_channels(new_channel_ids) - subrec.set_annotation("split_by_property", value=property) if outputs == "list": recordings.append(subrec) elif outputs == "dict": diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 3947a0decc..4fa1d88974 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -1,5 +1,4 @@ from __future__ import annotations -import warnings import numpy as np @@ -14,26 +13,10 @@ class ChannelsAggregationRecording(BaseRecording): """ - def __init__(self, recording_list_or_dict, renamed_channel_ids=None): - - if isinstance(recording_list_or_dict, dict): - recording_list = list(recording_list_or_dict.values()) - recording_ids = list(recording_list_or_dict.keys()) - elif isinstance(recording_list_or_dict, list): - recording_list = recording_list_or_dict - recording_ids = range(len(recording_list)) - else: - raise TypeError( - "`aggregate_channels` only accepts a list of recordings or a dict whose values are all recordings." 
- ) + def __init__(self, recording_list, renamed_channel_ids=None): self._recordings = recording_list - splitting_known = self._is_splitting_known() - if not splitting_known: - for group_id, recording in zip(recording_ids, recording_list): - recording.set_property("group", [group_id] * recording.get_num_channels()) - self._perform_consistency_checks() sampling_frequency = recording_list[0].get_sampling_frequency() dtype = recording_list[0].get_dtype() @@ -118,25 +101,6 @@ def __init__(self, recording_list_or_dict, renamed_channel_ids=None): def recordings(self): return self._recordings - def _is_splitting_known(self): - - # If we have the `split_by_property` annotation, we know how the recording was split - if self._recordings[0].get_annotation("split_by_property") is not None: - return True - - # Check if all 'group' properties are equal to 0 - recording_groups = [] - for recording in self._recordings: - if (group_labels := recording.get_property("group")) is not None: - recording_groups.extend(group_labels) - else: - recording_groups.extend([0]) - # If so, we don't know the splitting - if np.all(np.unique(recording_groups) == np.array([0])): - return False - else: - return True - def _perform_consistency_checks(self): # Check for consistent sampling frequency across recordings @@ -237,18 +201,14 @@ def get_traces( return np.concatenate(traces, axis=1) -def aggregate_channels( - recording_list_or_dict=None, - renamed_channel_ids=None, - recording_list=None, -): +def aggregate_channels(recording_list, renamed_channel_ids=None): """ Aggregates channels of multiple recording into a single recording object Parameters ---------- - recording_list_or_dict: list | dict - List or dict of BaseRecording objects to aggregate. + recording_list: list + List of BaseRecording objects to aggregate renamed_channel_ids: array-like If given, channel ids are renamed as provided. @@ -257,13 +217,4 @@ def aggregate_channels( aggregate_recording: ChannelsAggregationRecording The aggregated recording object """ - - if recording_list is not None: - warnings.warn( - "`recording_list` is deprecated and will be removed in 0.105.0. 
Please use `recording_list_or_dict` instead.", - category=DeprecationWarning, - stacklevel=2, - ) - recording_list_or_dict = recording_list - - return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids) + return ChannelsAggregationRecording(recording_list, renamed_channel_ids) diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index a9bb51dfed..16a91a55e1 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -1,6 +1,7 @@ import numpy as np from spikeinterface.core import aggregate_channels + from spikeinterface.core import generate_recording @@ -20,98 +21,65 @@ def test_channelsaggregationrecording(): recording2.set_channel_locations([[20.0, 20.0], [20.0, 40.0], [20.0, 60.0]]) recording3.set_channel_locations([[40.0, 20.0], [40.0, 40.0], [40.0, 60.0]]) - recordings_list_possibilities = [ - [recording1, recording2, recording3], - {0: recording1, 1: recording2, 2: recording3}, - ] - - for recordings_list in recordings_list_possibilities: - - # test num channels - recording_agg = aggregate_channels(recordings_list) - assert len(recording_agg.get_channel_ids()) == 3 * num_channels - - assert np.allclose(recording_agg.get_times(0), recording1.get_times(0)) - - # test traces - channel_ids = recording1.get_channel_ids() - - for seg in range(num_seg): - # single channels - traces1_1 = recording1.get_traces(channel_ids=[channel_ids[1]], segment_index=seg) - traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg) - traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg) - - assert np.allclose( - traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg) - ) - assert np.allclose( - traces2_0, - recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), - ) - assert np.allclose( - traces3_2, - recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), - ) - # all traces - traces1 = recording1.get_traces(segment_index=seg) - traces2 = recording2.get_traces(segment_index=seg) - traces3 = recording3.get_traces(segment_index=seg) - - assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg)) - assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg)) - assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg)) - - # test rename channels - renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)] - recording_agg_renamed = aggregate_channels(recordings_list, renamed_channel_ids=renamed_channel_ids) - assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids()) - - # test properties - # complete property - recording1.set_property("brain_area", ["CA1"] * num_channels) - recording2.set_property("brain_area", ["CA2"] * num_channels) - recording3.set_property("brain_area", ["CA3"] * num_channels) - - # skip for inconsistency - recording1.set_property("template", np.zeros((num_channels, 4, 30))) - recording2.set_property("template", np.zeros((num_channels, 20, 50))) - recording3.set_property("template", np.zeros((num_channels, 2, 10))) - - # incomplete property - recording1.set_property("quality", ["good"] * num_channels) - 
recording2.set_property("quality", ["bad"] * num_channels) - - recording_agg_prop = aggregate_channels(recordings_list) - assert "brain_area" in recording_agg_prop.get_property_keys() - assert "quality" not in recording_agg_prop.get_property_keys() - print(recording_agg_prop.get_property("brain_area")) - - -def test_split_then_aggreate_preserve_user_property(): - """ - Checks that splitting then aggregating a recording preserves the unit_id to property mapping. - """ - - num_channels = 10 - durations = [10, 5] - recording = generate_recording(num_channels=num_channels, durations=durations, set_probe=False) - - recording.set_property(key="group", values=[2, 0, 1, 1, 1, 0, 1, 0, 1, 2]) - - old_properties = recording.get_property(key="group") - old_channel_ids = recording.channel_ids - old_properties_ids_dict = dict(zip(old_channel_ids, old_properties)) - - split_recordings = recording.split_by("group") - - aggregated_recording = aggregate_channels(split_recordings) - - new_properties = aggregated_recording.get_property(key="group") - new_channel_ids = aggregated_recording.channel_ids - new_properties_ids_dict = dict(zip(new_channel_ids, new_properties)) - - assert np.all(old_properties_ids_dict == new_properties_ids_dict) + # test num channels + recording_agg = aggregate_channels([recording1, recording2, recording3]) + assert len(recording_agg.get_channel_ids()) == 3 * num_channels + + assert np.allclose(recording_agg.get_times(0), recording1.get_times(0)) + + # test traces + channel_ids = recording1.get_channel_ids() + + for seg in range(num_seg): + # single channels + traces1_1 = recording1.get_traces(channel_ids=[channel_ids[1]], segment_index=seg) + traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg) + traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg) + + assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) + assert np.allclose( + traces2_0, + recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), + ) + assert np.allclose( + traces3_2, + recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), + ) + # all traces + traces1 = recording1.get_traces(segment_index=seg) + traces2 = recording2.get_traces(segment_index=seg) + traces3 = recording3.get_traces(segment_index=seg) + + assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg)) + assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg)) + assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg)) + + # test rename channels + renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)] + recording_agg_renamed = aggregate_channels( + [recording1, recording2, recording3], renamed_channel_ids=renamed_channel_ids + ) + assert all(chan in renamed_channel_ids for chan in recording_agg_renamed.get_channel_ids()) + + # test properties + # complete property + recording1.set_property("brain_area", ["CA1"] * num_channels) + recording2.set_property("brain_area", ["CA2"] * num_channels) + recording3.set_property("brain_area", ["CA3"] * num_channels) + + # skip for inconsistency + recording1.set_property("template", np.zeros((num_channels, 4, 30))) + recording2.set_property("template", np.zeros((num_channels, 20, 50))) + recording3.set_property("template", np.zeros((num_channels, 2, 10))) + + # 
incomplete property + recording1.set_property("quality", ["good"] * num_channels) + recording2.set_property("quality", ["bad"] * num_channels) + + recording_agg_prop = aggregate_channels([recording1, recording2, recording3]) + assert "brain_area" in recording_agg_prop.get_property_keys() + assert "quality" not in recording_agg_prop.get_property_keys() + print(recording_agg_prop.get_property("brain_area")) def test_channel_aggregation_preserve_ids(): @@ -126,64 +94,6 @@ def test_channel_aggregation_preserve_ids(): assert list(aggregated_recording.get_channel_ids()) == ["a", "b", "c", "d", "e"] -def test_aggregation_labeling_for_lists(): - """Aggregated lists of recordings get different labels depending on their underlying `property`s""" - - recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) - recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False) - - # If we don't label at all, aggregation will add a 'group' label - aggregated_recording = aggregate_channels([recording1, recording2]) - group_property = aggregated_recording.get_property("group") - assert np.all(group_property == [0, 0, 0, 0, 1, 1]) - - # If we have different group labels, these should be respected - recording1.set_property("group", [2, 2, 2, 2]) - recording2.set_property("group", [6, 6]) - aggregated_recording = aggregate_channels([recording1, recording2]) - group_property = aggregated_recording.get_property("group") - assert np.all(group_property == [2, 2, 2, 2, 6, 6]) - - # If we use `split_by`, aggregation should retain the split_by property, even if we only pass the list - recording1.set_property("user_group", [6, 7, 6, 7]) - recording_list = list(recording1.split_by("user_group").values()) - aggregated_recording = aggregate_channels(recording_list) - group_property = aggregated_recording.get_property("group") - assert np.all(group_property == [2, 2, 2, 2]) - user_group_property = aggregated_recording.get_property("user_group") - # Note, aggregation reorders the channel_ids into the order of the ids of each individual recording - assert np.all(user_group_property == [6, 6, 7, 7]) - - -def test_aggretion_labelling_for_dicts(): - """Aggregated dicts of recordings get different labels depending on their underlying `property`s""" - - recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) - recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False) - - # If we don't label at all, aggregation will add a 'group' label based on the dict keys - aggregated_recording = aggregate_channels({0: recording1, "cat": recording2}) - group_property = aggregated_recording.get_property("group") - assert np.all(group_property == [0, 0, 0, 0, "cat", "cat"]) - - # If we have different group labels, these should be respected - recording1.set_property("group", [2, 2, 2, 2]) - recording2.set_property("group", [6, 6]) - aggregated_recording = aggregate_channels({0: recording1, "cat": recording2}) - group_property = aggregated_recording.get_property("group") - assert np.all(group_property == [2, 2, 2, 2, 6, 6]) - - # If we use `split_by`, aggregation should retain the split_by property, even if we pass a different dict - recording1.set_property("user_group", [6, 7, 6, 7]) - recordings_dict = recording1.split_by("user_group") - aggregated_recording = aggregate_channels(recordings_dict) - group_property = aggregated_recording.get_property("group") - assert np.all(group_property == [2, 2, 2, 2]) - user_group_property = 
aggregated_recording.get_property("user_group") - # Note, aggregation reorders the channel_ids into the order of the ids of each individual recording - assert np.all(user_group_property == [6, 6, 7, 7]) - - def test_channel_aggregation_does_not_preserve_ids_if_not_unique(): recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False) # To avoid location check diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index e40e4bf877..346f3e1aea 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -156,7 +156,6 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): - from kilosort import __version__ as ks_version from kilosort.run_kilosort import ( set_files, initialize_ops, @@ -179,9 +178,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): logging.basicConfig(level=logging.INFO) - if version.parse(cls.get_sorter_version()) < version.parse("4.0.16"): + if version.parse(cls.get_sorter_version()) < version.parse("4.0.5"): raise RuntimeError( - "Kilosort versions before 4.0.16 are not supported" + "Kilosort versions before 4.0.5 are not supported" "in SpikeInterface. " "Please upgrade Kilosort version." ) @@ -315,7 +314,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): print("Skipping drift correction.") ops["nblocks"] = 0 - drift_kwargs = dict( + # this function applies both preprocessing and drift correction + ops, bfile, st0 = compute_drift_correction( ops=ops, device=device, tic0=tic0, @@ -323,29 +323,16 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): file_object=file_object, clear_cache=clear_cache, ) - if version.parse(ks_version) >= version.parse("4.0.28"): - drift_kwargs.update(dict(verbose=verbose)) - - # this function applies both preprocessing and drift correction - ops, bfile, st0 = compute_drift_correction(**drift_kwargs) if save_preprocessed_copy: save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) # Sort spikes and save results - detect_spikes_kwargs = dict( - ops=ops, - device=device, - bfile=bfile, - tic0=tic0, - progress_bar=progress_bar, - clear_cache=clear_cache, + st, tF, _, _ = detect_spikes( + ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar, clear_cache=clear_cache ) - if version.parse(ks_version) >= version.parse("4.0.28"): - detect_spikes_kwargs.update(dict(verbose=verbose)) - st, tF, _, _ = detect_spikes(**detect_spikes_kwargs) - cluster_spikes_kwargs = dict( + clu, Wall = cluster_spikes( st=st, tF=tF, ops=ops, @@ -355,9 +342,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): progress_bar=progress_bar, clear_cache=clear_cache, ) - if version.parse(ks_version) >= version.parse("4.0.28"): - cluster_spikes_kwargs.update(dict(verbose=verbose)) - clu, Wall = cluster_spikes(**cluster_spikes_kwargs) if params["skip_kilosort_preprocessing"]: ops["preprocessing"] = dict( diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 67322398fb..d608f3ae8d 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -16,7 +16,7 @@ from spikeinterface.core import SortingAnalyzer -_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude_median", "snr", "rp_violations"] 
+_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude_median", "snr", "rp_violation"] class SortingSummaryWidget(BaseWidget): From 0b345c6101da0838c95717a88fb89bf158cc568a Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 14 Mar 2025 15:14:58 +0000 Subject: [PATCH 16/20] Revert "Merge branch 'main' into group-splitting-in-pp" This reverts commit 0d5cf229b4f8f851fbdf9f1dbc6d90cf0fecaedf, reversing changes made to 6a800c3070e1e88bd6859ee92fbaf4f6b2e2771a. --- src/spikeinterface/widgets/peak_activity.py | 98 +++------------------ 1 file changed, 14 insertions(+), 84 deletions(-) diff --git a/src/spikeinterface/widgets/peak_activity.py b/src/spikeinterface/widgets/peak_activity.py index 45a1db138a..f611927813 100644 --- a/src/spikeinterface/widgets/peak_activity.py +++ b/src/spikeinterface/widgets/peak_activity.py @@ -16,8 +16,11 @@ class PeakActivityMapWidget(BaseWidget): ---------- recording : RecordingExtractor The recording extractor object. - peaks : numpy array with peak_dtype - The pre detected peaks (with the `detect_peaks()` function). + peaks : None or numpy array + Optionally can give already detected peaks + to avoid multiple computation. + detect_peaks_kwargs : None or dict, default: None + If peaks is None here the kwargs for detect_peak function. bin_duration_s : None or float, default: None If None then static image If not None then it is an animation per bin. @@ -27,10 +30,8 @@ class PeakActivityMapWidget(BaseWidget): Plot rates with interpolated map with_channel_ids : bool, default: False Add channel ids text on the probe - color_range : tuple | list | None, default: None - Sets the color bar range when animating or plotting. - When None, uses the min-max of the entire time-series via imshow defaults. - If tuple/list, the length must be 2 representing the range. 
+ + """ def __init__( @@ -42,7 +43,6 @@ def __init__( with_interpolated_map=True, with_channel_ids=False, with_color_bar=True, - color_range=None, backend=None, **backend_kwargs, ): @@ -54,7 +54,6 @@ def __init__( with_interpolated_map=with_interpolated_map, with_channel_ids=with_channel_ids, with_color_bar=with_color_bar, - color_range=color_range, ) BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) @@ -64,6 +63,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) rec = dp.recording @@ -79,52 +81,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) probe = probes[0] - if dp.color_range is not None: - assert isinstance(dp.color_range, (tuple, list)), "color_range must be a tuple/list" - assert len(dp.color_range) == 2, "color_range must be a tuple/list of length 2 representing range" - vmin, vmax = dp.color_range - else: - vmin, vmax = None, None - if dp.bin_duration_s is None: - # plot aggregated activity map self._plot_one_bin( - rec, - probe, - peaks, - duration, - dp.with_channel_ids, - dp.with_contact_color, - dp.with_interpolated_map, - dp.with_color_bar, - vmin=vmin, - vmax=vmax, + rec, probe, peaks, duration, dp.with_channel_ids, dp.with_contact_color, dp.with_interpolated_map ) else: - # plot animated activity map bin_size = int(dp.bin_duration_s * fs) num_frames = int(duration / dp.bin_duration_s) - # Compute max values across all time bins if needed - if vmin is None: - all_rates = [] - for i in range(num_frames): - i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) - local_peaks = peaks[i0:i1] - rates = self._compute_rates(rec, local_peaks, dp.bin_duration_s) - all_rates.append(rates) - all_rates = np.concatenate(all_rates) - vmin, vmax = np.min(all_rates), np.max(all_rates) - - # Create a colorbar once - dummy_image = self.ax.imshow([[0, 1]], visible=False, aspect="auto") - self.cbar = self.figure.colorbar(dummy_image, ax=self.ax, label="Peaks (Hz)") - - # Create a text artist for displaying time - self.time_text = self.ax.text(0.02, 0.98, "", transform=self.ax.transAxes, va="top", ha="left") - def animate_func(i): - self.ax.clear() i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] artists = self._plot_one_bin( @@ -135,49 +100,19 @@ def animate_func(i): with_channel_ids=dp.with_channel_ids, with_contact_color=dp.with_contact_color, with_interpolated_map=dp.with_interpolated_map, - with_color_bar=False, - vmin=vmin, - vmax=vmax, ) - - # Update colorbar - if artists and isinstance(artists[-1], plt.matplotlib.image.AxesImage): - self.cbar.update_normal(artists[-1]) - - # Update time text - current_time = i * dp.bin_duration_s - self.time_text.set_text(f"Time: {current_time:.2f} s") - self.ax.add_artist(self.time_text) - - artists += (self.time_text,) return artists from matplotlib.animation import FuncAnimation - self.animation = FuncAnimation(self.figure, animate_func, frames=num_frames, interval=100, blit=False) + self.animation = FuncAnimation(self.figure, animate_func, frames=num_frames, interval=100, blit=True) - def _compute_rates(self, rec, peaks, duration): + def _plot_one_bin(self, rec, probe, peaks, duration, with_channel_ids, with_contact_color, with_interpolated_map): rates = 
np.zeros(rec.get_num_channels(), dtype="float64") for chan_ind, chan_id in enumerate(rec.channel_ids): mask = peaks["channel_index"] == chan_ind num_spike = np.sum(mask) rates[chan_ind] = num_spike / duration - return rates - - def _plot_one_bin( - self, - rec, - probe, - peaks, - duration, - with_channel_ids, - with_contact_color, - with_interpolated_map, - with_color_bar, - vmin=None, - vmax=None, - ): - rates = self._compute_rates(rec, peaks, duration) artists = () if with_contact_color: @@ -195,18 +130,13 @@ def _plot_one_bin( contacts_kargs={"alpha": 1.0}, text_on_contact=text_on_contact, ) - if vmin is not None and vmax is not None: - poly.set_clim(vmin, vmax) artists = artists + (poly, poly_contour) if with_interpolated_map: image, xlims, ylims = probe.to_image( rates, pixel_size=0.5, num_pixel=None, method="linear", xlims=None, ylims=None ) - im = self.ax.imshow(image, extent=xlims + ylims, origin="lower", alpha=0.5, vmin=vmin, vmax=vmax) + im = self.ax.imshow(image, extent=xlims + ylims, origin="lower", alpha=0.5) artists = artists + (im,) - if with_color_bar: - self.cbar = self.figure.colorbar(im, ax=self.ax, label=f"Peak rate (Hz)") - return artists From 251ae0a9f5b162084f546edf761879d397b40d10 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 14 Mar 2025 15:15:14 +0000 Subject: [PATCH 17/20] Revert "deepinterpolation bug" This reverts commit 5c2c7a363ab9b693ce447c63285d2c70a7706d71. --- .../preprocessing/deepinterpolation/deepinterpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index aa3f9a34cc..6d7c227534 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -5,7 +5,7 @@ from packaging.version import parse from .tf_utils import has_tf, import_tf -from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment From e6d196d529d38b9cd52179efaf92f24ceb6e14f1 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 14 Mar 2025 15:16:53 +0000 Subject: [PATCH 18/20] Revert "update docs" This reverts commit c50a773602742f104cfc5dc77d546bda700f8646. --- doc/how_to/process_by_channel_group.rst | 34 +++++++++++++++++-------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/doc/how_to/process_by_channel_group.rst b/doc/how_to/process_by_channel_group.rst index 2b9e0e12a9..08a87ab738 100644 --- a/doc/how_to/process_by_channel_group.rst +++ b/doc/how_to/process_by_channel_group.rst @@ -87,28 +87,40 @@ Splitting a recording by channel group returns a dictionary containing separate Preprocessing a Recording by Channel Group ------------------------------------------ -If a preprocessing function is given a dictionary of recordings, it will apply the preprocessing -seperately to each recording in the dict, and return a dictionary of preprocessed recordings. -Hence we can pass the ``split_recording_dict`` in the same way as we would pass a single recording -to any preprocessing function. +The essence of preprocessing by channel group is to first split the recording +into separate recordings, perform the preprocessing steps, then aggregate +the channels back together. 
+ +In the below example, we loop over the split recordings, preprocessing each channel group +individually. At the end, we use the :py:func:`~aggregate_channels` function +to combine the separate channel group recordings back together. .. code-block:: python - shifted_recordings = spre.phase_shift(split_recording_dict) - filtered_recording = spre.bandpass_filter(shifted_recording) - referenced_recording = spre.common_reference(filtered_recording) + preprocessed_recordings = [] -We can then aggregate the recordings back together using the ``aggregate_channels`` function + # loop over the recordings contained in the dictionary + for chan_group_rec in split_recordings_dict.values(): -.. code-block:: python + # Apply the preprocessing steps to the channel group in isolation + shifted_recording = spre.phase_shift(chan_group_rec) + + filtered_recording = spre.bandpass_filter(shifted_recording) - combined_preprocessed_recording = aggregate_channels(referenced_recording) + referenced_recording = spre.common_reference(filtered_recording) -Now, when ``combined_preprocessed_recording`` is used in sorting, plotting, or whenever + preprocessed_recordings.append(referenced_recording) + + # Combine our preprocessed channel groups back together + combined_preprocessed_recording = aggregate_channels(preprocessed_recordings) + +Now, when this recording is used in sorting, plotting, or whenever calling its :py:func:`~get_traces` method, the data will have been preprocessed separately per-channel group (then concatenated back together under the hood). +It is strongly recommended to use the above structure to preprocess by channel group. + .. note:: The splitting and aggregation of channels for preprocessing is flexible. From cfa56612a73f665dc775885ade27953dfcaaa4cd Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 14 Mar 2025 15:17:41 +0000 Subject: [PATCH 19/20] Revert "Merge branch 'main' into group-splitting-in-pp" This reverts commit f53150ab0741849cb0fd32fc2f3ef15370130380, reversing changes made to eb9bfe40b429ee67113cc63194c83d31468da22a. 
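For reference, the dict-based channel-group workflow described in the documentation and tests touched by the patches above boils down to the following sketch. It assumes the dict-handling support this branch is adding (preprocessing functions and aggregate_channels accepting a dict of recordings keyed by group), which these patches are still reworking, so it is illustrative rather than final API:

    import spikeinterface.preprocessing as spre
    from spikeinterface.core import aggregate_channels, generate_recording

    # toy recording with two channel groups (no probe needed for this sketch)
    recording = generate_recording(num_channels=4, durations=[10.0], set_probe=False)
    recording.set_property("group", [0, 0, 1, 1])

    # split_by returns a dict of sub-recordings keyed by the group value
    split_recording_dict = recording.split_by("group")

    # with the dict support from this branch, preprocessing a dict returns a dict,
    # one preprocessed recording per group
    filtered_dict = spre.bandpass_filter(split_recording_dict)
    referenced_dict = spre.common_reference(filtered_dict)

    # aggregating the dict stitches the groups back into a single recording; the
    # tests above expect the original "group" labels to be preserved
    combined_recording = aggregate_channels(referenced_dict)
    print(combined_recording.get_property("group"))
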
--- examples/how_to/drift_with_lfp.py | 2 +- src/spikeinterface/benchmark/benchmark_sorter.py | 4 ++-- src/spikeinterface/comparison/paircomparisons.py | 6 +++--- src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/curation/auto_merge.py | 4 ++-- src/spikeinterface/curation/remove_excess_spikes.py | 4 ++-- src/spikeinterface/curation/remove_redundant.py | 4 ++-- .../extractors/cellexplorersortingextractor.py | 4 ++-- .../extractors/neoextractors/neo_utils.py | 2 +- .../extractors/sinapsrecordingextractors.py | 6 +++--- src/spikeinterface/generation/__init__.py | 2 +- .../postprocessing/amplitude_scalings.py | 2 +- src/spikeinterface/postprocessing/noise_level.py | 2 +- .../postprocessing/template_metrics.py | 6 +++--- .../postprocessing/template_similarity.py | 4 ++-- src/spikeinterface/postprocessing/unit_locations.py | 2 +- src/spikeinterface/preprocessing/astype.py | 2 +- src/spikeinterface/preprocessing/clip.py | 2 +- src/spikeinterface/preprocessing/common_reference.py | 2 +- src/spikeinterface/preprocessing/correct_lsb.py | 2 +- src/spikeinterface/preprocessing/decimate.py | 2 +- .../deepinterpolation/deepinterpolation.py | 4 ++-- .../preprocessing/deepinterpolation/generators.py | 2 +- .../preprocessing/deepinterpolation/train.py | 2 +- src/spikeinterface/preprocessing/depth_order.py | 4 ++-- .../preprocessing/detect_bad_channels.py | 2 +- src/spikeinterface/preprocessing/filter.py | 2 +- src/spikeinterface/preprocessing/filter_opencl.py | 2 +- .../preprocessing/highpass_spatial_filter.py | 4 ++-- src/spikeinterface/preprocessing/normalize_scale.py | 2 +- src/spikeinterface/preprocessing/phase_shift.py | 2 +- src/spikeinterface/preprocessing/resample.py | 2 +- src/spikeinterface/preprocessing/silence_periods.py | 4 ++-- .../preprocessing/unsigned_to_signed.py | 3 ++- src/spikeinterface/preprocessing/whiten.py | 2 +- src/spikeinterface/qualitymetrics/misc_metrics.py | 8 ++++---- src/spikeinterface/qualitymetrics/pca_metrics.py | 4 ++-- src/spikeinterface/sorters/external/combinato.py | 4 ++-- src/spikeinterface/sorters/external/hdsort.py | 4 ++-- src/spikeinterface/sorters/external/herdingspikes.py | 2 +- src/spikeinterface/sorters/external/ironclust.py | 4 ++-- src/spikeinterface/sorters/external/kilosort.py | 4 ++-- src/spikeinterface/sorters/external/kilosort2.py | 4 ++-- src/spikeinterface/sorters/external/kilosort2_5.py | 4 ++-- src/spikeinterface/sorters/external/kilosort3.py | 4 ++-- src/spikeinterface/sorters/external/kilosort4.py | 6 +++--- src/spikeinterface/sorters/external/kilosortbase.py | 4 ++-- src/spikeinterface/sorters/external/klusta.py | 4 ++-- src/spikeinterface/sorters/external/mountainsort4.py | 2 +- src/spikeinterface/sorters/external/mountainsort5.py | 2 +- src/spikeinterface/sorters/external/pykilosort.py | 2 +- .../sorters/external/spyking_circus.py | 4 ++-- src/spikeinterface/sorters/external/tridesclous.py | 2 +- src/spikeinterface/sorters/external/waveclus.py | 4 ++-- .../sorters/external/waveclus_snippets.py | 4 ++-- src/spikeinterface/sorters/external/yass.py | 4 ++-- src/spikeinterface/sorters/runsorter.py | 6 +++--- .../sortingcomponents/motion/decentralized.py | 2 +- .../sortingcomponents/motion/dredge.py | 2 +- .../sortingcomponents/motion/iterative_template.py | 2 +- .../sortingcomponents/motion/motion_estimation.py | 2 +- .../sortingcomponents/peak_localization.py | 7 +++++-- .../widgets/all_amplitudes_distributions.py | 2 +- src/spikeinterface/widgets/amplitudes.py | 2 +- src/spikeinterface/widgets/base.py | 4 ++-- 
src/spikeinterface/widgets/crosscorrelograms.py | 6 +++--- src/spikeinterface/widgets/metrics.py | 2 +- src/spikeinterface/widgets/potential_merges.py | 2 +- src/spikeinterface/widgets/quality_metrics.py | 2 +- src/spikeinterface/widgets/sorting_summary.py | 2 +- src/spikeinterface/widgets/spike_locations.py | 2 +- .../widgets/spike_locations_by_time.py | 2 +- src/spikeinterface/widgets/spikes_on_traces.py | 12 ++++++------ src/spikeinterface/widgets/template_metrics.py | 2 +- src/spikeinterface/widgets/template_similarity.py | 2 +- src/spikeinterface/widgets/traces.py | 6 +++--- src/spikeinterface/widgets/unit_depths.py | 2 +- src/spikeinterface/widgets/unit_locations.py | 2 +- src/spikeinterface/widgets/unit_probe_map.py | 4 ++-- src/spikeinterface/widgets/unit_templates.py | 2 +- src/spikeinterface/widgets/unit_waveforms.py | 4 ++-- .../widgets/unit_waveforms_density_map.py | 2 +- src/spikeinterface/widgets/utils_sortingview.py | 4 ++-- 83 files changed, 138 insertions(+), 134 deletions(-) diff --git a/examples/how_to/drift_with_lfp.py b/examples/how_to/drift_with_lfp.py index bdb4e101e3..66a31bd6f2 100644 --- a/examples/how_to/drift_with_lfp.py +++ b/examples/how_to/drift_with_lfp.py @@ -47,7 +47,7 @@ # the dataset has been downloaded locally base_folder = Path("/mnt/data/sam/DataSpikeSorting/") -np_data_drift = base_folder / "human_neuropixel" / "Pt02" +np_data_drift = base_folder / 'human_neuropixel" / "Pt02" # ### Read the spikeglx file diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index 303ef2dd51..edf1d0f6cc 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -3,9 +3,9 @@ """ import numpy as np -from spikeinterface.core import NumpySorting +from ..core import NumpySorting from .benchmark_base import Benchmark, BenchmarkStudy -from spikeinterface.sorters import run_sorter +from ..sorters import run_sorter from spikeinterface.comparison import compare_sorter_to_ground_truth diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 7257411c10..e46ac74605 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -2,8 +2,8 @@ import numpy as np -from spikeinterface.core import BaseSorting -from spikeinterface.core.core_tools import define_function_from_class +from ..core import BaseSorting +from ..core.core_tools import define_function_from_class from .basecomparison import BasePairComparison, MixinSpikeTrainComparison, MixinTemplateComparison from .comparisontools import ( do_count_event, @@ -15,7 +15,7 @@ do_count_score, compute_performance, ) -from spikeinterface.postprocessing import compute_template_similarity_by_pair +from ..postprocessing import compute_template_similarity_by_pair class BasePairSorterComparison(BasePairComparison, MixinSpikeTrainComparison): diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 0edc2da49a..581a9eae53 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -882,7 +882,7 @@ def binary_compatible_with( return True def astype(self, dtype, round: bool | None = None): - from spikeinterface.preprocessing.astype import astype + from ..preprocessing.astype import astype return astype(self, dtype=dtype, round=round) diff --git a/src/spikeinterface/curation/auto_merge.py 
b/src/spikeinterface/curation/auto_merge.py index 8447728216..4c487315a4 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -13,8 +13,8 @@ except ImportError: HAVE_NUMBA = False -from spikeinterface.core import SortingAnalyzer -from spikeinterface.qualitymetrics import compute_refrac_period_violations, compute_firing_rates +from ..core import SortingAnalyzer +from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates from .mergeunitssorting import MergeUnitsSorting from .curation_tools import resolve_merging_graph diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index 8663b0fdbd..0d70e264a9 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -2,8 +2,8 @@ from typing import Optional import numpy as np -from spikeinterface.core import BaseSorting, BaseSortingSegment, BaseRecording -from spikeinterface.core.waveform_tools import has_exceeding_spikes +from ..core import BaseSorting, BaseSortingSegment, BaseRecording +from ..core.waveform_tools import has_exceeding_spikes class RemoveExcessSpikesSorting(BaseSorting): diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index 39fd462505..09c0b2f270 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -4,8 +4,8 @@ from spikeinterface import SortingAnalyzer -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift, get_template_amplitudes -from spikeinterface.postprocessing import align_sorting +from ..core.template_tools import get_template_extremum_channel_peak_shift, get_template_amplitudes +from ..postprocessing import align_sorting _remove_strategies = ("minimum_shift", "highest_amplitude", "max_spikes") diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 490ea61547..0dfa3a85ad 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -4,8 +4,8 @@ from pathlib import Path -from spikeinterface.core import BaseSorting, BaseSortingSegment -from spikeinterface.core.core_tools import define_function_from_class +from ..core import BaseSorting, BaseSortingSegment +from ..core.core_tools import define_function_from_class class CellExplorerSortingExtractor(BaseSorting): diff --git a/src/spikeinterface/extractors/neoextractors/neo_utils.py b/src/spikeinterface/extractors/neoextractors/neo_utils.py index 3de83ff607..ec6aae06c9 100644 --- a/src/spikeinterface/extractors/neoextractors/neo_utils.py +++ b/src/spikeinterface/extractors/neoextractors/neo_utils.py @@ -56,7 +56,7 @@ def get_neo_num_blocks(extractor_name, *args, **kwargs) -> int: def get_neo_extractor(extractor_name): - from spikeinterface.extractors.extractorlist import recording_extractor_full_dict + from ..extractorlist import recording_extractor_full_dict assert extractor_name in recording_extractor_full_dict, ( f"{extractor_name} not an extractor name:" f"\n{list(recording_extractor_full_dict.keys())}" diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py index 31a2a81f82..c3e92a63ff 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractors.py +++ 
b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -6,8 +6,8 @@ from probeinterface import get_probe -from spikeinterface.core import BaseRecording, BaseRecordingSegment, BinaryRecordingExtractor, ChannelSliceRecording -from spikeinterface.core.core_tools import define_function_from_class +from ..core import BaseRecording, BaseRecordingSegment, BinaryRecordingExtractor, ChannelSliceRecording +from ..core.core_tools import define_function_from_class class SinapsResearchPlatformRecordingExtractor(ChannelSliceRecording): @@ -24,7 +24,7 @@ class SinapsResearchPlatformRecordingExtractor(ChannelSliceRecording): """ def __init__(self, file_path: str | Path, stream_name: str = "filt"): - from spikeinterface.preprocessing import UnsignedToSignedRecording + from ..preprocessing import UnsignedToSignedRecording file_path = Path(file_path) meta_file = file_path.parent / f"metadata_{file_path.stem}.txt" diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index bf6158f031..5d18ce5676 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -31,7 +31,7 @@ ) # expose the core generate functions -from spikeinterface.core.generate import ( +from ..core.generate import ( generate_recording, generate_sorting, generate_snippets, diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 278151a930..298953c94a 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -11,7 +11,7 @@ from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type -from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore +from ..core.template_tools import get_dense_templates_array, _get_nbefore class ComputeAmplitudeScalings(AnalyzerExtension): diff --git a/src/spikeinterface/postprocessing/noise_level.py b/src/spikeinterface/postprocessing/noise_level.py index 59ede21fcb..a168f34c7b 100644 --- a/src/spikeinterface/postprocessing/noise_level.py +++ b/src/spikeinterface/postprocessing/noise_level.py @@ -1,3 +1,3 @@ # "noise_levels" extensions is now in core # this is kept name space compatibility but should be removed soon -from spikeinterface.core.analyzer_extension_core import ComputeNoiseLevels, compute_noise_levels +from ..core.analyzer_extension_core import ComputeNoiseLevels, compute_noise_levels diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 3f66692a02..3055888e82 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -10,9 +10,9 @@ import warnings from copy import deepcopy -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.template_tools import get_template_extremum_channel -from spikeinterface.core.template_tools import get_dense_templates_array +from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension +from ..core.template_tools import get_template_extremum_channel +from ..core.template_tools import get_dense_templates_array # DEBUG = False diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 1928e12edc..6c30e2730b 100644 --- 
a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -4,8 +4,8 @@ import warnings from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.template_tools import get_dense_templates_array -from spikeinterface.core.sparsity import ChannelSparsity +from ..core.template_tools import get_dense_templates_array +from ..core.sparsity import ChannelSparsity try: import numba diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 5618499770..df19458316 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -3,7 +3,7 @@ import numpy as np import warnings -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension +from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension from .localization_tools import _unit_location_methods diff --git a/src/spikeinterface/preprocessing/astype.py b/src/spikeinterface/preprocessing/astype.py index 89b9a3aaaa..cace73a0a3 100644 --- a/src/spikeinterface/preprocessing/astype.py +++ b/src/spikeinterface/preprocessing/astype.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_from_class +from ..core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from .filter import fix_dtype diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index 785da185b4..549706bee5 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -5,7 +5,7 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core import get_random_data_chunks +from ..core import get_random_data_chunks class ClipRecording(BasePreprocessor): diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 1a532d4410..3b572cf90f 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -5,7 +5,7 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core import get_closest_channels +from ..core import get_closest_channels from spikeinterface.core.baserecording import BaseRecording from .filter import fix_dtype diff --git a/src/spikeinterface/preprocessing/correct_lsb.py b/src/spikeinterface/preprocessing/correct_lsb.py index 714d5803ea..a8d21b165f 100644 --- a/src/spikeinterface/preprocessing/correct_lsb.py +++ b/src/spikeinterface/preprocessing/correct_lsb.py @@ -4,7 +4,7 @@ import numpy as np from .normalize_scale import scale -from spikeinterface.core import get_random_data_chunks +from ..core import get_random_data_chunks def correct_lsb(recording, num_chunks_per_segment=20, chunk_size=10000, seed=None, verbose=False): diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 8a2ec1f839..e857d6edf3 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -7,7 +7,7 @@ from .basepreprocessor import 
BasePreprocessor from .filter import fix_dtype -from spikeinterface.core import BaseRecordingSegment +from ..core import BaseRecordingSegment class DecimateRecording(BasePreprocessor): diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index 6d7c227534..5d9beac047 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -5,8 +5,8 @@ from packaging.version import parse from .tf_utils import has_tf, import_tf -from spikeinterface.core.core_tools import define_function_from_class -from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from ...core.core_tools import define_function_handling_dict_from_class +from ..basepreprocessor import BasePreprocessor, BasePreprocessorSegment class DeepInterpolatedRecording(BasePreprocessor): diff --git a/src/spikeinterface/preprocessing/deepinterpolation/generators.py b/src/spikeinterface/preprocessing/deepinterpolation/generators.py index 4a0ecd7a16..9f00ab9bbf 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/generators.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/generators.py @@ -2,7 +2,7 @@ from typing import Optional import numpy as np -from spikeinterface.core import concatenate_recordings, BaseRecording, BaseRecordingSegment +from ...core import concatenate_recordings, BaseRecording, BaseRecordingSegment from deepinterpolation.generator_collection import SequentialGenerator diff --git a/src/spikeinterface/preprocessing/deepinterpolation/train.py b/src/spikeinterface/preprocessing/deepinterpolation/train.py index b05cd3ebd3..71b9e2c3f5 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/train.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/train.py @@ -8,7 +8,7 @@ import multiprocessing as mp from .tf_utils import import_tf -from spikeinterface.core import BaseRecording +from ...core import BaseRecording global train_func diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 0587841667..f9c1f095f5 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -1,7 +1,7 @@ from __future__ import annotations -from spikeinterface.core import order_channels_by_depth, ChannelSliceRecording -from spikeinterface.core.core_tools import define_function_from_class +from ..core import order_channels_by_depth, ChannelSliceRecording +from ..core.core_tools import define_function_handling_dict_from_class class DepthOrderRecording(ChannelSliceRecording): diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 2175351f0b..5d8f7107c7 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -5,7 +5,7 @@ from typing import Literal from .filter import highpass_filter -from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording +from ..core import get_random_data_chunks, order_channels_by_depth, BaseRecording def detect_bad_channels( diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 42b7090c0d..66d757735a 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ 
b/src/spikeinterface/preprocessing/filter.py @@ -5,7 +5,7 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core import get_chunk_with_margin +from ..core import get_chunk_with_margin _common_filter_docs = """**filter_kwargs : dict diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index 1f4e18663b..903fef0b6e 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -5,7 +5,7 @@ from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core import get_chunk_with_margin +from ..core import get_chunk_with_margin try: import pyopencl diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index 1380a41ffc..d2074bce43 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -4,8 +4,8 @@ from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from .filter import fix_dtype -from spikeinterface.core import order_channels_by_depth, get_chunk_with_margin -from spikeinterface.core.core_tools import define_function_from_class +from ..core import order_channels_by_depth, get_chunk_with_margin +from ..core.core_tools import define_function_handling_dict_from_class class HighpassSpatialFilterRecording(BasePreprocessor): diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 62d1ea0cf0..c53aceb17a 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -8,7 +8,7 @@ from .filter import fix_dtype -from spikeinterface.core import get_random_data_chunks +from ..core import get_random_data_chunks class ScaleRecordingSegment(BasePreprocessorSegment): diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 872793a30e..5eccdace8f 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -4,7 +4,7 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from spikeinterface.core import get_chunk_with_margin +from ..core import get_chunk_with_margin from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index 0fbf8e54e0..ccab7df332 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -10,7 +10,7 @@ from .basepreprocessor import BasePreprocessor from .filter import fix_dtype -from spikeinterface.core import get_chunk_with_margin, BaseRecordingSegment +from ..core import get_chunk_with_margin, BaseRecordingSegment class ResampleRecording(BasePreprocessor): diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index fc1665dbd8..1921835b37 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -5,8 +5,8 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, 
BasePreprocessorSegment -from spikeinterface.core import get_random_data_chunks, get_noise_levels -from spikeinterface.core.generate import NoiseGeneratorRecording +from ..core import get_random_data_chunks, get_noise_levels +from ..core.generate import NoiseGeneratorRecording class SilencedPeriodsRecording(BasePreprocessor): diff --git a/src/spikeinterface/preprocessing/unsigned_to_signed.py b/src/spikeinterface/preprocessing/unsigned_to_signed.py index c209fff224..c5291da358 100644 --- a/src/spikeinterface/preprocessing/unsigned_to_signed.py +++ b/src/spikeinterface/preprocessing/unsigned_to_signed.py @@ -1,7 +1,8 @@ from __future__ import annotations import numpy as np -from spikeinterface.core.core_tools import define_function_from_class + +from ..core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 7cb16e82bf..7768cbb948 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -5,7 +5,7 @@ from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from spikeinterface.core import get_random_data_chunks, get_channel_distances +from ..core import get_random_data_chunks, get_channel_distances from .filter import fix_dtype diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 91d237d8ef..6464a99afa 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -17,9 +17,9 @@ import warnings from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs -from spikeinterface.postprocessing import correlogram_for_one_segment -from spikeinterface.core import SortingAnalyzer, get_noise_levels -from spikeinterface.core.template_tools import ( +from ..postprocessing import correlogram_for_one_segment +from ..core import SortingAnalyzer, get_noise_levels +from ..core.template_tools import ( get_template_extremum_channel, get_template_extremum_amplitude, get_dense_templates_array, @@ -1469,7 +1469,7 @@ def compute_sd_ratio( The number of spikes, across all segments, for each unit ID. 
""" import numba - from spikeinterface.curation.curation_tools import _find_duplicated_spikes_keep_first_iterative + from ..curation.curation_tools import _find_duplicated_spikes_keep_first_iterative kwargs, job_kwargs = split_job_kwargs(kwargs) job_kwargs = fix_job_kwargs(job_kwargs) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index f4e36b24c0..b1f61ce2b2 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -16,8 +16,8 @@ from .misc_metrics import compute_num_spikes, compute_firing_rates -from spikeinterface.core import get_random_data_chunks, compute_sparsity -from spikeinterface.core.template_tools import get_template_extremum_channel +from ..core import get_random_data_chunks, compute_sparsity +from ..core.template_tools import get_template_extremum_channel _possible_pc_metric_names = [ "isolation_distance", diff --git a/src/spikeinterface/sorters/external/combinato.py b/src/spikeinterface/sorters/external/combinato.py index d5d6114b1b..082c1d172e 100644 --- a/src/spikeinterface/sorters/external/combinato.py +++ b/src/spikeinterface/sorters/external/combinato.py @@ -6,9 +6,9 @@ import sys import json -from spikeinterface.sorters.utils import ShellScript +from ..utils import ShellScript from spikeinterface.core import write_to_h5_dataset_format -from spikeinterface.sorters.basesorter import BaseSorter +from ..basesorter import BaseSorter from spikeinterface.extractors import CombinatoSortingExtractor from spikeinterface.preprocessing import ScaleRecording diff --git a/src/spikeinterface/sorters/external/hdsort.py b/src/spikeinterface/sorters/external/hdsort.py index c5f510a631..fb8e09186d 100644 --- a/src/spikeinterface/sorters/external/hdsort.py +++ b/src/spikeinterface/sorters/external/hdsort.py @@ -9,8 +9,8 @@ import numpy as np from spikeinterface.core import write_to_h5_dataset_format -from spikeinterface.sorters.basesorter import BaseSorter -from spikeinterface.sorters.utils import ShellScript +from ..basesorter import BaseSorter +from ..utils import ShellScript # from spikeinterface.extractors import MaxOneRecordingExtractor from spikeinterface.extractors import HDSortSortingExtractor diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py index 17964f3064..94d66e7f86 100644 --- a/src/spikeinterface/sorters/external/herdingspikes.py +++ b/src/spikeinterface/sorters/external/herdingspikes.py @@ -3,7 +3,7 @@ from pathlib import Path from packaging import version -from spikeinterface.sorters.basesorter import BaseSorter +from ..basesorter import BaseSorter from spikeinterface.extractors import HerdingspikesSortingExtractor diff --git a/src/spikeinterface/sorters/external/ironclust.py b/src/spikeinterface/sorters/external/ironclust.py index 4373cad254..18b764eda2 100644 --- a/src/spikeinterface/sorters/external/ironclust.py +++ b/src/spikeinterface/sorters/external/ironclust.py @@ -5,8 +5,8 @@ from typing import Union import sys -from spikeinterface.sorters.utils import ShellScript -from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs +from ..utils import ShellScript +from ..basesorter import BaseSorter, get_job_kwargs from spikeinterface.extractors import MdaRecordingExtractor, MdaSortingExtractor diff --git a/src/spikeinterface/sorters/external/kilosort.py b/src/spikeinterface/sorters/external/kilosort.py index a94c35f84b..2beb802f63 100644 --- 
a/src/spikeinterface/sorters/external/kilosort.py +++ b/src/spikeinterface/sorters/external/kilosort.py @@ -5,9 +5,9 @@ from typing import Union import numpy as np -from spikeinterface.sorters.basesorter import BaseSorter +from ..basesorter import BaseSorter from .kilosortbase import KilosortBase -from spikeinterface.sorters.utils import get_git_commit +from ..utils import get_git_commit def check_if_installed(kilosort_path: Union[str, None]): diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index cf79098ae7..643769b6f9 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -4,9 +4,9 @@ import os from typing import Union -from spikeinterface.sorters.basesorter import BaseSorter +from ..basesorter import BaseSorter from .kilosortbase import KilosortBase -from spikeinterface.sorters.utils import get_git_commit +from ..utils import get_git_commit PathType = Union[str, Path] diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index 21c2cf94ad..df8f4e6873 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -4,9 +4,9 @@ import os from typing import Union -from spikeinterface.sorters.basesorter import BaseSorter +from ..basesorter import BaseSorter from .kilosortbase import KilosortBase -from spikeinterface.sorters.utils import get_git_commit +from ..utils import get_git_commit PathType = Union[str, Path] diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index af0b840d8a..3681b036a2 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -4,9 +4,9 @@ import os from typing import Union -from spikeinterface.sorters.basesorter import BaseSorter +from ..basesorter import BaseSorter from .kilosortbase import KilosortBase -from spikeinterface.sorters.utils import get_git_commit +from ..utils import get_git_commit PathType = Union[str, Path] diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 346f3e1aea..c08500b7a0 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,10 +6,10 @@ from packaging import version -from spikeinterface.core import write_binary_recording -from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs +from ...core import write_binary_recording +from ..basesorter import BaseSorter, get_job_kwargs from .kilosortbase import KilosortBase -from spikeinterface.sorters.basesorter import get_job_kwargs +from ..basesorter import get_job_kwargs from importlib.metadata import version as importlib_version PathType = Union[str, Path] diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 0e185a4051..4407cf4b69 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -8,8 +8,8 @@ import numpy as np -from spikeinterface.sorters.utils import ShellScript, get_matlab_shell_name, get_bash_path -from spikeinterface.sorters.basesorter import get_job_kwargs +from ..utils import ShellScript, get_matlab_shell_name, get_bash_path +from ..basesorter import get_job_kwargs from spikeinterface.extractors import KiloSortSortingExtractor from 
spikeinterface.core import write_binary_recording from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording diff --git a/src/spikeinterface/sorters/external/klusta.py b/src/spikeinterface/sorters/external/klusta.py index bc51d80537..db23821ef5 100644 --- a/src/spikeinterface/sorters/external/klusta.py +++ b/src/spikeinterface/sorters/external/klusta.py @@ -5,8 +5,8 @@ import sys import shutil -from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs -from spikeinterface.sorters.utils import ShellScript +from ..basesorter import BaseSorter, get_job_kwargs +from ..utils import ShellScript from probeinterface import write_prb diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py index 3b57c3870b..4553ae534e 100644 --- a/src/spikeinterface/sorters/external/mountainsort4.py +++ b/src/spikeinterface/sorters/external/mountainsort4.py @@ -6,7 +6,7 @@ from spikeinterface.preprocessing import bandpass_filter, whiten -from spikeinterface.sorters.basesorter import BaseSorter +from ..basesorter import BaseSorter from spikeinterface.core.old_api_utils import NewToOldRecording from spikeinterface.core import load_extractor diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 7422829402..cf6933c9e6 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -10,7 +10,7 @@ from spikeinterface.preprocessing import bandpass_filter, whiten from spikeinterface.core.baserecording import BaseRecording -from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs +from ..basesorter import BaseSorter, get_job_kwargs from spikeinterface.extractors import NpzSortingExtractor diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py index f73ac7257a..9d0aab9702 100644 --- a/src/spikeinterface/sorters/external/pykilosort.py +++ b/src/spikeinterface/sorters/external/pykilosort.py @@ -8,7 +8,7 @@ from spikeinterface.extractors import KiloSortSortingExtractor from spikeinterface.core import write_binary_recording import json -from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs +from ..basesorter import BaseSorter, get_job_kwargs class PyKilosortSorter(BaseSorter): diff --git a/src/spikeinterface/sorters/external/spyking_circus.py b/src/spikeinterface/sorters/external/spyking_circus.py index 693fc046da..826ffb244a 100644 --- a/src/spikeinterface/sorters/external/spyking_circus.py +++ b/src/spikeinterface/sorters/external/spyking_circus.py @@ -8,8 +8,8 @@ import sys from spikeinterface.extractors import SpykingCircusSortingExtractor -from spikeinterface.sorters.basesorter import BaseSorter -from spikeinterface.sorters.utils import ShellScript +from ..basesorter import BaseSorter +from ..utils import ShellScript from probeinterface import write_prb diff --git a/src/spikeinterface/sorters/external/tridesclous.py b/src/spikeinterface/sorters/external/tridesclous.py index 5e094b9b41..e9a22d0951 100644 --- a/src/spikeinterface/sorters/external/tridesclous.py +++ b/src/spikeinterface/sorters/external/tridesclous.py @@ -10,7 +10,7 @@ from spikeinterface.extractors import TridesclousSortingExtractor -from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs +from ..basesorter import BaseSorter, get_job_kwargs from spikeinterface.core import write_binary_recording from probeinterface 
import write_prb diff --git a/src/spikeinterface/sorters/external/waveclus.py b/src/spikeinterface/sorters/external/waveclus.py index 4f8fb8e164..5926e07b98 100644 --- a/src/spikeinterface/sorters/external/waveclus.py +++ b/src/spikeinterface/sorters/external/waveclus.py @@ -8,8 +8,8 @@ import json -from spikeinterface.sorters.basesorter import BaseSorter -from spikeinterface.sorters.utils import ShellScript +from ..basesorter import BaseSorter +from ..utils import ShellScript from spikeinterface.core import write_to_h5_dataset_format from spikeinterface.extractors import WaveClusSortingExtractor diff --git a/src/spikeinterface/sorters/external/waveclus_snippets.py b/src/spikeinterface/sorters/external/waveclus_snippets.py index 20d45b8c7f..411a6ef8f1 100644 --- a/src/spikeinterface/sorters/external/waveclus_snippets.py +++ b/src/spikeinterface/sorters/external/waveclus_snippets.py @@ -8,8 +8,8 @@ import json -from spikeinterface.sorters.basesorter import BaseSorter -from spikeinterface.sorters.utils import ShellScript +from ..basesorter import BaseSorter +from ..utils import ShellScript from spikeinterface.extractors import WaveClusSortingExtractor from spikeinterface.extractors import WaveClusSnippetsExtractor diff --git a/src/spikeinterface/sorters/external/yass.py b/src/spikeinterface/sorters/external/yass.py index 8c750ad287..9f948a3a79 100644 --- a/src/spikeinterface/sorters/external/yass.py +++ b/src/spikeinterface/sorters/external/yass.py @@ -5,8 +5,8 @@ import numpy as np import sys -from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs -from spikeinterface.sorters.utils import ShellScript +from ..basesorter import BaseSorter, get_job_kwargs +from ..utils import ShellScript from spikeinterface.core import write_binary_recording from spikeinterface.extractors import YassSortingExtractor diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index e25260dd58..d536d2480a 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -13,11 +13,11 @@ import spikeinterface -from spikeinterface import __version__ as si_version +from .. 
import __version__ as si_version -from spikeinterface.core import BaseRecording, NumpySorting, load -from spikeinterface.core.core_tools import check_json, is_editable_mode +from ..core import BaseRecording, NumpySorting, load +from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict from .utils import ( SpikeSortingError, diff --git a/src/spikeinterface/sortingcomponents/motion/decentralized.py b/src/spikeinterface/sortingcomponents/motion/decentralized.py index 956f23efba..f315b109a7 100644 --- a/src/spikeinterface/sortingcomponents/motion/decentralized.py +++ b/src/spikeinterface/sortingcomponents/motion/decentralized.py @@ -2,7 +2,7 @@ from tqdm.auto import tqdm, trange -from spikeinterface.core.motion import Motion +from ...core.motion import Motion from .motion_utils import get_spatial_windows, get_spatial_bin_edges, make_2d_motion_histogram, scipy_conv1d from .dredge import normxcorr1d diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index 0b1fdeeed2..0475fda2bc 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -28,7 +28,7 @@ import numpy as np from tqdm.auto import trange -from spikeinterface.core.motion import Motion +from ...core.motion import Motion from .motion_utils import ( get_spatial_bin_edges, get_spatial_windows, diff --git a/src/spikeinterface/sortingcomponents/motion/iterative_template.py b/src/spikeinterface/sortingcomponents/motion/iterative_template.py index 7bb067b5bd..59f456a55d 100644 --- a/src/spikeinterface/sortingcomponents/motion/iterative_template.py +++ b/src/spikeinterface/sortingcomponents/motion/iterative_template.py @@ -1,6 +1,6 @@ import numpy as np -from spikeinterface.core.motion import Motion +from ...core.motion import Motion from .motion_utils import get_spatial_windows, get_spatial_bin_edges, make_3d_motion_histograms diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py index 040737390e..d10734daaa 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -6,7 +6,7 @@ from spikeinterface.sortingcomponents.tools import make_multi_method_doc -from spikeinterface.core.motion import Motion +from ...core.motion import Motion from .decentralized import DecentralizedRegistration from .iterative_template import IterativeTemplateRegistration from .dredge import DredgeLfpRegistration, DredgeApRegistration diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index d2ba69103a..1e4e0edded 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -21,9 +21,12 @@ from spikeinterface.core import get_channel_distances -from spikeinterface.postprocessing.unit_locations import dtype_localize_by_method, possible_localization_methods +from ..postprocessing.unit_locations import ( + dtype_localize_by_method, + possible_localization_methods, +) -from spikeinterface.postprocessing.localization_tools import ( +from ..postprocessing.localization_tools import ( make_radial_order_parents, solve_monopolar_triangulation, enforce_decrease_shells_data, diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py 
b/src/spikeinterface/widgets/all_amplitudes_distributions.py index 2f799db4ed..f6e38b3f5e 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -6,7 +6,7 @@ from .base import BaseWidget, to_attr from .utils import get_some_colors -from spikeinterface.core import SortingAnalyzer +from ..core import SortingAnalyzer class AllAmplitudesDistributionsWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 197fefbab2..956c0d3c11 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -7,7 +7,7 @@ from .base import BaseWidget, to_attr from .utils import get_some_colors -from spikeinterface.core.sortinganalyzer import SortingAnalyzer +from ..core.sortinganalyzer import SortingAnalyzer from spikeinterface.core import SortingAnalyzer diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 6f69e7ea66..9566989d31 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -5,8 +5,8 @@ global default_backend_ default_backend_ = "matplotlib" -from spikeinterface.core import SortingAnalyzer, BaseSorting -from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor +from ..core import SortingAnalyzer, BaseSorting +from ..core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor def get_default_plotter_backend(): diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index ee43f65852..ae07b79e6d 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -4,9 +4,9 @@ from typing import Union from .base import BaseWidget, to_attr -from spikeinterface.core.sortinganalyzer import SortingAnalyzer -from spikeinterface.core.basesorting import BaseSorting -from spikeinterface.postprocessing import compute_correlograms +from ..core.sortinganalyzer import SortingAnalyzer +from ..core.basesorting import BaseSorting +from ..postprocessing import compute_correlograms class CrossCorrelogramsWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index c0812e85ad..e1b1b423f2 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -5,7 +5,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core.core_tools import check_json +from ..core.core_tools import check_json class MetricsBaseWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/potential_merges.py b/src/spikeinterface/widgets/potential_merges.py index 1762c68c40..be882209b8 100644 --- a/src/spikeinterface/widgets/potential_merges.py +++ b/src/spikeinterface/widgets/potential_merges.py @@ -11,7 +11,7 @@ from .utils import get_some_colors -from spikeinterface.core.sortinganalyzer import SortingAnalyzer +from ..core.sortinganalyzer import SortingAnalyzer class PotentialMergesWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index 908c5345a0..7d1ff44326 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -1,7 +1,7 @@ from __future__ import annotations from .metrics import MetricsBaseWidget -from spikeinterface.core.sortinganalyzer import SortingAnalyzer 
+from ..core.sortinganalyzer import SortingAnalyzer class QualityMetricsWidget(MetricsBaseWidget): diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index d608f3ae8d..ed0bfc4180 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -13,7 +13,7 @@ from .unit_templates import UnitTemplatesWidget -from spikeinterface.core import SortingAnalyzer +from ..core import SortingAnalyzer _default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude_median", "snr", "rp_violation"] diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 4673fae2d1..ada1546ac6 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -4,7 +4,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core.sortinganalyzer import SortingAnalyzer +from ..core.sortinganalyzer import SortingAnalyzer class SpikeLocationsWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/spike_locations_by_time.py b/src/spikeinterface/widgets/spike_locations_by_time.py index 72524cf063..89cc6227fe 100644 --- a/src/spikeinterface/widgets/spike_locations_by_time.py +++ b/src/spikeinterface/widgets/spike_locations_by_time.py @@ -6,7 +6,7 @@ from .base import BaseWidget, to_attr from .utils import get_some_colors -from spikeinterface.core.sortinganalyzer import SortingAnalyzer +from ..core.sortinganalyzer import SortingAnalyzer class LocationsWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 35d0b988c4..f1d5891967 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -5,12 +5,12 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors from .traces import TracesWidget -from spikeinterface.core import ChannelSparsity -from spikeinterface.core.template_tools import get_template_extremum_channel -from spikeinterface.core.sortinganalyzer import SortingAnalyzer -from spikeinterface.core.baserecording import BaseRecording -from spikeinterface.core.basesorting import BaseSorting -from spikeinterface.postprocessing import compute_unit_locations +from ..core import ChannelSparsity +from ..core.template_tools import get_template_extremum_channel +from ..core.sortinganalyzer import SortingAnalyzer +from ..core.baserecording import BaseRecording +from ..core.basesorting import BaseSorting +from ..postprocessing import compute_unit_locations class SpikesOnTracesWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index cc1b0118f2..4df719eda5 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -1,7 +1,7 @@ from __future__ import annotations from .metrics import MetricsBaseWidget -from spikeinterface.core.sortinganalyzer import SortingAnalyzer +from ..core.sortinganalyzer import SortingAnalyzer class TemplateMetricsWidget(MetricsBaseWidget): diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index aef620aa41..b469d9901f 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -3,7 +3,7 @@ import numpy as np from .base import BaseWidget, 
to_attr -from spikeinterface.core.sortinganalyzer import SortingAnalyzer +from ..core.sortinganalyzer import SortingAnalyzer class TemplateSimilarityWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index e7c3b5bec4..f944a4a80e 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -4,7 +4,7 @@ import numpy as np -from spikeinterface.core import BaseRecording +from ..core import BaseRecording from .base import BaseWidget, to_attr from .utils import get_some_colors, array_to_image @@ -107,7 +107,7 @@ def __init__( ) if order_channel_by_depth and rec0.has_channel_location(): - from spikeinterface.preprocessing import depth_order + from ..preprocessing import depth_order rec0 = depth_order(rec0) recordings = {k: depth_order(rec) for k, rec in recordings.items()} @@ -642,7 +642,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): def plot_ephyviewer(self, data_plot, **backend_kwargs): import ephyviewer - from spikeinterface.preprocessing import depth_order + from ..preprocessing import depth_order dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 98b8389714..d41982f766 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -7,7 +7,7 @@ from .utils import get_unit_colors -from spikeinterface.core.template_tools import get_template_extremum_amplitude +from ..core.template_tools import get_template_extremum_amplitude class UnitDepthsWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 569fc56989..37f9bb0491 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -7,7 +7,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core.sortinganalyzer import SortingAnalyzer +from ..core.sortinganalyzer import SortingAnalyzer class UnitLocationsWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index 22a8592a0c..bba4bd774e 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -8,8 +8,8 @@ from .base import BaseWidget, to_attr # from .utils import get_unit_colors -from spikeinterface.core.sortinganalyzer import SortingAnalyzer -from spikeinterface.core.template_tools import get_dense_templates_array +from ..core.sortinganalyzer import SortingAnalyzer +from ..core.template_tools import get_dense_templates_array class UnitProbeMapWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index d3ad0ada9f..1d0191092f 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -1,6 +1,6 @@ from __future__ import annotations -from spikeinterface.core import SortingAnalyzer +from ..core import SortingAnalyzer from .unit_waveforms import UnitWaveformsWidget from .base import to_attr diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 02cc39014b..ee2158d78e 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -6,8 +6,8 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core import ChannelSparsity, 
SortingAnalyzer, Templates -from spikeinterface.core.basesorting import BaseSorting +from ..core import ChannelSparsity, SortingAnalyzer, Templates +from ..core.basesorting import BaseSorting class UnitWaveformsWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9543cbf734..f7da0ef1f3 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -5,7 +5,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core import ChannelSparsity, get_template_extremum_channel +from ..core import ChannelSparsity, get_template_extremum_channel class UnitWaveformDensityMapWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 1e30a1f4c1..ea0112c9eb 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -4,8 +4,8 @@ import numpy as np -from spikeinterface.core import SortingAnalyzer, BaseSorting -from spikeinterface.core.core_tools import check_json +from ..core import SortingAnalyzer, BaseSorting +from ..core.core_tools import check_json from .utils import make_units_table_from_sorting, make_units_table_from_analyzer From 070896b82fde5914e5da03ba7fc4de32ec2f0566 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 14 Mar 2025 15:20:16 +0000 Subject: [PATCH 20/20] I did something bad --- .github/workflows/all-tests.yml | 23 +- .github/workflows/full-test-with-codecov.yml | 6 +- doc/api.rst | 5 +- doc/conf.py | 20 +- doc/how_to/load_your_data_into_sorting.rst | 10 +- doc/index.rst | 5 + doc/modules/comparison.rst | 2 +- doc/modules/core.rst | 9 +- doc/modules/index.rst | 1 - doc/modules/sorters.rst | 2 + .../core/plot_2_sorting_extractor.py | 2 +- .../curation/plot_2_train_a_model.py | 8 +- .../curation/plot_3_upload_a_model.py | 2 +- pyproject.toml | 11 +- .../benchmark/benchmark_base.py | 33 -- .../benchmark/benchmark_clustering.py | 115 +++++- .../benchmark/benchmark_matching.py | 47 ++- .../benchmark/benchmark_peak_localization.py | 2 +- .../benchmark/benchmark_peak_selection.py | 2 +- .../benchmark/benchmark_plot_tools.py | 274 ++++--------- .../benchmark/benchmark_sorter.py | 20 - .../comparison/basecomparison.py | 3 +- .../comparison/multicomparisons.py | 10 +- .../comparison/tests/test_comparisontools.py | 4 +- .../tests/test_groundtruthcomparison.py | 4 +- .../tests/test_multisortingcomparison.py | 16 +- .../tests/test_symmetricsortingcomparison.py | 4 +- src/spikeinterface/core/__init__.py | 2 - src/spikeinterface/core/baserecording.py | 7 + .../core/baserecordingsnippets.py | 2 +- .../core/channelsaggregationrecording.py | 13 +- src/spikeinterface/core/channelslice.py | 9 +- src/spikeinterface/core/generate.py | 4 +- src/spikeinterface/core/node_pipeline.py | 15 +- src/spikeinterface/core/numpyextractors.py | 59 +-- src/spikeinterface/core/sortinganalyzer.py | 33 +- src/spikeinterface/core/sparsity.py | 51 +-- .../test_channelsaggregationrecording.py | 20 - .../core/tests/test_numpy_extractors.py | 14 +- .../core/tests/test_segmentutils.py | 12 +- .../core/tests/test_sorting_tools.py | 10 +- .../core/tests/test_sparsity.py | 17 - .../curation/sortingview_curation.py | 39 +- .../tests/test_remove_excess_spikes.py | 2 +- .../tests/test_sortingview_curation.py | 6 +- .../curation/train_manual_curation.py | 42 +- 
.../extractors/iblextractors.py | 21 +- .../extractors/neoextractors/intan.py | 50 +-- .../neoextractors/neobaseextractor.py | 72 ++-- .../extractors/neoextractors/openephys.py | 2 +- .../extractors/tests/test_iblextractors.py | 15 +- src/spikeinterface/extractors/toy_example.py | 2 +- .../postprocessing/localization_tools.py | 4 +- .../postprocessing/spike_amplitudes.py | 1 + .../postprocessing/tests/test_correlograms.py | 12 +- .../preprocessing/remove_artifacts.py | 4 +- .../tests/test_metrics_functions.py | 2 +- .../sorters/internal/simplesorter.py | 2 +- .../sorters/internal/tridesclous2.py | 2 +- .../clustering/method_list.py | 8 - .../clustering/position_and_features.py | 2 +- .../tests/test_template_matching.py | 2 +- .../widgets/all_amplitudes_distributions.py | 5 +- src/spikeinterface/widgets/amplitudes.py | 194 +++++++-- .../widgets/crosscorrelograms.py | 5 +- src/spikeinterface/widgets/metrics.py | 5 +- src/spikeinterface/widgets/motion.py | 85 ++-- src/spikeinterface/widgets/quality_metrics.py | 5 +- src/spikeinterface/widgets/rasters.py | 377 +++--------------- src/spikeinterface/widgets/spike_locations.py | 5 +- .../widgets/spike_locations_by_time.py | 258 ------------ .../widgets/spikes_on_traces.py | 6 +- .../widgets/template_metrics.py | 5 +- .../widgets/tests/test_widgets.py | 24 -- src/spikeinterface/widgets/unit_depths.py | 5 +- src/spikeinterface/widgets/unit_locations.py | 5 +- src/spikeinterface/widgets/unit_summary.py | 17 +- src/spikeinterface/widgets/unit_waveforms.py | 10 +- .../widgets/unit_waveforms_density_map.py | 6 +- src/spikeinterface/widgets/widget_list.py | 25 +- 80 files changed, 778 insertions(+), 1462 deletions(-) delete mode 100644 src/spikeinterface/widgets/spike_locations_by_time.py diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 82fcc0fbdd..cfab49ef09 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -10,7 +10,9 @@ on: - main env: - KACHERY_API_KEY: ${{ secrets.KACHERY_API_KEY }} + KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} + KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} + KACHERY_ZONE: ${{ secrets.KACHERY_ZONE }} concurrency: # Cancel previous workflows on the same pull request group: ${{ github.workflow }}-${{ github.ref }} @@ -74,9 +76,6 @@ jobs: pip install -e .[test_core] shell: bash - - name: Pip list - run: pip list - - name: Test core run: pytest -m "core" shell: bash @@ -137,7 +136,6 @@ jobs: if: env.RUN_EXTRACTORS_TESTS == 'true' run: | pip install -e .[extractors,streaming_extractors,test_extractors] - pip list ./.github/run_tests.sh "extractors and not streaming_extractors" --no-virtual-env - name: Test streaming extractors @@ -145,7 +143,6 @@ jobs: if: env.RUN_STREAMING_EXTRACTORS_TESTS == 'true' run: | pip install -e .[streaming_extractors,test_extractors] - pip list ./.github/run_tests.sh "streaming_extractors" --no-virtual-env - name: Test preprocessing @@ -153,21 +150,18 @@ jobs: if: env.RUN_PREPROCESSING_TESTS == 'true' run: | pip install -e .[preprocessing,test_preprocessing] - pip list ./.github/run_tests.sh "preprocessing and not deepinterpolation" --no-virtual-env - name: Install remaining testing dependencies # TODO: Remove this step once we have better modularization shell: bash run: | pip install -e .[test] - pip list - name: Test postprocessing shell: bash if: env.RUN_POSTPROCESSING_TESTS == 'true' run: | pip install -e .[full] - pip list ./.github/run_tests.sh postprocessing --no-virtual-env - name: Test 
quality metrics @@ -175,7 +169,6 @@ jobs: if: env.RUN_QUALITYMETRICS_TESTS == 'true' run: | pip install -e .[qualitymetrics] - pip list ./.github/run_tests.sh qualitymetrics --no-virtual-env - name: Test comparison @@ -183,7 +176,6 @@ jobs: if: env.RUN_COMPARISON_TESTS == 'true' run: | pip install -e .[full] - pip list ./.github/run_tests.sh comparison --no-virtual-env - name: Test core sorters @@ -191,7 +183,6 @@ jobs: if: env.RUN_SORTERS_TESTS == 'true' run: | pip install -e .[full] - pip list ./.github/run_tests.sh sorters --no-virtual-env - name: Test internal sorters @@ -199,7 +190,6 @@ jobs: if: env.RUN_INTERNAL_SORTERS_TESTS == 'true' run: | pip install -e .[full] - pip list ./.github/run_tests.sh sorters_internal --no-virtual-env - name: Test curation @@ -207,17 +197,13 @@ jobs: if: env.RUN_CURATION_TESTS == 'true' run: | pip install -e .[full] - pip list ./.github/run_tests.sh curation --no-virtual-env - name: Test widgets shell: bash if: env.RUN_WIDGETS_TESTS == 'true' - env: - KACHERY_ZONE: "scratch" run: | pip install -e .[full,widgets] - pip list ./.github/run_tests.sh widgets --no-virtual-env - name: Test exporters @@ -225,7 +211,6 @@ jobs: if: env.RUN_EXPORTERS_TESTS == 'true' run: | pip install -e .[full] - pip list ./.github/run_tests.sh exporters --no-virtual-env - name: Test sortingcomponents @@ -233,7 +218,6 @@ jobs: if: env.RUN_SORTINGCOMPONENTS_TESTS == 'true' run: | pip install -e .[full] - pip list ./.github/run_tests.sh sortingcomponents --no-virtual-env - name: Test generation @@ -241,5 +225,4 @@ jobs: if: env.RUN_GENERATION_TESTS == 'true' run: | pip install -e .[full] - pip list ./.github/run_tests.sh generation --no-virtual-env diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index 9d56be8498..a53f5d2915 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -6,7 +6,9 @@ on: - cron: "0 12 * * *" # Daily at noon UTC env: - KACHERY_API_KEY: ${{ secrets.KACHERY_API_KEY }} + KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} + KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} + KACHERY_ZONE: ${{ secrets.KACHERY_ZONE }} jobs: full-tests-with-codecov: @@ -39,8 +41,6 @@ jobs: restore-keys: ${{ runner.os }}-datasets - name: Install packages uses: ./.github/actions/build-test-environment - - name: Pip list - run: pip list - name: run tests env: HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell diff --git a/doc/api.rst b/doc/api.rst index 21770f2095..eb9a61eb9c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -22,8 +22,6 @@ spikeinterface.core .. autofunction:: estimate_sparsity .. autoclass:: ChannelSparsity :members: - .. autoclass:: Motion - :members: .. autoclass:: BinaryRecordingExtractor .. autoclass:: ZarrRecordingExtractor .. autoclass:: BinaryFolderRecording @@ -279,6 +277,8 @@ spikeinterface.comparison .. autoclass:: CollisionGTComparison .. autoclass:: CorrelogramGTComparison + .. autoclass:: CollisionGTStudy + .. autoclass:: CorrelogramGTStudy @@ -449,6 +449,7 @@ Motion Correction ~~~~~~~~~~~~~~~~~ .. automodule:: spikeinterface.sortingcomponents.motion + .. autoclass:: Motion .. autofunction:: estimate_motion .. autofunction:: interpolate_motion .. 
autofunction:: correct_motion_on_peaks diff --git a/doc/conf.py b/doc/conf.py index 742f46662e..d229dc18ee 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -31,12 +31,14 @@ folders = [ '../examples/tutorials/core/my_recording', '../examples/tutorials/core/my_sorting', - '../examples/tutorials/core/analyzer_folder', - '../examples/tutorials/core/analyzer_some_units', - '../examples/tutorials/core/analyzer.zarr', - '../examples/tutorials/curation/my_folder', - '../examples/tutorials/qualitymetrics/curated_sorting', - '../examples/tutorials/qualitymetrics/clean_analyzer.zarr', + '../examples/tutorials/core/waveform_folder', + '../examples/tutorials/core/waveform_folder_parallel', + '../examples/tutorials/core/waveform_folder_sparse', + '../examples/tutorials/core/waveform_folder_sparse_direct', + '../examples/tutorials/core/waveform_folder2', + '../examples/tutorials/core/waveform_folder', + '../examples/tutorials/qualitymetrics/waveforms_mearec', + '../examples/tutorials/qualitymetrics/wfs_mearec', '../examples/tutorials/widgets/waveforms_mearec', ] @@ -99,6 +101,7 @@ import sphinx_rtd_theme html_theme = "sphinx_rtd_theme" + html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] except ImportError: print("RTD theme not installed, using default") html_theme = 'alabaster' @@ -107,9 +110,8 @@ # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] -# html_css_files = ['custom.css'] -html_favicon = "images/logo.png" -html_logo = "images/logo.png" + +html_favicon = "images/favicon-32x32.png" from sphinx_gallery.sorting import ExplicitOrder diff --git a/doc/how_to/load_your_data_into_sorting.rst b/doc/how_to/load_your_data_into_sorting.rst index 4c5d27fefb..e250cfa6e9 100644 --- a/doc/how_to/load_your_data_into_sorting.rst +++ b/doc/how_to/load_your_data_into_sorting.rst @@ -65,8 +65,8 @@ the requested unit_ids). # in this case we are making a monosegment sorting # we have four spikes that are spread among two neurons - my_sorting = NumpySorting.from_samples_and_labels( - samples_list=[ + my_sorting = NumpySorting.from_times_labels( + times_list=[ np.array([1000,12000,15000,22000]) # Note these are samples/frames not times in seconds ], labels_list=[ @@ -120,7 +120,7 @@ Loading multisegment data into a :code:`Sorting` One of the great advantages of SpikeInterface :code:`Sorting` objects is that they can also handle multisegment recordings and sortings (e.g. you have a baseline, stimulus, post-stimulus). The exact same machinery can be used to generate your sorting, but in this case we do a list of arrays instead of -a single list. Let's go through one example for using :code:`from_samples_and_labels`: +a single list. Let's go through one example for using :code:`from_times_labels`: .. code-block:: python @@ -130,8 +130,8 @@ a single list. Let's go through one example for using :code:`from_samples_and_la # in this case we are making three-segment sorting # we have four spikes that are spread among two neurons # in each segment - my_sorting = NumpySorting.from_samples_and_labels( - samples_list=[ + my_sorting = NumpySorting.from_times_labels( + times_list=[ np.array([1000,12000,15000,22000]), np.array([30000,33000, 41000, 47000]), np.array([50000,53000,64000,70000]), diff --git a/doc/index.rst b/doc/index.rst index 2e98f15392..e6d8aa3fea 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -2,6 +2,11 @@ Welcome to SpikeInterface's documentation! 
========================================== +.. image:: images/logo.png + :scale: 100 % + :align: center + + SpikeInterface is a Python module to analyze extracellular electrophysiology data. With a few lines of code, SpikeInterface enables you to load and pre-process the recording, run several diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst index eb8b33edd0..a02d76664d 100644 --- a/doc/modules/comparison.rst +++ b/doc/modules/comparison.rst @@ -26,7 +26,7 @@ For spike train comparison, there are three use cases: A ground-truth dataset can be a paired recording, in which a neuron is recorded both extracellularly and with a patch or juxtacellular electrode (either **in vitro** or **in vivo**), or it can be a simulated dataset -(**in silico**) using spiking activity simulators such as `MEArec `_. +(**in silico**) using spiking activity simulators such as `MEArec`_. The comparison to ground-truth datasets is useful to benchmark spike sorting algorithms. diff --git a/doc/modules/core.rst b/doc/modules/core.rst index ef26fb14c7..85844d4440 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -711,13 +711,8 @@ In this example, we create a recording and a sorting object from numpy objects: spike_trains += spike_trains_i labels += labels_i - # construct a mono-segment - samples_list = [np.array(spike_trains)] - labels_list = [np.array(labels)] - - sorting_memory = NumpySorting.from_samples_and_labels( - samples_list=samples_list, labels_list=labels_list, sampling_frequency=sampling_frequency - ) + sorting_memory = NumpySorting.from_times_labels(times=spike_trains, labels=labels, + sampling_frequency=sampling_frequency) Any sorting object can be transformed into a :py:class:`~spikeinterface.core.NumpySorting` or diff --git a/doc/modules/index.rst b/doc/modules/index.rst index a759569ae9..d849798a01 100644 --- a/doc/modules/index.rst +++ b/doc/modules/index.rst @@ -17,4 +17,3 @@ Modules documentation sortingcomponents motion_correction generation - benchmark diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index d8a4708236..a58fba1c98 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -473,6 +473,8 @@ Here is the list of external sorters accessible using the run_sorter wrapper: * **HDSort** :code:`run_sorter(sorter_name='hdsort')` * **YASS** :code:`run_sorter(sorter_name='yass')` +Internal Sorters +---------------- Here a list of internal sorter based on `spikeinterface.sortingcomponents`; they are totally experimental for now: diff --git a/examples/tutorials/core/plot_2_sorting_extractor.py b/examples/tutorials/core/plot_2_sorting_extractor.py index bf45e15aa7..b572218ed8 100644 --- a/examples/tutorials/core/plot_2_sorting_extractor.py +++ b/examples/tutorials/core/plot_2_sorting_extractor.py @@ -40,7 +40,7 @@ ############################################################################## # And instantiate a :py:class:`~spikeinterface.core.NumpySorting` object: -sorting = se.NumpySorting.from_samples_and_labels([times0, times1], [labels0, labels1], sampling_frequency) +sorting = se.NumpySorting.from_times_labels([times0, times1], [labels0, labels1], sampling_frequency) print(sorting) ############################################################################## diff --git a/examples/tutorials/curation/plot_2_train_a_model.py b/examples/tutorials/curation/plot_2_train_a_model.py index 6e08657af3..1a38836527 100644 --- a/examples/tutorials/curation/plot_2_train_a_model.py +++ b/examples/tutorials/curation/plot_2_train_a_model.py @@ -1,6 
+1,6 @@ """ Training a model for automated curation -======================================= +============================= If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface. """ @@ -57,7 +57,7 @@ sw.plot_unit_templates(analyzer, unit_ids=["0", "5"]) ############################################################################## -# This is as expected: great! (Find out more about plotting `using widgets `_.) +# This is as expected: great! (Find out more about plotting using widgets `here `_.) # We've set up our system so that the first five units are 'good' and the next five are 'bad'. # So we can make a list of labels which contain this information. For real data, you could # use a manual curation tool to make your own list. @@ -129,8 +129,8 @@ # half were pure noise and half were not. # # The model also contains some more information, such as which features are "important", -# as defined by sklearn (learn about feature importance of a -# `Random Forest Classifier `_.) +# as defined by sklearn (learn about feature importance of a Random Forest Classifier +# `here `_.) # We can plot these: # Plot feature importances diff --git a/examples/tutorials/curation/plot_3_upload_a_model.py b/examples/tutorials/curation/plot_3_upload_a_model.py index ad9d16cab5..0a9ea402db 100644 --- a/examples/tutorials/curation/plot_3_upload_a_model.py +++ b/examples/tutorials/curation/plot_3_upload_a_model.py @@ -136,4 +136,4 @@ # # Chris Halcrow # -# You can see the repo with this `Model card `_. +# You can see the repo with this Model card `here `_. diff --git a/pyproject.toml b/pyproject.toml index 15ce837774..97ba77299e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,15 +67,15 @@ extractors = [ "sonpy;python_version<'3.10'", "lxml", # lxml for neuroscope "scipy", - "ONE-api>=2.7.0,<3.0.0", # alf sorter and streaming IBL - "ibllib>=2.36.0,<3.0.0", # streaming IBL + "ONE-api>=2.7.0", # alf sorter and streaming IBL + "ibllib>=2.36.0", # streaming IBL "pymatreader>=0.0.32", # For cell explorer matlab files "zugbruecke>=0.2; sys_platform!='win32'", # For plexon2 ] streaming_extractors = [ - "ONE-api>=2.7.0,<3.0.0", # alf sorter and streaming IBL - "ibllib>=2.36.0,<3.0.0", # streaming IBL + "ONE-api>=2.7.0,<2.10.0", # alf sorter and streaming IBL + "ibllib>=2.36.0", # streaming IBL # Following dependencies are for streaming with nwb files "pynwb>=2.6.0", "fsspec", @@ -149,7 +149,6 @@ test = [ "pytest", "pytest-dependency", "pytest-cov", - "psutil", "huggingface_hub", @@ -165,7 +164,7 @@ test = [ "hdbscan>=0.8.33", # Previous version had a broken wheel # for sortingview backend - "sortingview>=0.12.0", + "sortingview", # for motion and sortingcomponents "torch", diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index 361048cc63..f4a1808b2c 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -366,39 +366,6 @@ def get_units_snr(self, key): def get_result(self, key): return self.benchmarks[key].result - def get_pairs_by_level(self, level): - """ - usefull for function like plot_performance_losses() where you need to plot one pair of results - This generate list of pairs for a given level. 
- """ - - level_index = self.levels.index(level) - - possible_values = [] - for key in self.cases.keys(): - assert isinstance(key, tuple), "get_pairs_by_level need tuple keys" - level_value = key[level_index] - if level_value not in possible_values: - possible_values.append(level_value) - assert len(possible_values) == 2, "get_pairs_by_level() : you need exactly 2 value for this levels" - - pairs = [] - for key in self.cases.keys(): - - case0 = list(key) - case1 = list(key) - case0[level_index] = possible_values[0] - case1[level_index] = possible_values[1] - case0 = tuple(case0) - case1 = tuple(case1) - - pair = (case0, case1) - - if pair not in pairs: - pairs.append(pair) - - return pairs - class Benchmark: """ diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py index b6ab68d9a0..2f81344a58 100644 --- a/src/spikeinterface/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -60,7 +60,7 @@ def compute_result(self, **result_params): self.result["sliced_gt_sorting"].set_property("gt_unit_locations", gt_unit_locations) - self.result["clustering"] = NumpySorting.from_samples_and_labels( + self.result["clustering"] = NumpySorting.from_times_labels( data["sample_index"], self.result["peak_labels"][~self.noise], self.recording.sampling_frequency ) @@ -181,21 +181,6 @@ def plot_performances_vs_snr(self, **kwargs): return plot_performances_vs_snr(self, **kwargs) - def plot_performances_comparison(self, *args, **kwargs): - from .benchmark_plot_tools import plot_performances_comparison - - return plot_performances_comparison(self, *args, **kwargs) - - def plot_performance_losses(self, *args, **kwargs): - from .benchmark_plot_tools import plot_performance_losses - - return plot_performance_losses(self, *args, **kwargs) - - def plot_performances_vs_depth_and_snr(self, *args, **kwargs): - from .benchmark_plot_tools import plot_performances_vs_depth_and_snr - - return plot_performances_vs_depth_and_snr(self, *args, **kwargs) - def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): if case_keys is None: @@ -366,6 +351,104 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs return fig + def plot_unit_losses(self, cases_before, cases_after, metric="agreement", figsize=None): + + fig, axs = plt.subplots(ncols=len(cases_before), nrows=1, figsize=figsize) + + for count, (case_before, case_after) in enumerate(zip(cases_before, cases_after)): + + ax = axs[count] + dataset_key = self.cases[case_before]["dataset"] + _, gt_sorting1 = self.datasets[dataset_key] + positions = gt_sorting1.get_property("gt_unit_locations") + + analyzer = self.get_sorting_analyzer(case_before) + metrics_before = analyzer.get_extension("quality_metrics").get_data() + x = metrics_before["snr"].values + + y_before = self.get_result(case_before)["gt_comparison"].get_performance()[metric].values + y_after = self.get_result(case_after)["gt_comparison"].get_performance()[metric].values + ax.set_ylabel("depth (um)") + ax.set_ylabel("snr") + if count > 0: + ax.set_ylabel("") + ax.set_yticks([], []) + im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm") + im.set_clim(-1, 1) + # fig.colorbar(im, ax=ax) + # ax.set_title(k) + + fig.subplots_adjust(right=0.85) + cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) + cbar = fig.colorbar(im, cax=cbar_ax, label=metric) + # cbar.set_clim(-1, 1) + + return fig + + def plot_comparison_clustering( + self, + case_keys=None, 
+ performance_names=["accuracy", "recall", "precision"], + colors=["g", "b", "r"], + ylim=(-0.1, 1.1), + figsize=None, + ): + + if case_keys is None: + case_keys = list(self.cases.keys()) + import pylab as plt + + num_methods = len(case_keys) + fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10)) + for i, key1 in enumerate(case_keys): + for j, key2 in enumerate(case_keys): + if len(axs.shape) > 1: + ax = axs[i, j] + else: + ax = axs[j] + comp1 = self.get_result(key1)["gt_comparison"] + comp2 = self.get_result(key2)["gt_comparison"] + if i <= j: + for performance, color in zip(performance_names, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.plot(perf2, perf1, ".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + label1 = self.cases[key1]["label"] + label2 = self.cases[key2]["label"] + if j == i: + ax.set_ylabel(f"{label1}") + else: + ax.set_yticks([]) + if i == j: + ax.set_xlabel(f"{label2}") + else: + ax.set_xticks([]) + if i == num_methods - 1 and j == num_methods - 1: + patches = [] + import matplotlib.patches as mpatches + + for color, name in zip(colors, performance_names): + patches.append(mpatches.Patch(color=color, label=name)) + ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) + else: + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + + plt.tight_layout(h_pad=0, w_pad=0) + + return fig + def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index 86ea68bc0d..3a1106e1fd 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -1,7 +1,5 @@ from __future__ import annotations -import warnings - from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core import NumpySorting from spikeinterface.comparison import CollisionGTComparison, compare_sorter_to_ground_truth @@ -79,11 +77,6 @@ def plot_performances_comparison(self, **kwargs): return plot_performances_comparison(self, **kwargs) - def plot_performances_vs_depth_and_snr(self, *args, **kwargs): - from .benchmark_plot_tools import plot_performances_vs_depth_and_snr - - return plot_performances_vs_depth_and_snr(self, *args, **kwargs) - def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -145,13 +138,39 @@ def plot_unit_counts(self, case_keys=None, **kwargs): return plot_unit_counts(self, case_keys, **kwargs) - def plot_unit_losses(self, *args, **kwargs): - from .benchmark_plot_tools import plot_performance_losses + def plot_unit_losses(self, before, after, metric=["accuracy"], figsize=None): + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False) + + for count, k in enumerate(metric): - warnings.warn("plot_unit_losses() is now plot_performance_losses()") - return plot_performance_losses(self, *args, **kwargs) + ax = axs[0, count] - def 
plot_performance_losses(self, *args, **kwargs): - from .benchmark_plot_tools import plot_performance_losses + label = self.cases[after]["label"] - return plot_performance_losses(self, *args, **kwargs) + positions = self.get_result(before)["gt_comparison"].sorting1.get_property("gt_unit_locations") + + analyzer = self.get_sorting_analyzer(before) + metrics_before = analyzer.get_extension("quality_metrics").get_data() + x = metrics_before["snr"].values + + y_before = self.get_result(before)["gt_comparison"].get_performance()[k].values + y_after = self.get_result(after)["gt_comparison"].get_performance()[k].values + # if count < 2: + # ax.set_xticks([], []) + # elif count == 2: + ax.set_xlabel("depth (um)") + im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm") + fig.colorbar(im, ax=ax, label=k) + im.set_clim(-1, 1) + ax.set_title(k) + ax.set_ylabel("snr") + + # fig.subplots_adjust(right=0.85) + # cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) + # cbar = fig.colorbar(im, cax=cbar_ax, label=metric) + + # if count == 2: + # ax.legend() + return fig diff --git a/src/spikeinterface/benchmark/benchmark_peak_localization.py b/src/spikeinterface/benchmark/benchmark_peak_localization.py index 3923a229e2..568ff2bb9b 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/benchmark_peak_localization.py @@ -403,7 +403,7 @@ def plot_comparison_positions(self, case_keys=None): # ax2.spines["top"].set_visible(False) # ax2.spines["right"].set_visible(False) # ax2.set_xlim(xmin, xmax) -# ax2.set_xlabel(r"x ($\\mu$m)") +# ax2.set_xlabel(r"x ($\mu$m)") # ax2.set_ylabel("# spikes") diff --git a/src/spikeinterface/benchmark/benchmark_peak_selection.py b/src/spikeinterface/benchmark/benchmark_peak_selection.py index 54cbe0bfe1..41edea156f 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_selection.py @@ -160,7 +160,7 @@ def create_benchmark(self, key): # nb_garbage = len(garbage_peaks) # ratio = 100 * len(garbage_peaks) / len(times2) -# self.garbage_sorting = NumpySorting.from_samples_and_labels(garbage_peaks, garbage_channels, self.sampling_rate) +# self.garbage_sorting = NumpySorting.from_times_labels(garbage_peaks, garbage_channels, self.sampling_rate) # print("The peaks have {0:.2f}% of garbage (without gt around)".format(ratio)) diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index 44e42656ab..31af101639 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -37,7 +37,6 @@ def plot_run_times(study, case_keys=None): ax.bar(i, rt, width=0.8, color=colors[key]) ax.set_xticks(np.arange(len(case_keys))) ax.set_xticklabels(labels, rotation=45.0) - ax.set_ylabel("run time (s)") return fig @@ -95,6 +94,84 @@ def plot_unit_counts(study, case_keys=None): return fig +def plot_performances(study, mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None): + """ + Plot performances over case for a study. + + Parameters + ---------- + study : GroundTruthStudy + A study object. 
+ mode : "ordered" | "snr" | "swarm", default: "ordered" + Which plot mode to use: + + * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy + * "snr": plot performance metrics vs snr + * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details) + performance_names : list or tuple, default: ("accuracy", "precision", "recall") + Which performances to plot ("accuracy", "precision", "recall") + case_keys : list or None + A selection of cases to plot, if None, then all. + """ + import matplotlib.pyplot as plt + import pandas as pd + import seaborn as sns + + if case_keys is None: + case_keys = list(study.cases.keys()) + + perfs = study.get_performance_by_unit(case_keys=case_keys) + colors = study.get_colors() + + if mode in ("ordered", "snr"): + num_axes = len(performance_names) + fig, axs = plt.subplots(ncols=num_axes) + else: + fig, ax = plt.subplots() + + if mode == "ordered": + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + for key in case_keys: + label = study.cases[key]["label"] + val = perfs.xs(key).loc[:, performance_name].values + val = np.sort(val)[::-1] + ax.plot(val, label=label, c=colors[key]) + ax.set_title(performance_name) + if count == len(performance_names) - 1: + ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) + + elif mode == "snr": + metric_name = mode + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + + max_metric = 0 + for key in case_keys: + x = study.get_metrics(key).loc[:, metric_name].values + y = perfs.xs(key).loc[:, performance_name].values + label = study.cases[key]["label"] + ax.scatter(x, y, s=10, label=label, color=colors[key]) + max_metric = max(max_metric, np.max(x)) + ax.set_title(performance_name) + ax.set_xlim(0, max_metric * 1.05) + ax.set_ylim(0, 1.05) + if count == 0: + ax.legend(loc="lower right") + + elif mode == "swarm": + levels = perfs.index.names + df = pd.melt( + perfs.reset_index(), + id_vars=levels, + var_name="Metric", + value_name="Score", + value_vars=performance_names, + ) + df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) + sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True, ax=ax) + + def plot_agreement_matrix(study, ordered=True, case_keys=None): """ Plot agreement matri ces for cases in a study. @@ -138,57 +215,20 @@ def plot_agreement_matrix(study, ordered=True, case_keys=None): ax.set_xticks([]) -def plot_performances(study, mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None): - """ - Plot performances over case for a study. - - Parameters - ---------- - study : BenchmarkStudy - A study object. - mode : "ordered" | "snr" | "swarm", default: "ordered" - Which plot mode to use: - - * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy - * "snr": plot performance metrics vs snr - * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details) - performance_names : list or tuple, default: ("accuracy", "precision", "recall") - Which performances to plot ("accuracy", "precision", "recall") - case_keys : list or None - A selection of cases to plot, if None, then all. 
- """ - if mode == "snr": - warnings.warn("Use study.plot_performances_vs_snr() instead") - return plot_performances_vs_snr(study, case_keys=case_keys, performance_names=performance_names) - elif mode == "ordered": - warnings.warn("Use study.plot_performances_ordered() instead") - return plot_performances_ordered(study, case_keys=case_keys, performance_names=performance_names) - elif mode == "swarm": - warnings.warn("Use study.plot_performances_swarm() instead") - return plot_performances_swarm(study, case_keys=case_keys, performance_names=performance_names) - else: - raise ValueError("plot_performances() : wrong mode ") - - def plot_performances_vs_snr( - study, - case_keys=None, - figsize=None, - performance_names=("accuracy", "recall", "precision"), - snr_dataset_reference=None, + study, case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"], snr_dataset_reference=None ): import matplotlib.pyplot as plt if case_keys is None: case_keys = list(study.cases.keys()) - fig, axs = plt.subplots(ncols=1, nrows=len(performance_names), figsize=figsize, squeeze=False) + fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False) - for count, k in enumerate(performance_names): + for count, k in enumerate(metrics): ax = axs[count, 0] for key in case_keys: - color = study.get_colors()[key] label = study.cases[key]["label"] if snr_dataset_reference is None: @@ -198,16 +238,12 @@ def plot_performances_vs_snr( # use the same SNR from a reference dataset analyzer = study.get_sorting_analyzer(dataset_key=snr_dataset_reference) - quality_metrics = analyzer.get_extension("quality_metrics").get_data() - x = quality_metrics["snr"].values + metrics = analyzer.get_extension("quality_metrics").get_data() + x = metrics["snr"].values y = study.get_result(key)["gt_comparison"].get_performance()[k].values - ax.scatter(x, y, marker=".", label=label, color=color) + ax.scatter(x, y, marker=".", label=label) ax.set_title(k) - popt = fit_sigmoid(x, y, p0=None) - xfit = np.linspace(0, max(x), 100) - ax.plot(xfit, sigmoid(xfit, *popt), color=color) - ax.set_ylim(-0.05, 1.05) if count == 2: @@ -216,75 +252,11 @@ def plot_performances_vs_snr( return fig -def plot_performances_ordered( - study, - case_keys=None, - performance_names=("accuracy", "recall", "precision"), - figsize=None, -): - import matplotlib.pyplot as plt - - num_axes = len(performance_names) - fig, axs = plt.subplots(nrows=num_axes, figsize=figsize, squeeze=False) - - if case_keys is None: - case_keys = list(study.cases.keys()) - - perfs = study.get_performance_by_unit(case_keys=case_keys) - colors = study.get_colors() - - for count, performance_name in enumerate(performance_names): - ax = axs[count, 0] - - for key in case_keys: - color = study.get_colors()[key] - label = study.cases[key]["label"] - - val = perfs.xs(key).loc[:, performance_name].values - val = np.sort(val)[::-1] - ax.plot(val, label=label, c=colors[key]) - - ax.set_title(performance_name) - if count == len(performance_names) - 1: - ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) - - return fig - - -def plot_performances_swarm(study, case_keys=None, performance_names=("accuracy", "recall", "precision"), figsize=None): - - import matplotlib.pyplot as plt - import pandas as pd - import seaborn as sns - - if case_keys is None: - case_keys = list(study.cases.keys()) - - perfs = study.get_performance_by_unit(case_keys=case_keys) - colors = study.get_colors() - - fig, ax = plt.subplots() - - levels = perfs.index.names - - df = 
pd.melt( - perfs.reset_index(), - id_vars=levels, - var_name="Metric", - value_name="Score", - value_vars=performance_names, - ) - df["x"] = df.apply(lambda r: " ".join([str(r[col]) for col in levels]), axis=1) - sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True, ax=ax) - - return fig - - def plot_performances_comparison( study, case_keys=None, figsize=None, - performance_names=("accuracy", "recall", "precision"), + metrics=["accuracy", "recall", "precision"], colors=["g", "b", "r"], ylim=(-0.1, 1.1), ): @@ -296,7 +268,7 @@ def plot_performances_comparison( num_methods = len(case_keys) assert num_methods >= 2, "plot_performances_comparison need at least 2 cases!" - fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=figsize, squeeze=False) + fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False) for i, key1 in enumerate(case_keys): for j, key2 in enumerate(case_keys): @@ -306,7 +278,7 @@ def plot_performances_comparison( comp1 = study.get_result(key1)["gt_comparison"] comp2 = study.get_result(key2)["gt_comparison"] - for performance, color in zip(performance_names, colors): + for performance, color in zip(metrics, colors): perf1 = comp1.get_performance()[performance] perf2 = comp2.get_performance()[performance] ax.scatter(perf2, perf1, marker=".", label=performance, color=color) @@ -335,83 +307,13 @@ def plot_performances_comparison( patches = [] from matplotlib.patches import Patch - for color, name in zip(colors, performance_names): + for color, name in zip(colors, metrics): patches.append(Patch(color=color, label=name)) ax.legend(handles=patches) fig.tight_layout() return fig -def plot_performances_vs_depth_and_snr(study, performance_name="accuracy", case_keys=None, figsize=None): - - import pylab as plt - - if case_keys is None: - case_keys = list(study.cases.keys()) - - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) - - for count, key in enumerate(case_keys): - - result = study.get_result(key) - - positions = result["sliced_gt_sorting"].get_property("gt_unit_locations") - depth = positions[:, 1] - - analyzer = study.get_sorting_analyzer(key) - metrics = analyzer.get_extension("quality_metrics").get_data() - snr = metrics["snr"] - perfs = result["gt_comparison"].get_performance()[performance_name].values - - ax = axs[0, count] - points = ax.scatter(depth, snr, c=perfs, label="matched") - points.set_clim(0, 1) - ax.set_xlabel("depth") - ax.set_ylabel("snr") - label = study.cases[key]["label"] - ax.set_title(label) - if count > 0: - ax.set_ylabel("") - ax.set_yticks([], []) - - fig.subplots_adjust(right=0.85) - cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) - fig.colorbar(points, cax=cbar_ax, label=performance_name) - - return fig - - -def plot_performance_losses(study, case0, case1, performance_names=["accuracy"], figsize=None): - import matplotlib.pyplot as plt - - fig, axs = plt.subplots(ncols=1, nrows=len(performance_names), figsize=figsize, squeeze=False) - - for count, perf_name in enumerate(performance_names): - - ax = axs[0, count] - - positions = study.get_result(case0)["gt_comparison"].sorting1.get_property("gt_unit_locations") - - analyzer = study.get_sorting_analyzer(case0) - metrics_case0 = analyzer.get_extension("quality_metrics").get_data() - x = metrics_case0["snr"].values - - y_case0 = study.get_result(case0)["gt_comparison"].get_performance()[perf_name].values - y_case1 = 
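For the pairwise comparison plot whose keyword is renamed in the hunk above (it takes `metrics` after this patch), a minimal sketch under the assumption that `study` exists and has at least two cases:

from spikeinterface.benchmark.benchmark_plot_tools import plot_performances_comparison

case_keys = list(study.cases.keys())[:2]  # any two existing case keys
fig = plot_performances_comparison(study, case_keys=case_keys, metrics=["accuracy", "recall", "precision"])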
study.get_result(case1)["gt_comparison"].get_performance()[perf_name].values - - ax.set_xlabel("depth (um)") - im = ax.scatter(positions[:, 1], x, c=(y_case1 - y_case0), cmap="coolwarm") - fig.colorbar(im, ax=ax, label=perf_name) - im.set_clim(-1, 1) - - label0 = study.cases[case0]["label"] - label1 = study.cases[case1]["label"] - ax.set_title(f"{label0}\n vs \n{label1}") - ax.set_ylabel("snr") - - return fig - - def sigmoid(x, x0, k, b): with warnings.catch_warnings(action="ignore"): out = (1 / (1 + np.exp(-k * (x - x0)))) + b diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index edf1d0f6cc..3cf6dca04f 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -138,27 +138,7 @@ def plot_performances(self, **kwargs): return plot_performances(self, **kwargs) - def plot_performances_vs_snr(self, **kwargs): - from .benchmark_plot_tools import plot_performances_vs_snr - - return plot_performances_vs_snr(self, **kwargs) - - def plot_performances_ordered(self, **kwargs): - from .benchmark_plot_tools import plot_performances_ordered - - return plot_performances_ordered(self, **kwargs) - - def plot_performances_swarm(self, **kwargs): - from .benchmark_plot_tools import plot_performances_swarm - - return plot_performances_swarm(self, **kwargs) - def plot_agreement_matrix(self, **kwargs): from .benchmark_plot_tools import plot_agreement_matrix return plot_agreement_matrix(self, **kwargs) - - def plot_performance_losses(self, *args, **kwargs): - from .benchmark_plot_tools import plot_performance_losses - - return plot_performance_losses(self, *args, **kwargs) diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index e3f23726c7..3a39f08a7c 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -283,10 +283,9 @@ class MixinSpikeTrainComparison: * n_jobs """ - def __init__(self, delta_time=0.4, agreement_method="count", n_jobs=-1): + def __init__(self, delta_time=0.4, n_jobs=-1): self.delta_time = delta_time self.n_jobs = n_jobs - self.agreement_method = agreement_method self.sampling_frequency = None self.delta_frames = None diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 8990b1e586..6a4be86796 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -35,9 +35,6 @@ class MultiSortingComparison(BaseMultiComparison, MixinSpikeTrainComparison): Minimum agreement score to match units chance_score : float, default: 0.1 Minimum agreement score to for a possible match - agreement_method : "count" | "distance", default: "count" - The method to compute agreement scores. The "count" method computes agreement scores from spike counts. - The "distance" method computes agreement scores from spike time distance functions. n_jobs : int, default: -1 Number of cores to use in parallel. 
Uses all available if -1 spiketrain_mode : "union" | "intersection", default: "union" @@ -63,7 +60,6 @@ def __init__( delta_time=0.4, # sampling_frequency=None, match_score=0.5, chance_score=0.1, - agreement_method="count", n_jobs=-1, spiketrain_mode="union", verbose=False, @@ -79,9 +75,7 @@ def __init__( chance_score=chance_score, verbose=verbose, ) - MixinSpikeTrainComparison.__init__( - self, delta_time=delta_time, agreement_method=agreement_method, n_jobs=n_jobs - ) + MixinSpikeTrainComparison.__init__(self, delta_time=delta_time, n_jobs=n_jobs) self.set_frames_and_frequency(self.object_list) self._spiketrain_mode = spiketrain_mode self._spiketrains = None @@ -99,8 +93,6 @@ def _compare_ij(self, i, j): sorting2_name=self.name_list[j], delta_time=self.delta_time, match_score=self.match_score, - chance_score=self.chance_score, - agreement_method=self.agreement_method, n_jobs=self.n_jobs, verbose=False, ) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index 43288415d1..31adee8ca4 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -20,8 +20,8 @@ def make_sorting(times1, labels1, times2, labels2): sampling_frequency = 30000.0 - sorting1 = NumpySorting.from_samples_and_labels([times1], [labels1], sampling_frequency) - sorting2 = NumpySorting.from_samples_and_labels([times2], [labels2], sampling_frequency) + sorting1 = NumpySorting.from_times_labels([times1], [labels1], sampling_frequency) + sorting2 = NumpySorting.from_times_labels([times2], [labels2], sampling_frequency) return sorting1, sorting2 diff --git a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py index 27b8f73077..b58a03eff2 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py @@ -8,8 +8,8 @@ def make_sorting(times1, labels1, times2, labels2): sampling_frequency = 30000.0 - gt_sorting = NumpySorting.from_samples_and_labels([times1], [labels1], sampling_frequency) - tested_sorting = NumpySorting.from_samples_and_labels([times2], [labels2], sampling_frequency) + gt_sorting = NumpySorting.from_times_labels([times1], [labels1], sampling_frequency) + tested_sorting = NumpySorting.from_times_labels([times2], [labels2], sampling_frequency) return gt_sorting, tested_sorting diff --git a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py index 6c16999d01..9ea8ba3e80 100644 --- a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py +++ b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py @@ -19,9 +19,9 @@ def setup_module(tmp_path_factory): def make_sorting(times1, labels1, times2, labels2, times3, labels3): sampling_frequency = 30000.0 - sorting1 = NumpySorting.from_samples_and_labels([times1], [labels1], sampling_frequency) - sorting2 = NumpySorting.from_samples_and_labels([times2], [labels2], sampling_frequency) - sorting3 = NumpySorting.from_samples_and_labels([times3], [labels3], sampling_frequency) + sorting1 = NumpySorting.from_times_labels([times1], [labels1], sampling_frequency) + sorting2 = NumpySorting.from_times_labels([times2], [labels2], sampling_frequency) + sorting3 = NumpySorting.from_times_labels([times3], [labels3], sampling_frequency) sorting1 = 
sorting1.save() sorting2 = sorting2.save() sorting3 = sorting3.save() @@ -41,15 +41,12 @@ def test_compare_multiple_sorters(setup_module): ) msc = compare_multiple_sorters([sorting1, sorting2, sorting3], verbose=True) msc_shuffle = compare_multiple_sorters([sorting3, sorting1, sorting2]) - msc_dist = compare_multiple_sorters([sorting3, sorting1, sorting2], agreement_method="distance") agr = msc._do_agreement_matrix() agr_shuffle = msc_shuffle._do_agreement_matrix() - agr_dist = msc_dist._do_agreement_matrix() print(agr) print(agr_shuffle) - print(agr_dist) assert len(msc.get_agreement_sorting(minimum_agreement_count=3).get_unit_ids()) == 3 assert len(msc.get_agreement_sorting(minimum_agreement_count=2).get_unit_ids()) == 5 @@ -60,14 +57,7 @@ def test_compare_multiple_sorters(setup_module): assert len(msc.get_agreement_sorting(minimum_agreement_count=2).get_unit_ids()) == len( msc_shuffle.get_agreement_sorting(minimum_agreement_count=2).get_unit_ids() ) - assert len(msc.get_agreement_sorting(minimum_agreement_count=3).get_unit_ids()) == len( - msc_dist.get_agreement_sorting(minimum_agreement_count=3).get_unit_ids() - ) - assert len(msc.get_agreement_sorting(minimum_agreement_count=2).get_unit_ids()) == len( - msc_dist.get_agreement_sorting(minimum_agreement_count=2).get_unit_ids() - ) assert len(msc.get_agreement_sorting().get_unit_ids()) == len(msc_shuffle.get_agreement_sorting().get_unit_ids()) - assert len(msc.get_agreement_sorting().get_unit_ids()) == len(msc_dist.get_agreement_sorting().get_unit_ids()) agreement_2 = msc.get_agreement_sorting(minimum_agreement_count=2, minimum_agreement_count_only=True) assert np.all([agreement_2.get_unit_property(u, "agreement_number")] == 2 for u in agreement_2.get_unit_ids()) diff --git a/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py b/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py index c6ab04707b..5725206a23 100644 --- a/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py +++ b/src/spikeinterface/comparison/tests/test_symmetricsortingcomparison.py @@ -7,8 +7,8 @@ def make_sorting(times1, labels1, times2, labels2): sampling_frequency = 30000.0 - sorting1 = NumpySorting.from_samples_and_labels([times1], [labels1], sampling_frequency) - sorting2 = NumpySorting.from_samples_and_labels([times2], [labels2], sampling_frequency) + sorting1 = NumpySorting.from_times_labels([times1], [labels1], sampling_frequency) + sorting2 = NumpySorting.from_times_labels([times2], [labels2], sampling_frequency) return sorting1, sorting2 diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index fb2e173b3e..f09458f6a6 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -172,8 +172,6 @@ compute_noise_levels, ) -from .motion import Motion - # Important not for compatibility!! 
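The multi-sorter agreement flow exercised by the tests above can be sketched in a few lines with toy spike trains at 30 kHz; the spike times and unit labels below are made up and only meant to show the call pattern.

import numpy as np
from spikeinterface.core import NumpySorting
from spikeinterface.comparison import compare_multiple_sorters

fs = 30000.0
sorting1 = NumpySorting.from_times_labels([np.array([100, 200, 300, 400])], [np.array([0, 0, 1, 1])], fs)
sorting2 = NumpySorting.from_times_labels([np.array([101, 202, 303, 404])], [np.array([0, 0, 1, 1])], fs)

msc = compare_multiple_sorters([sorting1, sorting2], delta_time=0.4, verbose=False)
agreement = msc.get_agreement_sorting(minimum_agreement_count=2)
print(agreement.get_unit_ids())  # units found by both sorters within delta_time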
# This wil be uncommented after 0.100 from .waveforms_extractor_backwards_compatibility import ( diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 581a9eae53..089b249fd4 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -344,6 +344,13 @@ def get_traces( traces = traces.astype(f"int{dtype.itemsize * 8}") if return_scaled: + if hasattr(self, "NeoRawIOClass"): + if self.has_non_standard_units: + message = ( + f"This extractor based on neo.{self.NeoRawIOClass} has channels with units not in (V, mV, uV)" + ) + warnings.warn(message) + if not self.has_scaleable_traces(): if self._dtype.kind == "f": # here we do not truely have scale but we assume this is scaled diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index b224e0d282..2ec3664a45 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -560,7 +560,7 @@ def split_by(self, property="group", outputs="dict"): recordings = [] elif outputs == "dict": recordings = {} - for value in np.unique(values).tolist(): + for value in np.unique(values): (inds,) = np.nonzero(values == value) new_channel_ids = self.get_channel_ids()[inds] subrec = self.select_channels(new_channel_ids) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 4fa1d88974..820b4fcd91 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -30,17 +30,14 @@ def __init__(self, recording_list, renamed_channel_ids=None): ), "'renamed_channel_ids' doesn't have the right size or has duplicates!" channel_ids = list(renamed_channel_ids) else: - - # Explicitly check if all channel_ids arrays are either all integers or all strings. - all_ids_are_int_dtype = all(np.issubdtype(rec.channel_ids.dtype, np.integer) for rec in recording_list) - all_ids_are_str_dtype = all(np.issubdtype(rec.channel_ids.dtype, np.str_) for rec in recording_list) - - all_ids_have_same_dtype = all_ids_are_int_dtype or all_ids_are_str_dtype - if all_ids_have_same_dtype: + # Collect channel IDs from all recordings + all_channels_have_same_type = np.unique([rec.channel_ids.dtype for rec in recording_list]).size == 1 + all_channel_ids_are_unique = False + if all_channels_have_same_type: combined_ids = np.concatenate([rec.channel_ids for rec in recording_list]) all_channel_ids_are_unique = np.unique(combined_ids).size == num_all_channels - if all_ids_have_same_dtype and all_channel_ids_are_unique: + if all_channels_have_same_type and all_channel_ids_are_unique: channel_ids = combined_ids else: # If IDs are not unique or not of the same type, use default as stringify IDs diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 8a4f29e86c..db15bd219a 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -31,12 +31,9 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parents_chan_ids = parent_recording.get_channel_ids() # some checks - # We use lists to compare numpy scalar types as their python versions (e.g. 
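As an aside on the split_by path touched above, a short sketch; it assumes `recording` already carries a "group" channel property (e.g. set from the probe).

grouped = recording.split_by(property="group", outputs="dict")
for group_value, sub_recording in grouped.items():
    print(group_value, sub_recording.get_num_channels())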
int vs int64()) - channel_ids_not_in_parents = [id for id in self._channel_ids.tolist() if id not in parents_chan_ids.tolist()] - assert ( - len(channel_ids_not_in_parents) == 0 - ), f"ChannelSliceRecording : channel ids {channel_ids_not_in_parents} are not all in parent ids {parents_chan_ids}" - + assert all( + chan_id in parents_chan_ids for chan_id in self._channel_ids + ), "ChannelSliceRecording : channel ids are not all in parents" assert len(self._channel_ids) == len( self._renamed_channel_ids ), "ChannelSliceRecording: renamed channel_ids must be the same size" diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ed18b815de..aa69fe585b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -509,7 +509,7 @@ def add_from_unit_dict( return sorting @staticmethod - def from_samples_and_labels( + def from_times_labels( sorting1, times_list, labels_list, sampling_frequency, unit_ids=None, refractory_period_ms=None ) -> "NumpySorting": """ @@ -537,7 +537,7 @@ def from_samples_and_labels( discarded. """ - sorting2 = NumpySorting.from_samples_and_labels(times_list, labels_list, sampling_frequency, unit_ids) + sorting2 = NumpySorting.from_times_labels(times_list, labels_list, sampling_frequency, unit_ids) sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms) return sorting diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 7bd3bbd860..d510204467 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -376,7 +376,6 @@ def __init__( parents: Optional[list[PipelineNode]] = None, return_output: bool = False, radius_um: float = 100.0, - sparsity_mask: np.ndarray = None, ): """ Extract sparse waveforms from a recording. The strategy in this specific node is to reshape the waveforms @@ -401,11 +400,6 @@ def __init__( Pass parents nodes to perform a previous computation return_output : bool, default: False Whether or not the output of the node is returned by the pipeline - radius_um : float, default: 100.0 - The radius to determine the neighborhood of channels to extract waveforms from. - sparsity_mask : np.ndarray, default: None - Optional mask to specify the sparsity of the waveforms. If provided, it should be a boolean array of shape - (num_channels, num_channels) where True indicates that the channel is active in the neighborhood. 
""" WaveformsNode.__init__( self, @@ -416,15 +410,10 @@ def __init__( return_output=return_output, ) + self.radius_um = radius_um self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - - if sparsity_mask is not None: - self.neighbours_mask = sparsity_mask - self.radius_um = None - else: - self.radius_um = radius_um - self.neighbours_mask = self.channel_distance <= radius_um + self.neighbours_mask = self.channel_distance <= radius_um self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1)) def get_trace_margin(self): diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index b1c715e7a3..f4790817a8 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings import numpy as np from spikeinterface.core import ( BaseRecording, @@ -237,8 +236,7 @@ class NumpySorting(BaseSorting): But we have convenient class methods to instantiate from: * other sorting object: `NumpySorting.from_sorting()` - * from samples+labels: `NumpySorting.from_samples_and_labels()` - * from times+labels: `NumpySorting.from_times_and_labels()` + * from time+labels: `NumpySorting.from_times_labels()` * from dict of list: `NumpySorting.from_unit_dict()` * from neo: `NumpySorting.from_neo_spiketrain_list()` @@ -289,17 +287,17 @@ def from_sorting(source_sorting: BaseSorting, with_metadata=False, copy_spike_ve return sorting @staticmethod - def from_samples_and_labels(samples_list, labels_list, sampling_frequency, unit_ids=None) -> "NumpySorting": + def from_times_labels(times_list, labels_list, sampling_frequency, unit_ids=None) -> "NumpySorting": """ Construct NumpySorting extractor from: - * an array of spike samples - * an array of spike labels + * an array of spike times (in frames) + * an array of spike labels and adds all the In case of multisegment, it is a list of array. 
Parameters ---------- - samples_list : list of array (or array) - An array of spike samples + times_list : list of array (or array) + An array of spike times (in frames) labels_list : list of array (or array) An array of spike labels corresponding to the given times unit_ids : list or None, default: None @@ -307,22 +305,22 @@ def from_samples_and_labels(samples_list, labels_list, sampling_frequency, unit_ If None, then it will be np.unique(labels_list) """ - if isinstance(samples_list, np.ndarray): + if isinstance(times_list, np.ndarray): assert isinstance(labels_list, np.ndarray) - samples_list = [samples_list] + times_list = [times_list] labels_list = [labels_list] - samples_list = [np.asarray(e) for e in samples_list] + times_list = [np.asarray(e) for e in times_list] labels_list = [np.asarray(e) for e in labels_list] - nseg = len(samples_list) + nseg = len(times_list) if unit_ids is None: unit_ids = np.unique(np.concatenate([np.unique(labels_list[i]) for i in range(nseg)])) spikes = [] for i in range(nseg): - times, labels = samples_list[i], labels_list[i] + times, labels = times_list[i], labels_list[i] unit_index = np.zeros(labels.size, dtype="int64") for u, unit_id in enumerate(unit_ids): unit_index[labels == unit_id] = u @@ -339,41 +337,6 @@ def from_samples_and_labels(samples_list, labels_list, sampling_frequency, unit_ return sorting - @staticmethod - def from_times_and_labels(times_list, labels_list, sampling_frequency, unit_ids=None) -> "NumpySorting": - """ - Construct NumpySorting extractor from: - * an array of spike times (in s) - * an array of spike labels - In case of multisegment, it is a list of array. - - Parameters - ---------- - times_list : list of array (or array) - An array of spike samples - labels_list : list of array (or array) - An array of spike labels corresponding to the given times - unit_ids : list or None, default: None - The explicit list of unit_ids that should be extracted from labels_list - If None, then it will be np.unique(labels_list) - """ - if isinstance(times_list, np.ndarray): - assert isinstance(labels_list, np.ndarray) - times_list = [times_list] - labels_list = [labels_list] - - sample_list = [np.round(t * sampling_frequency).astype("int64") for t in times_list] - return NumpySorting.from_samples_and_labels(sample_list, labels_list, sampling_frequency, unit_ids) - - @staticmethod - def from_times_labels(times_list, labels_list, sampling_frequency, unit_ids=None) -> "NumpySorting": - warnings.warn( - "from_times_labels is deprecated and will be removed in 0.104.0, use from_sample_and_labels instead", - DeprecationWarning, - stacklevel=2, - ) - return NumpySorting.from_times_and_labels(times_list, labels_list, sampling_frequency, unit_ids) - @staticmethod def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting": """ diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index bb4ee4db1c..85d405c443 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -88,10 +88,9 @@ def create_sorting_analyzer( If True, overwrite the folder if it already exists. backend_options : dict | None, default: None Keyword arguments for the backend specified by format. It can contain the: - - * storage_options: dict | None (fsspec storage options) - * saving_options: dict | None (additional saving options for creating and saving datasets, e.g. 
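A minimal construction example matching the from_times_labels docstring above; the spike frames and labels are invented, single-segment values passed directly as arrays.

import numpy as np
from spikeinterface.core import NumpySorting

sampling_frequency = 30000.0
spike_frames = np.array([120, 500, 750, 30000])     # spike times in samples (frames)
spike_labels = np.array(["u0", "u1", "u0", "u1"])   # one unit label per spike
sorting = NumpySorting.from_times_labels(spike_frames, spike_labels, sampling_frequency)
print(sorting.unit_ids)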
compression/filters for zarr) - + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. compression/filters for zarr) sparsity_kwargs : keyword arguments Returns @@ -188,9 +187,8 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_o backend_options : dict | None, default: None The backend options for the backend. The dictionary can contain the following keys: - - * storage_options: dict | None (fsspec storage options) - * saving_options: dict | None (additional saving options for creating and saving datasets) + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets) Returns ------- @@ -910,9 +908,9 @@ def _save_or_select_or_merge( If True, output is verbose. backend_options : dict | None, default: None Keyword arguments for the backend specified by format. It can contain the: - - * storage_options: dict | None (fsspec storage options) - * saving_options: dict | None (additional saving options for creating and saving datasets, e.g. compression/filters for zarr) + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. compression/filters for zarr) job_kwargs : keyword arguments Keyword arguments for the job parallelization. @@ -1075,9 +1073,9 @@ def save_as(self, format="memory", folder=None, backend_options=None) -> "Sortin The new backend format to use backend_options : dict | None, default: None Keyword arguments for the backend specified by format. It can contain the: - - * storage_options: dict | None (fsspec storage options) - * saving_options: dict | None (additional saving options for creating and saving datasets, e.g. compression/filters for zarr) + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. compression/filters for zarr) """ if format == "zarr": folder = clean_zarr_folder_name(folder) @@ -1153,7 +1151,7 @@ def merge_units( **job_kwargs, ) -> "SortingAnalyzer": """ - This method is equivalent to `save_as()` but with a list of merges that have to be achieved. + This method is equivalent to `save_as()`but with a list of merges that have to be achieved. Merges units by creating a new SortingAnalyzer object with the appropriate merges Extensions are also updated to display the merged `unit_ids`. @@ -1177,7 +1175,6 @@ def merge_units( achieved, soft merging will not be possible and an error will be raised new_id_strategy : "append" | "take_first", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. - * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges return_new_unit_ids : bool, default False @@ -1202,10 +1199,7 @@ def merge_units( if len(merge_unit_groups) == 0: # TODO I think we should raise an error or at least make a copy and not return itself - if return_new_unit_ids: - return self, [] - else: - return self + return self for units in merge_unit_groups: # TODO more checks like one units is only in one group @@ -1984,7 +1978,6 @@ class AnalyzerExtension: It also enables any custom computation on top of the SortingAnalyzer to be implemented by the user. 
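To make the backend_options dictionary described in the docstrings above concrete, a hedged sketch for saving an analyzer to zarr; the saving_options content is illustrative only (the accepted keys depend on the zarr backend) and `sorting_analyzer` is assumed to exist.

backend_options = {
    "storage_options": None,                  # fsspec options, e.g. credentials for a remote store
    "saving_options": {"compressor": None},   # illustrative dataset-creation options forwarded to zarr
}
analyzer_zarr = sorting_analyzer.save_as(format="zarr", folder="analyzer.zarr", backend_options=backend_options)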
An extension needs to inherit from this class and implement some attributes and abstract methods: - * extension_name * depend_on * need_recording diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 0e30760262..fd613e1fcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -15,8 +15,6 @@ method : str * "best_channels" : N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. - * "closest_channels" : N closest channels according to the distance. Use the "num_channels" argument to specify the - number of channels. * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um. * "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument @@ -322,39 +320,6 @@ def from_best_channels( mask[unit_ind, chan_inds] = True return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) - ## Some convinient function to compute sparsity from several strategy - @classmethod - def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels): - """ - Construct sparsity from N closest channels - Use the "num_channels" argument to specify the number of channels. - - Parameters - ---------- - templates_or_sorting_analyzer : Templates | SortingAnalyzer - A Templates or a SortingAnalyzer object. - num_channels : int - Number of channels for "best_channels" method. - - Returns - ------- - sparsity : ChannelSparsity - The estimated sparsity - """ - from .template_tools import get_template_amplitudes - - mask = np.zeros( - (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" - ) - channel_locations = templates_or_sorting_analyzer.get_channel_locations() - distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - - for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_inds = np.argsort(distances[unit_ind]) - chan_inds = chan_inds[:num_channels] - mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) - @classmethod def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): """ @@ -635,9 +600,7 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, - method: ( - "radius" | "best_channels" | "closest_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" - ) = "radius", + method: "radius" | "best_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", num_channels: int | None = 5, radius_um: float | None = 100.0, @@ -672,7 +635,7 @@ def compute_sparsity( # to keep backward compatibility templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer - if method in ("best_channels", "closest_channels", "radius", "snr", "amplitude", "ptp"): + if method in ("best_channels", "radius", "snr", "amplitude", "ptp"): assert isinstance( templates_or_sorting_analyzer, (Templates, SortingAnalyzer) ), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer" @@ -684,9 +647,6 @@ def compute_sparsity( if method == "best_channels": assert num_channels 
is not None, "For the 'best_channels' method, 'num_channels' needs to be given" sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels, peak_sign=peak_sign) - elif method == "closest_channels": - assert num_channels is not None, "For the 'closest_channels' method, 'num_channels' needs to be given" - sparsity = ChannelSparsity.from_closest_channels(templates_or_sorting_analyzer, num_channels) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" sparsity = ChannelSparsity.from_radius(templates_or_sorting_analyzer, radius_um, peak_sign=peak_sign) @@ -738,7 +698,7 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "closest_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius", + method: "radius" | "best_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, @@ -787,7 +747,7 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "closest_channels", "snr", "amplitude", "by_property", "ptp"), ( + assert method in ("radius", "best_channels", "snr", "amplitude", "by_property", "ptp"), ( f"method={method} is not available for `estimate_sparsity()`. " "Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property', 'ptp' (deprecated)" ) @@ -842,9 +802,6 @@ def estimate_sparsity( sparsity = ChannelSparsity.from_best_channels( templates, num_channels, peak_sign=peak_sign, amplitude_mode=amplitude_mode ) - elif method == "closest_channels": - assert num_channels is not None, "For the 'closest_channels' method, 'num_channels' needs to be given" - sparsity = ChannelSparsity.from_closest_channels(templates, num_channels) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" sparsity = ChannelSparsity.from_radius(templates, radius_um, peak_sign=peak_sign) diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 16a91a55e1..99d6890dfd 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -118,25 +118,5 @@ def test_channel_aggregation_does_not_preserve_ids_not_the_same_type(): assert list(aggregated_recording.get_channel_ids()) == ["0", "1", "2", "3", "4"] -def test_channel_aggregation_with_string_dtypes_of_different_size(): - """ - Fixes issue https://github.com/SpikeInterface/spikeinterface/issues/3733 - - This tests that the channel ids are propagated in the aggregation even if they are strings of different - string dtype sizes. 
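For the sparsity methods that remain after this patch, a short sketch using the "radius" estimator; it assumes `sorting_analyzer` is a SortingAnalyzer with templates already computed.

from spikeinterface.core import compute_sparsity

sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=100.0, peak_sign="neg")
first_unit = sorting_analyzer.unit_ids[0]
print(sparsity.unit_id_to_channel_ids[first_unit])  # channels retained for that unit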
- """ - recording1 = generate_recording(num_channels=2, durations=[10], set_probe=False) - recording1 = recording1.rename_channels(new_channel_ids=np.array(["8", "9"], dtype=">> pip install kachery\n(kachery-cloud is also supported, but deprecated)" - ) - - def apply_sortingview_curation( sorting_or_analyzer, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=None ): @@ -75,10 +56,16 @@ def apply_sortingview_curation( with open(uri_or_json, "r") as f: curation_dict = json.load(f) else: - ka = get_kachery() + try: + import kachery_cloud as kcl + except ImportError: + raise ImportError( + "To apply a SortingView manual curation, you need to have sortingview installed: " + ">>> pip install sortingview" + ) try: - curation_dict = ka.load_json(uri=uri_or_json) + curation_dict = kcl.load_json(uri=uri_or_json) except: raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") @@ -153,7 +140,13 @@ def apply_sortingview_curation_legacy( sorting_curated : BaseSorting The curated sorting """ - ka = get_kachery() + try: + import kachery_cloud as kcl + except ImportError: + raise ImportError( + "To apply a SortingView manual curation, you need to have sortingview installed: " + ">>> pip install sortingview" + ) curation_sorting = CurationSorting(sorting, make_graph=False, properties_policy="keep") # get sorting view curation @@ -162,7 +155,7 @@ def apply_sortingview_curation_legacy( sortingview_curation_dict = json.load(f) else: try: - sortingview_curation_dict = ka.load_json(uri=uri_or_json) + sortingview_curation_dict = kcl.load_json(uri=uri_or_json) except: raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") diff --git a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py index dd2b4c5536..141cc4c34e 100644 --- a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py +++ b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py @@ -38,7 +38,7 @@ def test_remove_excess_spikes(): times.append(times_segment) labels.append(labels_segment) - sorting = NumpySorting.from_samples_and_labels(times, labels, sampling_frequency=sampling_frequency) + sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=sampling_frequency) assert has_exceeding_spikes(sorting, recording) sorting_corrected = remove_excess_spikes(sorting, recording) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 36c518d8e0..ff80be365d 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -149,7 +149,7 @@ def test_false_positive_curation(): times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) labels = np.random.randint(1, num_units + 1, size=num_spikes) - sorting = se.NumpySorting.from_samples_and_labels(times, labels, sampling_frequency) + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) # print("Sorting: {}".format(sorting.get_unit_ids())) json_file = parent_folder / "sv-sorting-curation-false-positive.json" @@ -175,7 +175,7 @@ def test_label_inheritance_int(): times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) labels = np.random.randint(1, 1 + num_units, size=num_spikes) # 7 units: 1 to 7 - sorting = se.NumpySorting.from_samples_and_labels(times, labels, sampling_frequency) + sorting = 
se.NumpySorting.from_times_labels(times, labels, sampling_frequency) json_file = parent_folder / "sv-sorting-curation-int.json" sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file) @@ -217,7 +217,7 @@ def test_label_inheritance_str(): times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes) - sorting = se.NumpySorting.from_samples_and_labels(times, labels, sampling_frequency) + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) # print(f"Sorting: {sorting.get_unit_ids()}") # Apply curation diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py index 5bb13e3300..7b315b0fba 100644 --- a/src/spikeinterface/curation/train_manual_curation.py +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -83,7 +83,7 @@ class CurationModelTrainer: metric_names : list of str, default: None A list of metrics to use for training. If None, default metrics will be used. imputation_strategies : list of str | None, default: None - A list of imputation strategies to try. Can be "knn", "iterative" or any allowed + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed strategy passable to the sklearn `SimpleImputer`. If None, the default strategies `["median", "most_frequent", "knn", "iterative"]` will be used. scaling_techniques : list of str | None, default: None @@ -110,7 +110,7 @@ class CurationModelTrainer: labels : list of lists, default: None List of curated labels for each `sorting_analyzer` and each unit; must be in the same order as the metrics data. imputation_strategies : list of str | None, default: None - A list of imputation strategies to try. Can be "knn", "iterative" or any allowed + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed strategy passable to the sklearn `SimpleImputer`. If None, the default strategies `["median", "most_frequent", "knn", "iterative"]` will be used. scaling_techniques : list of str | None, default: None @@ -665,55 +665,53 @@ def train_model( """ Trains and evaluates machine learning models for spike sorting curation. - This function initializes a ``CurationModelTrainer`` object, loads and preprocesses the data, + This function initializes a `CurationModelTrainer` object, loads and preprocesses the data, and evaluates the specified combinations of imputation strategies, scaling techniques, and classifiers. The evaluation results, including the best model and its parameters, are saved to the output folder. Parameters ---------- - mode : ``"analyzers"`` | ``"csv"``, default: ``"analyzers"`` + mode : "analyzers" | "csv", default: "analyzers" Mode to use for training. - analyzers : list of ``SortingAnalyzer`` | None, default: None - List of ``SortingAnalyzer`` objects containing the quality metrics and labels to use for training, - if using ``"analyzers"`` mode. + analyzers : list of SortingAnalyzer | None, default: None + List of SortingAnalyzer objects containing the quality metrics and labels to use for training, if using 'analyzers' mode. labels : list of list | None, default: None List of curated labels for each unit; must be in the same order as the metrics data. metrics_paths : list of str or None, default: None - List of paths to the CSV files containing the metrics data if using ``"csv"`` mode. + List of paths to the CSV files containing the metrics data if using 'csv' mode. 
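A minimal sketch of the curation entry point exercised in the tests above, applied from a local SortingView JSON export; the file name is hypothetical and `sorting` is assumed to exist.

from spikeinterface.curation import apply_sortingview_curation

sorting_curated = apply_sortingview_curation(sorting, uri_or_json="sv-sorting-curation.json")
print(sorting_curated.unit_ids)  # curation labels are attached to the curated units as properties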
folder : str | None, default: None The folder where outputs such as models and evaluation metrics will be saved. metric_names : list of str | None, default: None A list of metrics to use for training. If None, default metrics will be used. imputation_strategies : list of str | None, default: None - A list of imputation strategies to try. Can be ``"knn"``, ``"iterative"``, or any allowed - strategy passable to the ``sklearn.SimpleImputer``. If None, the default strategies - ``["median", "most_frequent", "knn", "iterative"]`` will be used. + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. scaling_techniques : list of str | None, default: None - A list of scaling techniques to try. Can be ``"standard_scaler"``, ``"min_max_scaler"``, - or ``"robust_scaler"``. If None, all techniques will be used. + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. classifiers : list of str | dict | None, default: None - A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their - hyperparameter search spaces can be provided. If None, default classifiers will be used. - Check the ``get_classifier_search_space`` method for the default search spaces & format for custom spaces. + A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_classifier_search_space` method for the default search spaces & format for custom spaces. test_size : float, default: 0.2 - Proportion of the dataset to include in the test split, passed to ``train_test_split`` from ``sklearn``. + Proportion of the dataset to include in the test split, passed to `train_test_split` from `sklear`. overwrite : bool, default: False - Overwrites the ``folder`` if it already exists. + Overwrites the `folder` if it already exists seed : int | None, default: None Random seed for reproducibility. If None, a random seed will be generated. search_kwargs : dict or None, default: None - Keyword arguments passed to ``BayesSearchCV`` or ``RandomizedSearchCV`` from ``sklearn``. If None, use - ``search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}``. + Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use + `search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`. verbose : bool, default: True If True, useful information is printed during training. enforce_metric_params : bool, default: False - If True and metric parameters used to calculate metrics for different ``sorting_analyzer`` objects are + If True and metric parameters used to calculate metrics for different `sorting_analyzer`s are different, an error will be raised. + Returns ------- CurationModelTrainer - The ``CurationModelTrainer`` object used for training and evaluation. + The `CurationModelTrainer` object used for training and evaluation. 
Notes ----- diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 6247ea9591..317ea21cce 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -294,16 +294,15 @@ class IblSortingExtractor(BaseSorting): >>> one = ONE(base_url="https://openalyx.internationalbrainlab.org", password="international", silent=True) >>> pids, _ = one.eid2pid("session_eid") >>> pid = pids[0] - one: One | dict, required - Instance of ONE.api or dict to use for data loading. - For multi-processing applications, this can also be a dictionary of ONE.api arguments - For example: one=dict(base_url='https://alyx.internationalbrainlab.org', mode='remote') good_clusters_only: bool, default: False If True, only load the good clusters load_unit_properties: bool, default: True If True, load the unit properties from the IBL database - kwargs: dict, optional - Additional keyword arguments to pass to the IBL SpikeSortingLoader constructor, such as `revision`. + one: One | dict, default: None + Instance of ONE.api or dict to use for data loading. + For multi-processing applications, this can also be a dictionary of ONE.api arguments + For example: one={} or one=dict(base_url='https://alyx.internationalbrainlab.org', mode='remote') + Returns ------- extractor : IBLSortingExtractor @@ -312,14 +311,13 @@ class IblSortingExtractor(BaseSorting): installation_mesg = "IBL extractors require ibllib as a dependency." " To install, run: \n\n pip install ibllib\n\n" - def __init__( - self, pid: str, good_clusters_only: bool = False, load_unit_properties: bool = True, one=None, **kwargs - ): + def __init__(self, pid: str, good_clusters_only: bool = False, load_unit_properties: bool = True, one=None): try: from one.api import ONE from brainbox.io.one import SpikeSortingLoader - assert one is not None, "one is a required parameter." 
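For orientation, a hedged call to the train_model helper documented above, in "analyzers" mode with two labelled analyzers; every name below is a placeholder and scikit-learn is required.

from spikeinterface.curation import train_model

trainer = train_model(
    mode="analyzers",
    analyzers=[analyzer_1, analyzer_2],
    labels=[labels_1, labels_2],     # per-unit curated labels, same order as each analyzer's units
    folder="trained_curation_model",
    overwrite=True,
    seed=0,
)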
+ if one is None: + one = {} if isinstance(one, dict): one = ONE(**one) else: @@ -329,11 +327,12 @@ def __init__( self.ssl = SpikeSortingLoader(one=one, pid=pid) sr = self.ssl.raw_electrophysiology(band="ap", stream=True) self._folder_path = self.ssl.session_path - spikes, clusters, channels = self.ssl.load_spike_sorting(dataset_types=["spikes.samples"], **kwargs) + spikes, clusters, channels = self.ssl.load_spike_sorting(dataset_types=["spikes.samples"]) clusters = self.ssl.merge_clusters(spikes, clusters, channels) if good_clusters_only: good_cluster_slice = clusters["cluster_id"][clusters["label"] == 1] + unit_ids = clusters["cluster_id"][good_cluster_slice] else: good_cluster_slice = slice(None) unit_ids = clusters["cluster_id"][good_cluster_slice] diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 90d123bc32..261472ede9 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -1,10 +1,10 @@ from __future__ import annotations -from pathlib import Path -import numpy as np +from pathlib import Path from spikeinterface.core.core_tools import define_function_from_class -from .neobaseextractor import NeoBaseRecordingExtractor + +from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor class IntanRecordingExtractor(NeoBaseRecordingExtractor): @@ -64,43 +64,23 @@ def __init__( **neo_kwargs, ) - amplifier_streams = ["RHS2000 amplifier channel", "RHD2000 amplifier channel"] - if self.stream_name in amplifier_streams: - self._add_channel_groups() - - self._kwargs.update( - dict(file_path=str(Path(file_path).resolve()), ignore_integrity_checks=ignore_integrity_checks), - ) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) + if "ignore_integrity_checks" in neo_kwargs: + self._kwargs["ignore_integrity_checks"] = neo_kwargs["ignore_integrity_checks"] @classmethod def map_to_neo_kwargs(cls, file_path, ignore_integrity_checks: bool = False): - neo_kwargs = {"filename": str(file_path), "ignore_integrity_checks": ignore_integrity_checks} - return neo_kwargs - - def _add_channel_groups(self): + # Only propagate the argument if the version is greater than 0.13.1 + import packaging + import neo - num_channels = self.get_num_channels() - groups = np.zeros(shape=num_channels, dtype="uint16") - group_names = np.zeros(shape=num_channels, dtype="str") - - signal_header = self.neo_reader.header["signal_channels"] - amplifier_signal_header = signal_header[signal_header["stream_id"] == self.stream_id] - original_ids = amplifier_signal_header["id"] - - # The hard-coded IDS of intan ids is "Port-Number" (e.g. A-001, C-017, B-020, etc) for amplifier channels - channel_ports = [id[:1] for id in original_ids if id[1] == "-"] - - # This should be A, B, C, D, ... 
- amplifier_ports = np.unique(channel_ports).tolist() - - for port in amplifier_ports: - channel_index = np.where(np.array(channel_ports) == port) - group_names[channel_index] = port - groups[channel_index] = amplifier_ports.index(port) - - self.set_channel_groups(groups) - self.set_property(key="group_names", values=group_names) + neo_version = packaging.version.parse(neo.__version__) + if neo_version > packaging.version.parse("0.13.1"): + neo_kwargs = {"filename": str(file_path), "ignore_integrity_checks": ignore_integrity_checks} + else: + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs read_intan = define_function_from_class(source_class=IntanRecordingExtractor, name="read_intan") diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py index ca4acfed77..a916d140fb 100644 --- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py +++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py @@ -73,8 +73,8 @@ def get_streams(cls, *args, **kwargs): neo_reader = cls.get_neo_io_reader(cls.NeoRawIOClass, **neo_kwargs) stream_channels = neo_reader.header["signal_streams"] - stream_names = stream_channels["name"].tolist() - stream_ids = stream_channels["id"].tolist() + stream_names = list(stream_channels["name"]) + stream_ids = list(stream_channels["id"]) return stream_names, stream_ids def build_stream_id_to_sampling_frequency_dict(self) -> Dict[str, float]: @@ -200,13 +200,13 @@ def __init__( use_names_as_ids = False stream_channels = self.neo_reader.header["signal_streams"] - stream_names = stream_channels["name"].tolist() - stream_ids = stream_channels["id"].tolist() + stream_names = list(stream_channels["name"]) + stream_ids = list(stream_channels["id"]) if stream_id is None and stream_name is None: if stream_channels.size > 1: raise ValueError( - f"This reader have several streams: \n`stream_names`: {stream_names}\n`stream_ids`: {stream_ids}. \n" + f"This reader have several streams: \nNames: {stream_names}\nIDs: {stream_ids}. 
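A short sketch of opening an Intan amplifier stream with the extractor modified above; the file path is hypothetical and the stream name is one of the amplifier streams reported by get_streams.

import spikeinterface.extractors as se

stream_names, stream_ids = se.IntanRecordingExtractor.get_streams("session.rhd")
recording = se.read_intan("session.rhd", stream_name="RHD2000 amplifier channel")
print(recording.get_num_channels(), recording.get_sampling_frequency())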
\n" f"Specify it from the options above with the 'stream_name' or 'stream_id' arguments" ) else: @@ -246,52 +246,36 @@ def __init__( self.extra_requirements.append("neo") # find the gain to uV - neo_gains = signal_channels["gain"] - neo_offsets = signal_channels["offset"] - if dtype.kind == "i" and np.all(neo_gains < 0) and np.all(neo_offsets == 0): + gains = signal_channels["gain"] + offsets = signal_channels["offset"] + + if dtype.kind == "i" and np.all(gains < 0) and np.all(offsets == 0): # special hack when all channel have negative gain: we put back the gain positive # this help the end user experience self.inverted_gain = True - neo_gains = -neo_gains + gains = -gains else: self.inverted_gain = False - # Define standard voltage units and their conversion factors to uV - voltage_units_to_gains = {"V": 1e6, "Volt": 1e6, "Volts": 1e6, "mV": 1e3, "uV": 1.0} - - channel_units = signal_channels["units"] - fill_value = 1.0 # This should be np.nan but will break a lot of tests - gain_correction = np.full(shape=channel_units.size, fill_value=fill_value) - for unit, gain in voltage_units_to_gains.items(): - gain_correction[channel_units == unit] = gain - - # Note that gain_to_uV should be undefined (np.nan) for non-voltage units - gain_to_uV = neo_gains * gain_correction - offset_to_uV = neo_offsets * gain_correction - - self.set_property("gain_to_uV", gain_to_uV) - self.set_property("offset_to_uV", offset_to_uV) - - # Add machinery to keep the neo units for downstream users - self.set_property("original_unit", channel_units) - self.set_property("original_gain", neo_gains) - self.set_property("original_offset", neo_offsets) - - # Streams with mixed units are to be used with caution - # We warn the user when this is the case - # Eventually, this should not be allowed as streams should have the same units - supported_voltage_units = list(voltage_units_to_gains.keys()) - is_channel_in_voltage = np.isin(channel_units, supported_voltage_units) - self.has_voltage_channels = np.any(is_channel_in_voltage) - self.has_non_voltage_channels = not np.all(is_channel_in_voltage) - has_mixed_units = self.has_non_voltage_channels and self.has_voltage_channels - if has_mixed_units: - warning_msg = ( - "Found a mix of voltage and non-voltage units. " - 'Proceed with caution. Check channel units with `recording.get_property("original_unit")`.' 
- ) - warnings.warn(warning_msg) + units = signal_channels["units"] + + # mark that units are V, mV or uV + self.has_non_standard_units = False + if not np.all(np.isin(units, ["V", "Volt", "mV", "uV"])): + self.has_non_standard_units = True + + additional_gain = np.ones(units.size, dtype="float") + additional_gain[units == "V"] = 1e6 + additional_gain[units == "Volt"] = 1e6 + additional_gain[units == "mV"] = 1e3 + additional_gain[units == "uV"] = 1.0 + additional_gain = additional_gain + + final_gains = gains * additional_gain + final_offsets = offsets * additional_gain + self.set_property("gain_to_uV", final_gains) + self.set_property("offset_to_uV", final_offsets) if not use_names_as_ids: self.set_property("channel_names", signal_channels["name"]) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 53c66cd0c9..dd24e6cae7 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -164,7 +164,7 @@ def __init__( **neo_kwargs, ) # get streams to find correct probe - stream_names, stream_ids = self.get_streams(folder_path, load_sync_channel, experiment_names) + stream_names, stream_ids = self.get_streams(folder_path, experiment_names) if stream_name is None and stream_id is None: stream_name = stream_names[0] elif stream_name is None: diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index b131005152..c79627bb59 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -12,7 +12,6 @@ @pytest.mark.streaming_extractors -@pytest.mark.xfail(reason="We need to fix ibllib/one-api dependency") class TestDefaultIblRecordingExtractorApBand(TestCase): @classmethod def setUpClass(cls): @@ -107,7 +106,6 @@ def test_unscaled_trace_dtype(self): @pytest.mark.streaming_extractors -@pytest.mark.xfail(reason="We need to fix ibllib/one-api dependency") class TestIblStreamingRecordingExtractorApBandWithLoadSyncChannel(TestCase): @classmethod def setUpClass(cls): @@ -182,7 +180,6 @@ def test_unscaled_trace_dtype(self): @pytest.mark.streaming_extractors -@pytest.mark.xfail(reason="We need to fix ibllib/one-api dependency") class TestIblSortingExtractor(TestCase): def test_ibl_sorting_extractor(self): """ @@ -200,9 +197,9 @@ def test_ibl_sorting_extractor(self): ) except: pytest.skip("Skipping test due to server being down.") - sorting = read_ibl_sorting(pid=PID, one=one, revision="2023-12-05") + sorting = read_ibl_sorting(pid=PID, one=one) assert len(sorting.unit_ids) == 733 - sorting_good = read_ibl_sorting(pid=PID, good_clusters_only=True, one=one, revision="2023-12-05") + sorting_good = read_ibl_sorting(pid=PID, good_clusters_only=True) assert len(sorting_good.unit_ids) == 108 # check properties @@ -211,14 +208,14 @@ def test_ibl_sorting_extractor(self): assert "brain_area" in sorting_good.get_property_keys() # load without properties - sorting_no_properties = read_ibl_sorting(pid=PID, one=one, load_unit_properties=False, revision="2023-12-05") + sorting_no_properties = read_ibl_sorting(pid=PID, load_unit_properties=False) # check properties assert "firing_rate" not in sorting_no_properties.get_property_keys() if __name__ == "__main__": - TestDefaultIblRecordingExtractorApBand.setUpClass() - test1 = TestDefaultIblRecordingExtractorApBand() + 
TestDefaultIblStreamingRecordingExtractorApBand.setUpClass() + test1 = TestDefaultIblStreamingExtractorApBand() test1.setUp() test1.test_get_stream_names() test1.test_dtype() @@ -233,7 +230,7 @@ def test_ibl_sorting_extractor(self): test1.test_unscaled_trace_dtype() TestIblStreamingRecordingExtractorApBandWithLoadSyncChannel.setUpClass() - test2 = TestIblStreamingRecordingExtractorApBandWithLoadSyncChannel() + test2 = TestIblStreamingExtractorApBandWithLoadSyncChannel() test2.setUp() test2.test_get_stream_names() test2.test_get_stream_names() diff --git a/src/spikeinterface/extractors/toy_example.py b/src/spikeinterface/extractors/toy_example.py index e4c32f5dad..55b787f3ed 100644 --- a/src/spikeinterface/extractors/toy_example.py +++ b/src/spikeinterface/extractors/toy_example.py @@ -135,7 +135,7 @@ def toy_example( assert isinstance(spike_labels, list) assert len(spike_times) == len(spike_labels) assert len(spike_times) == num_segments - sorting = NumpySorting.from_samples_and_labels(spike_times, spike_labels, sampling_frequency, unit_ids=unit_ids) + sorting = NumpySorting.from_times_labels(spike_times, spike_labels, sampling_frequency, unit_ids=unit_ids) else: sorting = generate_sorting( num_units=num_units, diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 1197adcfb3..2601e8081c 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -363,7 +363,7 @@ def solve_monopolar_triangulation(wf_data, local_contact_locations, max_distance output = scipy.optimize.least_squares(estimate_distance_error, x0=x0, bounds=bounds, args=args) return tuple(output["x"]) except Exception as e: - warnings.warn(f"scipy.optimize.least_squares error: {e}") + print(f"scipy.optimize.least_squares error: {e}") return (np.nan, np.nan, np.nan, np.nan) if optimizer == "minimize_with_log_penality": @@ -378,7 +378,7 @@ def solve_monopolar_triangulation(wf_data, local_contact_locations, max_distance alpha = (wf_data * q).sum() / np.square(q).sum() return (*output["x"], alpha) except Exception as e: - warnings.warn(f"scipy.optimize.minimize error: {e}") + print(f"scipy.optimize.minimize error: {e}") return (np.nan, np.nan, np.nan, np.nan) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 577dc948c3..2efac0e0d0 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -31,6 +31,7 @@ class ComputeSpikeAmplitudes(AnalyzerExtension): spike_retriver_kwargs : dict A dictionary to control the behavior for getting the maximum channel for each spike This dictionary contains: + * channel_from_template: bool, default: True For each spike is the maximum channel computed from template or re estimated at every spikes channel_from_template = True is old behavior but less acurate diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 3bdfb38e06..0431c8d675 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -193,12 +193,12 @@ def test_compute_correlograms(fill_all_bins, on_time_bin, multi_segment): ) if multi_segment: - sorting = NumpySorting.from_samples_and_labels( - samples_list=[spike_times], labels_list=[spike_unit_indices], 
sampling_frequency=sampling_frequency + sorting = NumpySorting.from_times_labels( + times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency ) else: - sorting = NumpySorting.from_samples_and_labels( - samples_list=[spike_times, spike_times], + sorting = NumpySorting.from_times_labels( + times_list=[spike_times, spike_times], labels_list=[spike_unit_indices, spike_unit_indices], sampling_frequency=sampling_frequency, ) @@ -239,8 +239,8 @@ def test_compute_correlograms_different_units(method): window_ms = 40 bin_ms = 5 - sorting = NumpySorting.from_samples_and_labels( - samples_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency + sorting = NumpySorting.from_times_labels( + times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency ) result, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 034d827c81..1c1ee8bafa 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -168,9 +168,7 @@ def __init__( assert ( ms_before is not None and ms_after is not None ), f"ms_before/after should not be None for mode {mode}" - sorting = NumpySorting.from_samples_and_labels( - list_triggers, list_labels, recording.get_sampling_frequency() - ) + sorting = NumpySorting.from_times_labels(list_triggers, list_labels, recording.get_sampling_frequency()) nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index d2e005682e..18b49cd862 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -316,7 +316,7 @@ def _sorting_violation(): spike_labels = spike_labels[mask] unit_ids = ["a", "b", "c"] - sorting = NumpySorting.from_samples_and_labels(spike_times, spike_labels, sampling_frequency, unit_ids=unit_ids) + sorting = NumpySorting.from_times_labels(spike_times, spike_labels, sampling_frequency, unit_ids=unit_ids) return sorting diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 28b4bdedb7..0f44e4079a 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -227,7 +227,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # keep positive labels keep = peak_labels >= 0 - sorting_final = NumpySorting.from_samples_and_labels( + sorting_final = NumpySorting.from_times_labels( peaks["sample_index"][keep], peak_labels[keep], sampling_frequency ) sorting_final = sorting_final.save(folder=sorter_output_folder / "sorting") diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index da2a6f3807..65dfb2ed45 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -191,7 +191,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): new_peaks["sample_index"] -= peak_shifts mask = clustering_label >= 0 - sorting_pre_peeler = NumpySorting.from_samples_and_labels( + 
sorting_pre_peeler = NumpySorting.from_times_labels( new_peaks["sample_index"][mask], clustering_label[mask], sampling_frequency, diff --git a/src/spikeinterface/sortingcomponents/clustering/method_list.py b/src/spikeinterface/sortingcomponents/clustering/method_list.py index 94a502c2a0..f735291127 100644 --- a/src/spikeinterface/sortingcomponents/clustering/method_list.py +++ b/src/spikeinterface/sortingcomponents/clustering/method_list.py @@ -24,11 +24,3 @@ "circus": CircusClustering, "tdc_clustering": TdcClustering, } - -try: - # Kilosort licence (GPL 3) is forcing us to make and use an external package - from spikeinterface_kilosort_components import KiloSortClustering - - clustering_methods["kilosort_clustering"] = KiloSortClustering -except ImportError: - pass diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index a954474f5d..20067a2eec 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -161,7 +161,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) tmp_folder = Path(os.path.join(get_global_tmp_folder(), name)) - sorting = NumpySorting.from_samples_and_labels(spikes["sample_index"], spikes["unit_index"], fs) + sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) nbefore = int(params["ms_before"] * fs / 1000.0) nafter = int(params["ms_after"] * fs / 1000.0) diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index f2d2e461b5..72aabc07de 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -102,7 +102,7 @@ def test_find_spikes_from_templates(method, sorting_analyzer): gt_sorting = sorting_analyzer.sorting - sorting = NumpySorting.from_samples_and_labels( + sorting = NumpySorting.from_times_labels( spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency ) diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index f6e38b3f5e..8fe310d986 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -19,9 +19,8 @@ class AllAmplitudesDistributionsWidget(BaseWidget): The SortingAnalyzer unit_ids : list List of unit ids, default None - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. 
+ unit_colors : None or dict + Dict of colors with key : unit, value : color, default None """ def __init__( diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 956c0d3c11..ac73c57249 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -3,16 +3,13 @@ import numpy as np from warnings import warn -from .rasters import BaseRasterWidget from .base import BaseWidget, to_attr from .utils import get_some_colors from ..core.sortinganalyzer import SortingAnalyzer -from spikeinterface.core import SortingAnalyzer - -class AmplitudesWidget(BaseRasterWidget): +class AmplitudesWidget(BaseWidget): """ Plots spike amplitudes @@ -22,17 +19,10 @@ class AmplitudesWidget(BaseRasterWidget): The input waveform extractor unit_ids : list or None, default: None List of unit ids - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. segment_index : int or None, default: None The segment index (or None if mono-segment) max_spikes_per_unit : int or None, default: None Number of max spikes per unit to display. Use None for all spikes - y_lim : tuple or None, default: None - The min and max depth to display, if None (min and max of the amplitudes). - scatter_decimate : int, default: 1 - If equal to n, each nth spike is kept for plotting. hide_unit_selector : bool, default: False If True the unit selector is not displayed (sortingview backend) @@ -41,7 +31,7 @@ class AmplitudesWidget(BaseRasterWidget): (matplotlib backend) bins : int or None, default: None If plot_histogram is True, the number of bins for the amplitude histogram. - If None, uses 100 bins. + If None this is automatically adjusted plot_legend : bool, default: True True includes legend in plot """ @@ -53,8 +43,6 @@ def __init__( unit_colors=None, segment_index=None, max_spikes_per_unit=None, - y_lim=None, - scatter_decimate=1, hide_unit_selector=False, plot_histograms=False, bins=None, @@ -73,21 +61,25 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids + if unit_colors is None: + unit_colors = get_some_colors(sorting.unit_ids) + if sorting.get_num_segments() > 1: if segment_index is None: - warn("More than one segment available! Using `segment_index = 0`.") + warn("More than one segment available! 
Using segment_index 0") segment_index = 0 else: segment_index = 0 - amplitudes_segment = amplitudes[segment_index] total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency - all_spiketrains = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) - for unit_id in sorting.unit_ids - } + spiketrains_segment = {} + for i, unit_id in enumerate(sorting.unit_ids): + times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + times = times / sorting.get_sampling_frequency() + spiketrains_segment[unit_id] = times + all_spiketrains = spiketrains_segment all_amplitudes = amplitudes_segment if max_spikes_per_unit is not None: spiketrains_to_plot = dict() @@ -109,21 +101,163 @@ def __init__( bins = 100 plot_data = dict( - spike_train_data=spiketrains_to_plot, - y_axis_data=amplitudes_to_plot, + sorting_analyzer=sorting_analyzer, + amplitudes=amplitudes_to_plot, + unit_ids=unit_ids, unit_colors=unit_colors, + spiketrains=spiketrains_to_plot, + total_duration=total_duration, plot_histograms=plot_histograms, bins=bins, - total_duration=total_duration, - unit_ids=unit_ids, hide_unit_selector=hide_unit_selector, plot_legend=plot_legend, - y_label="Amplitude", - y_lim=y_lim, - scatter_decimate=scatter_decimate, ) - BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + if backend_kwargs["axes"] is not None: + axes = backend_kwargs["axes"] + if dp.plot_histograms: + assert np.asarray(axes).size == 2 + else: + assert np.asarray(axes).size == 1 + elif backend_kwargs["ax"] is not None: + assert not dp.plot_histograms + else: + if dp.plot_histograms: + backend_kwargs["num_axes"] = 2 + backend_kwargs["ncols"] = 2 + else: + backend_kwargs["num_axes"] = None + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + scatter_ax = self.axes.flatten()[0] + + for unit_id in dp.unit_ids: + spiketrains = dp.spiketrains[unit_id] + amps = dp.amplitudes[unit_id] + scatter_ax.scatter(spiketrains, amps, color=dp.unit_colors[unit_id], s=3, alpha=1, label=unit_id) + + if dp.plot_histograms: + if dp.bins is None: + bins = int(len(spiketrains) / 30) + else: + bins = dp.bins + ax_hist = self.axes.flatten()[1] + # this is super slow, using plot and np.histogram is really much faster (and nicer!) 
+ # ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + count, bins = np.histogram(amps, bins=bins) + ax_hist.plot(count, bins[:-1], color=dp.unit_colors[unit_id], alpha=0.8) + + if dp.plot_histograms: + ax_hist = self.axes.flatten()[1] + ax_hist.set_ylim(scatter_ax.get_ylim()) + ax_hist.axis("off") + # self.figure.tight_layout() + + if dp.plot_legend: + if hasattr(self, "legend") and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + scatter_ax.set_xlim(0, dp.total_duration) + scatter_ax.set_xlabel("Times [s]") + scatter_ax.set_ylabel(f"Amplitude") + scatter_ax.spines["top"].set_visible(False) + scatter_ax.spines["right"].set_visible(False) + self.figure.subplots_adjust(bottom=0.1, top=0.9, left=0.1) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + + # import ipywidgets.widgets as widgets + import ipywidgets.widgets as W + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + analyzer = data_plot["sorting_analyzer"] + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = W.Output() + with output: + self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + self.unit_selector = UnitSelector(analyzer.unit_ids) + self.unit_selector.value = list(analyzer.unit_ids)[:1] + + self.checkbox_histograms = W.Checkbox( + value=data_plot["plot_histograms"], + description="hist", + ) + + left_sidebar = W.VBox( + children=[ + self.unit_selector, + self.checkbox_histograms, + ], + layout=W.Layout(align_items="center", width="100%", height="100%"), + ) + + self.widget = W.AppLayout( + center=self.figure.canvas, + left_sidebar=left_sidebar, + pane_widths=ratios + [0], + ) + + # a first update + self._full_update_plot() + + self.unit_selector.observe(self._update_plot, names="value", type="change") + self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change") + + if backend_kwargs["display"]: + display(self.widget) + + def _full_update_plot(self, change=None): + self.figure.clear() + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False + + backend_kwargs = dict(figure=self.figure, axes=None, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + self._update_plot() + + def _update_plot(self, change=None): + for ax in self.axes.flatten(): + ax.clear() + + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False + + backend_kwargs = dict(figure=None, axes=self.axes, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv @@ -136,8 +270,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): sa_items = [ vv.SpikeAmplitudesItem( unit_id=u, - spike_times_sec=dp.spike_train_data[u].astype("float32"), - spike_amplitudes=dp.y_axis_data[u].astype("float32"), + 
spike_times_sec=dp.spiketrains[u].astype("float32"), + spike_amplitudes=dp.amplitudes[u].astype("float32"), ) for u in unit_ids ] diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index ae07b79e6d..26780ce124 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -31,9 +31,8 @@ class CrossCorrelogramsWidget(BaseWidget): this argument is ignored hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values """ def __init__( diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index e1b1b423f2..813e7d7b63 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -24,9 +24,8 @@ class MetricsBaseWidget(BaseWidget): If given, a list of quality metrics to skip, default: None include_metrics: list or None, default: None If given, a list of quality metrics to include, default: None - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed include_metrics_data : bool, default: True diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index a0c7e1e28c..187344f1c8 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -5,7 +5,6 @@ from .base import BaseWidget, to_attr from spikeinterface.core import BaseRecording, SortingAnalyzer -from .rasters import BaseRasterWidget from spikeinterface.core.motion import Motion @@ -98,7 +97,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.set_ylabel("Depth [um]") -class DriftRasterMapWidget(BaseRasterWidget): +class DriftRasterMapWidget(BaseWidget): """ Plot the drift raster map from peaks or a SortingAnalyzer. The drift raster map is a scatter plot of the estimated peak depth vs time and it is @@ -128,7 +127,7 @@ class DriftRasterMapWidget(BaseRasterWidget): depth_lim : tuple or None, default: None The min and max depth to display, if None (min and max of the recording). scatter_decimate : int, default: None - If equal to n, each nth spike is kept for plotting. + If > 1, the scatter points are decimated. color_amplitude : bool, default: True If True, the color of the scatter points is the amplitude of the peaks. 
cmap : str, default: "inferno" @@ -171,7 +170,6 @@ def __init__( if sorting_analyzer is not None: if sorting_analyzer.has_recording(): recording = sorting_analyzer.recording - sampling_frequency = recording.sampling_frequency else: recording = None sampling_frequency = sorting_analyzer.sampling_frequency @@ -202,14 +200,56 @@ def __init__( if peak_amplitudes is not None: peak_amplitudes = peak_amplitudes[peak_mask] - from matplotlib.pyplot import colormaps + plot_data = dict( + peaks=peaks, + peak_locations=peak_locations, + peak_amplitudes=peak_amplitudes, + direction=direction, + sampling_frequency=sampling_frequency, + segment_index=segment_index, + depth_lim=depth_lim, + color_amplitude=color_amplitude, + color=color, + scatter_decimate=scatter_decimate, + cmap=cmap, + clim=clim, + alpha=alpha, + recording=recording, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from matplotlib.colors import Normalize + from .utils_matplotlib import make_mpl_figure + + from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks - if color_amplitude: - amps = peak_amplitudes + dp = to_attr(data_plot) + + assert backend_kwargs["axes"] is None, "axes argument is not allowed in DriftRasterMapWidget. Use ax instead." + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + if dp.recording is None: + peak_times = dp.peaks["sample_index"] / dp.sampling_frequency + else: + peak_times = dp.recording.sample_index_to_time(dp.peaks["sample_index"], segment_index=dp.segment_index) + + peak_locs = dp.peak_locations[dp.direction] + if dp.scatter_decimate is not None: + peak_times = peak_times[:: dp.scatter_decimate] + peak_locs = peak_locs[:: dp.scatter_decimate] + + if dp.color_amplitude: + amps = dp.peak_amplitudes amps_abs = np.abs(amps) q_95 = np.quantile(amps_abs, 0.95) - cmap = colormaps[cmap] - if clim is None: + if dp.scatter_decimate is not None: + amps = amps[:: dp.scatter_decimate] + amps_abs = amps_abs[:: dp.scatter_decimate] + cmap = plt.colormaps[dp.cmap] + if dp.clim is None: amps = amps_abs amps /= q_95 c = cmap(amps) @@ -219,26 +259,17 @@ def __init__( color_kwargs = dict( color=None, c=c, - alpha=alpha, + alpha=dp.alpha, ) else: - color_kwargs = dict(color=color, c=None, alpha=alpha) - - # convert data into format that `BaseRasterWidget` can take it in - spike_train_data = {0: peaks["sample_index"] / sampling_frequency} - y_axis_data = {0: peak_locations[direction]} - - plot_data = dict( - spike_train_data=spike_train_data, - y_axis_data=y_axis_data, - y_lim=depth_lim, - color_kwargs=color_kwargs, - scatter_decimate=scatter_decimate, - title="Peak depth", - y_label="Depth [um]", - ) + color_kwargs = dict(color=dp.color, c=None, alpha=dp.alpha) - BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) + self.ax.scatter(peak_times, peak_locs, s=1, **color_kwargs) + if dp.depth_lim is not None: + self.ax.set_ylim(*dp.depth_lim) + self.ax.set_title("Peak depth") + self.ax.set_xlabel("Times [s]") + self.ax.set_ylabel("Depth [$\\mu$m]") class MotionInfoWidget(BaseWidget): @@ -264,7 +295,7 @@ class MotionInfoWidget(BaseWidget): motion_lim : tuple or None, default: None The min and max motion to display, if None (min and max of the motion). scatter_decimate : int, default: None - If equal to n, each nth spike is kept for plotting. + If > 1, the scatter points are decimated. 
color_amplitude : bool, default: False If True, the color of the scatter points is the amplitude of the peaks. amplitude_cmap : str, default: "inferno" diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index 7d1ff44326..d2625451c8 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -18,9 +18,8 @@ class QualityMetricsWidget(MetricsBaseWidget): If given, a list of quality metrics to include skip_metrics : list or None, default: None If given, a list of quality metrics to skip - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed """ diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 398ae4d728..ca579c975f 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -4,284 +4,16 @@ from warnings import warn from .base import BaseWidget, to_attr, default_backend_kwargs -from .utils import get_some_colors -class BaseRasterWidget(BaseWidget): - """ - Make a raster plot with spike times on the x axis and arbitrary data on the y axis. - Can customize plot with histograms, title, labels, ticks etc. - - - Parameters - ---------- - spike_train_data : dict - A dict of spike trains, indexed by the unit_id - y_axis_data : dict - A dict of the y-axis data, indexed by the unit_id - unit_ids : array-like | None, default: None - List of unit_ids to plot - total_duration : int | None, default: None - Duration of spike_train_data in seconds. - plot_histograms : bool, default: False - Plot histogram of y-axis data in another subplot - bins : int | None, default: None - Number of bins to use in histogram. If None, use 1/30 of spike train sample length. - scatter_decimate : int | None, default: None - If equal to n, each nth spike is kept for plotting. - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. - color_kwargs : dict | None, default: None - More color control for e.g. coloring spikes by property. Passed to `matplotlib.scatter`. - plot_legend : bool, default: False - If True, the legend is plotted - x_lim : tuple or None, default: None - The min and max width to display, if None use (0, total_duration) - y_lim : tuple or None, default: None - The min and max depth to display, if None use the min and max of y_axis_data. - title : str | None, default: None - Title of plot. If None, no title is displayed. - y_label : str | None, default: None - Label of y-axis. If None, no label is displayed. - y_ticks : dict | None, default: None - Ticks on y-axis, passed to `set_yticks`. If None, default ticks are used. - hide_unit_selector : bool, default: False - For sortingview backend, if True the unit selector is not displayed - backend : str | None, default None - Which plotting backend to use e.g. 'matplotlib', 'ipywidgets'. If None, uses - default from `get_default_plotter_backend`. 
- """ - - def __init__( - self, - spike_train_data: dict, - y_axis_data: dict, - unit_ids: list | None = None, - total_duration: int | None = None, - plot_histograms: bool = False, - bins: int | None = None, - scatter_decimate: int = 1, - unit_colors: dict | None = None, - color_kwargs: dict | None = None, - plot_legend: bool | None = False, - y_lim: tuple[float, float] | None = None, - x_lim: tuple[float, float] | None = None, - title: str | None = None, - y_label: str | None = None, - y_ticks: bool = False, - hide_unit_selector: bool = True, - backend: str | None = None, - **backend_kwargs, - ): - - plot_data = dict( - spike_train_data=spike_train_data, - y_axis_data=y_axis_data, - unit_ids=unit_ids, - plot_histograms=plot_histograms, - y_lim=y_lim, - x_lim=x_lim, - scatter_decimate=scatter_decimate, - color_kwargs=color_kwargs, - unit_colors=unit_colors, - y_label=y_label, - title=title, - total_duration=total_duration, - plot_legend=plot_legend, - bins=bins, - y_ticks=y_ticks, - hide_unit_selector=hide_unit_selector, - ) - - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from matplotlib.colors import Normalize - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) - - if dp.unit_colors is None and dp.color_kwargs is None: - unit_colors = get_some_colors(dp.spike_train_data.keys()) - else: - unit_colors = dp.unit_colors - - if backend_kwargs["axes"] is not None: - axes = backend_kwargs["axes"] - if dp.plot_histograms: - assert np.asarray(axes).size == 2 - else: - assert np.asarray(axes).size == 1 - elif backend_kwargs["ax"] is not None: - assert not dp.plot_histograms - else: - if dp.plot_histograms: - backend_kwargs["num_axes"] = 2 - backend_kwargs["ncols"] = 2 - else: - backend_kwargs["num_axes"] = None - - unit_ids = dp.unit_ids - if dp.unit_ids is None: - unit_ids = dp.spike_train_data.keys() - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - scatter_ax = self.axes.flatten()[0] - - spike_train_data = dp.spike_train_data - y_axis_data = dp.y_axis_data - - for unit_id in unit_ids: - - unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate] - unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate] - - if dp.color_kwargs is None: - scatter_ax.scatter(unit_spike_train, unit_y_data, s=1, label=unit_id, color=unit_colors[unit_id]) - else: - color_kwargs = dp.color_kwargs - if dp.scatter_decimate != 1 and color_kwargs.get("c") is not None: - color_kwargs["c"] = dp.color_kwargs["c"][:: dp.scatter_decimate] - scatter_ax.scatter(unit_spike_train, unit_y_data, s=1, label=unit_id, **color_kwargs) - - if dp.plot_histograms: - if dp.bins is None: - bins = int(len(unit_spike_train) / 30) - else: - bins = dp.bins - ax_hist = self.axes.flatten()[1] - count, bins = np.histogram(unit_y_data, bins=bins) - ax_hist.plot(count, bins[:-1], color=unit_colors[unit_id], alpha=0.8) - - if dp.plot_histograms: - ax_hist = self.axes.flatten()[1] - ax_hist.set_ylim(scatter_ax.get_ylim()) - ax_hist.axis("off") - - if dp.plot_legend: - if hasattr(self, "legend") and self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - if dp.y_lim is not None: - scatter_ax.set_ylim(*dp.y_lim) - x_lim = dp.x_lim - if x_lim is None: - x_lim = [0, dp.total_duration] - scatter_ax.set_xlim(x_lim) - - if dp.y_ticks: - 
scatter_ax.set_yticks(**dp.y_ticks) - - scatter_ax.set_title(dp.title) - scatter_ax.set_xlabel("Times [s]") - scatter_ax.set_ylabel(dp.y_label) - scatter_ax.spines["top"].set_visible(False) - scatter_ax.spines["right"].set_visible(False) - - def plot_ipywidgets(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - - import ipywidgets.widgets as W - from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, UnitSelector - - check_ipywidget_backend() - - self.next_data_plot = data_plot.copy() - - cm = 1 / 2.54 - - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = W.Output() - with output: - self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - self.unit_selector = UnitSelector(list(data_plot["spike_train_data"].keys())) - self.unit_selector.value = list(data_plot["spike_train_data"].keys())[:1] - - children = [self.unit_selector] - - if data_plot["plot_histograms"] is not None: - self.checkbox_histograms = W.Checkbox( - value=data_plot["plot_histograms"], - description="hist", - ) - children.append(self.checkbox_histograms) - - left_sidebar = W.VBox( - children=children, - layout=W.Layout(align_items="center", width="100%", height="100%"), - ) - - self.widget = W.AppLayout( - center=self.figure.canvas, - left_sidebar=left_sidebar, - pane_widths=ratios + [0], - ) - - # a first update - self._full_update_plot() - - self.unit_selector.observe(self._update_plot, names="value", type="change") - if data_plot["plot_histograms"] is not None: - self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change") - - if backend_kwargs["display"]: - display(self.widget) - - def _full_update_plot(self, change=None): - self.figure.clear() - data_plot = self.next_data_plot - data_plot["unit_ids"] = self.unit_selector.value - if data_plot["plot_histograms"] is not None: - data_plot["plot_histograms"] = self.checkbox_histograms.value - data_plot["plot_legend"] = False - - backend_kwargs = dict(figure=self.figure, axes=None, ax=None) - self.plot_matplotlib(data_plot, **backend_kwargs) - self._update_plot() - - def _update_plot(self, change=None): - for ax in self.axes.flatten(): - ax.clear() - - data_plot = self.next_data_plot - data_plot["unit_ids"] = self.unit_selector.value - if data_plot["plot_histograms"] is not None: - data_plot["plot_histograms"] = self.checkbox_histograms.value - data_plot["plot_legend"] = False - - backend_kwargs = dict(figure=None, axes=self.axes, ax=None) - self.plot_matplotlib(data_plot, **backend_kwargs) - - self.figure.canvas.draw() - self.figure.canvas.flush_events() - - -import numpy as np - - -class RasterWidget(BaseRasterWidget): +class RasterWidget(BaseWidget): """ Plots spike train rasters. Parameters ---------- - sorting : SortingExtractor | None, default: None - A sorting object - sorting_analyzer : SortingAnalyzer | None, default: None - A sorting analyzer object + sorting : SortingExtractor + The sorting extractor object segment_index : None or int The segment index. 
unit_ids : list @@ -293,67 +25,64 @@ class RasterWidget(BaseRasterWidget): """ def __init__( - self, - sorting=None, - sorting_analyzer=None, - segment_index=None, - unit_ids=None, - time_range=None, - color="k", - backend=None, - **backend_kwargs, + self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", backend=None, **backend_kwargs ): - if sorting is None and sorting_analyzer is None: - raise Exception("Must supply either a sorting or a sorting_analyzer") - elif sorting is not None and sorting_analyzer is not None: - raise Exception("Should supply either a sorting or a sorting_analyzer, not both") - elif sorting_analyzer is not None: - sorting = sorting_analyzer.sorting - sorting = self.ensure_sorting(sorting) - if sorting.get_num_segments() > 1: - if segment_index is None: - warn("More than one segment available! Using `segment_index = 0`.") - segment_index = 0 - else: + if segment_index is None: + if sorting.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") segment_index = 0 - if unit_ids is None: - unit_ids = sorting.unit_ids - - all_spiketrains = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) - for unit_id in unit_ids - } - - if time_range is not None: + if time_range is None: + frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]] + time_range = [f / sorting.sampling_frequency for f in frame_range] + else: assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" - for unit_id in unit_ids: - unit_st = all_spiketrains[unit_id] - all_spiketrains[unit_id] = unit_st[(time_range[0] < unit_st) & (unit_st < time_range[1])] - - raster_locations = { - unit_id: unit_index * np.ones(len(all_spiketrains[unit_id])) for unit_index, unit_id in enumerate(unit_ids) - } - - unit_indices = list(range(len(unit_ids))) - - if color is None: - color = "black" - - unit_colors = {unit_id: color for unit_id in unit_ids} - y_ticks = {"ticks": unit_indices, "labels": unit_ids} + frame_range = [int(t * sorting.sampling_frequency) for t in time_range] plot_data = dict( - spike_train_data=all_spiketrains, - y_axis_data=raster_locations, - x_lim=time_range, - y_label="Unit id", + sorting=sorting, + segment_index=segment_index, unit_ids=unit_ids, - unit_colors=unit_colors, - plot_histograms=None, - y_ticks=y_ticks, + color=color, + frame_range=frame_range, + time_range=time_range, ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + sorting = dp.sorting + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) + units_ids = dp.unit_ids + if units_ids is None: + units_ids = sorting.unit_ids + + with plt.rc_context({"axes.edgecolor": "gray"}): + for unit_index, unit_id in enumerate(units_ids): + spiketrain = sorting.get_unit_spike_train( + unit_id, + start_frame=dp.frame_range[0], + end_frame=dp.frame_range[1], + segment_index=dp.segment_index, + ) + spiketimes = spiketrain / float(sorting.sampling_frequency) + self.ax.plot( + spiketimes, + unit_index * np.ones_like(spiketimes), + marker="|", + mew=1, + markersize=3, + ls="", + color=dp.color, + ) + self.ax.set_yticks(np.arange(len(units_ids))) + self.ax.set_yticklabels(units_ids) + self.ax.set_xlim(*dp.time_range) + 
self.ax.set_xlabel("time (s)") diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index ada1546ac6..94c9def630 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -23,9 +23,8 @@ class SpikeLocationsWidget(BaseWidget): Number of max spikes per unit to display. Use None for all spikes. with_channel_ids : bool, default: False Add channel ids text on the probe - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed plot_all_units : bool, default: True diff --git a/src/spikeinterface/widgets/spike_locations_by_time.py b/src/spikeinterface/widgets/spike_locations_by_time.py deleted file mode 100644 index 89cc6227fe..0000000000 --- a/src/spikeinterface/widgets/spike_locations_by_time.py +++ /dev/null @@ -1,258 +0,0 @@ -from __future__ import annotations - -import numpy as np -from warnings import warn - -from .base import BaseWidget, to_attr -from .utils import get_some_colors - -from ..core.sortinganalyzer import SortingAnalyzer - - -class LocationsWidget(BaseWidget): - """ - Plots spike locations as a function of time - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - The input sorting analyzer - unit_ids : list or None, default: None - List of unit ids - segment_index : int or None, default: None - The segment index (or None if mono-segment) - max_spikes_per_unit : int or None, default: None - Number of max spikes per unit to display. Use None for all spikes - plot_histogram : bool, default: False - If True, an histogram of the locations is plotted on the right axis - (matplotlib backend) - bins : int or None, default: None - If plot_histogram is True, the number of bins for the location histogram. - If None this is automatically adjusted - plot_legend : bool, default: True - True includes legend in plot - locations_axis : str, default: 'y' - Which location axis to use when plotting locations. - """ - - def __init__( - self, - sorting_analyzer: SortingAnalyzer, - unit_ids=None, - unit_colors=None, - segment_index=None, - max_spikes_per_unit=None, - plot_histograms=False, - bins=None, - plot_legend=True, - locations_axis="y", - backend=None, - **backend_kwargs, - ): - - sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - - sorting = sorting_analyzer.sorting - self.check_extensions(sorting_analyzer, "spike_locations") - - locations = sorting_analyzer.get_extension("spike_locations").get_data(outputs="by_unit") - - if unit_ids is None: - unit_ids = sorting.unit_ids - - if unit_colors is None: - unit_colors = get_some_colors(sorting.unit_ids) - - if sorting.get_num_segments() > 1: - if segment_index is None: - warn("More than one segment available! 
Using segment_index 0") - segment_index = 0 - else: - segment_index = 0 - locations_segment = locations[segment_index] - total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency - - spiketrains_segment = {} - for i, unit_id in enumerate(sorting.unit_ids): - times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) - times = times / sorting.get_sampling_frequency() - spiketrains_segment[unit_id] = times - - all_spiketrains = spiketrains_segment - all_locations = locations_segment - if max_spikes_per_unit is not None: - spiketrains_to_plot = dict() - locations_to_plot = dict() - for unit, st in all_spiketrains.items(): - locs = all_locations[unit][locations_axis] - if len(st) > max_spikes_per_unit: - random_idxs = np.random.choice(len(st), size=max_spikes_per_unit, replace=False) - spiketrains_to_plot[unit] = st[random_idxs] - locations_to_plot[unit] = locs[random_idxs] - else: - spiketrains_to_plot[unit] = st - locations_to_plot[unit] = locs - else: - spiketrains_to_plot = all_spiketrains - locations_to_plot = { - unit_id: all_locations[unit_id][locations_axis] for unit_id in sorting_analyzer.unit_ids - } - - if plot_histograms and bins is None: - bins = 100 - - plot_data = dict( - sorting_analyzer=sorting_analyzer, - locations=locations_to_plot, - unit_ids=unit_ids, - unit_colors=unit_colors, - spiketrains=spiketrains_to_plot, - total_duration=total_duration, - plot_histograms=plot_histograms, - bins=bins, - plot_legend=plot_legend, - ) - - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) - - if backend_kwargs["axes"] is not None: - axes = backend_kwargs["axes"] - if dp.plot_histograms: - assert np.asarray(axes).size == 2 - else: - assert np.asarray(axes).size == 1 - elif backend_kwargs["ax"] is not None: - assert not dp.plot_histograms - else: - if dp.plot_histograms: - backend_kwargs["num_axes"] = 2 - backend_kwargs["ncols"] = 2 - else: - backend_kwargs["num_axes"] = None - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - scatter_ax = self.axes.flatten()[0] - - for unit_id in dp.unit_ids: - spiketrains = dp.spiketrains[unit_id] - locs = dp.locations[unit_id] - scatter_ax.scatter(spiketrains, locs, color=dp.unit_colors[unit_id], s=3, alpha=1, label=unit_id) - - if dp.plot_histograms: - if dp.bins is None: - bins = int(len(spiketrains) / 30) - else: - bins = dp.bins - ax_hist = self.axes.flatten()[1] - count, bins = np.histogram(locs, bins=bins) - ax_hist.plot(count, bins[:-1], color=dp.unit_colors[unit_id], alpha=0.8) - - if dp.plot_histograms: - ax_hist = self.axes.flatten()[1] - ax_hist.set_ylim(scatter_ax.get_ylim()) - ax_hist.axis("off") - # self.figure.tight_layout() - - if dp.plot_legend: - if hasattr(self, "legend") and self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - scatter_ax.set_xlim(0, dp.total_duration) - scatter_ax.set_xlabel("Times [s]") - scatter_ax.set_ylabel(f"Location [um]") - scatter_ax.spines["top"].set_visible(False) - scatter_ax.spines["right"].set_visible(False) - self.figure.subplots_adjust(bottom=0.1, top=0.9, left=0.1) - - def plot_ipywidgets(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - - # import ipywidgets.widgets as widgets - import 
ipywidgets.widgets as W - from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, UnitSelector - - check_ipywidget_backend() - - self.next_data_plot = data_plot.copy() - - cm = 1 / 2.54 - analyzer = data_plot["sorting_analyzer"] - - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = W.Output() - with output: - self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - self.unit_selector = UnitSelector(analyzer.unit_ids) - self.unit_selector.value = list(analyzer.unit_ids)[:1] - - self.checkbox_histograms = W.Checkbox( - value=data_plot["plot_histograms"], - description="hist", - ) - - left_sidebar = W.VBox( - children=[ - self.unit_selector, - self.checkbox_histograms, - ], - layout=W.Layout(align_items="center", width="100%", height="100%"), - ) - - self.widget = W.AppLayout( - center=self.figure.canvas, - left_sidebar=left_sidebar, - pane_widths=ratios + [0], - ) - - # a first update - self._full_update_plot() - - self.unit_selector.observe(self._update_plot, names="value", type="change") - self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change") - - if backend_kwargs["display"]: - display(self.widget) - - def _full_update_plot(self, change=None): - self.figure.clear() - data_plot = self.next_data_plot - data_plot["unit_ids"] = self.unit_selector.value - data_plot["plot_histograms"] = self.checkbox_histograms.value - data_plot["plot_legend"] = False - - backend_kwargs = dict(figure=self.figure, axes=None, ax=None) - self.plot_matplotlib(data_plot, **backend_kwargs) - self._update_plot() - - def _update_plot(self, change=None): - for ax in self.axes.flatten(): - ax.clear() - - data_plot = self.next_data_plot - data_plot["unit_ids"] = self.unit_selector.value - data_plot["plot_histograms"] = self.checkbox_histograms.value - data_plot["plot_legend"] = False - - backend_kwargs = dict(figure=None, axes=self.axes, ax=None) - self.plot_matplotlib(data_plot, **backend_kwargs) - - self.figure.canvas.draw() - self.figure.canvas.flush_events() diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index f1d5891967..a8eb022847 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -32,9 +32,9 @@ class SpikesOnTracesWidget(BaseWidget): sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply If SortingAnalyzer is already sparse, the argument is ignored - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values + If None, then the get_unit_colors() is internally used. 
(matplotlib backend) mode : "line" | "map" | "auto", default: "auto" * "line": classical for low channel count * "map": for high channel count use color heat map diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index 4df719eda5..b80c863e75 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -18,9 +18,8 @@ class TemplateMetricsWidget(MetricsBaseWidget): If given list of quality metrics to include skip_metrics : list or None or None, default: None If given, a list of quality metrics to skip - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed """ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index de31ce2993..d5ffec6dba 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -480,30 +480,6 @@ def test_plot_spike_locations(self): self.sorting_analyzer_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) - def test_plot_locations(self): - possible_backends = list(sw.LocationsWidget.get_possible_backends()) - for backend in possible_backends: - if backend not in self.skip_backends: - sw.plot_locations(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) - unit_ids = self.sorting_analyzer_dense.unit_ids[:4] - sw.plot_locations( - self.sorting_analyzer_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] - ) - sw.plot_locations( - self.sorting_analyzer_dense, - unit_ids=unit_ids, - plot_histograms=True, - backend=backend, - **self.backend_kwargs[backend], - ) - sw.plot_locations( - self.sorting_analyzer_sparse, - unit_ids=unit_ids, - plot_histograms=True, - backend=backend, - **self.backend_kwargs[backend], - ) - def test_plot_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index d41982f766..18d173fc36 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -18,9 +18,8 @@ class UnitDepthsWidget(BaseWidget): ---------- sorting_analyzer : SortingAnalyzer The SortingAnalyzer object - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. 
+ unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values depth_axis : int, default: 1 The dimension of unit_locations that is depth peak_sign : "neg" | "pos" | "both", default: "neg" diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 37f9bb0491..3329c2183c 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -22,9 +22,8 @@ class UnitLocationsWidget(BaseWidget): List of unit ids with_channel_ids : bool, default: False Add channel ids text on the probe - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values hide_unit_selector : bool, default: False If True, the unit selector is not displayed (sortingview backend) plot_all_units : bool, default: True diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index b1c1682c8a..9466110110 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -26,21 +26,18 @@ class UnitSummaryWidget(BaseWidget): The SortingAnalyzer object unit_id : int or str The unit id to plot the summary of - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values, sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored subwidget_kwargs : dict or None, default: None Parameters for the subwidgets in a nested dictionary - - * unit_locations : UnitLocationsWidget (see UnitLocationsWidget for details) - * unit_waveforms : UnitWaveformsWidget (see UnitWaveformsWidget for details) - * unit_waveform_density_map : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) - * autocorrelograms : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) - * amplitudes : AmplitudesWidget (see AmplitudesWidget for details) - + unit_locations : UnitLocationsWidget (see UnitLocationsWidget for details) + unit_waveforms : UnitWaveformsWidget (see UnitWaveformsWidget for details) + unit_waveform_density_map : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelograms : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudes : AmplitudesWidget (see AmplitudesWidget for details) Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary. 
""" diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index ee2158d78e..3b31eacee5 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -48,9 +48,9 @@ class UnitWaveformsWidget(BaseWidget): Line width for the waveforms, (matplotlib backend) lw_templates : float, default: 2 Line width for the templates, (matplotlib backend) - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : None or dict, default: None + A dict key is unit_id and value is any color format handled by matplotlib. + If None, then the get_unit_colors() is internally used. (matplotlib / ipywidgets backend) alpha_waveforms : float, default: 0.5 Alpha value for waveforms (matplotlib backend) alpha_templates : float, default: 1 @@ -307,7 +307,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): length_uv = int(np.ptp(wfs_for_scale) // 5) x_offset = xscale_bar[0] - np.ptp(xscale_bar) // 2 ax.plot([xscale_bar[0], xscale_bar[0]], [min_wfs - offset, min_wfs - offset + length], color="k") - ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\\mu$V", fontsize=8, rotation=90) + ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\mu$V", fontsize=8, rotation=90) # plot template if dp.plot_templates: @@ -379,7 +379,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): length_uv = int(np.ptp(template_for_scale) // 5) x_offset = xscale_bar[0] - np.ptp(xscale_bar) // 2 ax.plot([xscale_bar[0], xscale_bar[0]], [min_wfs - offset, min_wfs - offset + length], color="k") - ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\\mu$V", fontsize=8, rotation=90) + ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\mu$V", fontsize=8, rotation=90) # plot channels if dp.plot_channels: diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index f7da0ef1f3..6ef1a7a782 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -27,9 +27,9 @@ class UnitWaveformDensityMapWidget(BaseWidget): Use only the max channel peak_sign : "neg" | "pos" | "both", default: "neg" Used to detect max channel only when use_max_channel=True - unit_colors : dict | None, default: None - Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted - by matplotlib. If None, default colors are chosen using the `get_some_colors` function. + unit_colors : None or dict, default: None + A dict key is unit_id and value is any color format handled by matplotlib. 
+ If None, then the get_unit_colors() is internally used same_axis : bool, default: False If True then all density are plot on the same axis and then channels is the union all channel per units diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 8590aab948..8163271ec4 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -20,7 +20,6 @@ from .rasters import RasterWidget from .sorting_summary import SortingSummaryWidget from .spike_locations import SpikeLocationsWidget -from .spike_locations_by_time import LocationsWidget from .spikes_on_traces import SpikesOnTracesWidget from .template_metrics import TemplateMetricsWidget from .template_similarity import TemplateSimilarityWidget @@ -47,7 +46,6 @@ CrossCorrelogramsWidget, DriftRasterMapWidget, ISIDistributionWidget, - LocationsWidget, MotionWidget, MotionInfoWidget, MultiCompGlobalAgreementWidget, @@ -86,28 +84,30 @@ for wcls in widget_list: wcls_doc = wcls.__doc__ - wcls_doc += """backend: str + wcls_doc += """ + + backend: str {backends} -**backend_kwargs: kwargs + **backend_kwargs: kwargs {backend_kwargs} -Returns -------- -w : BaseWidget - The output widget object. + + Returns + ------- + w : BaseWidget + The output widget object. """ - backend_str = "" + # backend_str = f" {list(wcls.possible_backends.keys())}" + backend_str = f" {wcls.get_possible_backends()}" backend_kwargs_str = "" # for backend, backend_plotter in wcls.possible_backends.items(): for backend in wcls.get_possible_backends(): - backend_str += f"\n * {backend}" # backend_kwargs_desc = backend_plotter.backend_kwargs_desc kwargs_desc = backend_kwargs_desc[backend] if len(kwargs_desc) > 0: - backend_kwargs_str += f"\n * {backend}:\n\n" + backend_kwargs_str += f"\n {backend}:\n\n" for bk, bk_dsc in kwargs_desc.items(): backend_kwargs_str += f" * {bk}: {bk_dsc}\n" - backend_str += "\n" wcls.__doc__ = wcls_doc.format(backends=backend_str, backend_kwargs=backend_kwargs_str) @@ -121,7 +121,6 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_drift_raster_map = DriftRasterMapWidget plot_isi_distribution = ISIDistributionWidget -plot_locations = LocationsWidget plot_motion = MotionWidget plot_motion_info = MotionInfoWidget plot_multicomparison_agreement = MultiCompGlobalAgreementWidget
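
As an aside for reviewers, the widget_list.py hunk above reworks how each widget's class docstring is assembled at import time: the list of backends is now rendered inline via get_possible_backends() instead of as a bulleted list. Below is a minimal, self-contained sketch of that assembly pattern; FakeWidget, its backend list, and the backend_kwargs_desc contents are hypothetical stand-ins for illustration, not the real SpikeInterface classes or their actual descriptions.

# Minimal sketch (hypothetical names, not the actual SpikeInterface code) of the
# docstring-assembly pattern changed in the widget_list.py hunk above: each widget
# class reports its plotting backends, and a shared template describing `backend`
# and `**backend_kwargs` is appended to the class docstring at import time.


class FakeWidget:
    """Plot something.

    Parameters
    ----------
    sorting_analyzer : SortingAnalyzer
        The SortingAnalyzer object
    """

    @classmethod
    def get_possible_backends(cls):
        # stand-in for BaseWidget.get_possible_backends()
        return ["matplotlib", "ipywidgets"]


# hypothetical per-backend kwargs descriptions, mirroring backend_kwargs_desc
backend_kwargs_desc = {
    "matplotlib": {"figsize": "Size of figure", "ax": "Single axis to use"},
    "ipywidgets": {},
}

doc_template = """
    backend: str
    {backends}
    **backend_kwargs: kwargs
    {backend_kwargs}

    Returns
    -------
    w : BaseWidget
        The output widget object.
"""

# backends shown inline, as in the new version of the hunk
backend_str = f"    {FakeWidget.get_possible_backends()}"
backend_kwargs_str = ""
for backend in FakeWidget.get_possible_backends():
    kwargs_desc = backend_kwargs_desc[backend]
    if len(kwargs_desc) > 0:
        backend_kwargs_str += f"\n    {backend}:\n\n"
        for bk, bk_dsc in kwargs_desc.items():
            backend_kwargs_str += f"        * {bk}: {bk_dsc}\n"

# append the formatted template to the class docstring
FakeWidget.__doc__ = FakeWidget.__doc__ + doc_template.format(
    backends=backend_str, backend_kwargs=backend_kwargs_str
)

print(FakeWidget.__doc__)

Running the sketch prints the original docstring followed by the generated backend section, which is the behaviour the hunk relies on when it formats wcls_doc with backend_str and backend_kwargs_str.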