Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow export_to_phy to work with fast_templates #2549

Merged
merged 15 commits into from Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 46 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
Expand Up @@ -397,10 +397,10 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
----------
unit_ids: list or None
Unit ids to retrieve waveforms for
mode: "average" | "median" | "std" | "percentile", default: "average"
The mode to compute the templates
operator: "average" | "median" | "std" | "percentile", default: "average"
The operator to compute the templates
percentile: float, default: None
Percentile to use for mode="percentile"
Percentile to use for operator="percentile"
save: bool, default True
In case, the operator is not computed yet it can be saved to folder or zarr.

Expand Down Expand Up @@ -520,6 +520,49 @@ def _select_extension_data(self, unit_ids):

return new_data

def get_templates(self, unit_ids=None):
"""
Return average templates for multiple units.

Parameters
----------
unit_ids: list or None, default: None
Unit ids to retrieve waveforms for

Returns
-------
templates: np.array
The returned templates (num_units, num_samples, num_channels)
"""

templates = self.data["average"]

if unit_ids is not None:
unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
templates = templates[unit_indices, :, :]

return np.array(templates)

def get_unit_template(self, unit_id):
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
"""
Return average template for a single unit.

Parameters
----------
unit_id: str | int
Unit id to retrieve waveforms for

Returns
-------
template: np.array
The returned template (num_samples, num_channels)
"""

templates = self.data["average"]
unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)

return np.array(templates[unit_index, :, :])


compute_fast_templates = ComputeFastTemplates.function_factory()
register_result_extension(ComputeFastTemplates)
Expand Down
14 changes: 12 additions & 2 deletions src/spikeinterface/exporters/to_phy.py
Expand Up @@ -181,9 +181,19 @@ def export_to_phy(
# export templates/templates_ind/similar_templates
# shape (num_units, num_samples, max_num_channels)
templates_ext = sorting_analyzer.get_extension("templates")
templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates'"
if templates_ext is None:
templates_ext = sorting_analyzer.get_extension("fast_templates")
if templates_ext is not None and template_mode != "average":
raise ValueError(
"export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'"
)
dense_templates = templates_ext.get_templates(unit_ids=unit_ids)
else:
dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode)
assert (
templates_ext is not None
), "export_to_phy requires SortingAnalyzer with either extension 'templates' or 'fast_templates'"
Copy link
Collaborator

@zm711 zm711 Mar 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about this and templates_ext cannot be None, because if it is None above the templates_ext.get_templates will fail. Or am I misreading this? It seems to me that it would be better to do a check for none after line 185. Something like

if templates_ext is None:
    raise ValueError('Must have either calculated templates or fast_templates')

That way regardless of if the | notation is accepted you have the check in place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, nice catch!

But as you said, I'm waiting for the | PR to be merged, to re-write this more cleanly :)

max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values())
dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode)
num_samples = dense_templates.shape[1]
templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64")
# here we pad template inds with -1 if len of sparse channels is unequal
Expand Down