From af9d5aa86f3f1782ef51c736e92f225178ac9314 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 19 May 2026 12:38:14 -0600 Subject: [PATCH] migrate generete ground truth --- .../tests/common_benchmark_testing.py | 5 +- .../benchmark/tests/test_benchmark_sorter.py | 2 +- .../tests/test_benchmark_sorter_without_gt.py | 2 +- .../tests/test_templatecomparison.py | 3 +- src/spikeinterface/core/__init__.py | 15 +- src/spikeinterface/core/generate.py | 204 ++---------------- .../tests/test_analyzer_extension_core.py | 4 +- src/spikeinterface/core/tests/test_base.py | 3 +- .../core/tests/test_basesorting.py | 2 +- .../core/tests/test_generate.py | 5 +- src/spikeinterface/core/tests/test_loading.py | 2 +- .../core/tests/test_node_pipeline.py | 3 +- .../core/tests/test_recording_tools.py | 12 +- .../core/tests/test_sorting_tools.py | 4 +- .../core/tests/test_sortinganalyzer.py | 4 +- .../core/tests/test_sparsity.py | 4 +- .../core/tests/test_template_tools.py | 5 +- .../core/tests/test_time_series_tools.py | 8 +- .../core/tests/test_waveform_tools.py | 5 +- ...forms_extractor_backwards_compatibility.py | 5 +- src/spikeinterface/curation/tests/common.py | 5 +- .../curation/tests/test_curation_format.py | 3 +- src/spikeinterface/exporters/tests/common.py | 5 +- .../extractors/phykilosortextractors.py | 2 +- .../extractors/tests/test_mdaextractors.py | 2 +- .../tests/test_shybridextractors.py | 2 +- src/spikeinterface/extractors/toy_example.py | 4 +- src/spikeinterface/generation/__init__.py | 3 +- .../generation/ground_truth_generator.py | 201 +++++++++++++++++ src/spikeinterface/generation/noise_tools.py | 1 - .../generation/tests/test_splitting_tools.py | 2 +- src/spikeinterface/metrics/conftest.py | 4 +- .../quality/tests/test_metrics_functions.py | 4 +- .../tests/test_quality_metric_calculator.py | 2 +- .../tests/common_extension_tests.py | 4 +- .../postprocessing/tests/conftest.py | 2 +- .../tests/test_extension_merges.py | 3 +- .../tests/test_multi_extensions.py | 6 +- .../tests/test_unit_locations.py | 3 +- .../preprocessing/silence_periods.py | 1 - .../external/tests/test_docker_containers.py | 2 +- .../sorters/external/tests/test_kilosort4.py | 3 +- .../tests/test_singularity_containers.py | 2 +- .../tests/test_singularity_containers_gpu.py | 2 +- .../sorters/tests/common_tests.py | 2 +- .../sorters/tests/test_container_tools.py | 2 +- .../sorters/tests/test_launcher.py | 2 +- .../sorters/tests/test_runsorter.py | 3 +- .../tests/test_runsorter_dependency_checks.py | 2 +- .../motion/tests/test_motion_interpolation.py | 4 +- .../sortingcomponents/tests/common.py | 4 +- .../waveforms/tests/conftest.py | 2 +- .../widgets/tests/test_widgets.py | 4 +- 53 files changed, 311 insertions(+), 279 deletions(-) create mode 100644 src/spikeinterface/generation/ground_truth_generator.py diff --git a/src/spikeinterface/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py index 49f9e9350a..bf16279ffb 100644 --- a/src/spikeinterface/benchmark/tests/common_benchmark_testing.py +++ b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py @@ -12,13 +12,12 @@ import numpy as np from spikeinterface.core import ( - generate_ground_truth_recording, estimate_templates, Templates, create_sorting_analyzer, ms_to_samples, ) -from spikeinterface.generation import generate_drifting_recording +from spikeinterface.generation import generate_drifting_recording, generate_ground_truth_recording ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -39,7 +38,7 @@ def make_dataset(job_kwargs={}): contact_shape_params={"radius": 6}, ), generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index ceaa9331cc..4a118aa648 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -2,7 +2,7 @@ import pytest from pathlib import Path -from spikeinterface import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.preprocessing import bandpass_filter from spikeinterface.benchmark import SorterStudy diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter_without_gt.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter_without_gt.py index 62492cbdbe..032fd43666 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter_without_gt.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter_without_gt.py @@ -2,7 +2,7 @@ import pytest from pathlib import Path -from spikeinterface import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.preprocessing import bandpass_filter from spikeinterface.benchmark import SorterStudyWithoutGroundTruth diff --git a/src/spikeinterface/comparison/tests/test_templatecomparison.py b/src/spikeinterface/comparison/tests/test_templatecomparison.py index 001f80f4f6..aca4d06122 100644 --- a/src/spikeinterface/comparison/tests/test_templatecomparison.py +++ b/src/spikeinterface/comparison/tests/test_templatecomparison.py @@ -2,7 +2,8 @@ import pytest import numpy as np -from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording +from spikeinterface.core import create_sorting_analyzer +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.comparison import compare_templates, compare_multiple_templates # def setup_module(): diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index c32d919bc7..a047f191f2 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -48,7 +48,6 @@ generate_recording_by_size, InjectTemplatesRecording, inject_templates, - generate_ground_truth_recording, ) # utils to append and concatenate segment (equivalent to OLD MultiRecordingTimeExtractor) @@ -213,4 +212,18 @@ def __getattr__(name): "noise_generator_recording": noise_generator_recording, } return _map[name] + if name == "generate_ground_truth_recording": + import warnings + + warnings.warn( + "Importing generate_ground_truth_recording from spikeinterface.core is deprecated. " + "Import from spikeinterface.generation instead: " + "`from spikeinterface.generation import generate_ground_truth_recording`. " + "This will be removed in version 0.106.0.", + FutureWarning, + stacklevel=2, + ) + from spikeinterface.generation.ground_truth_generator import generate_ground_truth_recording + + return generate_ground_truth_recording raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 1470473fe2..92ed827b79 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2266,196 +2266,6 @@ def generate_unit_locations( return units_locations -def generate_ground_truth_recording( - durations=[10.0], - sampling_frequency=25000.0, - num_channels=4, - num_units=10, - sorting=None, - probe=None, - generate_probe_kwargs=dict( - num_columns=2, - xpitch=20, - ypitch=20, - contact_shapes="circle", - contact_shape_params={"radius": 6}, - ), - templates=None, - ms_before=1.0, - ms_after=3.0, - upsample_factor=None, - upsample_vector=None, - generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), - generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), - generate_templates_kwargs=None, - dtype="float32", - seed=None, -): - """ - Generate a recording with spike given a probe+sorting+templates. - - Parameters - ---------- - durations : list[float], default: [10.] - Durations in seconds for all segments. - sampling_frequency : float, default: 25000.0 - Sampling frequency. - num_channels : int, default: 4 - Number of channels, not used when probe is given. - num_units : int, default: 10 - Number of units, not used when sorting is given. - sorting : Sorting | None - An external sorting object. If not provide, one is genrated. - probe : Probe | None - An external Probe object. If not provided a probe is generated using generate_probe_kwargs. - generate_probe_kwargs : dict - A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. - templates : np.ndarray | None - The templates of units. - If None they are generated. - Shape can be: - - * (num_units, num_samples, num_channels): standard case - * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce jitter. - ms_before : float, default: 1.5 - Cut out in ms before spike peak. - ms_after : float, default: 3.0 - Cut out in ms after spike peak. - upsample_factor : None | int, default: None - A upsampling factor used only when templates are not provided. - upsample_vector : np.ndarray | None - Optional the upsample_vector can given. This has the same shape as spike_vector - generate_sorting_kwargs : dict - When sorting is not provide, this dict is used to generated a Sorting. - noise_kwargs : dict - Dict used to generated the noise with NoiseGeneratorRecording. - generate_unit_locations_kwargs : dict - Dict used to generated template when template not provided. - generate_templates_kwargs : dict - Dict used to generated template when template not provided. - dtype : np.dtype, default: "float32" - The dtype of the recording. - seed : int | None - Seed for random initialization. - If None a diffrent Recording is generated at every call. - Note: even with None a generated recording keep internaly a seed to regenerate the same signal after dump/load. - - Returns - ------- - recording : Recording - The generated recording extractor. - sorting : Sorting - The generated sorting extractor. - """ - generate_templates_kwargs = generate_templates_kwargs or dict() - - # TODO implement upsample_factor in InjectTemplatesRecording and propagate into toy_example - - # if None so the same seed will be used for all steps - seed = _ensure_seed(seed) - rng = np.random.default_rng(seed) - - if sorting is None: - generate_sorting_kwargs = generate_sorting_kwargs.copy() - generate_sorting_kwargs["durations"] = durations - generate_sorting_kwargs["num_units"] = num_units - generate_sorting_kwargs["sampling_frequency"] = sampling_frequency - generate_sorting_kwargs["seed"] = seed - sorting = generate_sorting(**generate_sorting_kwargs) - else: - num_units = sorting.get_num_units() - assert sorting.sampling_frequency == sampling_frequency - num_spikes = sorting.to_spike_vector().size - - if probe is None: - # probe = generate_linear_probe(num_elec=num_channels) - # probe.set_device_channel_indices(np.arange(num_channels)) - - prb_kwargs = generate_probe_kwargs.copy() - if "num_contact_per_column" in prb_kwargs: - assert ( - prb_kwargs["num_contact_per_column"] * prb_kwargs["num_columns"] - ) == num_channels, ( - "generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns" - ) - elif "num_contact_per_column" not in prb_kwargs and "num_columns" in prb_kwargs: - n = num_channels // prb_kwargs["num_columns"] - num_contact_per_column = [n] * prb_kwargs["num_columns"] - mid = prb_kwargs["num_columns"] // 2 - num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"] - prb_kwargs["num_contact_per_column"] = num_contact_per_column - else: - raise ValueError("num_columns should be provided in dict generate_probe_kwargs") - - probe = generate_multi_columns_probe(**prb_kwargs) - probe.set_device_channel_indices(np.arange(num_channels)) - - else: - num_channels = probe.get_contact_count() - - if templates is None: - channel_locations = probe.contact_positions - unit_locations = generate_unit_locations( - num_units, channel_locations, seed=seed, **generate_unit_locations_kwargs - ) - templates = generate_templates( - channel_locations, - unit_locations, - sampling_frequency, - ms_before, - ms_after, - upsample_factor=upsample_factor, - seed=seed, - dtype=dtype, - **generate_templates_kwargs, - ) - sorting.set_property("gt_unit_locations", unit_locations) - else: - assert templates.shape[0] == num_units - - if templates.ndim == 3: - upsample_vector = None - else: - if upsample_vector is None: - upsample_factor = templates.shape[3] - upsample_vector = rng.integers(0, upsample_factor, size=num_spikes) - - nbefore = ms_to_samples(ms_before, sampling_frequency) - nafter = ms_to_samples(ms_after, sampling_frequency) - assert (nbefore + nafter) == templates.shape[1] - - # construct recording - from spikeinterface.generation.noise_tools import NoiseGeneratorRecording - - noise_rec = NoiseGeneratorRecording( - num_channels=num_channels, - sampling_frequency=sampling_frequency, - durations=durations, - dtype=dtype, - seed=seed, - noise_block_size=int(sampling_frequency), - **noise_kwargs, - ) - - recording = InjectTemplatesRecording( - sorting, - templates, - nbefore=nbefore, - parent_recording=noise_rec, - upsample_vector=upsample_vector, - ) - recording.annotate(is_filtered=True) - recording.set_probe(probe, in_place=True) - recording.set_channel_gains(1.0) - recording.set_channel_offsets(0.0) - - recording.name = "GroundTruthRecording" - sorting.name = "GroundTruthSorting" - - return recording, sorting - - def __getattr__(name): if name in ("NoiseGeneratorRecording", "noise_generator_recording"): import warnings @@ -2475,4 +2285,18 @@ def __getattr__(name): "noise_generator_recording": noise_generator_recording, } return _map[name] + if name == "generate_ground_truth_recording": + import warnings + + warnings.warn( + "Importing generate_ground_truth_recording from spikeinterface.core.generate is deprecated. " + "Import from spikeinterface.generation instead: " + "`from spikeinterface.generation import generate_ground_truth_recording`. " + "This will be removed in version 0.106.0.", + FutureWarning, + stacklevel=2, + ) + from spikeinterface.generation.ground_truth_generator import generate_ground_truth_recording + + return generate_ground_truth_recording raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index fd550c729e..92426c5802 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -4,7 +4,7 @@ from pathlib import Path -from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer from spikeinterface.core import Templates @@ -30,7 +30,7 @@ def get_sorting_analyzer(cache_folder, format="memory", sparse=True): alpha=(200.0, 500.0), ) ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), seed=2406, ) if format == "memory": diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index e3c3149ab0..802f487ed1 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -6,7 +6,8 @@ from typing import Sequence import numpy as np from spikeinterface.core.base import BaseExtractor -from spikeinterface.core import generate_recording, generate_ground_truth_recording, concatenate_recordings +from spikeinterface.core import generate_recording, concatenate_recordings +from spikeinterface.generation import generate_ground_truth_recording class DummyDictExtractor(BaseExtractor): diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 6c06b212b8..5342f3114a 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -15,12 +15,12 @@ SharedMemorySorting, NpzFolderSorting, NumpyFolderSorting, - generate_ground_truth_recording, generate_sorting, create_sorting_npz, generate_sorting, load, ) +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.base import BaseExtractor, unit_period_dtype from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index c5a0b83f87..fa2f97d4c1 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -18,11 +18,10 @@ generate_templates, generate_channel_locations, generate_unit_locations, - generate_ground_truth_recording, generate_sorting_to_inject, synthesize_random_firings, ) -from spikeinterface.generation import NoiseGeneratorRecording +from spikeinterface.generation import NoiseGeneratorRecording, generate_ground_truth_recording from spikeinterface.core.numpyextractors import NumpySorting @@ -220,7 +219,6 @@ def test_noise_generator_several_noise_levels(): dtype="float32", seed=32, noise_levels=1, - strategy="on_the_fly", noise_block_size=20000, ) assert np.all(np.abs(get_noise_levels(rec1) - 1) < 0.1) @@ -232,7 +230,6 @@ def test_noise_generator_several_noise_levels(): dtype="float32", seed=32, noise_levels=[0, 1, 2, 3], - strategy="on_the_fly", noise_block_size=20000, ) assert np.all(np.abs(get_noise_levels(rec2) - np.arange(4)) < 0.1) diff --git a/src/spikeinterface/core/tests/test_loading.py b/src/spikeinterface/core/tests/test_loading.py index c9d6e888f9..f607aff7d1 100644 --- a/src/spikeinterface/core/tests/test_loading.py +++ b/src/spikeinterface/core/tests/test_loading.py @@ -2,7 +2,6 @@ import numpy as np from spikeinterface import ( - generate_ground_truth_recording, create_sorting_analyzer, load, ms_to_samples, @@ -10,6 +9,7 @@ Templates, aggregate_channels, ) +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.motion import Motion from spikeinterface.core.generate import generate_unit_locations, generate_templates from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 61c2fda873..89cb6dc727 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -3,7 +3,8 @@ from pathlib import Path import shutil -from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording +from spikeinterface import create_sorting_analyzer, get_template_extremum_channel +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.base import spike_peak_dtype from spikeinterface.core.job_tools import divide_time_series_into_chunks diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 477e4f04aa..08f343feea 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -33,7 +33,6 @@ def test_write_binary_recording(tmp_path): durations=durations, num_channels=num_channels, sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw"] @@ -60,7 +59,6 @@ def test_write_binary_recording_offset(tmp_path): durations=durations, num_channels=num_channels, sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw"] @@ -95,7 +93,6 @@ def test_write_binary_recording_parallel(tmp_path): num_channels=num_channels, sampling_frequency=sampling_frequency, dtype=dtype, - strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] @@ -125,7 +122,6 @@ def test_write_binary_recording_multiple_segment(tmp_path): durations=durations, num_channels=num_channels, sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] @@ -146,9 +142,7 @@ def test_write_binary_recording_multiple_segment(tmp_path): def test_write_memory_recording(): # 2 segments - recording = MockRecording( - num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" - ) + recording = MockRecording(num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000) recording = recording.save() # write with loop @@ -316,9 +310,7 @@ def test_order_channels_by_depth(): def test_do_recording_attributes_match(): - recording = MockRecording( - num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" - ) + recording = MockRecording(num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000) rec_attributes = get_rec_attributes(recording) do_match, _ = do_recording_attributes_match(recording, rec_attributes) assert do_match diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 4194f459b3..405df715d2 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -4,7 +4,7 @@ from spikeinterface.core import NumpySorting -from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.sorting_tools import ( spike_vector_to_spike_trains, random_spikes_selection, @@ -52,7 +52,7 @@ def test_random_spikes_selection(): num_channels=10, num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) max_spikes_per_unit = 12 diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..05702748cb 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -5,13 +5,13 @@ import shutil from spikeinterface.core import ( - generate_ground_truth_recording, create_sorting_analyzer, load_sorting_analyzer, get_available_analyzer_extensions, get_default_analyzer_extension_params, get_default_zarr_compressor, ) +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.sortinganalyzer import ( register_result_extension, AnalyzerExtension, @@ -31,7 +31,7 @@ def get_dataset(): num_channels=10, num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 6e85221621..dc147e0af3 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -5,7 +5,7 @@ from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, get_noise_levels from spikeinterface.core.core_tools import check_json -from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer @@ -152,7 +152,7 @@ def get_dataset(): num_channels=10, num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=1.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=1.0), seed=2205, ) recording.set_property("group", ["a"] * 5 + ["b"] * 5) diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index a28680612a..f04a7f181e 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -1,7 +1,8 @@ import pytest import numpy as np -from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer +from spikeinterface.core import create_sorting_analyzer +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface import Templates @@ -19,7 +20,7 @@ def get_sorting_analyzer(): sampling_frequency=10_000.0, num_channels=4, num_units=10, - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) recording.annotate(is_filtered=True) diff --git a/src/spikeinterface/core/tests/test_time_series_tools.py b/src/spikeinterface/core/tests/test_time_series_tools.py index 4c6ba6b105..d41d223ae7 100644 --- a/src/spikeinterface/core/tests/test_time_series_tools.py +++ b/src/spikeinterface/core/tests/test_time_series_tools.py @@ -25,7 +25,6 @@ def test_write_binary(tmp_path): durations=durations, num_channels=num_channels, sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw"] @@ -52,7 +51,6 @@ def test_write_binary_offset(tmp_path): durations=durations, num_channels=num_channels, sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw"] @@ -85,7 +83,6 @@ def test_write_binary_parallel(tmp_path): num_channels=num_channels, sampling_frequency=sampling_frequency, dtype=dtype, - strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] @@ -115,7 +112,6 @@ def test_write_binary_multiple_segment(tmp_path): durations=durations, num_channels=num_channels, sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", ) file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] @@ -136,9 +132,7 @@ def test_write_binary_multiple_segment(tmp_path): def test_write_memory_recording(): # 2 segments - recording = MockRecording( - num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" - ) + recording = MockRecording(num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000) recording = recording.save() # write with loop diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 5e0350f833..af30dce079 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -5,7 +5,8 @@ import numpy as np -from spikeinterface.core import generate_recording, generate_sorting, generate_ground_truth_recording, ms_to_samples +from spikeinterface.core import generate_recording, generate_sorting, ms_to_samples +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.waveform_tools import ( extract_waveforms_to_buffers, extract_waveforms_to_single_buffer, @@ -29,7 +30,7 @@ def get_dataset(): num_channels=4, num_units=7, generate_sorting_kwargs=dict(firing_rates=5.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=1.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=1.0), seed=2205, ) return recording, sorting diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index e75f3be156..608e9809b2 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -6,7 +6,8 @@ import numpy as np -from spikeinterface.core import generate_ground_truth_recording, SortingAnalyzer +from spikeinterface.core import SortingAnalyzer +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.core_tools import SIJsonEncoder from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor @@ -34,7 +35,7 @@ def get_dataset(): alpha=(100.0, 500.0), ) ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), seed=2406, ) return recording, sorting diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index a665b074a6..27367d244b 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -1,6 +1,7 @@ import pytest -from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer, aggregate_units +from spikeinterface.core import create_sorting_analyzer, aggregate_units +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.generate import inject_some_split_units from spikeinterface.curation import train_model from pathlib import Path @@ -25,7 +26,7 @@ def make_sorting_analyzer(sparse=True, num_units=5, durations=[300.0]): num_channels=4, num_units=num_units, generate_sorting_kwargs=dict(firing_rates=20.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 40e0a9b2b7..58b934d4e0 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -4,7 +4,8 @@ import json import numpy as np -from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer +from spikeinterface.core import create_sorting_analyzer +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.curation.curation_format import ( validate_curation_dict, diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index d86df0c6c8..e890e763ae 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -1,7 +1,8 @@ import pytest from pathlib import Path -from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer, compute_sparsity +from spikeinterface.core import create_sorting_analyzer, compute_sparsity +from spikeinterface.generation import generate_ground_truth_recording def make_sorting_analyzer(sparse=True, with_group=False): @@ -18,7 +19,7 @@ def make_sorting_analyzer(sparse=True, with_group=False): contact_shape_params={"radius": 6}, ), generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 0e5dd2694d..6fbb510547 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -7,12 +7,12 @@ BaseSorting, BaseSortingSegment, read_python, - generate_ground_truth_recording, ChannelSparsity, ComputeTemplates, create_sorting_analyzer, SortingAnalyzer, ) +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations diff --git a/src/spikeinterface/extractors/tests/test_mdaextractors.py b/src/spikeinterface/extractors/tests/test_mdaextractors.py index 1ed930613f..fdb35da365 100644 --- a/src/spikeinterface/extractors/tests/test_mdaextractors.py +++ b/src/spikeinterface/extractors/tests/test_mdaextractors.py @@ -1,7 +1,7 @@ import pytest from pathlib import Path from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal -from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.extractors.extractor_classes import MdaRecordingExtractor, MdaSortingExtractor diff --git a/src/spikeinterface/extractors/tests/test_shybridextractors.py b/src/spikeinterface/extractors/tests/test_shybridextractors.py index e4458f0b36..8d0e5b8106 100644 --- a/src/spikeinterface/extractors/tests/test_shybridextractors.py +++ b/src/spikeinterface/extractors/tests/test_shybridextractors.py @@ -1,6 +1,6 @@ import pytest -from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal from spikeinterface.extractors.extractor_classes import SHYBRIDRecordingExtractor, SHYBRIDSortingExtractor diff --git a/src/spikeinterface/extractors/toy_example.py b/src/spikeinterface/extractors/toy_example.py index 626ea0050a..29ef55775c 100644 --- a/src/spikeinterface/extractors/toy_example.py +++ b/src/spikeinterface/extractors/toy_example.py @@ -7,8 +7,8 @@ generate_channel_locations, generate_unit_locations, generate_templates, - generate_ground_truth_recording, ) +from spikeinterface.generation import generate_ground_truth_recording def toy_example( @@ -155,7 +155,7 @@ def toy_example( ms_after=ms_after, dtype="float32", seed=seed, - noise_kwargs=dict(noise_levels=10.0, strategy="on_the_fly"), + noise_kwargs=dict(noise_levels=10.0), ) return recording, sorting diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index e0fb98dbbd..0548838cc6 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -15,6 +15,8 @@ ) from .noise_tools import generate_noise, NoiseGeneratorRecording, noise_generator_recording +from .ground_truth_generator import generate_ground_truth_recording + from .splitting_tools import split_sorting_by_amplitudes, split_sorting_by_times from .drifting_generator import ( @@ -37,7 +39,6 @@ generate_snippets, generate_templates, generate_recording_by_size, - generate_ground_truth_recording, add_synchrony_to_sorting, synthesize_random_firings, inject_some_duplicate_units, diff --git a/src/spikeinterface/generation/ground_truth_generator.py b/src/spikeinterface/generation/ground_truth_generator.py new file mode 100644 index 0000000000..683afdfca5 --- /dev/null +++ b/src/spikeinterface/generation/ground_truth_generator.py @@ -0,0 +1,201 @@ +import numpy as np + +from probeinterface import generate_multi_columns_probe + +from spikeinterface.core.core_tools import ms_to_samples +from spikeinterface.core.generate import ( + _ensure_seed, + generate_sorting, + generate_unit_locations, + generate_templates, + InjectTemplatesRecording, +) + +from .noise_tools import NoiseGeneratorRecording + + +def generate_ground_truth_recording( + durations=[10.0], + sampling_frequency=25000.0, + num_channels=4, + num_units=10, + sorting=None, + probe=None, + generate_probe_kwargs=dict( + num_columns=2, + xpitch=20, + ypitch=20, + contact_shapes="circle", + contact_shape_params={"radius": 6}, + ), + templates=None, + ms_before=1.0, + ms_after=3.0, + upsample_factor=None, + upsample_vector=None, + generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), + noise_kwargs=dict(noise_levels=5.0), + generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), + generate_templates_kwargs=None, + dtype="float32", + seed=None, +): + """ + Generate a recording with spike given a probe+sorting+templates. + + Parameters + ---------- + durations : list[float], default: [10.] + Durations in seconds for all segments. + sampling_frequency : float, default: 25000.0 + Sampling frequency. + num_channels : int, default: 4 + Number of channels, not used when probe is given. + num_units : int, default: 10 + Number of units, not used when sorting is given. + sorting : Sorting | None + An external sorting object. If not provide, one is genrated. + probe : Probe | None + An external Probe object. If not provided a probe is generated using generate_probe_kwargs. + generate_probe_kwargs : dict + A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. + templates : np.ndarray | None + The templates of units. + If None they are generated. + Shape can be: + + * (num_units, num_samples, num_channels): standard case + * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce jitter. + ms_before : float, default: 1.5 + Cut out in ms before spike peak. + ms_after : float, default: 3.0 + Cut out in ms after spike peak. + upsample_factor : None | int, default: None + A upsampling factor used only when templates are not provided. + upsample_vector : np.ndarray | None + Optional the upsample_vector can given. This has the same shape as spike_vector + generate_sorting_kwargs : dict + When sorting is not provide, this dict is used to generated a Sorting. + noise_kwargs : dict + Dict used to generated the noise with NoiseGeneratorRecording. + generate_unit_locations_kwargs : dict + Dict used to generated template when template not provided. + generate_templates_kwargs : dict + Dict used to generated template when template not provided. + dtype : np.dtype, default: "float32" + The dtype of the recording. + seed : int | None + Seed for random initialization. + If None a diffrent Recording is generated at every call. + Note: even with None a generated recording keep internaly a seed to regenerate the same signal after dump/load. + + Returns + ------- + recording : Recording + The generated recording extractor. + sorting : Sorting + The generated sorting extractor. + """ + generate_templates_kwargs = generate_templates_kwargs or dict() + + # TODO implement upsample_factor in InjectTemplatesRecording and propagate into toy_example + + # if None so the same seed will be used for all steps + seed = _ensure_seed(seed) + rng = np.random.default_rng(seed) + + if sorting is None: + generate_sorting_kwargs = generate_sorting_kwargs.copy() + generate_sorting_kwargs["durations"] = durations + generate_sorting_kwargs["num_units"] = num_units + generate_sorting_kwargs["sampling_frequency"] = sampling_frequency + generate_sorting_kwargs["seed"] = seed + sorting = generate_sorting(**generate_sorting_kwargs) + else: + num_units = sorting.get_num_units() + assert sorting.sampling_frequency == sampling_frequency + num_spikes = sorting.to_spike_vector().size + + if probe is None: + # probe = generate_linear_probe(num_elec=num_channels) + # probe.set_device_channel_indices(np.arange(num_channels)) + + prb_kwargs = generate_probe_kwargs.copy() + if "num_contact_per_column" in prb_kwargs: + assert ( + prb_kwargs["num_contact_per_column"] * prb_kwargs["num_columns"] + ) == num_channels, ( + "generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns" + ) + elif "num_contact_per_column" not in prb_kwargs and "num_columns" in prb_kwargs: + n = num_channels // prb_kwargs["num_columns"] + num_contact_per_column = [n] * prb_kwargs["num_columns"] + mid = prb_kwargs["num_columns"] // 2 + num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"] + prb_kwargs["num_contact_per_column"] = num_contact_per_column + else: + raise ValueError("num_columns should be provided in dict generate_probe_kwargs") + + probe = generate_multi_columns_probe(**prb_kwargs) + probe.set_device_channel_indices(np.arange(num_channels)) + + else: + num_channels = probe.get_contact_count() + + if templates is None: + channel_locations = probe.contact_positions + unit_locations = generate_unit_locations( + num_units, channel_locations, seed=seed, **generate_unit_locations_kwargs + ) + templates = generate_templates( + channel_locations, + unit_locations, + sampling_frequency, + ms_before, + ms_after, + upsample_factor=upsample_factor, + seed=seed, + dtype=dtype, + **generate_templates_kwargs, + ) + sorting.set_property("gt_unit_locations", unit_locations) + else: + assert templates.shape[0] == num_units + + if templates.ndim == 3: + upsample_vector = None + else: + if upsample_vector is None: + upsample_factor = templates.shape[3] + upsample_vector = rng.integers(0, upsample_factor, size=num_spikes) + + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) + assert (nbefore + nafter) == templates.shape[1] + + noise_rec = NoiseGeneratorRecording( + num_channels=num_channels, + sampling_frequency=sampling_frequency, + durations=durations, + dtype=dtype, + seed=seed, + noise_block_size=int(sampling_frequency), + **noise_kwargs, + ) + + recording = InjectTemplatesRecording( + sorting, + templates, + nbefore=nbefore, + parent_recording=noise_rec, + upsample_vector=upsample_vector, + ) + recording.annotate(is_filtered=True) + recording.set_probe(probe, in_place=True) + recording.set_channel_gains(1.0) + recording.set_channel_offsets(0.0) + + recording.name = "GroundTruthRecording" + sorting.name = "GroundTruthSorting" + + return recording, sorting diff --git a/src/spikeinterface/generation/noise_tools.py b/src/spikeinterface/generation/noise_tools.py index f64a8abe72..f9a73aad9f 100644 --- a/src/spikeinterface/generation/noise_tools.py +++ b/src/spikeinterface/generation/noise_tools.py @@ -284,7 +284,6 @@ def generate_noise( sampling_frequency=sampling_frequency, durations=durations, dtype=dtype, - strategy="on_the_fly", noise_levels=noise_levels, cov_matrix=cov_matrix, seed=seed, diff --git a/src/spikeinterface/generation/tests/test_splitting_tools.py b/src/spikeinterface/generation/tests/test_splitting_tools.py index 313df5c9fa..6066088d4c 100644 --- a/src/spikeinterface/generation/tests/test_splitting_tools.py +++ b/src/spikeinterface/generation/tests/test_splitting_tools.py @@ -1,7 +1,7 @@ from spikeinterface.generation import split_sorting_by_amplitudes, split_sorting_by_times from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.generate import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording def test_split_by_times(): diff --git a/src/spikeinterface/metrics/conftest.py b/src/spikeinterface/metrics/conftest.py index 5313e763c1..51038fc00b 100644 --- a/src/spikeinterface/metrics/conftest.py +++ b/src/spikeinterface/metrics/conftest.py @@ -1,9 +1,9 @@ import pytest from spikeinterface.core import ( - generate_ground_truth_recording, create_sorting_analyzer, ) +from spikeinterface.generation import generate_ground_truth_recording job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") @@ -65,7 +65,7 @@ def sorting_analyzer_simple(): alpha=(200.0, 500.0), ) ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), seed=1205, ) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index e267b176ce..7eb10784d2 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -6,10 +6,10 @@ NumpySorting, synthetize_spike_train_bad_isi, add_synchrony_to_sorting, - generate_ground_truth_recording, create_sorting_analyzer, synthesize_random_firings, ) +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.metrics.utils import create_ground_truth_pc_distributions, create_regular_periods @@ -84,7 +84,7 @@ def _sorting_analyzer_violations(): sampling_frequency=sorting.sampling_frequency, num_channels=6, sorting=sorting, - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) diff --git a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py index dfd47c4df9..ba8a7eefcc 100644 --- a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py @@ -2,11 +2,11 @@ import numpy as np from spikeinterface.core import ( - generate_ground_truth_recording, create_sorting_analyzer, NumpySorting, aggregate_units, ) +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.metrics.quality.misc_metrics import compute_snrs, compute_drift_metrics diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index e03d54fe74..a868258e47 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -3,11 +3,11 @@ import numpy as np from spikeinterface.core import ( - generate_ground_truth_recording, create_sorting_analyzer, load_sorting_analyzer, estimate_sparsity, ) +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.sortinganalyzer import get_extension_class extensions_which_allow_unit_ids = ["unit_locations"] @@ -30,7 +30,7 @@ def get_dataset(): alpha=(100.0, 500.0), ) ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) return recording, sorting diff --git a/src/spikeinterface/postprocessing/tests/conftest.py b/src/spikeinterface/postprocessing/tests/conftest.py index 51ac8aa250..75e6d56547 100644 --- a/src/spikeinterface/postprocessing/tests/conftest.py +++ b/src/spikeinterface/postprocessing/tests/conftest.py @@ -1,9 +1,9 @@ import pytest from spikeinterface.core import ( - generate_ground_truth_recording, create_sorting_analyzer, ) +from spikeinterface.generation import generate_ground_truth_recording def _small_sorting_analyzer(): diff --git a/src/spikeinterface/postprocessing/tests/test_extension_merges.py b/src/spikeinterface/postprocessing/tests/test_extension_merges.py index fa0310af5c..e075128a8c 100644 --- a/src/spikeinterface/postprocessing/tests/test_extension_merges.py +++ b/src/spikeinterface/postprocessing/tests/test_extension_merges.py @@ -1,6 +1,7 @@ import numpy as np -from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer +from spikeinterface.core import create_sorting_analyzer +from spikeinterface.generation import generate_ground_truth_recording def test_correlograms_merge(): diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index e3c45fe8ef..ce4f117058 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -5,10 +5,10 @@ from spikeinterface import ( create_sorting_analyzer, - generate_ground_truth_recording, set_global_job_kwargs, get_template_extremum_amplitude, ) +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.generate import inject_some_split_units # even if this is in postprocessing, we make an extension for quality metrics @@ -70,7 +70,7 @@ def get_dataset_to_merge(): num_channels=10, num_units=10, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=2.0, maximum_z=15.0, minimum_distance=20), seed=2205, ) @@ -102,7 +102,7 @@ def get_dataset_to_split(): num_channels=10, num_units=10, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) diff --git a/src/spikeinterface/postprocessing/tests/test_unit_locations.py b/src/spikeinterface/postprocessing/tests/test_unit_locations.py index 4202b783e6..17d6e2d986 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_locations.py @@ -2,7 +2,8 @@ from spikeinterface.postprocessing import ComputeUnitLocations import pytest from probeinterface import Probe -from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording +from spikeinterface.core import create_sorting_analyzer +from spikeinterface.generation import generate_ground_truth_recording import numpy as np diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 1343d447b2..f475dfa7a8 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -94,7 +94,6 @@ def __init__( durations=[recording.select_segments(i).get_duration() for i in range(recording.get_num_segments())], dtype=recording.dtype, seed=seed, - strategy="on_the_fly", noise_block_size=int(recording.sampling_frequency), ) noise_generator = ScaleRecording(mock_noise, gain=noise_levels, dtype=recording.dtype) diff --git a/src/spikeinterface/sorters/external/tests/test_docker_containers.py b/src/spikeinterface/sorters/external/tests/test_docker_containers.py index 5c8a3f6777..857511231a 100644 --- a/src/spikeinterface/sorters/external/tests/test_docker_containers.py +++ b/src/spikeinterface/sorters/external/tests/test_docker_containers.py @@ -2,7 +2,7 @@ import pytest -from spikeinterface import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.core_tools import is_editable_mode import spikeinterface.sorters as ss diff --git a/src/spikeinterface/sorters/external/tests/test_kilosort4.py b/src/spikeinterface/sorters/external/tests/test_kilosort4.py index 5d82b84bc5..77e7a822aa 100644 --- a/src/spikeinterface/sorters/external/tests/test_kilosort4.py +++ b/src/spikeinterface/sorters/external/tests/test_kilosort4.py @@ -1,7 +1,8 @@ import unittest import pytest -from spikeinterface import load, generate_ground_truth_recording +from spikeinterface import load +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.sorters import Kilosort4Sorter, run_sorter from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite diff --git a/src/spikeinterface/sorters/external/tests/test_singularity_containers.py b/src/spikeinterface/sorters/external/tests/test_singularity_containers.py index c7ff54db53..0763780f18 100644 --- a/src/spikeinterface/sorters/external/tests/test_singularity_containers.py +++ b/src/spikeinterface/sorters/external/tests/test_singularity_containers.py @@ -2,7 +2,7 @@ import pytest -from spikeinterface import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.core_tools import is_editable_mode import spikeinterface.sorters as ss diff --git a/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py b/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py index eb238abdf4..8ce6714d31 100644 --- a/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py +++ b/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py @@ -3,7 +3,7 @@ import pytest -from spikeinterface import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.core_tools import is_editable_mode import spikeinterface.sorters as ss diff --git a/src/spikeinterface/sorters/tests/common_tests.py b/src/spikeinterface/sorters/tests/common_tests.py index 316ea21d40..a7ecc883e8 100644 --- a/src/spikeinterface/sorters/tests/common_tests.py +++ b/src/spikeinterface/sorters/tests/common_tests.py @@ -1,7 +1,7 @@ import pytest import shutil -from spikeinterface import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.sorters import run_sorter from spikeinterface.core.snippets_tools import snippets_from_sorting diff --git a/src/spikeinterface/sorters/tests/test_container_tools.py b/src/spikeinterface/sorters/tests/test_container_tools.py index 606fe9940e..a759249418 100644 --- a/src/spikeinterface/sorters/tests/test_container_tools.py +++ b/src/spikeinterface/sorters/tests/test_container_tools.py @@ -5,7 +5,7 @@ import os import spikeinterface as si -from spikeinterface import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.sorters.container_tools import find_recording_folders, ContainerClient, install_package_in_container import platform diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index 7cca9d3eb0..87e9522ee0 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -5,7 +5,7 @@ import pytest from pathlib import Path from platform import system -from spikeinterface import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.sorters import run_sorter_jobs, run_sorter_by_property # no need to have many diff --git a/src/spikeinterface/sorters/tests/test_runsorter.py b/src/spikeinterface/sorters/tests/test_runsorter.py index 332e6e857e..107ad73022 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter.py +++ b/src/spikeinterface/sorters/tests/test_runsorter.py @@ -7,7 +7,8 @@ import json import numpy as np -from spikeinterface import generate_ground_truth_recording, load +from spikeinterface import load +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.sorters import run_sorter ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py index 12ef8caafd..ccc2dde38b 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -1,6 +1,6 @@ import pytest import platform -from spikeinterface import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.sorters.utils import has_spython, has_docker_python, has_docker, has_singularity from spikeinterface.sorters import run_sorter import sys diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index 9635bad48b..fcb7ab4145 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -4,7 +4,7 @@ import pickle from spikeinterface import NumpyRecording, load -from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.core.motion import Motion from spikeinterface.core.testing import check_recordings_equal from spikeinterface.sortingcomponents.motion.motion_interpolation import ( @@ -241,7 +241,7 @@ def test_InterpolateMotionRecording(): contact_shape_params={"radius": 6}, ), generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) diff --git a/src/spikeinterface/sortingcomponents/tests/common.py b/src/spikeinterface/sortingcomponents/tests/common.py index e8bee33794..8777adf3ce 100644 --- a/src/spikeinterface/sortingcomponents/tests/common.py +++ b/src/spikeinterface/sortingcomponents/tests/common.py @@ -1,4 +1,4 @@ -from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording def make_dataset(): @@ -16,7 +16,7 @@ def make_dataset(): contact_shape_params={"radius": 6}, ), generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) diff --git a/src/spikeinterface/sortingcomponents/waveforms/tests/conftest.py b/src/spikeinterface/sortingcomponents/waveforms/tests/conftest.py index f3ab58d87b..f8130446e9 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/tests/conftest.py +++ b/src/spikeinterface/sortingcomponents/waveforms/tests/conftest.py @@ -1,6 +1,6 @@ import pytest -from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.generation import generate_ground_truth_recording from spikeinterface.sortingcomponents.peak_detection import detect_peaks diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index c811c386e2..f199e12897 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -13,9 +13,9 @@ from spikeinterface import ( compute_sparsity, - generate_ground_truth_recording, create_sorting_analyzer, ) +from spikeinterface.generation import generate_ground_truth_recording import spikeinterface.widgets as sw import spikeinterface.comparison as sc @@ -54,7 +54,7 @@ def setUpClass(cls): contact_shape_params={"radius": 6}, ), generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), + noise_kwargs=dict(noise_levels=5.0), seed=2205, ) # cls.recording = recording.save(folder=cache_folder / "recording")