Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
import numpy as np

from spikeinterface.core import (
generate_ground_truth_recording,
estimate_templates,
Templates,
create_sorting_analyzer,
ms_to_samples,
)
from spikeinterface.generation import generate_drifting_recording
from spikeinterface.generation import generate_drifting_recording, generate_ground_truth_recording

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))

Expand All @@ -39,7 +38,7 @@ def make_dataset(job_kwargs={}):
contact_shape_params={"radius": 6},
),
generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
noise_kwargs=dict(noise_levels=5.0),
seed=2205,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from pathlib import Path

from spikeinterface import generate_ground_truth_recording
from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.preprocessing import bandpass_filter
from spikeinterface.benchmark import SorterStudy

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from pathlib import Path

from spikeinterface import generate_ground_truth_recording
from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.preprocessing import bandpass_filter
from spikeinterface.benchmark import SorterStudyWithoutGroundTruth

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pytest
import numpy as np

from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording
from spikeinterface.core import create_sorting_analyzer
from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.comparison import compare_templates, compare_multiple_templates

# def setup_module():
Expand Down
15 changes: 14 additions & 1 deletion src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
generate_recording_by_size,
InjectTemplatesRecording,
inject_templates,
generate_ground_truth_recording,
)

# utils to append and concatenate segment (equivalent to OLD MultiRecordingTimeExtractor)
Expand Down Expand Up @@ -213,4 +212,18 @@ def __getattr__(name):
"noise_generator_recording": noise_generator_recording,
}
return _map[name]
if name == "generate_ground_truth_recording":
import warnings

warnings.warn(
"Importing generate_ground_truth_recording from spikeinterface.core is deprecated. "
"Import from spikeinterface.generation instead: "
"`from spikeinterface.generation import generate_ground_truth_recording`. "
"This will be removed in version 0.106.0.",
FutureWarning,
stacklevel=2,
)
from spikeinterface.generation.ground_truth_generator import generate_ground_truth_recording

return generate_ground_truth_recording
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
204 changes: 14 additions & 190 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2266,196 +2266,6 @@ def generate_unit_locations(
return units_locations


def generate_ground_truth_recording(
durations=[10.0],
sampling_frequency=25000.0,
num_channels=4,
num_units=10,
sorting=None,
probe=None,
generate_probe_kwargs=dict(
num_columns=2,
xpitch=20,
ypitch=20,
contact_shapes="circle",
contact_shape_params={"radius": 6},
),
templates=None,
ms_before=1.0,
ms_after=3.0,
upsample_factor=None,
upsample_vector=None,
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20),
generate_templates_kwargs=None,
dtype="float32",
seed=None,
):
"""
Generate a recording with spike given a probe+sorting+templates.

Parameters
----------
durations : list[float], default: [10.]
Durations in seconds for all segments.
sampling_frequency : float, default: 25000.0
Sampling frequency.
num_channels : int, default: 4
Number of channels, not used when probe is given.
num_units : int, default: 10
Number of units, not used when sorting is given.
sorting : Sorting | None
An external sorting object. If not provide, one is genrated.
probe : Probe | None
An external Probe object. If not provided a probe is generated using generate_probe_kwargs.
generate_probe_kwargs : dict
A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`.
templates : np.ndarray | None
The templates of units.
If None they are generated.
Shape can be:

* (num_units, num_samples, num_channels): standard case
* (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce jitter.
ms_before : float, default: 1.5
Cut out in ms before spike peak.
ms_after : float, default: 3.0
Cut out in ms after spike peak.
upsample_factor : None | int, default: None
A upsampling factor used only when templates are not provided.
upsample_vector : np.ndarray | None
Optional the upsample_vector can given. This has the same shape as spike_vector
generate_sorting_kwargs : dict
When sorting is not provide, this dict is used to generated a Sorting.
noise_kwargs : dict
Dict used to generated the noise with NoiseGeneratorRecording.
generate_unit_locations_kwargs : dict
Dict used to generated template when template not provided.
generate_templates_kwargs : dict
Dict used to generated template when template not provided.
dtype : np.dtype, default: "float32"
The dtype of the recording.
seed : int | None
Seed for random initialization.
If None a diffrent Recording is generated at every call.
Note: even with None a generated recording keep internaly a seed to regenerate the same signal after dump/load.

Returns
-------
recording : Recording
The generated recording extractor.
sorting : Sorting
The generated sorting extractor.
"""
generate_templates_kwargs = generate_templates_kwargs or dict()

# TODO implement upsample_factor in InjectTemplatesRecording and propagate into toy_example

# if None so the same seed will be used for all steps
seed = _ensure_seed(seed)
rng = np.random.default_rng(seed)

if sorting is None:
generate_sorting_kwargs = generate_sorting_kwargs.copy()
generate_sorting_kwargs["durations"] = durations
generate_sorting_kwargs["num_units"] = num_units
generate_sorting_kwargs["sampling_frequency"] = sampling_frequency
generate_sorting_kwargs["seed"] = seed
sorting = generate_sorting(**generate_sorting_kwargs)
else:
num_units = sorting.get_num_units()
assert sorting.sampling_frequency == sampling_frequency
num_spikes = sorting.to_spike_vector().size

if probe is None:
# probe = generate_linear_probe(num_elec=num_channels)
# probe.set_device_channel_indices(np.arange(num_channels))

prb_kwargs = generate_probe_kwargs.copy()
if "num_contact_per_column" in prb_kwargs:
assert (
prb_kwargs["num_contact_per_column"] * prb_kwargs["num_columns"]
) == num_channels, (
"generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns"
)
elif "num_contact_per_column" not in prb_kwargs and "num_columns" in prb_kwargs:
n = num_channels // prb_kwargs["num_columns"]
num_contact_per_column = [n] * prb_kwargs["num_columns"]
mid = prb_kwargs["num_columns"] // 2
num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"]
prb_kwargs["num_contact_per_column"] = num_contact_per_column
else:
raise ValueError("num_columns should be provided in dict generate_probe_kwargs")

probe = generate_multi_columns_probe(**prb_kwargs)
probe.set_device_channel_indices(np.arange(num_channels))

else:
num_channels = probe.get_contact_count()

if templates is None:
channel_locations = probe.contact_positions
unit_locations = generate_unit_locations(
num_units, channel_locations, seed=seed, **generate_unit_locations_kwargs
)
templates = generate_templates(
channel_locations,
unit_locations,
sampling_frequency,
ms_before,
ms_after,
upsample_factor=upsample_factor,
seed=seed,
dtype=dtype,
**generate_templates_kwargs,
)
sorting.set_property("gt_unit_locations", unit_locations)
else:
assert templates.shape[0] == num_units

if templates.ndim == 3:
upsample_vector = None
else:
if upsample_vector is None:
upsample_factor = templates.shape[3]
upsample_vector = rng.integers(0, upsample_factor, size=num_spikes)

nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)
assert (nbefore + nafter) == templates.shape[1]

# construct recording
from spikeinterface.generation.noise_tools import NoiseGeneratorRecording

noise_rec = NoiseGeneratorRecording(
num_channels=num_channels,
sampling_frequency=sampling_frequency,
durations=durations,
dtype=dtype,
seed=seed,
noise_block_size=int(sampling_frequency),
**noise_kwargs,
)

recording = InjectTemplatesRecording(
sorting,
templates,
nbefore=nbefore,
parent_recording=noise_rec,
upsample_vector=upsample_vector,
)
recording.annotate(is_filtered=True)
recording.set_probe(probe, in_place=True)
recording.set_channel_gains(1.0)
recording.set_channel_offsets(0.0)

recording.name = "GroundTruthRecording"
sorting.name = "GroundTruthSorting"

return recording, sorting


def __getattr__(name):
if name in ("NoiseGeneratorRecording", "noise_generator_recording"):
import warnings
Expand All @@ -2475,4 +2285,18 @@ def __getattr__(name):
"noise_generator_recording": noise_generator_recording,
}
return _map[name]
if name == "generate_ground_truth_recording":
import warnings

warnings.warn(
"Importing generate_ground_truth_recording from spikeinterface.core.generate is deprecated. "
"Import from spikeinterface.generation instead: "
"`from spikeinterface.generation import generate_ground_truth_recording`. "
"This will be removed in version 0.106.0.",
FutureWarning,
stacklevel=2,
)
from spikeinterface.generation.ground_truth_generator import generate_ground_truth_recording

return generate_ground_truth_recording
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pathlib import Path

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.core import create_sorting_analyzer
from spikeinterface.core import Templates

Expand All @@ -30,7 +30,7 @@ def get_sorting_analyzer(cache_folder, format="memory", sparse=True):
alpha=(200.0, 500.0),
)
),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
noise_kwargs=dict(noise_levels=5.0),
seed=2406,
)
if format == "memory":
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/core/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from typing import Sequence
import numpy as np
from spikeinterface.core.base import BaseExtractor
from spikeinterface.core import generate_recording, generate_ground_truth_recording, concatenate_recordings
from spikeinterface.core import generate_recording, concatenate_recordings
from spikeinterface.generation import generate_ground_truth_recording


class DummyDictExtractor(BaseExtractor):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
SharedMemorySorting,
NpzFolderSorting,
NumpyFolderSorting,
generate_ground_truth_recording,
generate_sorting,
create_sorting_npz,
generate_sorting,
load,
)
from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.core.base import BaseExtractor, unit_period_dtype
from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal

Expand Down
5 changes: 1 addition & 4 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
generate_templates,
generate_channel_locations,
generate_unit_locations,
generate_ground_truth_recording,
generate_sorting_to_inject,
synthesize_random_firings,
)
from spikeinterface.generation import NoiseGeneratorRecording
from spikeinterface.generation import NoiseGeneratorRecording, generate_ground_truth_recording

from spikeinterface.core.numpyextractors import NumpySorting

Expand Down Expand Up @@ -220,7 +219,6 @@ def test_noise_generator_several_noise_levels():
dtype="float32",
seed=32,
noise_levels=1,
strategy="on_the_fly",
noise_block_size=20000,
)
assert np.all(np.abs(get_noise_levels(rec1) - 1) < 0.1)
Expand All @@ -232,7 +230,6 @@ def test_noise_generator_several_noise_levels():
dtype="float32",
seed=32,
noise_levels=[0, 1, 2, 3],
strategy="on_the_fly",
noise_block_size=20000,
)
assert np.all(np.abs(get_noise_levels(rec2) - np.arange(4)) < 0.1)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import numpy as np
from spikeinterface import (
generate_ground_truth_recording,
create_sorting_analyzer,
load,
ms_to_samples,
SortingAnalyzer,
Templates,
aggregate_channels,
)
from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.core.motion import Motion
from spikeinterface.core.generate import generate_unit_locations, generate_templates
from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/core/tests/test_node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from pathlib import Path
import shutil

from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording
from spikeinterface import create_sorting_analyzer, get_template_extremum_channel
from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.core.base import spike_peak_dtype
from spikeinterface.core.job_tools import divide_time_series_into_chunks

Expand Down
Loading
Loading