Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute quality metrics after pipeline nodes #2773

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,13 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):

extensions_with_pipeline = {}
extensions_without_pipeline = {}
extensions_post_pipeline = {}
for extension_name, extension_params in extensions.items():
if extension_name == "quality_metrics":
# PATCH: the quality metric is computed after the pipeline, since some of the metrics optionally require
# the output of the pipeline extensions (e.g., spike_amplitudes, spike_locations).
extensions_post_pipeline[extension_name] = extension_params
continue
extension_class = get_extension_class(extension_name)
if extension_class.use_nodepipeline:
extensions_with_pipeline[extension_name] = extension_params
Expand Down Expand Up @@ -1020,6 +1026,17 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
if save:
extension_instance.save()

# PATCH: the quality metric is computed after the pipeline, since some of the metrics optionally require
# the output of the pipeline extensions (e.g., spike_amplitudes, spike_locations).
# An alternative could be to extend the "depend_on" attribute to use optional and to check if an extension
# depends on the output of the pipeline nodes (e.g. depend_on=["spike_amplitudes[optional]"])
Comment on lines +1029 to +1032
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perfect

for extension_name, extension_params in extensions_post_pipeline.items():
extension_class = get_extension_class(extension_name)
if extension_class.need_job_kwargs:
self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs)
else:
self.compute_one_extension(extension_name, save=save, **extension_params)

def get_saved_extension_names(self):
"""
Get extension names saved in folder or zarr that can be loaded.
Expand Down Expand Up @@ -1173,7 +1190,7 @@ def _get_children_dependencies(extension_name):
This function is making the reverse way : get all children that depend of a
particular extension.

This is recurssive so this includes : children and so grand children and grand grand children
This is recursive so this includes : children and so grand children and great grand children

This function is usefull for deleting on recompute.
For instance recompute the "waveforms" need to delete "template"
Expand Down