From 1ebb6dba2bf16b3f56c14977b850b86cfc4b82f4 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Tue, 25 Feb 2025 10:42:55 +0100
Subject: [PATCH 01/35] Handle automatic RAM allocation for chunks

---
 src/spikeinterface/core/recording_tools.py    |  2 +-
 .../sorters/internal/spyking_circus2.py       |  7 +-
 src/spikeinterface/sortingcomponents/tools.py | 94 +++++++++++++++++--
 3 files changed, 93 insertions(+), 10 deletions(-)

diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py
index f7a2bce6a7..e5e4694708 100644
--- a/src/spikeinterface/core/recording_tools.py
+++ b/src/spikeinterface/core/recording_tools.py
@@ -541,7 +541,7 @@ def get_random_recording_slices(
     chunk_duration : str | float | None, default "500ms"
         The duration of each chunk in 's' or 'ms'
     chunk_size : int | None
-        Size of a chunk in number of frames. This is ued only if chunk_duration is None.
+        Size of a chunk in number of frames. This is used only if chunk_duration is None.
         This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead.
     concatenated : bool, default: True
         If True chunk are concatenated along time axis
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 36b1383229..1f31856fb7 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -15,6 +15,7 @@
     cache_preprocessing,
     get_prototype_and_waveforms_from_recording,
     get_shuffled_recording_slices,
+    set_optimal_chunk_size
 )
 from spikeinterface.core.basesorting import minimum_spike_dtype
 from spikeinterface.core.sparsity import compute_sparsity
@@ -55,6 +56,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "apply_preprocessing": True,
         "matched_filtering": True,
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
+        "chunk_preprocessing": {"memory_limit": 0.5},
         "multi_units_only": False,
         "job_kwargs": {"n_jobs": 0.5},
         "seed": 42,
         "debug": False,
     }
@@ -84,6 +86,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
         "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
        memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
+        "chunk_preprocessing": "How much RAM (approximately) should be devoted to load data chunks. memory_limit will control how much RAM can be used\
+        as a fraction of available memory. Otherwise, use total_memory to fix a hard limit",
         "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
         "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
         "seed": "An int to control how chunks are shuffled while detecting peaks",
@@ -126,8 +130,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
 
         job_kwargs = fix_job_kwargs(params["job_kwargs"])
         job_kwargs.update({"progress_bar": verbose})
-
         recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
+        job_kwargs = set_optimal_chunk_size(recording, job_kwargs, **params["chunk_preprocessing"])
 
         sampling_frequency = recording.get_sampling_frequency()
         num_channels = recording.get_num_channels()
@@ -136,6 +140,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         radius_um = params["general"].get("radius_um", 75)
         exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after))
 
+        ## First, we are filtering the data
         filtering_params = params["filtering"].copy()
         if params["apply_preprocessing"]:
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 92f3e3f942..8b44e0254c 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -14,7 +14,7 @@
 from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever
 from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer
-from spikeinterface.core.job_tools import split_job_kwargs
+from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs
 
 
 def make_multi_method_doc(methods, ident="    "):
@@ -248,19 +248,97 @@ def check_probe_for_drift_correction(recording, dist_x_max=60):
     return True
 
 
-def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache=True, **extra_kwargs):
-    save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs)
-
-    if mode == "memory":
+def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
+    """
+    Set the optimal chunk size for a job given the memory_limit and the number of jobs
+
+    Parameters
+    ----------
+
+    recording: Recording
+        The recording object
+    job_kwargs: dict
+        The job kwargs
+    memory_limit: float
+        The memory limit in fraction of available memory
+    total_memory: str, Default None
+        The total memory to use for the job in bytes
+
+    Returns
+    -------
+
+    job_kwargs: dict
+        The updated job kwargs
+    """
+    job_kwargs = fix_job_kwargs(job_kwargs)
+    n_jobs = job_kwargs['n_jobs']
+    if total_memory is None:
         if HAVE_PSUTIL:
             assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
             memory_usage = memory_limit * psutil.virtual_memory().available
-            if recording.get_total_memory_size() < memory_usage:
+            num_channels = recording.get_num_channels()
+            dtype_size_bytes = recording.get_dtype().itemsize
+            chunk_size = (memory_usage / ((num_channels * dtype_size_bytes) * n_jobs))
+            chunk_duration = chunk_size/recording.get_sampling_frequency()
+            job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s"))
+        else:
+            print("psutil is required to use only a fraction of available memory")
+    else:
+        from spikeinterface.core.job_tools import convert_string_to_bytes
+        total_memory = convert_string_to_bytes(total_memory)
+        num_channels = recording.get_num_channels()
+        dtype_size_bytes = recording.get_dtype().itemsize
+        chunk_size = (total_memory / ((num_channels * dtype_size_bytes) * n_jobs))
+        chunk_duration = chunk_size/recording.get_sampling_frequency()
+        job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s"))
+    return job_kwargs
+
+
+def cache_preprocessing(recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs):
+    """
+    Cache the preprocessing of a recording object
+
+    Parameters
+    ----------
+
+    recording: Recording
+        The recording object
+    mode: str
+        The mode to cache the preprocessing, can be 'memory', 'folder', 'zarr' or 'no-cache'
+    memory_limit: float
+        The memory limit in fraction of available memory
+    total_memory: str, Default None
+        The total memory to use for the job in bytes
+    delete_cache: bool
+        If True, delete the cache after the job
+    **extra_kwargs: dict
+        The extra kwargs for the job
+
+    Returns
+    -------
+
+    recording: Recording
+        The cached recording object
+    """
+
+    save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs)
+
+    if mode == "memory":
+        if total_memory is None:
+            if HAVE_PSUTIL:
+                assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
+                memory_usage = memory_limit * psutil.virtual_memory().available
+                if recording.get_total_memory_size() < memory_usage:
+                    recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
+                else:
+                    print("Recording too large to be preloaded in RAM...")
+            else:
+                print("psutil is required to preload in memory given only a fraction of available memory")
+        else:
+            if recording.get_total_memory_size() < total_memory:
                 recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
             else:
                 print("Recording too large to be preloaded in RAM...")
-        else:
-            print("psutil is required to preload in memory")
     elif mode == "folder":
         recording = recording.save_to_folder(**extra_kwargs)
     elif mode == "zarr":

From 0ce092cbe38fac37ab7c3d7d5542da6c93e483d0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 25 Feb 2025 09:45:40 +0000
Subject: [PATCH 02/35] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../sorters/internal/spyking_circus2.py       |  5 ++--
 src/spikeinterface/sortingcomponents/tools.py | 29 ++++++++++---------
 2 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 1f31856fb7..c9c8a0c337 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -15,7 +15,7 @@
     cache_preprocessing,
     get_prototype_and_waveforms_from_recording,
     get_shuffled_recording_slices,
-    set_optimal_chunk_size
+    set_optimal_chunk_size,
 )
 from spikeinterface.core.basesorting import minimum_spike_dtype
 from spikeinterface.core.sparsity import compute_sparsity
@@ -87,7 +87,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
        memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
         "chunk_preprocessing": "How much RAM (approximately) should be devoted to load data chunks. memory_limit will control how much RAM can be used\
-        as a fraction of available memory. Otherwise, use total_memory to fix a hard limit",
+            as a fraction of available memory. Otherwise, use total_memory to fix a hard limit",
         "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
         "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
         "seed": "An int to control how chunks are shuffled while detecting peaks",
@@ -140,7 +140,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         radius_um = params["general"].get("radius_um", 75)
         exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after))
 
-
         ## First, we are filtering the data
         filtering_params = params["filtering"].copy()
         if params["apply_preprocessing"]:
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 8b44e0254c..7b45c81f33 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -251,10 +251,10 @@
 def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
     """
     Set the optimal chunk size for a job given the memory_limit and the number of jobs
-
+
     Parameters
     ----------
-
+
     recording: Recording
         The recording object
     job_kwargs: dict
@@ -262,39 +262,42 @@ def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
     memory_limit: float
         The memory limit in fraction of available memory
     total_memory: str, Default None
-        The total memory to use for the job in bytes
-
+        The total memory to use for the job in bytes
+
     Returns
     -------
-
+
     job_kwargs: dict
         The updated job kwargs
     """
     job_kwargs = fix_job_kwargs(job_kwargs)
-    n_jobs = job_kwargs['n_jobs']
+    n_jobs = job_kwargs["n_jobs"]
     if total_memory is None:
         if HAVE_PSUTIL:
             assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
             memory_usage = memory_limit * psutil.virtual_memory().available
             num_channels = recording.get_num_channels()
             dtype_size_bytes = recording.get_dtype().itemsize
-            chunk_size = (memory_usage / ((num_channels * dtype_size_bytes) * n_jobs))
-            chunk_duration = chunk_size/recording.get_sampling_frequency()
+            chunk_size = memory_usage / ((num_channels * dtype_size_bytes) * n_jobs)
+            chunk_duration = chunk_size / recording.get_sampling_frequency()
             job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s"))
         else:
             print("psutil is required to use only a fraction of available memory")
     else:
         from spikeinterface.core.job_tools import convert_string_to_bytes
+
         total_memory = convert_string_to_bytes(total_memory)
         num_channels = recording.get_num_channels()
         dtype_size_bytes = recording.get_dtype().itemsize
-        chunk_size = (total_memory / ((num_channels * dtype_size_bytes) * n_jobs))
-        chunk_duration = chunk_size/recording.get_sampling_frequency()
+        chunk_size = total_memory / ((num_channels * dtype_size_bytes) * n_jobs)
+        chunk_duration = chunk_size / recording.get_sampling_frequency()
         job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s"))
     return job_kwargs
 
 
-def cache_preprocessing(recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs):
+def cache_preprocessing(
+    recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs
+):
     """
     Cache the preprocessing of a recording object
 
@@ -313,10 +316,10 @@ def cache_preprocessing(
         If True, delete the cache after the job
     **extra_kwargs: dict
         The extra kwargs for the job
-
+
     Returns
     -------
-
+
     recording: Recording
         The cached recording object
     """
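The chunk sizing introduced in the two patches above boils down to a few lines of
arithmetic, which can be sanity-checked outside the codebase. A minimal sketch with
made-up numbers (a 384-channel int16 recording at 30 kHz, 8 jobs, 32 GB reported
available -- none of these values come from the patches themselves):

    # Stand-alone sketch of the set_optimal_chunk_size() arithmetic (illustrative values).
    num_channels = 384               # hypothetical probe
    dtype_size_bytes = 2             # int16 samples
    n_jobs = 8                       # hypothetical worker count
    available_bytes = 32 * 1024**3   # pretend psutil.virtual_memory().available
    memory_limit = 0.25              # fraction of available RAM devoted to chunks

    memory_usage = memory_limit * available_bytes
    # n_jobs chunks are resident at once, each frame costing num_channels * dtype_size_bytes bytes:
    chunk_size = memory_usage / ((num_channels * dtype_size_bytes) * n_jobs)
    chunk_duration = chunk_size / 30_000.0  # sampling frequency in Hz
    print(f"chunk_duration = {chunk_duration:.1f}s")  # ~46.6 s with these numbers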
From cf31e7ffddc60c819a3e9ccc8385675740abcad8 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Tue, 25 Feb 2025 10:55:40 +0100
Subject: [PATCH 03/35] Default

---
 src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 1f31856fb7..80efe8c6fb 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -58,7 +58,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
         "chunk_preprocessing": {"memory_limit": 0.5},
         "multi_units_only": False,
-        "job_kwargs": {"n_jobs": 0.5},
+        "job_kwargs": {"n_jobs": 0.01},
         "seed": 42,
         "debug": False,
     }

From 72555308365fe4f41a3b62019d60a31a837844ca Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Tue, 25 Feb 2025 10:56:46 +0100
Subject: [PATCH 04/35] Default

---
 src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index e454921e80..f7b0d00136 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -56,9 +56,9 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "apply_preprocessing": True,
         "matched_filtering": True,
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
-        "chunk_preprocessing": {"memory_limit": 0.5},
+        "chunk_preprocessing": {"memory_limit": 0.01},
         "multi_units_only": False,
-        "job_kwargs": {"n_jobs": 0.01},
+        "job_kwargs": {"n_jobs": 0.75},
         "seed": 42,
         "debug": False,
     }

From 26fc22651801e6bc6b8564fdfbf684f6a9ce5d30 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Thu, 27 Feb 2025 15:30:53 +0100
Subject: [PATCH 05/35] WIP

---
 .../sortingcomponents/clustering/circus.py    |  9 ++++--
 src/spikeinterface/sortingcomponents/tools.py | 31 +++++++++++++++++++
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 7bce0800d3..436bae7cfd 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -23,7 +23,7 @@
 from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter
 from spikeinterface.core.template import Templates
 from spikeinterface.core.sparsity import compute_sparsity
-from spikeinterface.sortingcomponents.tools import remove_empty_templates
+from spikeinterface.sortingcomponents.tools import remove_empty_templates, get_optimal_n_jobs
 import pickle, json
 from spikeinterface.core.node_pipeline import (
     run_node_pipeline,
@@ -66,6 +66,7 @@ class CircusClustering:
         "tmp_folder": None,
         "verbose": True,
         "debug": False,
+        "memory_limit": 0.25,
     }
 
     @classmethod
@@ -257,6 +258,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         if params["noise_levels"] is None:
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
 
+        job_kwargs_local = job_kwargs.copy()
+        ram_requested = recording.get_num_channels() *(nbefore + nafter) * len(unit_ids) * 4
+        job_kwargs_local = get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"])
+
         templates_array = estimate_templates(
             recording,
             spikes,
@@ -265,7 +270,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             nafter,
             return_scaled=False,
             job_name=None,
-            **job_kwargs,
+            **job_kwargs_local,
         )
 
         best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 1e1a46badc..f3140a2fbe 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -295,6 +295,37 @@ def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
         job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s"))
     return job_kwargs
 
+def get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
+    """
+    Set the optimal chunk size for a job given the memory_limit and the number of jobs
+
+    Parameters
+    ----------
+
+    recording: Recording
+        The recording object
+    ram_requested: dict
+        The amount of RAM (in bytes) requested for the job
+    memory_limit: float
+        The memory limit in fraction of available memory
+
+    Returns
+    -------
+
+    job_kwargs: dict
+        The updated job kwargs
+    """
+    job_kwargs = fix_job_kwargs(job_kwargs)
+    n_jobs = job_kwargs["n_jobs"]
+    if HAVE_PSUTIL:
+        assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
+        memory_usage = memory_limit * psutil.virtual_memory().available
+        n_jobs = min(n_jobs, memory_usage // ram_requested)
+        job_kwargs = fix_job_kwargs(dict(n_jobs=n_jobs))
+    else:
+        print("psutil is required to use only a fraction of available memory")
+    return job_kwargs
+
 
 def cache_preprocessing(
     recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs

From 3819f8c796d897d84c3c1174a21f5181f574d50b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 27 Feb 2025 14:32:17 +0000
Subject: [PATCH 06/35] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +-
 src/spikeinterface/sortingcomponents/tools.py             | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 436bae7cfd..0e57c4d5c8 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -259,7 +259,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
 
         job_kwargs_local = job_kwargs.copy()
-        ram_requested = recording.get_num_channels() *(nbefore + nafter) * len(unit_ids) * 4
+        ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4
         job_kwargs_local = get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"])
 
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index f3140a2fbe..e8ebc0e2eb 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -295,6 +295,7 @@ def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
         job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s"))
     return job_kwargs
 
+
 def get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
     """
     Set the optimal chunk size for a job given the memory_limit and the number of jobs
From 42d1c2c39ba12f34386798b432207fba25aa8144 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Thu, 27 Feb 2025 15:38:37 +0100
Subject: [PATCH 07/35] Keeping the input dict

---
 src/spikeinterface/sortingcomponents/tools.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index f3140a2fbe..edcb1c4769 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -320,8 +320,8 @@ def get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
     if HAVE_PSUTIL:
         assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
         memory_usage = memory_limit * psutil.virtual_memory().available
-        n_jobs = min(n_jobs, memory_usage // ram_requested)
-        job_kwargs = fix_job_kwargs(dict(n_jobs=n_jobs))
+        n_jobs = int(min(n_jobs, memory_usage // ram_requested))
+        job_kwargs.update(dict(n_jobs=n_jobs))
     else:
         print("psutil is required to use only a fraction of available memory")
     return job_kwargs
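With the int cast added above, the capping logic of get_optimal_n_jobs() can be
exercised on its own. A short sketch with invented numbers (the RAM estimate mirrors
the num_channels * (nbefore + nafter) * num_units * 4 expression used by the
clustering step in PATCH 05; none of the values are taken from a real run):

    # Illustrative check of the n_jobs capping (all values made up).
    num_channels, nbefore, nafter, num_units = 64, 30, 60, 200
    ram_requested = num_channels * (nbefore + nafter) * num_units * 4  # float32 templates, in bytes
    memory_usage = 0.25 * (16 * 1024**3)  # 25% of a hypothetical 16 GB of available RAM
    requested_n_jobs = 8
    n_jobs = int(min(requested_n_jobs, memory_usage // ram_requested))
    print(ram_requested, n_jobs)  # 4608000 8 -> the cap only bites for much larger template sets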
From 29eb160c3647ab0429f935413a6e913aef9eeabd Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Fri, 28 Feb 2025 12:23:16 +0100
Subject: [PATCH 08/35] Reducing memory footprint

---
 .../sortingcomponents/matching/circus.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py
index 3b97f2dc6a..ed8c050437 100644
--- a/src/spikeinterface/sortingcomponents/matching/circus.py
+++ b/src/spikeinterface/sortingcomponents/matching/circus.py
@@ -50,11 +50,16 @@ def compress_templates(
     if remove_mean:
         templates_array -= templates_array.mean(axis=(1, 2))[:, None, None]
 
-    temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False)
-    # Keep only the strongest components
-    temporal = temporal[:, :, :approx_rank].astype(np.float32)
-    singular = singular[:, :approx_rank].astype(np.float32)
-    spatial = spatial[:, :approx_rank, :].astype(np.float32)
+    num_templates, num_samples, num_channels = templates_array.shape
+    temporal = np.zeros((num_templates, num_samples, approx_rank), dtype=np.float32)
+    spatial = np.zeros((num_templates, approx_rank, num_channels), dtype=np.float32)
+    singular = np.zeros((num_templates, approx_rank), dtype=np.float32)
+
+    for i in range(num_templates):
+        i_temporal, i_singular, i_spatial = np.linalg.svd(templates_array[i], full_matrices=False)
+        temporal[i] = i_temporal[:, :approx_rank]
+        spatial[i] = i_spatial[:approx_rank, :]
+        singular[i] = i_singular[:approx_rank]
 
     if return_new_templates:
         templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial)

From ed319bf1fbf81a8a1b0446e00b158afadf306839 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 28 Feb 2025 13:54:57 +0100
Subject: [PATCH 09/35] Patch for small num_channels

---
 src/spikeinterface/sortingcomponents/matching/circus.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py
index ed8c050437..d187d714d8 100644
--- a/src/spikeinterface/sortingcomponents/matching/circus.py
+++ b/src/spikeinterface/sortingcomponents/matching/circus.py
@@ -57,9 +57,9 @@ def compress_templates(
     for i in range(num_templates):
         i_temporal, i_singular, i_spatial = np.linalg.svd(templates_array[i], full_matrices=False)
-        temporal[i] = i_temporal[:, :approx_rank]
-        spatial[i] = i_spatial[:approx_rank, :]
-        singular[i] = i_singular[:approx_rank]
+        temporal[i, :, :min(approx_rank, num_channels)] = i_temporal[:, :approx_rank]
+        spatial[i, :min(approx_rank, num_channels), :] = i_spatial[:approx_rank, :]
+        singular[i, :min(approx_rank, num_channels)] = i_singular[:approx_rank]
 
     if return_new_templates:
         templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial)

From 8c1ca39ebc52c34a95489946c13934849236f5ba Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 28 Feb 2025 12:55:24 +0000
Subject: [PATCH 10/35] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sortingcomponents/matching/circus.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py
index d187d714d8..04dd785616 100644
--- a/src/spikeinterface/sortingcomponents/matching/circus.py
+++ b/src/spikeinterface/sortingcomponents/matching/circus.py
@@ -57,9 +57,9 @@ def compress_templates(
     for i in range(num_templates):
         i_temporal, i_singular, i_spatial = np.linalg.svd(templates_array[i], full_matrices=False)
-        temporal[i, :, :min(approx_rank, num_channels)] = i_temporal[:, :approx_rank]
-        spatial[i, :min(approx_rank, num_channels), :] = i_spatial[:approx_rank, :]
-        singular[i, :min(approx_rank, num_channels)] = i_singular[:approx_rank]
+        temporal[i, :, : min(approx_rank, num_channels)] = i_temporal[:, :approx_rank]
+        spatial[i, : min(approx_rank, num_channels), :] = i_spatial[:approx_rank, :]
+        singular[i, : min(approx_rank, num_channels)] = i_singular[:approx_rank]
 
     if return_new_templates:
         templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial)

From 0a2bd232fcd4ce865d9f31171dd441bd4c055296 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 28 Feb 2025 14:05:59 +0100
Subject: [PATCH 11/35] Saving the final analyzer

---
 .../sorters/internal/spyking_circus2.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 32712cc648..ad5457439f 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -347,8 +347,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 shutil.rmtree(sorting_folder)
 
         merging_params = params["merging"].copy()
-        if params["debug"]:
-            merging_params["debug_folder"] = sorter_output_folder / "merging"
+        merging_params["debug_folder"] = sorter_output_folder / "merging"
 
         if len(merging_params) > 0:
             if params["motion_correction"] and motion_folder is not None:
@@ -369,7 +368,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 sorting.save(folder=curation_folder)
                 # np.save(fitting_folder / "amplitudes", guessed_amplitudes)
 
-            sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)
+            final_analyzer = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)
+            final_analyzer.save_as(format="binary_folder", folder=sorter_output_folder / "final_analyzer")
+
+            sorting = final_analyzer.sorting
 
             if verbose:
                 print(f"Kept {len(sorting.unit_ids)} units after final merging")
@@ -428,4 +430,5 @@ def final_cleaning_circus(
         sparsity_overlap=sparsity_overlap,
         **job_kwargs,
     )
-    return final_sa.sorting
+
+    return final_sa
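The per-template SVD loop from patches 08-10 can be reproduced on a toy array to see
both the memory-friendly iteration and the zero-padding added for recordings with
fewer channels than approx_rank (shapes below are illustrative only):

    import numpy as np

    num_templates, num_samples, num_channels, approx_rank = 3, 90, 4, 5
    templates_array = np.random.randn(num_templates, num_samples, num_channels).astype(np.float32)

    temporal = np.zeros((num_templates, num_samples, approx_rank), dtype=np.float32)
    spatial = np.zeros((num_templates, approx_rank, num_channels), dtype=np.float32)
    singular = np.zeros((num_templates, approx_rank), dtype=np.float32)

    for i in range(num_templates):
        # full_matrices=False yields k = min(num_samples, num_channels) = 4 components here
        u, s, vh = np.linalg.svd(templates_array[i], full_matrices=False)
        k = min(approx_rank, num_channels)
        temporal[i, :, :k] = u[:, :approx_rank]
        spatial[i, :k, :] = vh[:approx_rank, :]
        singular[i, :k] = s[:approx_rank]

    # With num_channels < approx_rank the truncation is lossless:
    reconstructed = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
    print(np.allclose(reconstructed, templates_array, atol=1e-4))  # True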
From 5d88e7234b8614862284eb665070554a66478446 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Sun, 2 Mar 2025 21:06:30 +0100
Subject: [PATCH 12/35] Docstrings

---
 .../sorters/internal/spyking_circus2.py       |  5 +++--
 src/spikeinterface/sortingcomponents/tools.py | 14 +++++++++-----
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index ad5457439f..50a6603d24 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -75,8 +75,9 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
         "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
        memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
-        "chunk_preprocessing": "How much RAM (approximately) should be devoted to load data chunks. memory_limit will control how much RAM can be used\
-            as a fraction of available memory. Otherwise, use total_memory to fix a hard limit",
+        "chunk_preprocessing": "How much RAM (approximately) should be devoted to loading all data chunks (given n_jobs).\
+            memory_limit will control how much RAM can be used as a fraction of available memory. Otherwise, use total_memory to fix a hard limit, with\
+            a string syntax (e.g. '1G', '500M')",
         "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
         "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
         "seed": "An int to control how chunks are shuffled while detecting peaks",
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index e55af1edfb..ba4154eb80 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -305,7 +305,7 @@ def get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
 
     recording: Recording
         The recording object
-    ram_requested: dict
+    ram_requested: int
         The amount of RAM (in bytes) requested for the job
     memory_limit: float
         The memory limit in fraction of available memory
@@ -324,7 +324,8 @@ def get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
         n_jobs = int(min(n_jobs, memory_usage // ram_requested))
         job_kwargs.update(dict(n_jobs=n_jobs))
     else:
-        print("psutil is required to use only a fraction of available memory")
+        import warnings
+        warnings.warn("psutil is required to use only a fraction of available memory")
     return job_kwargs
 
@@ -367,14 +368,17 @@ def cache_preprocessing(
                 if recording.get_total_memory_size() < memory_usage:
                     recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
                 else:
-                    print("Recording too large to be preloaded in RAM...")
+                    import warnings
+                    warnings.warn("Recording too large to be preloaded in RAM...")
             else:
-                print("psutil is required to preload in memory given only a fraction of available memory")
+                import warnings
+                warnings.warn("psutil is required to preload in memory given only a fraction of available memory")
         else:
             if recording.get_total_memory_size() < total_memory:
                 recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
             else:
-                print("Recording too large to be preloaded in RAM...")
+                import warnings
+                warnings.warn("Recording too large to be preloaded in RAM...")
     elif mode == "folder":
         recording = recording.save_to_folder(**extra_kwargs)
     elif mode == "zarr":
From f746d4194bd2133c8e3e90790b57a35dd2f78b95 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 2 Mar 2025 20:07:03 +0000
Subject: [PATCH 13/35] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sortingcomponents/tools.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index ba4154eb80..b97aea9ad0 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -325,6 +325,7 @@ def get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
         job_kwargs.update(dict(n_jobs=n_jobs))
     else:
         import warnings
+
         warnings.warn("psutil is required to use only a fraction of available memory")
     return job_kwargs
 
@@ -370,15 +370,18 @@ def cache_preprocessing(
                 recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
                 else:
                     import warnings
+
                     warnings.warn("Recording too large to be preloaded in RAM...")
             else:
                 import warnings
+
                 warnings.warn("psutil is required to preload in memory given only a fraction of available memory")
         else:
             if recording.get_total_memory_size() < total_memory:
                 recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
             else:
                 import warnings
+
                 warnings.warn("Recording too large to be preloaded in RAM...")
     elif mode == "folder":
         recording = recording.save_to_folder(**extra_kwargs)

From d6b3bdbeaeffa872fc1e507e13ecb7f7463fa422 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Sun, 2 Mar 2025 21:07:41 +0100
Subject: [PATCH 14/35] More docstrings

---
 src/spikeinterface/sortingcomponents/tools.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index ba4154eb80..72f0001f05 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -283,7 +283,8 @@ def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
             chunk_duration = chunk_size / recording.get_sampling_frequency()
             job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s"))
         else:
-            print("psutil is required to use only a fraction of available memory")
+            import warnings
+            warnings.warn("psutil is required to use only a fraction of available memory")
     else:
         from spikeinterface.core.job_tools import convert_string_to_bytes
From f2a3ac4fa6ade3c1ef040fb7a6990d1506de7668 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 2 Mar 2025 20:08:14 +0000
Subject: [PATCH 15/35] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sortingcomponents/tools.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 99af5116f3..56b128d028 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -284,6 +284,7 @@ def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
             job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s"))
         else:
             import warnings
+
            warnings.warn("psutil is required to use only a fraction of available memory")
     else:
         from spikeinterface.core.job_tools import convert_string_to_bytes

From b72ee86faa9847d32309e4e120996d59b961daf0 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Mon, 3 Mar 2025 15:37:06 +0100
Subject: [PATCH 16/35] Cosmetic and bug fixes

---
 src/spikeinterface/core/sortinganalyzer.py      | 5 ++++-
 src/spikeinterface/widgets/crosscorrelograms.py | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 85d405c443..8af3024cfd 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1199,7 +1199,10 @@ def merge_units(
 
         if len(merge_unit_groups) == 0:
             # TODO I think we should raise an error or at least make a copy and not return itself
-            return self
+            if return_new_unit_ids:
+                return self, []
+            else:
+                return self
 
         for units in merge_unit_groups:
             # TODO more checks like one units is only in one group
diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py
index 26780ce124..a8ef91f54d 100644
--- a/src/spikeinterface/widgets/crosscorrelograms.py
+++ b/src/spikeinterface/widgets/crosscorrelograms.py
@@ -119,7 +119,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                 if i < len(self.axes) - 1:
                     self.axes[i, j].set_xticks([], [])
 
-        plt.tight_layout()
+        self.figure.tight_layout()
 
         for i, unit_id in enumerate(unit_ids):
             self.axes[0, i].set_title(str(unit_id))
From 0445d4e4dcb146476dc888061d1244e6932638e0 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Wed, 5 Mar 2025 09:48:39 +0100
Subject: [PATCH 17/35] Remove HDBSCAN dependency

---
 .../sorters/external/tridesclous.py           |  2 +-
 .../sorters/internal/simplesorter.py          | 10 +++++-----
 .../sorters/internal/spyking_circus2.py       | 10 +---------
 .../sortingcomponents/clustering/circus.py    | 15 ++++----------
 .../clustering/clustering_tools.py            | 20 ++++++++++---------
 .../sortingcomponents/clustering/merge.py     |  4 ++--
 .../sortingcomponents/clustering/position.py  | 16 +++++----------
 .../clustering/position_and_features.py       | 12 ++----------
 .../clustering/position_and_pca.py            | 12 ++----------
 .../clustering/position_ptp_scaled.py         | 11 +++-------
 .../clustering/random_projections.py          | 16 +++++----------
 .../clustering/sliding_hdbscan.py             | 11 ++--------
 .../clustering/sliding_nn.py                  | 11 +---------
 .../sortingcomponents/clustering/split.py     |  2 +-
 .../sortingcomponents/clustering/tdc.py       |  2 +-
 15 files changed, 46 insertions(+), 108 deletions(-)

diff --git a/src/spikeinterface/sorters/external/tridesclous.py b/src/spikeinterface/sorters/external/tridesclous.py
index e9a22d0951..13e8d9bb32 100644
--- a/src/spikeinterface/sorters/external/tridesclous.py
+++ b/src/spikeinterface/sorters/external/tridesclous.py
@@ -65,7 +65,7 @@ def is_installed(cls):
             HAVE_TDC = False
         except:
             print(
-                "tridesclous is installed, but it has some dependency problems, check numba or hdbscan installations!"
+                "tridesclous is installed, but it has some dependency problems, check numba install!"
             )
             HAVE_TDC = False
         return HAVE_TDC
diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py
index 0f44e4079a..9fea538a04 100644
--- a/src/spikeinterface/sorters/internal/simplesorter.py
+++ b/src/spikeinterface/sorters/internal/simplesorter.py
@@ -12,6 +12,7 @@
 import json
 
 
+
 class SimpleSorter(ComponentsBasedSorter):
     """
     Implementation of a very simple sorter usefull for teaching.
@@ -181,13 +182,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         clust_method = clust_params.pop("method", "hdbscan")
 
         if clust_method == "hdbscan":
-            import hdbscan
-
-            out = hdbscan.hdbscan(features_flat, **clust_params)
-            peak_labels = out[0]
+            from sklearn.cluster import HDBSCAN
+            clusterer = HDBSCAN(**clust_params)
+            clusterer.fit(features_flat)
+            peak_labels = clusterer.labels_
         elif clust_method in ("kmeans"):
             from sklearn.cluster import MiniBatchKMeans
-
             peak_labels = MiniBatchKMeans(**clust_params).fit_predict(features_flat)
         elif clust_method in ("mean_shift"):
             from sklearn.cluster import MeanShift
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 50a6603d24..179b75f1d6 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -95,15 +95,7 @@ def get_sorter_version(cls):
 
     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
-        try:
-            import hdbscan
-
-            HAVE_HDBSCAN = True
-        except:
-            HAVE_HDBSCAN = False
-
-        assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed"
-
+
         try:
             import torch
         except ImportError:
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 0e57c4d5c8..ed42d1e1ba 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -4,14 +4,6 @@
 from pathlib import Path
 
 import numpy as np
-
-try:
-    import hdbscan
-
-    HAVE_HDBSCAN = True
-except:
-    HAVE_HDBSCAN = False
-
 import random, string
 from spikeinterface.core import get_global_tmp_folder
 from spikeinterface.core.basesorting import minimum_spike_dtype
@@ -30,6 +22,7 @@
     ExtractSparseWaveforms,
     PeakRetriever,
 )
+from sklearn.cluster import HDBSCAN
 
 from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
 
@@ -71,7 +64,6 @@ class CircusClustering:
 
     @classmethod
     def main_function(cls, recording, peaks, params, job_kwargs=dict()):
-        assert HAVE_HDBSCAN, "random projections clustering needs hdbscan to be installed"
 
         d = params
         verbose = d["verbose"]
@@ -180,8 +172,9 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
                     hdbscan_data = tsvd.fit_transform(sub_data)
                     try:
-                        clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
-                        local_labels = clustering[0]
+                        clusterer = HDBSCAN(**d["hdbscan_kwargs"])
+                        clusterer.fit(hdbscan_data)
+                        local_labels = clusterer.labels_
                     except Exception:
                         local_labels = np.zeros(len(hdbscan_data))
                     valid_clusters = local_labels > -1
diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
index 93db9a268f..173820e9f6 100644
--- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -15,7 +15,7 @@ def _split_waveforms(
     wfs_and_noise, noise_size, n_components_by_channel, n_components, hdbscan_params, probability_thr, debug
 ):
     import sklearn.decomposition
-    import hdbscan
+    from sklearn.cluster import HDBSCAN
 
     valid_size = wfs_and_noise.shape[0] - noise_size
 
@@ -30,9 +30,10 @@ def _split_waveforms(
         local_feature = pca.fit_transform(local_feature)
 
     # hdbscan on pca
-    clustering = hdbscan.hdbscan(local_feature, **hdbscan_params)
-    local_labels_with_noise = clustering[0]
-    cluster_probability = clustering[2]
+    clusterer = HDBSCAN(**hdbscan_params)
+    clusterer.fit(local_feature)
+    local_labels_with_noise = clusterer.labels_
+    cluster_probability = clusterer.probabilities_
     (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr)
     local_labels_with_noise[~np.isin(local_labels_with_noise, persistent_clusters)] = -1
 
@@ -95,7 +96,7 @@ def _split_waveforms_nested(
     wfs_and_noise, noise_size, nbefore, n_components_by_channel, n_components, hdbscan_params, probability_thr, debug
 ):
     import sklearn.decomposition
-    import hdbscan
+    from sklearn.cluster import HDBSCAN
 
     valid_size = wfs_and_noise.shape[0] - noise_size
 
@@ -123,9 +124,10 @@ def _split_waveforms_nested(
         # ~ local_feature = pca.fit_transform(local_feature)
 
         # hdbscan on pca
-        clustering = hdbscan.hdbscan(local_feature, **hdbscan_params)
-        active_labels_with_noise = clustering[0]
-        cluster_probability = clustering[2]
+        clusterer = HDBSCAN(**hdbscan_params)
+        clusterer.fit(local_feature)
+        active_labels_with_noise = clusterer.labels_
+        cluster_probability = clusterer.probabilities_
-        (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr)
+        (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr)
         active_labels_with_noise[~np.isin(active_labels_with_noise, persistent_clusters)] = -1
 
@@ -233,7 +235,7 @@ def auto_split_clustering(
     """
     import sklearn.decomposition
-    import hdbscan
+    from sklearn.cluster import HDBSCAN
 
     split_peak_labels = -1 * np.ones(peak_labels.size, dtype=np.int64)
     nb_clusters = 0
diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py
index c9a9841c57..f18135acd2 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merge.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -94,10 +94,10 @@ def merge_clusters(
 
         pair_values[~pair_mask] = 20
 
-        import hdbscan
+        from sklearn.cluster import HDBSCAN
 
         fig, ax = plt.subplots()
-        clusterer = hdbscan.HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True)
+        clusterer = HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True)
         clusterer.fit(pair_values)
         # print(clusterer.labels_)
         clusterer.single_linkage_tree_.plot(cmap="viridis", colorbar=True)
diff --git a/src/spikeinterface/sortingcomponents/clustering/position.py b/src/spikeinterface/sortingcomponents/clustering/position.py
index dc76d787f6..4cdcf6d201 100644
--- a/src/spikeinterface/sortingcomponents/clustering/position.py
+++ b/src/spikeinterface/sortingcomponents/clustering/position.py
@@ -2,16 +2,10 @@
 # """Sorting components: clustering"""
 from pathlib import Path
+from sklearn.cluster import HDBSCAN
 
 import numpy as np
 
-try:
-    import hdbscan
-
-    HAVE_HDBSCAN = True
-except:
-    HAVE_HDBSCAN = False
-
 
 class PositionClustering:
     """
@@ -29,7 +23,7 @@ class PositionClustering:
 
     @classmethod
     def main_function(cls, recording, peaks, params, job_kwargs=dict()):
-        assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed"
+
         d = params
 
         if d["peak_locations"] is None:
@@ -51,9 +45,9 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         else:
             to_cluster_from = locations
 
-        clustering = hdbscan.hdbscan(to_cluster_from, **d["hdbscan_kwargs"])
-        peak_labels = clustering[0]
-
+        clusterer = HDBSCAN(**d["hdbscan_kwargs"])
+        clusterer.fit(to_cluster_from)
+        peak_labels = clusterer.labels_
         labels = np.unique(peak_labels)
         labels = labels[labels >= 0]
diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
index 20067a2eec..65f13a29ab 100644
--- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
+++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
@@ -5,13 +5,7 @@
 import shutil
 
 import numpy as np
-
-try:
-    import hdbscan
-
-    HAVE_HDBSCAN = True
-except:
-    HAVE_HDBSCAN = False
+from sklearn.cluster import HDBSCAN
 
 import random, string, os
 from spikeinterface.core import get_global_tmp_folder, get_noise_levels
@@ -48,8 +42,6 @@ class PositionAndFeaturesClustering:
     def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         from sklearn.preprocessing import QuantileTransformer
 
-        assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed"
-
         d = params
 
         peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
@@ -82,7 +74,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             preprocessing = QuantileTransformer(output_distribution="uniform")
             hdbscan_data = preprocessing.fit_transform(hdbscan_data)
 
-        clusterer = hdbscan.HDBSCAN(**d["hdbscan_kwargs"])
+        clusterer = HDBSCAN(**d["hdbscan_kwargs"])
         clusterer.fit(X=hdbscan_data)
         peak_labels = clusterer.labels_
diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py
index c4f372fc21..c34170f073 100644
--- a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py
+++ b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py
@@ -7,13 +7,7 @@
 import os
 
 import numpy as np
-
-try:
-    import hdbscan
-
-    HAVE_HDBSCAN = True
-except:
-    HAVE_HDBSCAN = False
+from sklearn.cluster import HDBSCAN
 
 from spikeinterface.core import get_global_tmp_folder
 from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks
@@ -75,8 +69,6 @@ def _check_params(cls, recording, peaks, params):
     def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         # res = PositionClustering(recording, peaks, params)
 
-        assert HAVE_HDBSCAN, "position_and_pca clustering need hdbscan to be installed"
-
         params = cls._check_params(recording, peaks, params)
         # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params)
@@ -96,7 +88,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         else:
             to_cluster_from = locations
 
-        clusterer = hdbscan.HDBSCAN(**params["hdbscan_global_kwargs"])
+        clusterer = HDBSCAN(**params["hdbscan_global_kwargs"])
         clusterer.fit(X=to_cluster_from)
         spatial_peak_labels = clusterer.labels_
diff --git a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py
index 0f7390d7ac..750690ef3f 100644
--- a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py
+++ b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py
@@ -4,12 +4,7 @@
 from pathlib import Path
 import numpy as np
 
-try:
-    import hdbscan
-
-    HAVE_HDBSCAN = True
-except:
-    HAVE_HDBSCAN = False
+from sklearn.cluster import HDBSCAN
 
 from spikeinterface.sortingcomponents.features_from_peaks import (
     compute_features_from_peaks,
@@ -38,7 +33,7 @@ class PositionPTPScaledClustering:
 
     @classmethod
     def main_function(cls, recording, peaks, params, job_kwargs=dict()):
-        assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed"
+
         d = params
 
         if d["peak_locations"] is None:
@@ -70,7 +65,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         to_cluster_from = np.hstack((locations, logmaxptps[:, np.newaxis]))
         to_cluster_from = to_cluster_from * d["scales"]
 
-        clusterer = hdbscan.HDBSCAN(**d["hdbscan_kwargs"])
+        clusterer = HDBSCAN(**d["hdbscan_kwargs"])
         clusterer.fit(X=to_cluster_from)
         peak_labels = clusterer.labels_
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index e67e1907f1..1dae122dc6 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -5,13 +5,7 @@
 import shutil
 
 import numpy as np
-
-try:
-    import hdbscan
-
-    HAVE_HDBSCAN = True
-except:
-    HAVE_HDBSCAN = False
+from sklearn.cluster import HDBSCAN
 
 from spikeinterface.core.basesorting import minimum_spike_dtype
 from spikeinterface.core.waveform_tools import estimate_templates
@@ -61,8 +55,7 @@ class RandomProjectionClustering:
 
     @classmethod
     def main_function(cls, recording, peaks, params, job_kwargs=dict()):
-        assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed"
-
+
         d = params
         verbose = d["verbose"]
 
@@ -114,8 +107,9 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             recording, pipeline_nodes, job_kwargs=job_kwargs, job_name="extracting features"
         )
 
-        clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
-        peak_labels = clustering[0]
+        clusterer = HDBSCAN(**d["hdbscan_kwargs"])
+        clusterer.fit(hdbscan_data)
+        peak_labels = clusterer.labels_
 
         labels = np.unique(peak_labels)
         labels = labels[labels >= 0]
diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py
index 5f8ac99848..609182f5ad 100644
--- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py
+++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py
@@ -5,17 +5,11 @@
 import time
 import random
 import string
+from sklearn.cluster import HDBSCAN
 
 import numpy as np
 
-try:
-    import hdbscan
-
-    HAVE_HDBSCAN = True
-except:
-    HAVE_HDBSCAN = False
-
 
 from spikeinterface.core import (
     get_global_tmp_folder,
@@ -59,7 +53,6 @@ class SlidingHdbscanClustering:
 
     @classmethod
     def main_function(cls, recording, peaks, params, job_kwargs=dict()):
-        assert HAVE_HDBSCAN, "sliding_hdbscan clustering need hdbscan to be installed"
         params = cls._check_params(recording, peaks, params)
         wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params)
         peak_labels = cls._find_clusters(recording, peaks, wfs_arrays, sparsity_mask, noise, params)
@@ -275,7 +268,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d):
 
             # find some clusters
             # ~ t0 = time.perf_counter()
-            clusterer = hdbscan.HDBSCAN(min_cluster_size=d["min_cluster_size"], allow_single_cluster=True, metric="l2")
+            clusterer = HDBSCAN(min_cluster_size=d["min_cluster_size"], allow_single_cluster=True, metric="l2")
             all_labels = clusterer.fit_predict(local_feature)
             # ~ t1 = time.perf_counter()
             # ~ print('HDBSCAN time', t1 - t0)
diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py
index 40cedacdc5..635eb47cd2 100644
--- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py
+++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py
@@ -20,14 +20,6 @@
 from spikeinterface.core import get_channel_distances
 from tqdm.auto import tqdm
 
-try:
-    import hdbscan
-
-    HAVE_HDBSCAN = True
-except:
-    HAVE_HDBSCAN = False
-
-
 try:
     import pymde
 
@@ -79,7 +71,6 @@ def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()):
         assert HAVE_TORCH, "SlidingNN needs torch to work"
         assert HAVE_NNDESCENT, "SlidingNN needs pynndescent to work"
         assert HAVE_PYMDE, "SlidingNN needs pymde to work"
-        assert HAVE_HDBSCAN, "SlidingNN needs hdbscan to work"
 
         d = params
         tmp_folder = params["tmp_folder"]
@@ -294,7 +285,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
                 "Clustering MDE embeddings (n={}): {}".format(embeddings_chunk.shape, datetime.datetime.now())
             )
         # TODO HDBSCAN can be done on GPU with NVIDIA RAPIDS for speed
-        clusterer = hdbscan.HDBSCAN(
+        clusterer = HDBSCAN(
             prediction_data=True,
             core_dist_n_jobs=job_kwargs["n_jobs"],
             **d["hdbscan_kwargs"],
diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py
index 310ee0969f..bfa1cbe035 100644
--- a/src/spikeinterface/sortingcomponents/clustering/split.py
+++ b/src/spikeinterface/sortingcomponents/clustering/split.py
@@ -259,7 +259,7 @@ def split(
             final_features = flatten_features
 
         if clusterer == "hdbscan":
-            from hdbscan import HDBSCAN
+            from sklearn.cluster import HDBSCAN
 
             clust = HDBSCAN(**clusterer_kwargs, core_dist_n_jobs=1)
             clust.fit(final_features)
diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py b/src/spikeinterface/sortingcomponents/clustering/tdc.py
index 027503a7c8..42b893810e 100644
--- a/src/spikeinterface/sortingcomponents/clustering/tdc.py
+++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py
@@ -57,7 +57,7 @@ class TdcClustering:
 
     @classmethod
     def main_function(cls, recording, peaks, params, job_kwargs=dict()):
-        import hdbscan
+        from sklearn.cluster import HDBSCAN
 
         if params["folder"] is None:
             randname = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))

From 6409824f48f8472716be702f1d70361efe902ba4 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Wed, 5 Mar 2025 09:49:54 +0100
Subject: [PATCH 18/35] Remove hdbscan

---
 pyproject.toml | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 97ba77299e..d39b4df2d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -161,7 +161,6 @@ test = [
     # tridesclous
     "numba<0.61.0;python_version<'3.13'",
     "numba>=0.61.0;python_version>='3.13'",
-    "hdbscan>=0.8.33",  # Previous version had a broken wheel
 
     # for sortingview backend
     "sortingview",
@@ -191,7 +190,6 @@ docs = [
     # for notebooks in the gallery
     "MEArec",   # Use as an example
     "pandas",   # in the modules gallery comparison tutorial
-    "hdbscan>=0.8.33",   # For sorters spykingcircus2 + tridesclous
     "numba",   # For many postprocessing functions
     "networkx",
     "skops",   # For auotmated curation
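The two patches above replace the stand-alone hdbscan package with scikit-learn's
estimator (available from scikit-learn 1.3 onward). The mapping between the old
tuple-returning API and the estimator attributes can be sketched on a toy data set:

    import numpy as np
    from sklearn.cluster import HDBSCAN

    rng = np.random.default_rng(42)
    data = np.vstack([rng.normal(0.0, 0.1, (50, 2)), rng.normal(3.0, 0.1, (50, 2))])

    # Old style: out = hdbscan.hdbscan(data, min_cluster_size=10); out[0] are labels, out[2] probabilities
    clusterer = HDBSCAN(min_cluster_size=10)
    clusterer.fit(data)
    labels = clusterer.labels_                # replaces clustering[0]
    probabilities = clusterer.probabilities_  # replaces clustering[2]
    print(np.unique(labels))                  # two clusters expected (plus -1 for any noise)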
From b9b3457e84d244ab06b0462b07ebf06e99d95201 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 5 Mar 2025 08:53:06 +0000
Subject: [PATCH 19/35] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sorters/external/tridesclous.py            | 4 +---
 src/spikeinterface/sorters/internal/simplesorter.py           | 3 ++-
 src/spikeinterface/sorters/internal/spyking_circus2.py        | 2 +-
 .../sortingcomponents/clustering/random_projections.py        | 2 +-
 4 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/spikeinterface/sorters/external/tridesclous.py b/src/spikeinterface/sorters/external/tridesclous.py
index 13e8d9bb32..460acc2287 100644
--- a/src/spikeinterface/sorters/external/tridesclous.py
+++ b/src/spikeinterface/sorters/external/tridesclous.py
@@ -64,9 +64,7 @@ def is_installed(cls):
         except ImportError:
             HAVE_TDC = False
         except:
-            print(
-                "tridesclous is installed, but it has some dependency problems, check numba install!"
-            )
+            print("tridesclous is installed, but it has some dependency problems, check numba install!")
             HAVE_TDC = False
         return HAVE_TDC
diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py
index 9fea538a04..8bd2c1923a 100644
--- a/src/spikeinterface/sorters/internal/simplesorter.py
+++ b/src/spikeinterface/sorters/internal/simplesorter.py
@@ -12,7 +12,6 @@
 import json
 
 
-
 class SimpleSorter(ComponentsBasedSorter):
     """
     Implementation of a very simple sorter usefull for teaching.
@@ -183,11 +182,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
 
         if clust_method == "hdbscan":
             from sklearn.cluster import HDBSCAN
+
             clusterer = HDBSCAN(**clust_params)
             clusterer.fit(features_flat)
             peak_labels = clusterer.labels_
         elif clust_method in ("kmeans"):
             from sklearn.cluster import MiniBatchKMeans
+
             peak_labels = MiniBatchKMeans(**clust_params).fit_predict(features_flat)
         elif clust_method in ("mean_shift"):
             from sklearn.cluster import MeanShift
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 179b75f1d6..3f64478359 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -95,7 +95,7 @@ def get_sorter_version(cls):
 
     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
-
+
         try:
             import torch
         except ImportError:
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index 1dae122dc6..56a07d642c 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -55,7 +55,7 @@ class RandomProjectionClustering:
 
     @classmethod
     def main_function(cls, recording, peaks, params, job_kwargs=dict()):
-
+
         d = params
         verbose = d["verbose"]

From 5d4e516ca8e888f25796ebfda17794f49efd1af1 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Wed, 5 Mar 2025 09:56:39 +0100
Subject: [PATCH 20/35] Adding sklearn as a dependency for testing

---
 pyproject.toml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index d39b4df2d6..75edd94f10 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -162,6 +162,8 @@ test = [
     "numba<0.61.0;python_version<'3.13'",
     "numba>=0.61.0;python_version>='3.13'",
 
+    "scikit-learn",
+
     # for sortingview backend
     "sortingview",
 
@@ -185,6 +187,7 @@ docs = [
     "sphinx-design",
     "numpydoc",
     "ipython",
+    "scikit-learn",
     "sphinxcontrib-jquery",
 
     # for notebooks in the gallery
21/35] Spaces --- src/spikeinterface/sortingcomponents/tools.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 56b128d028..ba7e273311 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -372,18 +372,15 @@ def cache_preprocessing( recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) else: import warnings - warnings.warn("Recording too large to be preloaded in RAM...") else: import warnings - warnings.warn("psutil is required to preload in memory given only a fraction of available memory") else: if recording.get_total_memory_size() < total_memory: recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) else: import warnings - warnings.warn("Recording too large to be preloaded in RAM...") elif mode == "folder": recording = recording.save_to_folder(**extra_kwargs) From 3812ece6eedbbc0b4076126cd132f950ab90d007 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Mar 2025 09:10:39 +0000 Subject: [PATCH 22/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/tools.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index ba7e273311..56b128d028 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -372,15 +372,18 @@ def cache_preprocessing( recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) else: import warnings + warnings.warn("Recording too large to be preloaded in RAM...") else: import warnings + warnings.warn("psutil is required to preload in memory given only a fraction of available memory") else: if recording.get_total_memory_size() < total_memory: recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) else: import warnings + warnings.warn("Recording too large to be preloaded in RAM...") elif mode == "folder": recording = recording.save_to_folder(**extra_kwargs) From 8c400f439a8395f294758db87853a65df5096b8c Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 5 Mar 2025 11:18:16 +0100 Subject: [PATCH 23/35] HDBSCAN will go in another PR --- pyproject.toml | 7 +++-- .../sorters/internal/simplesorter.py | 9 +++---- .../sorters/internal/spyking_circus2.py | 8 ++++++ .../sorters/internal/tridesclous2.py | 2 +- .../sortingcomponents/clustering/circus.py | 26 ++++++++++--------- .../clustering/clustering_tools.py | 22 +++++++--------- .../sortingcomponents/clustering/merge.py | 6 ++--- .../sortingcomponents/clustering/position.py | 18 ++++++++----- .../clustering/position_and_features.py | 14 +++++++--- .../clustering/position_and_pca.py | 14 +++++++--- .../clustering/position_ptp_scaled.py | 13 +++++++--- .../clustering/random_projections.py | 16 ++++++++---- .../clustering/sliding_hdbscan.py | 13 +++++++--- .../clustering/sliding_nn.py | 13 ++++++++-- .../sortingcomponents/clustering/split.py | 4 +-- .../sortingcomponents/clustering/tdc.py | 4 +-- 16 files changed, 122 insertions(+), 67 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 75edd94f10..f6d83b6616 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,8 +161,7 @@ test = [ # tridesclous "numba<0.61.0;python_version<'3.13'",
"numba>=0.61.0;python_version>='3.13'", - - "scikit-learn", + "hdbscan>=0.8.33", # Previous version had a broken wheel # for sortingview backend "sortingview", @@ -187,12 +186,12 @@ docs = [ "sphinx-design", "numpydoc", "ipython", - "scikit-learn", "sphinxcontrib-jquery", # for notebooks in the gallery "MEArec", # Use as an example "pandas", # in the modules gallery comparison tutorial + "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions "networkx", "skops", # For auotmated curation @@ -236,4 +235,4 @@ markers = [ filterwarnings =[ 'ignore:.*distutils Version classes are deprecated.*:DeprecationWarning', 'ignore:.*the imp module is deprecated in favour of importlib.*:DeprecationWarning', -] +] \ No newline at end of file diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 8bd2c1923a..91d1fccd84 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -181,11 +181,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clust_method = clust_params.pop("method", "hdbscan") if clust_method == "hdbscan": - from sklearn.cluster import HDBSCAN + import hdbscan - clusterer = HDBSCAN(**clust_params) - clusterer.fit(features_flat) - peak_labels = clusterer.labels_ + out = hdbscan.hdbscan(features_flat, **clust_params) + peak_labels = out[0] elif clust_method in ("kmeans"): from sklearn.cluster import MiniBatchKMeans @@ -233,4 +232,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) sorting_final = sorting_final.save(folder=sorter_output_folder / "sorting") - return sorting_final + return sorting_final \ No newline at end of file diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 3f64478359..f2ac44c191 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -95,7 +95,15 @@ def get_sorter_version(cls): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): + try: + import hdbscan + + HAVE_HDBSCAN = True + except: + HAVE_HDBSCAN = False + assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed" + try: import torch except ImportError: diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 65dfb2ed45..1bbd852b4f 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -258,4 +258,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = NumpySorting(final_spikes, sampling_frequency, labels_set) sorting = sorting.save(folder=sorter_output_folder / "sorting") - return sorting + return sorting \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index ed42d1e1ba..7f12cd09fb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -4,6 +4,14 @@ from pathlib import Path import numpy as np + +try: + import hdbscan + + HAVE_HDBSCAN = True +except: + HAVE_HDBSCAN = False + import random, string from spikeinterface.core import get_global_tmp_folder from spikeinterface.core.basesorting import minimum_spike_dtype @@ -15,14 +23,13 @@ from 
spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.sortingcomponents.tools import remove_empty_templates, get_optimal_n_jobs +from spikeinterface.sortingcomponents.tools import remove_empty_templates import pickle, json from spikeinterface.core.node_pipeline import ( run_node_pipeline, ExtractSparseWaveforms, PeakRetriever, ) -from sklearn.cluster import HDBSCAN from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel @@ -59,11 +66,11 @@ class CircusClustering: "tmp_folder": None, "verbose": True, "debug": False, - "memory_limit": 0.25, } @classmethod def main_function(cls, recording, peaks, params, job_kwargs=dict()): + assert HAVE_HDBSCAN, "random projections clustering needs hdbscan to be installed" d = params verbose = d["verbose"] @@ -172,9 +179,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): hdbscan_data = tsvd.fit_transform(sub_data) try: - clusterer = HDBSCAN(**d["hdbscan_kwargs"]) - clusterer.fit(hdbscan_data) - local_labels = clusterer.labels_ + clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) + local_labels = clustering[0] except Exception: local_labels = np.zeros(len(hdbscan_data)) valid_clusters = local_labels > -1 @@ -251,10 +257,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if params["noise_levels"] is None: params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) - job_kwargs_local = job_kwargs.copy() - ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4 - job_kwargs_local = get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"]) - templates_array = estimate_templates( recording, spikes, @@ -263,7 +265,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): nafter, return_scaled=False, job_name=None, - **job_kwargs_local, + **job_kwargs, ) best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) @@ -312,4 +314,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("Kept %d non-duplicated clusters" % len(labels)) - return labels, peak_labels + return labels, peak_labels \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 173820e9f6..51b46e3f67 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -15,7 +15,7 @@ def _split_waveforms( wfs_and_noise, noise_size, n_components_by_channel, n_components, hdbscan_params, probability_thr, debug ): import sklearn.decomposition - from sklearn.cluster import HDBSCAN + import hdbscan valid_size = wfs_and_noise.shape[0] - noise_size @@ -30,10 +30,9 @@ def _split_waveforms( local_feature = pca.fit_transform(local_feature) # hdbscan on pca - clusterer = HDBSCAN(**hdbscan_params) - clusterer.fit(local_feature) - local_labels_with_noise = clusterer.labels_ - cluster_probability = clusterer.probabilities_ + clustering = hdbscan.hdbscan(local_feature, **hdbscan_params) + local_labels_with_noise = clustering[0] + cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr) local_labels_with_noise[~np.isin(local_labels_with_noise, persistent_clusters)] = -1 @@ 
-96,7 +95,7 @@ def _split_waveforms_nested( wfs_and_noise, noise_size, nbefore, n_components_by_channel, n_components, hdbscan_params, probability_thr, debug ): import sklearn.decomposition - from sklearn.cluster import HDBSCAN + import hdbscan valid_size = wfs_and_noise.shape[0] - noise_size @@ -124,10 +123,9 @@ def _split_waveforms_nested( # ~ local_feature = pca.fit_transform(local_feature) # hdbscan on pca - clusterer = HDBSCAN(**hdbscan_params) - clusterer.fit(local_feature) - active_labels_with_noise = clusterer.labels_ - cluster_probability = clusterer.probabilities_ + clustering = hdbscan.hdbscan(local_feature, **hdbscan_params) + active_labels_with_noise = clustering[0] + cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr) active_labels_with_noise[~np.isin(active_labels_with_noise, persistent_clusters)] = -1 @@ -235,7 +233,7 @@ def auto_split_clustering( """ import sklearn.decomposition - from sklearn.cluster import HDBSCAN + import hdbscan split_peak_labels = -1 * np.ones(peak_labels.size, dtype=np.int64) nb_clusters = 0 @@ -777,4 +775,4 @@ def remove_duplicates_via_dip(wfs_arrays, peak_labels, dip_threshold=1, cosine_t labels = np.unique(new_labels) labels = labels[labels >= 0] - return labels, new_labels + return labels, new_labels \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index f18135acd2..b1ce79c557 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -94,10 +94,10 @@ def merge_clusters( pair_values[~pair_mask] = 20 - from sklearn.cluster import HDBSCAN + import hdbscan fig, ax = plt.subplots() - clusterer = HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True) + clusterer = hdbscan.HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True) clusterer.fit(pair_values) # print(clusterer.labels_) clusterer.single_linkage_tree_.plot(cmap="viridis", colorbar=True) @@ -753,4 +753,4 @@ def merge( ProjectDistribution, NormalizedTemplateDiff, ] -find_pair_method_dict = {e.name: e for e in find_pair_method_list} +find_pair_method_dict = {e.name: e for e in find_pair_method_list} \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/position.py b/src/spikeinterface/sortingcomponents/clustering/position.py index 4cdcf6d201..3d86c70b5f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position.py +++ b/src/spikeinterface/sortingcomponents/clustering/position.py @@ -2,10 +2,16 @@ # """Sorting components: clustering""" from pathlib import Path -from sklearn.cluster import HDBSCAN import numpy as np +try: + import hdbscan + + HAVE_HDBSCAN = True +except: + HAVE_HDBSCAN = False + class PositionClustering: """ @@ -23,7 +29,7 @@ class PositionClustering: @classmethod def main_function(cls, recording, peaks, params, job_kwargs=dict()): - + assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params if d["peak_locations"] is None: @@ -45,9 +51,9 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): else: to_cluster_from = locations - clusterer = HDBSCAN(**d["hdbscan_kwargs"]) - clusterer.fit(to_cluster_from) - peak_labels = clusterer.labels_ + clustering = hdbscan.hdbscan(to_cluster_from, **d["hdbscan_kwargs"]) + peak_labels = clustering[0] + labels = np.unique(peak_labels) labels = labels[labels >= 0] @@ -71,4 +77,4 @@ def 
main_function(cls, recording, peaks, params, job_kwargs=dict()): fig1.savefig(tmp_folder / "peak_locations.png") fig2.savefig(tmp_folder / "peak_locations_clustered.png") - return labels, peak_labels + return labels, peak_labels \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index 65f13a29ab..4540bdca11 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -5,7 +5,13 @@ import shutil import numpy as np -from sklearn.cluster import HDBSCAN + +try: + import hdbscan + + HAVE_HDBSCAN = True +except: + HAVE_HDBSCAN = False import random, string, os from spikeinterface.core import get_global_tmp_folder, get_noise_levels @@ -42,6 +48,8 @@ class PositionAndFeaturesClustering: def main_function(cls, recording, peaks, params, job_kwargs=dict()): from sklearn.preprocessing import QuantileTransformer + assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" + d = params peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] @@ -74,7 +82,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): preprocessing = QuantileTransformer(output_distribution="uniform") hdbscan_data = preprocessing.fit_transform(hdbscan_data) - clusterer = HDBSCAN(**d["hdbscan_kwargs"]) + clusterer = hdbscan.HDBSCAN(**d["hdbscan_kwargs"]) clusterer.fit(X=hdbscan_data) peak_labels = clusterer.labels_ @@ -182,4 +190,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): print("We kept %d non-duplicated clusters..." % len(labels)) - return labels, peak_labels + return labels, peak_labels \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py index c34170f073..3e4b8e997c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py @@ -7,7 +7,13 @@ import os import numpy as np -from sklearn.cluster import HDBSCAN + +try: + import hdbscan + + HAVE_HDBSCAN = True +except: + HAVE_HDBSCAN = False from spikeinterface.core import get_global_tmp_folder from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks @@ -69,6 +75,8 @@ def _check_params(cls, recording, peaks, params): def main_function(cls, recording, peaks, params, job_kwargs=dict()): # res = PositionClustering(recording, peaks, params) + assert HAVE_HDBSCAN, "position_and_pca clustering need hdbscan to be installed" + params = cls._check_params(recording, peaks, params) # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params) @@ -88,7 +96,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): else: to_cluster_from = locations - clusterer = HDBSCAN(**params["hdbscan_global_kwargs"]) + clusterer = hdbscan.HDBSCAN(**params["hdbscan_global_kwargs"]) clusterer.fit(X=to_cluster_from) spatial_peak_labels = clusterer.labels_ @@ -232,4 +240,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): labels = np.unique(clean_peak_labels) labels = labels[labels >= 0] - return labels, clean_peak_labels + return labels, clean_peak_labels \ No newline at end of file diff --git 
a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py index 750690ef3f..9953ccdaa4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py @@ -4,7 +4,12 @@ from pathlib import Path import numpy as np -from sklearn.cluster import HDBSCAN +try: + import hdbscan + + HAVE_HDBSCAN = True +except: + HAVE_HDBSCAN = False from spikeinterface.sortingcomponents.features_from_peaks import ( compute_features_from_peaks, @@ -33,7 +38,7 @@ class PositionPTPScaledClustering: @classmethod def main_function(cls, recording, peaks, params, job_kwargs=dict()): - + assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params if d["peak_locations"] is None: @@ -65,7 +70,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): to_cluster_from = np.hstack((locations, logmaxptps[:, np.newaxis])) to_cluster_from = to_cluster_from * d["scales"] - clusterer = HDBSCAN(**d["hdbscan_kwargs"]) + clusterer = hdbscan.HDBSCAN(**d["hdbscan_kwargs"]) clusterer.fit(X=to_cluster_from) peak_labels = clusterer.labels_ @@ -92,4 +97,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): fig1.savefig(tmp_folder / "peak_locations.png") fig2.savefig(tmp_folder / "peak_locations_clustered.png") - return labels, peak_labels + return labels, peak_labels \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 56a07d642c..ffbede8724 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -5,7 +5,13 @@ import shutil import numpy as np -from sklearn.cluster import HDBSCAN + +try: + import hdbscan + + HAVE_HDBSCAN = True +except: + HAVE_HDBSCAN = False from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.waveform_tools import estimate_templates @@ -55,6 +61,7 @@ class RandomProjectionClustering: @classmethod def main_function(cls, recording, peaks, params, job_kwargs=dict()): + assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" d = params verbose = d["verbose"] @@ -107,9 +114,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): recording, pipeline_nodes, job_kwargs=job_kwargs, job_name="extracting features" ) - clusterer = HDBSCAN(**d["hdbscan_kwargs"]) - clusterer.fit(hdbscan_data) - peak_labels = clusterer.labels_ + clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) + peak_labels = clustering[0] labels = np.unique(peak_labels) labels = labels[labels >= 0] @@ -180,4 +186,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("Kept %d non-duplicated clusters" % len(labels)) - return labels, peak_labels + return labels, peak_labels \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 609182f5ad..2f6c84a05c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -5,11 +5,17 @@ import time import random import string -from sklearn.cluster import HDBSCAN import numpy as np +try: + import hdbscan + + 
HAVE_HDBSCAN = True +except: + HAVE_HDBSCAN = False + from spikeinterface.core import ( get_global_tmp_folder, @@ -53,6 +59,7 @@ class SlidingHdbscanClustering: @classmethod def main_function(cls, recording, peaks, params, job_kwargs=dict()): + assert HAVE_HDBSCAN, "sliding_hdbscan clustering need hdbscan to be installed" params = cls._check_params(recording, peaks, params) wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params) peak_labels = cls._find_clusters(recording, peaks, wfs_arrays, sparsity_mask, noise, params) @@ -268,7 +275,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): # find some clusters # ~ t0 = time.perf_counter() - clusterer = HDBSCAN(min_cluster_size=d["min_cluster_size"], allow_single_cluster=True, metric="l2") + clusterer = hdbscan.HDBSCAN(min_cluster_size=d["min_cluster_size"], allow_single_cluster=True, metric="l2") all_labels = clusterer.fit_predict(local_feature) # ~ t1 = time.perf_counter() # ~ print('HDBSCAN time', t1 - t0) @@ -529,4 +536,4 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, # TODO DEBUG and check assert wanted_chans.shape[0] == wfs.shape[2] - return wfs, wanted_chans + return wfs, wanted_chans \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index 635eb47cd2..cb354a1373 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -20,6 +20,14 @@ from spikeinterface.core import get_channel_distances from tqdm.auto import tqdm +try: + import hdbscan + + HAVE_HDBSCAN = True +except: + HAVE_HDBSCAN = False + + try: import pymde @@ -71,6 +79,7 @@ def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_TORCH, "SlidingNN needs torch to work" assert HAVE_NNDESCENT, "SlidingNN needs pynndescent to work" assert HAVE_PYMDE, "SlidingNN needs pymde to work" + assert HAVE_HDBSCAN, "SlidingNN needs hdbscan to work" d = params tmp_folder = params["tmp_folder"] @@ -285,7 +294,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): "Clustering MDE embeddings (n={}): {}".format(embeddings_chunk.shape, datetime.datetime.now()) ) # TODO HDBSCAN can be done on GPU with NVIDIA RAPIDS for speed - clusterer = HDBSCAN( + clusterer = hdbscan.HDBSCAN( prediction_data=True, core_dist_n_jobs=job_kwargs["n_jobs"], **d["hdbscan_kwargs"], @@ -645,4 +654,4 @@ def embed_graph( if mde_device == "cuda": x = np.array(x.cpu()) - return x + return x \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index bfa1cbe035..777bd6c003 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -259,7 +259,7 @@ def split( final_features = flatten_features if clusterer == "hdbscan": - from sklearn.cluster import HDBSCAN + from hdbscan import HDBSCAN clust = HDBSCAN(**clusterer_kwargs, core_dist_n_jobs=1) clust.fit(final_features) @@ -327,4 +327,4 @@ def split( split_methods_list = [ LocalFeatureClustering, ] -split_methods_dict = {e.name: e for e in split_methods_list} +split_methods_dict = {e.name: e for e in split_methods_list} \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py 
b/src/spikeinterface/sortingcomponents/clustering/tdc.py index 42b893810e..b9aa75953b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tdc.py +++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py @@ -57,7 +57,7 @@ class TdcClustering: @classmethod def main_function(cls, recording, peaks, params, job_kwargs=dict()): - from sklearn.cluster import HDBSCAN + import hdbscan if params["folder"] is None: randname = "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) @@ -239,4 +239,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): shutil.rmtree(clustering_folder) extra_out = {"peak_shifts": peak_shifts} - return labels_set, post_clean_label, extra_out + return labels_set, post_clean_label, extra_out \ No newline at end of file From d6917c91a23ef3211122a1703910a7496a95fb5c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Mar 2025 10:19:40 +0000 Subject: [PATCH 24/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyproject.toml | 2 +- src/spikeinterface/sorters/internal/simplesorter.py | 2 +- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- src/spikeinterface/sorters/internal/tridesclous2.py | 2 +- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- src/spikeinterface/sortingcomponents/clustering/merge.py | 2 +- src/spikeinterface/sortingcomponents/clustering/position.py | 2 +- .../sortingcomponents/clustering/position_and_features.py | 2 +- .../sortingcomponents/clustering/position_and_pca.py | 2 +- .../sortingcomponents/clustering/position_ptp_scaled.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 +- .../sortingcomponents/clustering/sliding_hdbscan.py | 2 +- src/spikeinterface/sortingcomponents/clustering/sliding_nn.py | 2 +- src/spikeinterface/sortingcomponents/clustering/split.py | 2 +- src/spikeinterface/sortingcomponents/clustering/tdc.py | 2 +- 16 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f6d83b6616..97ba77299e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -235,4 +235,4 @@ markers = [ filterwarnings =[ 'ignore:.*distutils Version classes are deprecated.*:DeprecationWarning', 'ignore:.*the imp module is deprecated in favour of importlib.*:DeprecationWarning', -] \ No newline at end of file +] diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 91d1fccd84..0f44e4079a 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -232,4 +232,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) sorting_final = sorting_final.save(folder=sorter_output_folder / "sorting") - return sorting_final \ No newline at end of file + return sorting_final diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index f2ac44c191..50a6603d24 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -103,7 +103,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): HAVE_HDBSCAN = False assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed" - + try: import torch except ImportError: diff --git 
a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 1bbd852b4f..65dfb2ed45 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -258,4 +258,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = NumpySorting(final_spikes, sampling_frequency, labels_set) sorting = sorting.save(folder=sorter_output_folder / "sorting") - return sorting \ No newline at end of file + return sorting diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 7f12cd09fb..7bce0800d3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -314,4 +314,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("Kept %d non-duplicated clusters" % len(labels)) - return labels, peak_labels \ No newline at end of file + return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 51b46e3f67..93db9a268f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -775,4 +775,4 @@ def remove_duplicates_via_dip(wfs_arrays, peak_labels, dip_threshold=1, cosine_t labels = np.unique(new_labels) labels = labels[labels >= 0] - return labels, new_labels \ No newline at end of file + return labels, new_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index b1ce79c557..c9a9841c57 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -753,4 +753,4 @@ def merge( ProjectDistribution, NormalizedTemplateDiff, ] -find_pair_method_dict = {e.name: e for e in find_pair_method_list} \ No newline at end of file +find_pair_method_dict = {e.name: e for e in find_pair_method_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/position.py b/src/spikeinterface/sortingcomponents/clustering/position.py index 3d86c70b5f..dc76d787f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position.py +++ b/src/spikeinterface/sortingcomponents/clustering/position.py @@ -77,4 +77,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): fig1.savefig(tmp_folder / "peak_locations.png") fig2.savefig(tmp_folder / "peak_locations_clustered.png") - return labels, peak_labels \ No newline at end of file + return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index 4540bdca11..20067a2eec 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -190,4 +190,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): print("We kept %d non-duplicated clusters..." 
% len(labels)) - return labels, peak_labels \ No newline at end of file + return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py index 3e4b8e997c..c4f372fc21 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py @@ -240,4 +240,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): labels = np.unique(clean_peak_labels) labels = labels[labels >= 0] - return labels, clean_peak_labels \ No newline at end of file + return labels, clean_peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py index 9953ccdaa4..0f7390d7ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py @@ -97,4 +97,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): fig1.savefig(tmp_folder / "peak_locations.png") fig2.savefig(tmp_folder / "peak_locations_clustered.png") - return labels, peak_labels \ No newline at end of file + return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index ffbede8724..e67e1907f1 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -186,4 +186,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("Kept %d non-duplicated clusters" % len(labels)) - return labels, peak_labels \ No newline at end of file + return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 2f6c84a05c..5f8ac99848 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -536,4 +536,4 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, # TODO DEBUG and check assert wanted_chans.shape[0] == wfs.shape[2] - return wfs, wanted_chans \ No newline at end of file + return wfs, wanted_chans diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index cb354a1373..40cedacdc5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -654,4 +654,4 @@ def embed_graph( if mde_device == "cuda": x = np.array(x.cpu()) - return x \ No newline at end of file + return x diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 777bd6c003..310ee0969f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -327,4 +327,4 @@ def split( split_methods_list = [ LocalFeatureClustering, ] -split_methods_dict = {e.name: e for e in split_methods_list} \ No newline at end of file +split_methods_dict = {e.name: e for e in split_methods_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py 
b/src/spikeinterface/sortingcomponents/clustering/tdc.py index b9aa75953b..027503a7c8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tdc.py +++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py @@ -239,4 +239,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): shutil.rmtree(clustering_folder) extra_out = {"peak_shifts": peak_shifts} - return labels_set, post_clean_label, extra_out \ No newline at end of file + return labels_set, post_clean_label, extra_out From 12b5edb4aa06bbb335e630e6c11dd87ee5c5ff84 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 5 Mar 2025 11:19:55 +0100 Subject: [PATCH 25/35] Reverting --- src/spikeinterface/sorters/external/tridesclous.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/tridesclous.py b/src/spikeinterface/sorters/external/tridesclous.py index 460acc2287..e9a22d0951 100644 --- a/src/spikeinterface/sorters/external/tridesclous.py +++ b/src/spikeinterface/sorters/external/tridesclous.py @@ -64,7 +64,9 @@ def is_installed(cls): except ImportError: HAVE_TDC = False except: - print("tridesclous is installed, but it has some dependency problems, check numba install!") + print( + "tridesclous is installed, but it has some dependency problems, check numba or hdbscan installations!" + ) HAVE_TDC = False return HAVE_TDC From 8282461da03e2dcaa0d8c9ab1c45194c98175935 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 12 Mar 2025 16:44:01 +0100 Subject: [PATCH 26/35] Bringing back optimal n jobs --- src/spikeinterface/sortingcomponents/clustering/circus.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 7bce0800d3..e4af325f71 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -23,7 +23,7 @@ from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.sortingcomponents.tools import remove_empty_templates +from spikeinterface.sortingcomponents.tools import remove_empty_templates, get_optimal_n_jobs import pickle, json from spikeinterface.core.node_pipeline import ( run_node_pipeline, @@ -257,6 +257,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if params["noise_levels"] is None: params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + job_kwargs_local = job_kwargs.copy() + ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4 + job_kwargs_local = get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"]) + templates_array = estimate_templates( recording, spikes, @@ -265,7 +269,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): nafter, return_scaled=False, job_name=None, - **job_kwargs, + **job_kwargs_local, ) best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) From ce5676085063b820c496e8a9c147294592d49d3f Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 12 Mar 2025 16:45:16 +0100 Subject: [PATCH 27/35] Fixes --- src/spikeinterface/sortingcomponents/tools.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 
56b128d028..0c12fa1453 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -281,7 +281,8 @@ def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory dtype_size_bytes = recording.get_dtype().itemsize chunk_size = memory_usage / ((num_channels * dtype_size_bytes) * n_jobs) chunk_duration = chunk_size / recording.get_sampling_frequency() - job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s")) + job_kwargs.update(dict(chunk_duration=f"{chunk_duration}s")) + job_kwargs = fix_job_kwargs(job_kwargs) else: import warnings @@ -294,7 +295,8 @@ def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory dtype_size_bytes = recording.get_dtype().itemsize chunk_size = (num_channels * dtype_size_bytes) * n_jobs / total_memory chunk_duration = chunk_size / recording.get_sampling_frequency() - job_kwargs = fix_job_kwargs(dict(chunk_duration=f"{chunk_duration}s")) + job_kwargs.update(dict(chunk_duration=f"{chunk_duration}s")) + job_kwargs = fix_job_kwargs(job_kwargs) return job_kwargs From 7cc16ce369d3ede45846a094fb96274844dc852f Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 12 Mar 2025 16:53:45 +0100 Subject: [PATCH 28/35] WIP --- src/spikeinterface/sortingcomponents/clustering/circus.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index e4af325f71..aec7e77805 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -65,6 +65,7 @@ class CircusClustering: "noise_levels": None, "tmp_folder": None, "verbose": True, + "memory_limit":0.25, "debug": False, } From c4faee6cf69acdfbfbfe7d0f6ae57b5a15cdfd73 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Mar 2025 15:55:00 +0000 Subject: [PATCH 29/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index aec7e77805..53ec116a22 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -65,7 +65,7 @@ class CircusClustering: "noise_levels": None, "tmp_folder": None, "verbose": True, - "memory_limit":0.25, + "memory_limit": 0.25, "debug": False, } From 08f579e58d85b29eb3bb95312f148f590623bfed Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 4 Apr 2025 14:37:15 +0200 Subject: [PATCH 30/35] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 6 +++--- src/spikeinterface/sortingcomponents/clustering/circus.py | 4 ++-- src/spikeinterface/sortingcomponents/tools.py | 6 ++++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 50a6603d24..cee25fa124 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -15,7 +15,7 @@ cache_preprocessing, get_prototype_and_waveforms_from_recording, get_shuffled_recording_slices, - set_optimal_chunk_size, + _set_optimal_chunk_size, ) from 
spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity @@ -91,7 +91,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): @classmethod def get_sorter_version(cls): - return "2.0" + return "2.0.1" @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): @@ -121,7 +121,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs = fix_job_kwargs(params["job_kwargs"]) job_kwargs.update({"progress_bar": verbose}) recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) - job_kwargs = set_optimal_chunk_size(recording, job_kwargs, **params["chunk_preprocessing"]) + job_kwargs = _set_optimal_chunk_size(recording, job_kwargs, **params["chunk_preprocessing"]) sampling_frequency = recording.get_sampling_frequency() num_channels = recording.get_num_channels() diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 53ec116a22..3286d390ff 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -23,7 +23,7 @@ from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.sortingcomponents.tools import remove_empty_templates, get_optimal_n_jobs +from spikeinterface.sortingcomponents.tools import remove_empty_templates, _get_optimal_n_jobs import pickle, json from spikeinterface.core.node_pipeline import ( run_node_pipeline, @@ -260,7 +260,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): job_kwargs_local = job_kwargs.copy() ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4 - job_kwargs_local = get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"]) + job_kwargs_local = _get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"]) templates_array = estimate_templates( recording, diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 0c12fa1453..782fb0392d 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -249,7 +249,7 @@ def check_probe_for_drift_correction(recording, dist_x_max=60): return True -def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None): +def _set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None): """ Set the optimal chunk size for a job given the memory_limit and the number of jobs @@ -300,7 +300,7 @@ def set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory return job_kwargs -def get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25): +def _get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25): """ Set the optimal chunk size for a job given the memory_limit and the number of jobs @@ -458,3 +458,5 @@ def get_shuffled_recording_slices(recording, seed=None, **job_kwargs): recording_slices = rng.permutation(recording_slices) return recording_slices + + From c26b09222fdfa667e3887d3b8d8c9c0e1d8dc1c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Apr 2025 12:39:17 +0000 Subject: [PATCH 31/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/tools.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 782fb0392d..847ca9c8d8 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -458,5 +458,3 @@ def get_shuffled_recording_slices(recording, seed=None, **job_kwargs): recording_slices = rng.permutation(recording_slices) return recording_slices - - From 16e9a512be281c26a77d12f2b7ff549d474d0da8 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 4 Apr 2025 16:49:40 +0200 Subject: [PATCH 32/35] Deactivate by default --- src/spikeinterface/sorters/internal/spyking_circus2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index cee25fa124..8e711c65d8 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -45,7 +45,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "apply_preprocessing": True, "matched_filtering": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, - "chunk_preprocessing": {"memory_limit": 0.01}, + "chunk_preprocessing": {"memory_limit": None}, "multi_units_only": False, "job_kwargs": {"n_jobs": 0.75}, "seed": 42, @@ -121,7 +121,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs = fix_job_kwargs(params["job_kwargs"]) job_kwargs.update({"progress_bar": verbose}) recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) - job_kwargs = _set_optimal_chunk_size(recording, job_kwargs, **params["chunk_preprocessing"]) + if params["chunk_preprocessing"].get("memory_limit", None) is not None: + job_kwargs = _set_optimal_chunk_size(recording, job_kwargs, **params["chunk_preprocessing"]) sampling_frequency = recording.get_sampling_frequency() num_channels = recording.get_num_channels() From 51f753a0fedab87f36c230b1808ef31e29f476af Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 7 Apr 2025 16:02:48 +0200 Subject: [PATCH 33/35] WIP --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- src/spikeinterface/sortingcomponents/tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 3286d390ff..96a02b762d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -308,7 +308,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) - cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs = job_kwargs_local.copy() cleaning_job_kwargs["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 847ca9c8d8..6a380b4d56 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -325,7 +325,7 @@ def _get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25): if HAVE_PSUTIL: assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
memory_usage = memory_limit * psutil.virtual_memory().available - n_jobs = int(min(n_jobs, memory_usage // ram_requested)) + n_jobs = max(1, int(min(n_jobs, memory_usage // ram_requested))) job_kwargs.update(dict(n_jobs=n_jobs)) else: import warnings From 456e25738cdc2a0c82315678be49fbf01e962c63 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 20:36:59 +0000 Subject: [PATCH 34/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/circus.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 2a757a0bf3..47d8dcf77c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -157,6 +157,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if not params["templates_from_svd"]: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording + job_kwargs_local = job_kwargs.copy() unit_ids = np.unique(peak_labels) ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4 From f375efc8d5de56937da6d3f512741186a05ba700 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Jun 2025 07:45:30 +0000 Subject: [PATCH 35/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 0aec3718ae..3e3695b9ef 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -407,7 +407,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # np.save(fitting_folder / "amplitudes", guessed_amplitudes) if sorting.get_non_empty_unit_ids().size > 0: - final_analyzer = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs) + final_analyzer = final_cleaning_circus( + recording_w, sorting, templates, **merging_params, **job_kwargs + ) final_analyzer.save_as(format="binary_folder", folder=sorter_output_folder / "final_analyzer") sorting = final_analyzer.sorting
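
A note on the HDBSCAN migration in patch 23/35: every clusterer moves from sklearn.cluster.HDBSCAN to the standalone hdbscan package (pinned at hdbscan>=0.8.33 because earlier wheels were broken), which exposes both a scikit-learn-style class and a functional entry point. A minimal sketch of the two call styles, runnable on synthetic data; the feature array and min_cluster_size value are illustrative, not taken from the codebase, and the tuple layout follows the hdbscan documentation:

    import numpy as np
    import hdbscan

    rng = np.random.default_rng(42)
    features = rng.normal(size=(500, 2))  # stand-in for per-peak features

    # Class-based API, kept in merge.py, position_and_pca.py and sliding_nn.py
    clusterer = hdbscan.HDBSCAN(min_cluster_size=20)
    clusterer.fit(features)
    labels = clusterer.labels_                # -1 marks noise points
    probabilities = clusterer.probabilities_  # per-sample membership strength

    # Functional API, adopted in circus.py, position.py and clustering_tools.py;
    # per the hdbscan docs the returned tuple starts with (labels, probabilities, persistence, ...)
    out = hdbscan.hdbscan(features, min_cluster_size=20)
    labels, probabilities = out[0], out[1]

Both paths yield the same labeling; the functional form simply avoids keeping the fitted estimator around.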
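The chunk-sizing arithmetic behind _set_optimal_chunk_size (patch 27/35) can be checked by hand: the chunk duration is chosen so that all workers together stay within the memory budget. A sketch under assumed figures (384 channels, int16 samples, 30 kHz, 8 jobs, 16 GB reported available, memory_limit=0.5); only the formula comes from the patch, the numbers are invented:

    num_channels = 384
    dtype_size_bytes = 2        # int16 recording (assumed)
    n_jobs = 8
    sampling_frequency = 30_000.0
    available_bytes = 16e9      # what psutil.virtual_memory().available might report
    memory_limit = 0.5

    memory_usage = memory_limit * available_bytes  # 8 GB budget across all workers
    chunk_size = memory_usage / ((num_channels * dtype_size_bytes) * n_jobs)  # ~1.3e6 frames
    chunk_duration = chunk_size / sampling_frequency  # ~43.4 s
    job_kwargs = dict(n_jobs=n_jobs, chunk_duration=f"{chunk_duration}s")

The fix in patch 27/35 matters here: updating the caller's job_kwargs in place preserves n_jobs and the other settings, whereas the previous code rebuilt job_kwargs from the computed chunk_duration alone.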
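The companion clamp in _get_optimal_n_jobs (patches 26/35 and 33/35) works the other way around: ram_requested approximates the float32 footprint of one full templates array, and n_jobs is capped so that n_jobs such buffers fit in the budget. Illustrative numbers only; the formula and the max(1, ...) guard come from the patches:

    n_jobs = 8
    memory_limit = 0.25
    available_bytes = 4e9                # pretend psutil reports 4 GB free

    num_channels, num_units = 384, 2000  # assumed recording and clustering sizes
    nbefore, nafter = 60, 120            # waveform window in samples (assumed)
    ram_requested = num_channels * (nbefore + nafter) * num_units * 4  # ~0.55 GB of float32

    memory_usage = memory_limit * available_bytes                     # 1 GB budget
    n_jobs = max(1, int(min(n_jobs, memory_usage // ram_requested)))  # -> 1 worker

Without the max(1, ...) introduced in patch 33/35, a templates array larger than the budget would drive n_jobs to 0 instead of falling back to a single worker.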