Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1b66a49
Smal fix or backward compatibuility
samuelgarcia Jul 16, 2024
c97802a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2024
841796c
oups
samuelgarcia Jul 16, 2024
9a918ba
Merge branch 'fix_sa' of github.com:samuelgarcia/spikeinterface into …
samuelgarcia Jul 16, 2024
7253e4f
Handle dtype in merges for unit_ids
samuelgarcia Jul 16, 2024
0b29f70
oups
samuelgarcia Jul 17, 2024
1e026cc
fix dtype when merging
samuelgarcia Jul 17, 2024
3284469
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2024
772c5e7
Merge branch 'main' into fix_sa
alejoe91 Jul 17, 2024
3cd4d03
Use also _handle_backward_compatibility_on_load for ComputeTemplates
samuelgarcia Jul 17, 2024
f457ba7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2024
8b1079b
waveforms backward compatibility, unit locations, and quality metrics…
alejoe91 Jul 17, 2024
a854dca
Merge branch 'fix_sa' of https://github.com/samuelgarcia/spikeinterfa…
alejoe91 Jul 17, 2024
8091f60
Fix pop
alejoe91 Jul 17, 2024
097706b
Extension job kwargs: check has_recording or has_temporary_recording
alejoe91 Jul 17, 2024
14615b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2024
456d15b
fix dict method
zm711 Jul 17, 2024
7bfccc2
Fix templates backward compatibility
alejoe91 Jul 18, 2024
a22182e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
11c0514
Merge branch 'main' into fix_sa
alejoe91 Jul 18, 2024
03c3d41
Update src/spikeinterface/core/waveforms_extractor_backwards_compatib…
alejoe91 Jul 18, 2024
f2bc276
move update params
alejoe91 Jul 18, 2024
ace9c19
Merge branch 'fix_sa' of https://github.com/samuelgarcia/spikeinterfa…
alejoe91 Jul 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 10 additions & 20 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,16 @@ class ComputeTemplates(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = True
need_backward_compatibility_on_load = True

def _handle_backward_compatibility_on_load(self):
if "ms_before" not in self.params:
# compatibility february 2024 > july 2024
self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency

if "ms_after" not in self.params:
# compatibility february 2024 > july 2024
self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency

def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=None):
operators = operators or ["average", "std"]
Expand Down Expand Up @@ -487,31 +497,11 @@ def _compute_and_append_from_waveforms(self, operators):

@property
def nbefore(self):
if "ms_before" not in self.params:
# compatibility february 2024 > july 2024
self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency
warnings.warn(
"The 'nbefore' parameter is deprecated and it's been replaced by 'ms_before' in the params."
"You can save the sorting_analyzer to update the params.",
DeprecationWarning,
stacklevel=2,
)

nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
return nbefore

@property
def nafter(self):
if "ms_after" not in self.params:
# compatibility february 2024 > july 2024
warnings.warn(
"The 'nafter' parameter is deprecated and it's been replaced by 'ms_after' in the params."
"You can save the sorting_analyzer to update the params.",
DeprecationWarning,
stacklevel=2,
)
self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency

nafter = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
return nafter

Expand Down
6 changes: 5 additions & 1 deletion src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,10 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids):

"""
old_unit_ids = np.asarray(old_unit_ids)
dtype = old_unit_ids.dtype
if dtype.kind == "U":
# the new dtype can be longer
dtype = "U"

assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups"

Expand All @@ -361,7 +365,7 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids):
all_unit_ids.remove(unit_id)
if new_unit_id not in all_unit_ids:
all_unit_ids.append(new_unit_id)
return np.array(all_unit_ids)
return np.array(all_unit_ids, dtype=dtype)


def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy="append"):
Expand Down
20 changes: 16 additions & 4 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,10 @@ def _save_or_select_or_merge(
# make a copy of extensions
# note that the copy of extension handle itself the slicing of units when necessary and also the saveing
sorted_extensions = _sort_extensions_by_dependency(self.extensions)
# hack: quality metrics are computed at last
qm_extension_params = sorted_extensions.pop("quality_metrics", None)
if qm_extension_params is not None:
sorted_extensions["quality_metrics"] = qm_extension_params
recompute_dict = {}

for extension_name, extension in sorted_extensions.items():
Expand Down Expand Up @@ -1204,7 +1208,9 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar

# check dependencies
if extension_class.need_recording:
assert self.has_recording(), f"Extension {extension_name} requires the recording"
assert (
self.has_recording() or self.has_temporary_recording()
), f"Extension {extension_name} requires the recording"
for dependency_name in extension_class.depend_on:
if "|" in dependency_name:
ok = any(self.get_extension(name) is not None for name in dependency_name.split("|"))
Expand Down Expand Up @@ -1398,9 +1404,7 @@ def load_extension(self, extension_name: str):

extension_class = get_extension_class(extension_name)

extension_instance = extension_class(self)
extension_instance.load_params()
extension_instance.load_data()
extension_instance = extension_class.load(self)

self.extensions[extension_name] = extension_instance

Expand Down Expand Up @@ -1699,6 +1703,7 @@ class AnalyzerExtension:
use_nodepipeline = False
nodepipeline_variables = None
need_job_kwargs = False
need_backward_compatibility_on_load = False

def __init__(self, sorting_analyzer):
self._sorting_analyzer = weakref.ref(sorting_analyzer)
Expand Down Expand Up @@ -1737,6 +1742,10 @@ def _get_data(self):
# must be implemented in subclass
raise NotImplementedError

def _handle_backward_compatibility_on_load(self):
# must be implemented in subclass only if need_backward_compatibility_on_load=True
raise NotImplementedError

@classmethod
def function_factory(cls):
# make equivalent
Expand Down Expand Up @@ -1814,6 +1823,9 @@ def load(cls, sorting_analyzer):
ext = cls(sorting_analyzer)
ext.load_params()
ext.load_data()
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()

return ext

def load_params(self):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,4 @@ def test_generate_unit_ids_for_merge_group():

test_apply_merges_to_sorting()
test_get_ids_after_merging()
test_generate_unit_ids_for_merge_group()
# test_generate_unit_ids_for_merge_group()
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
templates[mode] = np.load(template_file)
if len(templates) > 0:
ext = ComputeTemplates(sorting_analyzer)
ext.params = dict(nbefore=nbefore, nafter=nafter, operators=list(templates.keys()))
ext.params = dict(ms_before=params["ms_before"], ms_after=params["ms_after"], operators=list(templates.keys()))
for mode, arr in templates.items():
ext.data[mode] = arr
sorting_analyzer.extensions["templates"] = ext
Expand All @@ -544,10 +544,6 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
ext = new_class(sorting_analyzer)
with open(ext_folder / "params.json", "r") as f:
params = json.load(f)
# update params
new_params = ext._set_params()
updated_params = make_ext_params_up_to_date(ext, params, new_params)
ext.set_params(**updated_params)

if new_name == "spike_amplitudes":
amplitudes = []
Expand Down Expand Up @@ -604,6 +600,13 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
pc_all[mask, ...] = pc_one
ext.data["pca_projection"] = pc_all

# update params
new_params = ext._set_params()
updated_params = make_ext_params_up_to_date(ext, params, new_params)
ext.set_params(**updated_params, save=False)
if ext.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()

sorting_analyzer.extensions[new_name] = ext

return sorting_analyzer
Expand All @@ -614,13 +617,12 @@ def make_ext_params_up_to_date(ext, old_params, new_params):
old_name = ext.extension_name
updated_params = old_params.copy()
for p, values in old_params.items():
if isinstance(values, dict):
if p not in new_params:
warnings.warn(f"Removing legacy parameter {p} from {old_name} extension")
updated_params.pop(p)
elif isinstance(values, dict):
new_values = new_params.get(p, {})
updated_params[p] = make_ext_params_up_to_date(ext, values, new_values)
else:
if p not in new_params:
warnings.warn(f"Removing legacy param {p} from {old_name} extension")
updated_params.pop(p)
return updated_params


Expand Down
7 changes: 7 additions & 0 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,17 @@ class ComputeTemplateSimilarity(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _handle_backward_compatibility_on_load(self):
if "max_lag_ms" not in self.params:
# make compatible analyzer created between february 24 and july 24
self.params["max_lag_ms"] = 0.0
self.params["support"] = "union"

def _set_params(self, method="cosine", max_lag_ms=0, support="union"):
if method == "cosine_similarity":
warnings.warn(
Expand Down
7 changes: 7 additions & 0 deletions src/spikeinterface/postprocessing/unit_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,17 @@ class ComputeUnitLocations(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _handle_backward_compatibility_on_load(self):
if "method_kwargs" in self.params:
# make compatible analyzer created between february 24 and july 24
method_kwargs = self.params.pop("method_kwargs")
self.params.update(**method_kwargs)
Comment thread
alejoe91 marked this conversation as resolved.

def _set_params(self, method="monopolar_triangulation", **method_kwargs):
params = dict(method=method)
params.update(method_kwargs)
Expand Down