Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow export_to_phy to work with fast_templates #2549

Merged
merged 15 commits into from Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
Expand Up @@ -397,10 +397,10 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
----------
unit_ids: list or None
Unit ids to retrieve waveforms for
mode: "average" | "median" | "std" | "percentile", default: "average"
The mode to compute the templates
operator: "average" | "median" | "std" | "percentile", default: "average"
The operator to compute the templates
percentile: float, default: None
Percentile to use for mode="percentile"
Percentile to use for operator="percentile"
save: bool, default True
In case, the operator is not computed yet it can be saved to folder or zarr.

Expand Down Expand Up @@ -520,6 +520,34 @@ def _select_extension_data(self, unit_ids):

return new_data

def get_templates(self, unit_ids=None, save=True):
"""
Return average templates for multiple units.

Parameters
----------
unit_ids: list or None
DradeAW marked this conversation as resolved.
Show resolved Hide resolved
Unit ids to retrieve waveforms for
save: bool, default True
In case, the operator is not computed yet it can be saved to folder or zarr.

Returns
-------
templates: np.array
The returned templates (num_units, num_samples, num_channels)
"""

templates = self.data["average"]

if save:
self.save()
DradeAW marked this conversation as resolved.
Show resolved Hide resolved

if unit_ids is not None:
unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
templates = templates[unit_indices, :, :]

return np.array(templates)


compute_fast_templates = ComputeFastTemplates.function_factory()
register_result_extension(ComputeFastTemplates)
Expand Down
14 changes: 12 additions & 2 deletions src/spikeinterface/exporters/to_phy.py
Expand Up @@ -181,9 +181,19 @@ def export_to_phy(
# export templates/templates_ind/similar_templates
# shape (num_units, num_samples, max_num_channels)
templates_ext = sorting_analyzer.get_extension("templates")
templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates'"
if templates_ext is None:
templates_ext = sorting_analyzer.get_extension("fast_templates")
if templates_ext is not None and template_mode != "average":
assert (
False
), "export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'"
DradeAW marked this conversation as resolved.
Show resolved Hide resolved
dense_templates = templates_ext.get_templates(unit_ids=unit_ids)
else:
dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode)
assert (
templates_ext is not None
), "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'"
DradeAW marked this conversation as resolved.
Show resolved Hide resolved
max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values())
dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode)
num_samples = dense_templates.shape[1]
templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64")
# here we pad template inds with -1 if len of sparse channels is unequal
Expand Down