Skip to content

Commit

Permalink
Merge pull request #2848 from zm711/fix-template-tools
Browse files Browse the repository at this point in the history
Use the `is_scaled` attribute inside of `template_tools` functions
  • Loading branch information
samuelgarcia committed May 21, 2024
2 parents d66f8eb + 098a0bc commit 5a195f6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p
return_scaled = templates_or_sorting_analyzer.return_scaled
elif isinstance(templates_or_sorting_analyzer, Templates):
assert noise_levels is not None
return_scaled = True
return_scaled = templates_or_sorting_analyzer.is_scaled

mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")

Expand Down Expand Up @@ -369,7 +369,7 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None):
return_scaled = templates_or_sorting_analyzer.return_scaled
elif isinstance(templates_or_sorting_analyzer, Templates):
assert noise_levels is not None
return_scaled = True
return_scaled = templates_or_sorting_analyzer.is_scaled

from .template_tools import get_dense_templates_array

Expand Down
12 changes: 6 additions & 6 deletions src/spikeinterface/core/template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ def get_template_extremum_channel(
channel_ids = templates_or_sorting_analyzer.channel_ids

# if SortingAnalyzer need to use global SortingAnalyzer return_scaled otherwise
# we just use the previous default of return_scaled=True (for templates)
# we use the Templates is_scaled
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
return_scaled = templates_or_sorting_analyzer.return_scaled
else:
return_scaled = True
return_scaled = templates_or_sorting_analyzer.is_scaled

peak_values = get_template_amplitudes(
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled
Expand Down Expand Up @@ -200,12 +200,12 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak

shifts = {}

# We need to use the SortingAnalyzer return_scaled if possible
# otherwise for Templates default to True
# We need to use the SortingAnalyzer return_scaled
# We need to use the Templates is_scaled
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
return_scaled = templates_or_sorting_analyzer.return_scaled
else:
return_scaled = True
return_scaled = templates_or_sorting_analyzer.is_scaled

templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled)

Expand Down Expand Up @@ -265,7 +265,7 @@ def get_template_extremum_amplitude(
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
return_scaled = templates_or_sorting_analyzer.return_scaled
else:
return_scaled = True
return_scaled = templates_or_sorting_analyzer.is_scaled

extremum_amplitudes = get_template_amplitudes(
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled, abs_value=abs_value
Expand Down

0 comments on commit 5a195f6

Please sign in to comment.