Merged
79 changes: 49 additions & 30 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -22,21 +22,23 @@

class ComputeRandomSpikes(AnalyzerExtension):
    """
-    AnalyzerExtension that select somes random spikes.
+    AnalyzerExtension that selects some random spikes.
+
+    This allows for a subsampling of spikes for further calculations and is important
+    for managing the amount of memory and speed of computation in the analyzer.
Collaborator:
This is a very nice addition. Could "for further calculations" and the "This will be used by the `waveforms`/`templates` extensions." below be combined, emphasising that this choice can have important consequences for results? e.g. "The sampled spikes will be used for calculating waveforms and templates, and as such determine many downstream parameters (e.g. quality metrics). Therefore it is important that a sufficient number of spikes are sampled and that these are distributed evenly through the dataset."

Member Author:
I'm torn. I was just trying to clarify. I prefer

one-line short summary
params
returns
examples
notes

and would put that big explanation in the notes. The statement you're making I would put into the notes, but I didn't want to change things too much, although since I made the diff it's my fault you looked.

The real point of this PR is just to improve error messaging. I think we could improve the docstrings etc. further in a separate PR (and this is a reminder for me to focus on my task so we don't get bogged down in other stuff for a PR with a specific purpose :) ).

Collaborator:
That sounds good, happy to leave this for another day. It's easier said than done to stick religiously to one change in a PR; there are always some appealing changes that inevitably catch your eye!


    This will be used by the `waveforms`/`templates` extensions.

-    This internally use `random_spikes_selection()` parameters are the same.
+    This internally uses `random_spikes_selection()`; the parameters are the same.

    Parameters
    ----------
-    method: "uniform" | "all", default: "uniform"
+    method : "uniform" | "all", default: "uniform"
        The method to select the spikes
-    max_spikes_per_unit: int, default: 500
+    max_spikes_per_unit : int, default: 500
        The maximum number of spikes per unit, ignored if method="all"
-    margin_size: int, default: None
+    margin_size : int, default: None
        A margin on each border of segments to avoid border spikes, ignored if method="all"
-    seed: int or None, default: None
+    seed : int or None, default: None
        A seed for the random generator, ignored if method="all"

    Returns
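For context (not part of this diff), a minimal usage sketch of the extension with the parameters documented above, assuming an existing `SortingAnalyzer` named `sorting_analyzer`:

```python
# Hypothetical usage sketch: select the random spikes that the
# waveforms/templates extensions will later operate on.
sorting_analyzer.compute(
    "random_spikes",
    method="uniform",         # or "all" to keep every spike
    max_spikes_per_unit=500,  # ignored when method="all"
    seed=2205,                # fix the RNG for reproducibility
)
```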
@@ -104,7 +106,7 @@ def get_random_spikes(self):
        return self._some_spikes

    def get_selected_indices_in_spike_train(self, unit_id, segment_index):
-        # usefull for Waveforms extractor backwars compatibility
+        # useful for WaveformExtractor backwards compatibility
        # In WaveformExtractor, "selected_spikes" was a dict (key: unit_id) of lists (segment_index) of indices of spikes in the spike train
        sorting = self.sorting_analyzer.sorting
        random_spikes_indices = self.data["random_spikes_indices"]
@@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension):

    Parameters
    ----------
-    ms_before: float, default: 1.0
+    ms_before : float, default: 1.0
        The number of ms to extract before the spike events
-    ms_after: float, default: 2.0
+    ms_after : float, default: 2.0
        The number of ms to extract after the spike events
-    dtype: None | dtype, default: None
+    dtype : None | dtype, default: None
        The dtype of the waveforms. If None, the dtype of the recording is used.

    Returns
    -------
-    waveforms: np.ndarray
+    waveforms : np.ndarray
        Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels)
    """

@@ -380,7 +382,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N
        assert isinstance(operators, list)
        for operator in operators:
            if isinstance(operator, str):
-                assert operator in ("average", "std", "median", "mad")
+                if operator not in ("average", "std", "median", "mad"):
+                    error_msg = (
+                        f"You have entered an operator {operator} in your `operators` argument which is "
+                        f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead."
+                    )
Collaborator:
Oh nice, in that case discard the above comment 😆

+                    raise ValueError(error_msg)
            else:
                assert isinstance(operator, (list, tuple))
                assert len(operator) == 2
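For reference (not part of this diff), the two operator forms this validation accepts could be passed like so; the values are purely illustrative:

```python
# Hypothetical call sketch: operators are either plain strings or
# ("percentile", value) pairs of length 2.
sorting_analyzer.compute(
    "templates",
    operators=["average", "std", ("percentile", 95.0)],
)
```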
@@ -405,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs):
            self._compute_and_append_from_waveforms(self.params["operators"])

        else:
-            for operator in self.params["operators"]:
-                if operator not in ("average", "std"):
-                    raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension")
+            bad_operator_list = [
+                operator for operator in self.params["operators"] if operator not in ("average", "std")
+            ]
+            if len(bad_operator_list) > 0:
Collaborator:
Maybe `if any(bad_operator_list)` is slightly more readable(?)

Member Author:
I think that is stylistic. I don't use `any` that often in code; checking the length of the list is just my default. I don't know what others think?

Member:
I think it's clear enough :)
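A side note on the style question (not from the thread): the two spellings only coincide here because operator names are non-empty strings:

```python
# `len(x) > 0` tests emptiness; `any(x)` tests whether any element is truthy.
any(["", ""])      # False, even though the list is non-empty
len(["", ""]) > 0  # True
```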

+                raise ValueError(
+                    f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension"
Collaborator:
Interesting, I assumed waveforms always needed to be computed for templates. What other ways of doing it are there?

Member Author:
You should dig into the templates code. Alessio + Sam (I think it was just them, but maybe Heberto too) developed an accumulator method that uses the waveforms to make an average without saving them, so you would do

analyzer.compute(['random_spikes', 'templates'])

and it will read the waveforms from random spikes while making the templates, then discard them to save on memory. I think the only thing it breaks would be doing PCA later, but it is way less storage intensive since you don't save all the extra waveforms. It is limited in the types of operators it can do, though, as you can see from the error.

Member:
Yep, this accumulator was great stuff during the analyzer dev. The only drawback is that MAD cannot be computed that way, but saving RAM and disk space is cool. You can see the idea in waveform_tools.py: every worker accumulates snippets in parallel, and the sum + divide is done at the end. This is quite fast.

Collaborator:
Cool! Will check that out.

+            )
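To make the accumulator idea from this thread concrete, here is a rough single-process sketch; the function name and signature are illustrative, and the real parallel implementation lives in waveform_tools.py:

```python
import numpy as np

# Simplified sketch of the accumulator approach: sum spike-triggered
# snippets per unit without ever storing individual waveforms, then
# divide by the spike count at the end. Statistics that need all the
# samples (e.g. MAD, percentiles) cannot be computed this way.
def accumulate_mean_templates(traces, spike_samples, spike_units, n_units, nbefore, nafter):
    n_samples, n_channels = traces.shape
    acc = np.zeros((n_units, nbefore + nafter, n_channels), dtype="float64")
    counts = np.zeros(n_units, dtype="int64")
    for sample, unit in zip(spike_samples, spike_units):
        if sample - nbefore < 0 or sample + nafter > n_samples:
            continue  # skip spikes too close to a segment border
        acc[unit] += traces[sample - nbefore : sample + nafter, :]
        counts[unit] += 1
    # sum + divide at the end; in the real code each worker accumulates
    # partial sums in parallel, which are combined before this division
    return acc / np.maximum(counts, 1)[:, None, None]
```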

        recording = self.sorting_analyzer.recording
        sorting = self.sorting_analyzer.sorting
@@ -441,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs):

    def _compute_and_append_from_waveforms(self, operators):
        if not self.sorting_analyzer.has_extension("waveforms"):
-            raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension")
+            raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension")

        unit_ids = self.sorting_analyzer.unit_ids
        channel_ids = self.sorting_analyzer.channel_ids
@@ -466,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators):

        assert self.sorting_analyzer.has_extension(
            "random_spikes"
-        ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()"
+        ), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')"
        some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
        for unit_index, unit_id in enumerate(unit_ids):
            spike_mask = some_spikes["unit_index"] == unit_index
@@ -549,9 +560,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
if operator != "percentile":
key = operator
else:
assert percentile is not None, "You must provide percentile=..."
assert percentile is not None, "You must provide percentile=... if `operator=percentile`"
key = f"percentile_{percentile}"

if key not in self.data.keys():
error_msg = (
f"You have entered `operator={key}`, but the only operators calculated are "
f"{list(self.data.keys())}. Please use one of these as your `operator` in the "
f"`get_data` function."
)
raise ValueError(error_msg)

templates_array = self.data[key]

if outputs == "numpy":
@@ -566,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
                probe=self.sorting_analyzer.get_probe(),
            )
        else:
-            raise ValueError("outputs must be numpy or Templates")
+            raise ValueError("outputs must be `numpy` or `Templates`")

    def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"):
        """
@@ -576,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save

        Parameters
        ----------
-        unit_ids: list or None
+        unit_ids : list or None
            Unit ids to retrieve waveforms for
-        operator: "average" | "median" | "std" | "percentile", default: "average"
+        operator : "average" | "median" | "std" | "percentile", default: "average"
            The operator to compute the templates
-        percentile: float, default: None
+        percentile : float, default: None
            Percentile to use for operator="percentile"
-        save: bool, default True
+        save : bool, default: True
            In case the operator is not computed yet, it can be saved to folder or zarr
-        outputs: "numpy" | "Templates"
+        outputs : "numpy" | "Templates", default: "numpy"
            Whether to return a numpy array or a Templates object

        Returns
        -------
-        templates: np.array
+        templates : np.array | Templates
            The returned templates (num_units, num_samples, num_channels)
"""
if operator != "percentile":
key = operator
else:
assert percentile is not None, "You must provide percentile=..."
assert percentile is not None, "You must provide percentile=... if `operator='percentile'`"
key = f"pencentile_{percentile}"

if key in self.data:
@@ -632,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
                is_scaled=self.sorting_analyzer.return_scaled,
            )
        else:
-            raise ValueError("outputs must be numpy or Templates")
+            raise ValueError("`outputs` must be 'numpy' or 'Templates'")

    def get_unit_template(self, unit_id, operator="average"):
        """
@@ -642,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"):
        ----------
        unit_id: str | int
            Unit id to retrieve waveforms for
-        operator: str
+        operator: str, default: "average"
            The operator to compute the templates

        Returns
@@ -701,13 +720,13 @@ def _set_params(self, **noise_level_params):
        return params

    def _select_extension_data(self, unit_ids):
-        # this do not depend on units
+        # this does not depend on units
        return self.data

    def _merge_extension_data(
        self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
    ):
-        # this do not depend on units
+        # this does not depend on units
        return self.data.copy()

    def _run(self, verbose=False):