Add error messaging around use of get data in templates #3501
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,21 +22,23 @@ | |
|
|
||
| class ComputeRandomSpikes(AnalyzerExtension): | ||
| """ | ||
| AnalyzerExtension that select some random spikes. | ||
| AnalyzerExtension that select somes random spikes. | ||
| This allows for a subsampling of spikes for further calculations and is important | ||
| for managing that amount of memory and speed of computation in the analyzer. | ||
|
|
||
| This will be used by the `waveforms`/`templates` extensions. | ||
|
|
||
| This internally use `random_spikes_selection()` parameters are the same. | ||
| This internally uses `random_spikes_selection()` parameters. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| method: "uniform" | "all", default: "uniform" | ||
| method : "uniform" | "all", default: "uniform" | ||
| The method to select the spikes | ||
| max_spikes_per_unit: int, default: 500 | ||
| max_spikes_per_unit : int, default: 500 | ||
| The maximum number of spikes per unit, ignored if method="all" | ||
| margin_size: int, default: None | ||
| margin_size : int, default: None | ||
| A margin on each border of segments to avoid border spikes, ignored if method="all" | ||
| seed: int or None, default: None | ||
| seed : int or None, default: None | ||
| A seed for the random generator, ignored if method="all" | ||
|
|
||
| Returns | ||
|
|
@@ -104,7 +106,7 @@ def get_random_spikes(self): | |
| return self._some_spikes | ||
|
|
||
| def get_selected_indices_in_spike_train(self, unit_id, segment_index): | ||
| # usefull for Waveforms extractor backwars compatibility | ||
| # useful for WaveformExtractor backwards compatibility | ||
| # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain | ||
| sorting = self.sorting_analyzer.sorting | ||
| random_spikes_indices = self.data["random_spikes_indices"] | ||
|
|
@@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension): | |
|
|
||
| Parameters | ||
| ---------- | ||
| ms_before: float, default: 1.0 | ||
| ms_before : float, default: 1.0 | ||
| The number of ms to extract before the spike events | ||
| ms_after: float, default: 2.0 | ||
| ms_after : float, default: 2.0 | ||
| The number of ms to extract after the spike events | ||
| dtype: None | dtype, default: None | ||
| dtype : None | dtype, default: None | ||
| The dtype of the waveforms. If None, the dtype of the recording is used. | ||
|
|
||
| Returns | ||
| ------- | ||
| waveforms: np.ndarray | ||
| waveforms : np.ndarray | ||
| Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels) | ||
| """ | ||
|
|
||
|
|
@@ -380,7 +382,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N | |
| assert isinstance(operators, list) | ||
| for operator in operators: | ||
| if isinstance(operator, str): | ||
| assert operator in ("average", "std", "median", "mad") | ||
| if operator not in ("average", "std", "median", "mad"): | ||
alejoe91 marked this conversation as resolved.
|
||
| error_msg = ( | ||
| f"You have entered an operator {operator} in your `operators` argument which is " | ||
| f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead." | ||
| ) | ||
|
Collaborator
Oh nice, in that case discard above comment 😆
||
| raise ValueError(error_msg) | ||
| else: | ||
| assert isinstance(operator, (list, tuple)) | ||
| assert len(operator) == 2 | ||
|
|
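The hunk above replaces a bare `assert` with an explicit `ValueError`. One practical reason to prefer this is that `assert` statements are skipped when Python runs with `-O`, while a raised exception always fires and carries a message the user can act on. A minimal standalone sketch of the same check follows; the helper name `validate_operators` and the `("percentile", 90.0)` pair are illustrative assumptions, not part of the library.

```python
SUPPORTED_OPERATORS = ("average", "std", "median", "mad")


def validate_operators(operators):
    """Hypothetical helper mirroring the checks shown in the hunk above."""
    if not isinstance(operators, list):
        raise TypeError("`operators` must be a list")
    for operator in operators:
        if isinstance(operator, str):
            if operator not in SUPPORTED_OPERATORS:
                raise ValueError(
                    f"You have entered an operator {operator} in your `operators` argument which is "
                    f"not supported. Please use any of {list(SUPPORTED_OPERATORS)} instead."
                )
        else:
            # Non-string entries are assumed to be a pair, e.g. ("percentile", 90.0)
            if not (isinstance(operator, (list, tuple)) and len(operator) == 2):
                raise ValueError(f"Expected a (name, value) pair, got {operator!r}")
    return operators
```

Calling `validate_operators(["mean"])` raises a `ValueError` that names the offending entry and lists the accepted strings.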
@@ -405,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs): | |
| self._compute_and_append_from_waveforms(self.params["operators"]) | ||
|
|
||
| else: | ||
| for operator in self.params["operators"]: | ||
| if operator not in ("average", "std"): | ||
| raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension") | ||
| bad_operator_list = [ | ||
| operator for operator in self.params["operators"] if operator not in ("average", "std") | ||
| ] | ||
| if len(bad_operator_list) > 0: | ||
|
Collaborator
maybe if
Member
Author
I think that is stylistic. I don't use
Member
I think it's clear enough :)
||
| raise ValueError( | ||
| f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension" | ||
|
Collaborator
Interesting, I assumed waveforms always needed to be computed for templates. What other ways of doing it are there?
Member
Author
You should dig into the templates code. Alessio + Sam (I think it was just them, but maybe Heberto too) developed an accumulator method that uses the waveforms to make an average without saving them, so you would `analyzer.compute(['random_spikes', 'templates'])` and it will read the waveforms from random spikes while making the templates, then discard them to save on memory. I think the only thing it breaks would be doing PCA later, but it is way less storage intensive since you don't save all the extra waveforms. It is limited in the types of operators it can do, though, as you can see from the error.
Member
Yep, this accumulator was great stuff during the analyzer dev.
Collaborator
cool! will check that out
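To illustrate the accumulator idea discussed in this thread: an average and a standard deviation can both be derived from running sums, so each waveform can be dropped as soon as it has been added. A median or percentile needs all values at once, which is presumably why those operators require the 'waveforms' extension. The sketch below is a generic pure-Python illustration under those assumptions; the class `RunningMoments` is hypothetical and is not the library's implementation.

```python
import math


class RunningMoments:
    """Accumulate per-sample sum and sum of squares; no waveform is kept."""

    def __init__(self, num_samples):
        self.count = 0
        self.total = [0.0] * num_samples
        self.total_sq = [0.0] * num_samples

    def add(self, waveform):
        # `waveform` is a flat sequence of floats of length num_samples
        self.count += 1
        for i, value in enumerate(waveform):
            self.total[i] += value
            self.total_sq[i] += value * value

    def average(self):
        return [s / self.count for s in self.total]

    def std(self):
        # Population std via E[x^2] - E[x]^2, clipped at 0 against rounding error
        return [
            math.sqrt(max(sq / self.count - (s / self.count) ** 2, 0.0))
            for s, sq in zip(self.total, self.total_sq)
        ]


acc = RunningMoments(num_samples=3)
for waveform in ([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]):
    acc.add(waveform)  # each waveform can be discarded after this call
```

Memory use here depends only on the template size, not on the number of spikes, which matches the storage argument made in the comment above.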
||
| ) | ||
|
|
||
| recording = self.sorting_analyzer.recording | ||
| sorting = self.sorting_analyzer.sorting | ||
|
|
@@ -441,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs): | |
|
|
||
| def _compute_and_append_from_waveforms(self, operators): | ||
| if not self.sorting_analyzer.has_extension("waveforms"): | ||
| raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension") | ||
| raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension") | ||
|
|
||
| unit_ids = self.sorting_analyzer.unit_ids | ||
| channel_ids = self.sorting_analyzer.channel_ids | ||
|
|
@@ -466,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators): | |
|
|
||
| assert self.sorting_analyzer.has_extension( | ||
| "random_spikes" | ||
| ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()" | ||
| ), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')" | ||
| some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() | ||
| for unit_index, unit_id in enumerate(unit_ids): | ||
| spike_mask = some_spikes["unit_index"] == unit_index | ||
|
|
@@ -549,9 +560,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): | |
| if operator != "percentile": | ||
| key = operator | ||
| else: | ||
| assert percentile is not None, "You must provide percentile=..." | ||
| assert percentile is not None, "You must provide percentile=... if `operator=percentile`" | ||
| key = f"percentile_{percentile}" | ||
|
|
||
| if key not in self.data.keys(): | ||
| error_msg = ( | ||
| f"You have entered `operator={key}`, but the only operators calculated are " | ||
| f"{list(self.data.keys())}. Please use one of these as your `operator` in the " | ||
| f"`get_data` function." | ||
| ) | ||
| raise ValueError(error_msg) | ||
|
|
||
| templates_array = self.data[key] | ||
|
|
||
| if outputs == "numpy": | ||
|
|
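Without the membership check added in this hunk, asking for an operator that was never computed would surface as a bare `KeyError` from the dictionary lookup. A minimal standalone sketch of the lookup with the clearer error is shown below; `get_computed` and the plain `data` dict are hypothetical stand-ins for the extension's `_get_data` and `self.data`.

```python
def get_computed(data, operator="average", percentile=None):
    """Hypothetical stand-in for the key lookup shown in `_get_data` above."""
    if operator != "percentile":
        key = operator
    else:
        if percentile is None:
            raise ValueError("You must provide percentile=... if `operator='percentile'`")
        key = f"percentile_{percentile}"

    if key not in data:
        # List what is available so the caller can correct the request
        raise ValueError(
            f"You have entered `operator={key}`, but the only operators calculated are "
            f"{list(data.keys())}. Please use one of these as your `operator`."
        )
    return data[key]
```

With `data = {"average": ..., "std": ...}`, requesting `"median"` raises a `ValueError` that lists `['average', 'std']` as the available choices.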
@@ -566,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): | |
| probe=self.sorting_analyzer.get_probe(), | ||
| ) | ||
| else: | ||
| raise ValueError("outputs must be numpy or Templates") | ||
| raise ValueError("outputs must be `numpy` or `Templates`") | ||
|
|
||
| def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"): | ||
| """ | ||
|
|
@@ -576,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save | |
|
|
||
| Parameters | ||
| ---------- | ||
| unit_ids: list or None | ||
| unit_ids : list or None | ||
| Unit ids to retrieve waveforms for | ||
| operator: "average" | "median" | "std" | "percentile", default: "average" | ||
| operator : "average" | "median" | "std" | "percentile", default: "average" | ||
| The operator to compute the templates | ||
| percentile: float, default: None | ||
| percentile : float, default: None | ||
| Percentile to use for operator="percentile" | ||
| save: bool, default True | ||
| save : bool, default: True | ||
| In case, the operator is not computed yet it can be saved to folder or zarr | ||
| outputs: "numpy" | "Templates" | ||
| outputs : "numpy" | "Templates", default: "numpy" | ||
| Whether to return a numpy array or a Templates object | ||
|
|
||
| Returns | ||
| ------- | ||
| templates: np.array | ||
| templates : np.array | Templates | ||
| The returned templates (num_units, num_samples, num_channels) | ||
| """ | ||
| if operator != "percentile": | ||
| key = operator | ||
| else: | ||
| assert percentile is not None, "You must provide percentile=..." | ||
| assert percentile is not None, "You must provide percentile=... if `operator='percentile'`" | ||
| key = f"pencentile_{percentile}" | ||
|
|
||
| if key in self.data: | ||
|
|
@@ -632,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save | |
| is_scaled=self.sorting_analyzer.return_scaled, | ||
| ) | ||
| else: | ||
| raise ValueError("outputs must be numpy or Templates") | ||
| raise ValueError("`outputs` must be 'numpy' or 'Templates'") | ||
|
|
||
| def get_unit_template(self, unit_id, operator="average"): | ||
| """ | ||
|
|
@@ -642,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"): | |
| ---------- | ||
| unit_id: str | int | ||
| Unit id to retrieve waveforms for | ||
| operator: str | ||
| operator: str, default: "average" | ||
| The operator to compute the templates | ||
|
|
||
| Returns | ||
|
|
@@ -701,13 +720,13 @@ def _set_params(self, **noise_level_params): | |
| return params | ||
|
|
||
| def _select_extension_data(self, unit_ids): | ||
| # this do not depend on units | ||
| # this does not depend on units | ||
| return self.data | ||
|
|
||
| def _merge_extension_data( | ||
| self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs | ||
| ): | ||
| # this do not depend on units | ||
| # this does not depend on units | ||
| return self.data.copy() | ||
|
|
||
| def _run(self, verbose=False): | ||
|
|
||
This is a very nice addition. Could 'for further calculations' and the below 'This will be used by the `waveforms`/`templates` extensions.' be combined to emphasise that this choice could potentially have important consequences for results? E.g. 'The sampled spikes will be used for calculating waveforms and templates and as such determine many downstream parameters (e.g. quality metrics). Therefore it is important that a sufficient number of spikes are sampled and that these are distributed evenly through the dataset'.
I'm torn. I was just trying to clarify. I prefer:
- one-line short summary
- params
- return
- examples
- notes

and to do that big explanation in the notes. The statement you're making I would put into notes, but I didn't want to change stuff too much, although since I made the diff it's my fault you looked.
The real point of this PR is just to improve error messaging. I think we could further improve the docstrings etc. in a separate PR (and this is a reminder for me to focus on my task so we don't get bogged down on other stuff for a PR with a specific purpose :) ).
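As a rough illustration of the docstring layout described in this comment (summary, params, return, examples, notes, in that order), here is a sketch on a small standalone function. The function `select_random_indices` is hypothetical and only exists to carry the docstring; the Notes text paraphrases the reviewer's suggestion above.

```python
import random


def select_random_indices(num_spikes, max_spikes=500, seed=None):
    """
    Select a random subset of spike indices.

    Parameters
    ----------
    num_spikes : int
        The total number of spikes available
    max_spikes : int, default: 500
        The maximum number of indices to return
    seed : int or None, default: None
        A seed for the random generator

    Returns
    -------
    indices : list of int
        The sorted selected indices

    Examples
    --------
    >>> len(select_random_indices(1000, max_spikes=10, seed=0))
    10

    Notes
    -----
    The selected spikes feed later computations, so enough spikes should be
    sampled and they should be spread evenly through the dataset.
    """
    rng = random.Random(seed)
    count = min(num_spikes, max_spikes)
    return sorted(rng.sample(range(num_spikes), count))
```

Keeping the longer explanation in Notes leaves the first line short enough to read in autogenerated summaries.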
That sounds good, happy to leave this for another day. It's easier said than done to stick religiously to one change in a PR, there are always some appealing changes to be made that inevitably catch your eye!