diff --git a/environment_rtd.yml b/environment_rtd.yml index 1032e19e53..569ecbb66c 100644 --- a/environment_rtd.yml +++ b/environment_rtd.yml @@ -25,10 +25,10 @@ dependencies: - isosplit5 - mountainsort4>=1.0.0 - tridesclous>=1.6.4 - - herdingspikes + - herdingspikes<=0.3.99 - sphinx-gallery - numpydoc - - numpy<1.21 - - git+https://github.com/scikit-learn-contrib/hdbscan.git + - numpy<1.22 + - hdbscan - numba==0.54.1 - git+https://github.com/SpikeInterface/spikeinterface.git diff --git a/spikeinterface/core/base.py b/spikeinterface/core/base.py index 3d9e662b1c..03fab911a4 100644 --- a/spikeinterface/core/base.py +++ b/spikeinterface/core/base.py @@ -8,6 +8,7 @@ import random import string import warnings +from packaging.version import parse import numpy as np @@ -874,9 +875,12 @@ def _get_class_from_string(class_string): def _check_same_version(class_string, version): module = class_string.split('.')[0] imported_module = importlib.import_module(module) + + current_version = parse(imported_module.__version__) + saved_version = parse(version) try: - return imported_module.__version__ == version + return current_version.major == saved_version.major and current_version.minor == saved_version.minor except AttributeError: return 'unknown' diff --git a/spikeinterface/sorters/launcher.py b/spikeinterface/sorters/launcher.py index 54ff5ac176..1f803336eb 100644 --- a/spikeinterface/sorters/launcher.py +++ b/spikeinterface/sorters/launcher.py @@ -329,6 +329,8 @@ def iter_working_folder(working_folder): sorter_name = job_dict["sorter_name"] yield rec_name, sorter_name, output_folder else: + rec_name = rec_folder.name + sorter_name = output_folder.name if not output_folder.is_dir(): continue if not is_log_ok(output_folder): diff --git a/spikeinterface/toolkit/postprocessing/template_metrics.py b/spikeinterface/toolkit/postprocessing/template_metrics.py index 2667537b5b..6ffdf32561 100644 --- a/spikeinterface/toolkit/postprocessing/template_metrics.py +++ b/spikeinterface/toolkit/postprocessing/template_metrics.py @@ -90,7 +90,7 @@ def calculate_template_metrics(waveform_extractor, feature_names=None, peak_sign multi_index = pd.MultiIndex.from_tuples(list(zip(unit_ids, channel_ids)), names=["unit_id", "channel_id"]) template_metrics = pd.DataFrame( - index=multi_index, columns=[feature_names]) + index=multi_index, columns=feature_names) for unit_id in unit_ids: template_all_chans = waveform_extractor.get_template(unit_id)