diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index ad23a5f249..ff1dc5dafa 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -364,6 +364,16 @@ class ComputeTemplates(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = True + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + if "ms_before" not in self.params: + # compatibility february 2024 > july 2024 + self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency + + if "ms_after" not in self.params: + # compatibility february 2024 > july 2024 + self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=None): operators = operators or ["average", "std"] @@ -487,31 +497,11 @@ def _compute_and_append_from_waveforms(self, operators): @property def nbefore(self): - if "ms_before" not in self.params: - # compatibility february 2024 > july 2024 - self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency - warnings.warn( - "The 'nbefore' parameter is deprecated and it's been replaced by 'ms_before' in the params." - "You can save the sorting_analyzer to update the params.", - DeprecationWarning, - stacklevel=2, - ) - nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0) return nbefore @property def nafter(self): - if "ms_after" not in self.params: - # compatibility february 2024 > july 2024 - warnings.warn( - "The 'nafter' parameter is deprecated and it's been replaced by 'ms_after' in the params." - "You can save the sorting_analyzer to update the params.", - DeprecationWarning, - stacklevel=2, - ) - self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency - nafter = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) return nafter diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 2a2f7b6b5a..6994575150 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -347,6 +347,10 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids): """ old_unit_ids = np.asarray(old_unit_ids) + dtype = old_unit_ids.dtype + if dtype.kind == "U": + # the new dtype can be longer + dtype = "U" assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups" @@ -361,7 +365,7 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids): all_unit_ids.remove(unit_id) if new_unit_id not in all_unit_ids: all_unit_ids.append(new_unit_id) - return np.array(all_unit_ids) + return np.array(all_unit_ids, dtype=dtype) def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy="append"): diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 27a47a31ac..3e92733974 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -777,6 +777,10 @@ def _save_or_select_or_merge( # make a copy of extensions # note that the copy of extension handle itself the slicing of units when necessary and also the saveing sorted_extensions = _sort_extensions_by_dependency(self.extensions) + # hack: quality metrics are computed at last + qm_extension_params = sorted_extensions.pop("quality_metrics", None) + if qm_extension_params is not None: + sorted_extensions["quality_metrics"] = qm_extension_params recompute_dict = {} for extension_name, extension in sorted_extensions.items(): @@ -1204,7 +1208,9 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar # check dependencies if extension_class.need_recording: - assert self.has_recording(), f"Extension {extension_name} requires the recording" + assert ( + self.has_recording() or self.has_temporary_recording() + ), f"Extension {extension_name} requires the recording" for dependency_name in extension_class.depend_on: if "|" in dependency_name: ok = any(self.get_extension(name) is not None for name in dependency_name.split("|")) @@ -1398,9 +1404,7 @@ def load_extension(self, extension_name: str): extension_class = get_extension_class(extension_name) - extension_instance = extension_class(self) - extension_instance.load_params() - extension_instance.load_data() + extension_instance = extension_class.load(self) self.extensions[extension_name] = extension_instance @@ -1699,6 +1703,7 @@ class AnalyzerExtension: use_nodepipeline = False nodepipeline_variables = None need_job_kwargs = False + need_backward_compatibility_on_load = False def __init__(self, sorting_analyzer): self._sorting_analyzer = weakref.ref(sorting_analyzer) @@ -1737,6 +1742,10 @@ def _get_data(self): # must be implemented in subclass raise NotImplementedError + def _handle_backward_compatibility_on_load(self): + # must be implemented in subclass only if need_backward_compatibility_on_load=True + raise NotImplementedError + @classmethod def function_factory(cls): # make equivalent @@ -1814,6 +1823,9 @@ def load(cls, sorting_analyzer): ext = cls(sorting_analyzer) ext.load_params() ext.load_data() + if cls.need_backward_compatibility_on_load: + ext._handle_backward_compatibility_on_load() + return ext def load_params(self): diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 38baf62c35..6d0e61f844 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -161,4 +161,4 @@ def test_generate_unit_ids_for_merge_group(): test_apply_merges_to_sorting() test_get_ids_after_merging() - test_generate_unit_ids_for_merge_group() + # test_generate_unit_ids_for_merge_group() diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index d6d60ee73b..da1f5a71f5 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -531,7 +531,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting): templates[mode] = np.load(template_file) if len(templates) > 0: ext = ComputeTemplates(sorting_analyzer) - ext.params = dict(nbefore=nbefore, nafter=nafter, operators=list(templates.keys())) + ext.params = dict(ms_before=params["ms_before"], ms_after=params["ms_after"], operators=list(templates.keys())) for mode, arr in templates.items(): ext.data[mode] = arr sorting_analyzer.extensions["templates"] = ext @@ -544,10 +544,6 @@ def _read_old_waveforms_extractor_binary(folder, sorting): ext = new_class(sorting_analyzer) with open(ext_folder / "params.json", "r") as f: params = json.load(f) - # update params - new_params = ext._set_params() - updated_params = make_ext_params_up_to_date(ext, params, new_params) - ext.set_params(**updated_params) if new_name == "spike_amplitudes": amplitudes = [] @@ -604,6 +600,13 @@ def _read_old_waveforms_extractor_binary(folder, sorting): pc_all[mask, ...] = pc_one ext.data["pca_projection"] = pc_all + # update params + new_params = ext._set_params() + updated_params = make_ext_params_up_to_date(ext, params, new_params) + ext.set_params(**updated_params, save=False) + if ext.need_backward_compatibility_on_load: + ext._handle_backward_compatibility_on_load() + sorting_analyzer.extensions[new_name] = ext return sorting_analyzer @@ -614,13 +617,12 @@ def make_ext_params_up_to_date(ext, old_params, new_params): old_name = ext.extension_name updated_params = old_params.copy() for p, values in old_params.items(): - if isinstance(values, dict): + if p not in new_params: + warnings.warn(f"Removing legacy parameter {p} from {old_name} extension") + updated_params.pop(p) + elif isinstance(values, dict): new_values = new_params.get(p, {}) updated_params[p] = make_ext_params_up_to_date(ext, values, new_values) - else: - if p not in new_params: - warnings.warn(f"Removing legacy param {p} from {old_name} extension") - updated_params.pop(p) return updated_params diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index a9592b0b91..cb4cc323ad 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -43,10 +43,17 @@ class ComputeTemplateSimilarity(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) + def _handle_backward_compatibility_on_load(self): + if "max_lag_ms" not in self.params: + # make compatible analyzer created between february 24 and july 24 + self.params["max_lag_ms"] = 0.0 + self.params["support"] = "union" + def _set_params(self, method="cosine", max_lag_ms=0, support="union"): if method == "cosine_similarity": warnings.warn( diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 516f22e31e..818f0a8062 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -42,10 +42,17 @@ class ComputeUnitLocations(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) + def _handle_backward_compatibility_on_load(self): + if "method_kwargs" in self.params: + # make compatible analyzer created between february 24 and july 24 + method_kwargs = self.params.pop("method_kwargs") + self.params.update(**method_kwargs) + def _set_params(self, method="monopolar_triangulation", **method_kwargs): params = dict(method=method) params.update(method_kwargs)