Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
sampling_frequency = zarr_group.attrs["sampling_frequency"]
nbefore = zarr_group.attrs["nbefore"]

# TODO: Consider eliminating the True and make it required
is_scaled = zarr_group.attrs.get("is_scaled", True)

sparsity_mask = None
if "sparsity_mask" in zarr_group:
sparsity_mask = zarr_group["sparsity_mask"]
Expand All @@ -316,6 +319,7 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
channel_ids=channel_ids,
unit_ids=unit_ids,
probe=probe,
is_scaled=is_scaled,
)

@staticmethod
Expand Down
30 changes: 20 additions & 10 deletions src/spikeinterface/core/tests/test_template_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from probeinterface import generate_multi_columns_probe


def generate_test_template(template_type):
def generate_test_template(template_type, is_scaled=True) -> Templates:
num_units = 2
num_samples = 5
num_channels = 3
Expand All @@ -21,7 +21,11 @@ def generate_test_template(template_type):

if template_type == "dense":
return Templates(
templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe
templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=is_scaled,
)
elif template_type == "sparse": # sparse with sparse templates
sparsity_mask = np.array([[True, False, True], [False, True, False]])
Expand All @@ -42,6 +46,7 @@ def generate_test_template(template_type):
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=is_scaled,
)

elif template_type == "sparse_with_dense_templates": # sparse with dense templates
Expand All @@ -53,12 +58,14 @@ def generate_test_template(template_type):
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=is_scaled,
)


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_pickle_serialization(template_type, tmp_path):
template = generate_test_template(template_type)
def test_pickle_serialization(template_type, is_scaled, tmp_path):
template = generate_test_template(template_type, is_scaled)

# Dump to pickle
pkl_path = tmp_path / "templates.pkl"
Expand All @@ -72,19 +79,21 @@ def test_pickle_serialization(template_type, tmp_path):
assert template == template_reloaded


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_json_serialization(template_type):
template = generate_test_template(template_type)
def test_json_serialization(template_type, is_scaled):
template = generate_test_template(template_type, is_scaled)

json_str = template.to_json()
template_reloaded_from_json = Templates.from_json(json_str)

assert template == template_reloaded_from_json


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_get_dense_templates(template_type):
template = generate_test_template(template_type)
def test_get_dense_templates(template_type, is_scaled):
template = generate_test_template(template_type, is_scaled)
dense_templates = template.get_dense_templates()
assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels)

Expand All @@ -94,9 +103,10 @@ def test_initialization_fail_with_dense_templates():
template = generate_test_template(template_type="sparse_with_dense_templates")


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_save_and_load_zarr(template_type, tmp_path):
original_template = generate_test_template(template_type)
def test_save_and_load_zarr(template_type, is_scaled, tmp_path):
original_template = generate_test_template(template_type, is_scaled)

zarr_path = tmp_path / "templates.zarr"
original_template.to_zarr(str(zarr_path))
Expand Down