diff --git a/spikeinterface/comparison/groundtruthstudy.py b/spikeinterface/comparison/groundtruthstudy.py
index 51044da097..6b270ed7af 100644
--- a/spikeinterface/comparison/groundtruthstudy.py
+++ b/spikeinterface/comparison/groundtruthstudy.py
@@ -15,7 +15,7 @@
 from .paircomparisons import compare_sorter_to_ground_truth
 
 from .studytools import (setup_comparison_study, get_rec_names, get_recordings,
-                         iter_output_folders, iter_computed_names, iter_computed_sorting, collect_run_times)
+                         iter_working_folder, iter_computed_names, iter_computed_sorting, collect_run_times)
 
 
 class GroundTruthStudy:
@@ -102,7 +102,7 @@ def copy_sortings(self):
 
         log_olders.mkdir(parents=True, exist_ok=True)
 
-        for rec_name, sorter_name, output_folder in iter_output_folders(sorter_folders):
+        for rec_name, sorter_name, output_folder in iter_working_folder(sorter_folders):
             SorterClass = sorter_dict[sorter_name]
             fname = rec_name + '[#]' + sorter_name
             npz_filename = sorting_folders / (fname + '.npz')
diff --git a/spikeinterface/comparison/studytools.py b/spikeinterface/comparison/studytools.py
index 7a75d575e2..da464aba13 100644
--- a/spikeinterface/comparison/studytools.py
+++ b/spikeinterface/comparison/studytools.py
@@ -22,7 +22,7 @@
 from spikeinterface.core import load_extractor
 from spikeinterface.extractors import NpzSortingExtractor
 from spikeinterface.sorters import sorter_dict
-from spikeinterface.sorters.launcher import iter_output_folders, iter_sorting_output
+from spikeinterface.sorters.launcher import iter_working_folder, iter_sorting_output
 
 from .comparisontools import _perf_keys
 from .paircomparisons import compare_sorter_to_ground_truth
diff --git a/spikeinterface/sorters/__init__.py b/spikeinterface/sorters/__init__.py
index c4572bb1ed..9a36634edb 100644
--- a/spikeinterface/sorters/__init__.py
+++ b/spikeinterface/sorters/__init__.py
@@ -3,4 +3,4 @@
 from .runsorter import *
 
 from .launcher import (run_sorters, run_sorter_by_property,
-                       collect_sorting_outputs, iter_output_folders, iter_sorting_output)
+                       collect_sorting_outputs, iter_working_folder, iter_sorting_output)
diff --git a/spikeinterface/sorters/launcher.py b/spikeinterface/sorters/launcher.py
index de4e0127b4..54ff5ac176 100644
--- a/spikeinterface/sorters/launcher.py
+++ b/spikeinterface/sorters/launcher.py
@@ -1,13 +1,13 @@
 """
 Utils functions to launch several sorter on several recording in parallel or not.
 """
-import os
 from pathlib import Path
 import shutil
 import numpy as np
 import json
 
 from spikeinterface.core import load_extractor, aggregate_units
+from spikeinterface.core.core_tools import check_json
 
 from .sorterlist import sorter_dict
 from .runsorter import run_sorter, _common_param_doc, run_sorter
@@ -110,8 +110,7 @@
 
     assert grouping_property in recording.get_property_keys(), f"The 'grouping_property' {grouping_property} is not " \
                                                                 f"a recording property!"
 
-    recording_dict = recording.split_by(grouping_property)
-
+    recording_dict = recording.split_by(grouping_property)
     sorting_output = run_sorters([sorter_name], recording_dict, working_folder,
                                  mode_if_folder_exists=mode_if_folder_exists, engine=engine,
@@ -122,13 +121,17 @@
                                  singularity_images={sorter_name: singularity_image},
                                  sorter_params={sorter_name: sorter_params})
 
-    grouping_property_values = np.array([])
+    grouping_property_values = None
     sorting_list = []
     for (output_name, sorting) in sorting_output.items():
         prop_name, sorter_name = output_name
         sorting_list.append(sorting)
-        grouping_property_values = np.concatenate(
-            (grouping_property_values, [prop_name] * len(sorting.get_unit_ids())))
+        if grouping_property_values is None:
+            grouping_property_values = np.array([prop_name] * len(sorting.get_unit_ids()),
+                                                dtype=np.dtype(type(prop_name)))
+        else:
+            grouping_property_values = np.concatenate(
+                (grouping_property_values, [prop_name] * len(sorting.get_unit_ids())))
 
     aggregate_sorting = aggregate_units(sorting_list)
     aggregate_sorting.set_property(
@@ -213,6 +216,10 @@ def run_sorters(sorter_list,
         recording_dict = recording_dict_or_list
     else:
        raise ValueError('bad recording dict')
+
+    dtype_rec_name = np.dtype(type(list(recording_dict.keys())[0]))
+    assert dtype_rec_name.kind in ("i", "u", "S", "U"), "Dict keys can only be integers or strings!"
+
     need_dump = engine != 'loop'
 
     task_args_list = []
@@ -249,7 +256,7 @@ def run_sorters(sorter_list,
 
             task_args = (sorter_name, recording_arg, output_folder,
                          verbose, params, docker_image, singularity_image, with_output)
-            task_args_list.append(task_args)
+            task_args_list.append(task_args)
 
     if engine == 'loop':
         # simple loop in main process
@@ -274,6 +281,18 @@
 
         for task in tasks:
             task.result()
+
+    # dump spikeinterface_job.json
+    for rec_name, recording in recording_dict.items():
+        for sorter_name in sorter_list:
+            output_folder = working_folder / str(rec_name) / sorter_name
+            with open(output_folder / "spikeinterface_job.json", "w") as f:
+                dump_dict = {"rec_name": rec_name,
+                             "sorter_name": sorter_name,
+                             "engine": engine}
+                if engine != "dask":
+                    dump_dict.update({"engine_kwargs": engine_kwargs})
+                json.dump(check_json(dump_dict), f)
 
     if with_output:
         if engine == 'dask':
@@ -297,35 +316,41 @@ def is_log_ok(output_folder):
     return False
 
 
-def iter_output_folders(output_folders):
-    output_folders = Path(output_folders)
-    for rec_name in os.listdir(output_folders):
-        if not os.path.isdir(output_folders / rec_name):
+def iter_working_folder(working_folder):
+    working_folder = Path(working_folder)
+    for rec_folder in working_folder.iterdir():
+        if not rec_folder.is_dir():
             continue
-        for sorter_name in os.listdir(output_folders / rec_name):
-            output_folder = output_folders / rec_name / sorter_name
-            if not os.path.isdir(output_folder):
-                continue
-            if not is_log_ok(output_folder):
-                continue
-            yield rec_name, sorter_name, output_folder
+        for output_folder in rec_folder.iterdir():
+            if (output_folder / "spikeinterface_job.json").is_file():
+                with open(output_folder / "spikeinterface_job.json", "r") as f:
+                    job_dict = json.load(f)
+                rec_name = job_dict["rec_name"]
+                sorter_name = job_dict["sorter_name"]
+                yield rec_name, sorter_name, output_folder
+            else:
+                if not output_folder.is_dir():
+                    continue
+                if not is_log_ok(output_folder):
+                    continue
+                yield rec_name, sorter_name, output_folder
 
 
-def iter_sorting_output(output_folders):
+def iter_sorting_output(working_folder):
     """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting)."""
-    for rec_name, sorter_name, output_folder in iter_output_folders(output_folders):
+    for rec_name, sorter_name, output_folder in iter_working_folder(working_folder):
         SorterClass = sorter_dict[sorter_name]
         sorting = SorterClass.get_result_from_folder(output_folder)
         yield rec_name, sorter_name, sorting
 
 
-def collect_sorting_outputs(output_folders):
-    """Collect results in a output_folders.
+def collect_sorting_outputs(working_folder):
+    """Collect results in a working_folder.
 
     The output is a dict with double key access results[(rec_name, sorter_name)] of SortingExtractor.
     """
     results = {}
-    for rec_name, sorter_name, sorting in iter_sorting_output(output_folders):
+    for rec_name, sorter_name, sorting in iter_sorting_output(working_folder):
         results[(rec_name, sorter_name)] = sorting
     return results
 
diff --git a/spikeinterface/sorters/tests/test_launcher.py b/spikeinterface/sorters/tests/test_launcher.py
index 1b644a7953..93820e5c9c 100644
--- a/spikeinterface/sorters/tests/test_launcher.py
+++ b/spikeinterface/sorters/tests/test_launcher.py
@@ -37,21 +37,42 @@ def test_run_sorters_with_list():
 
 
 def test_run_sorter_by_property():
-    working_folder = cache_folder / 'test_run_sorter_by_property'
-    if working_folder.is_dir():
-        shutil.rmtree(working_folder)
+    working_folder1 = cache_folder / 'test_run_sorter_by_property1'
+    if working_folder1.is_dir():
+        shutil.rmtree(working_folder1)
+    working_folder2 = cache_folder / 'test_run_sorter_by_property2'
+    if working_folder2.is_dir():
+        shutil.rmtree(working_folder2)
 
     rec0, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)
     rec0.set_channel_groups(["0"] * 4 + ["1"] * 4)
+    rec0_by = rec0.split_by("group")
+    group_names0 = list(rec0_by.keys())
 
     # make dumpable
     set_global_tmp_folder(cache_folder)
     rec0 = rec0.save(name='rec000')
 
     sorter_name = 'tridesclous'
-    sorting = run_sorter_by_property(sorter_name, rec0, "group", working_folder,
-                                     engine='loop', verbose=False)
-    assert "group" in sorting.get_property_keys()
+    sorting0 = run_sorter_by_property(sorter_name, rec0, "group", working_folder1,
+                                      engine='loop', verbose=False)
+    assert "group" in sorting0.get_property_keys()
+    assert all([g in group_names0 for g in sorting0.get_property("group")])
+
+    rec1, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)
+    rec1.set_channel_groups([0] * 4 + [1] * 4)
+    rec1_by = rec1.split_by("group")
+    group_names1 = list(rec1_by.keys())
+
+    # make dumpable
+    set_global_tmp_folder(cache_folder)
+    rec1 = rec1.save(name='rec001')
+    sorter_name = 'tridesclous'
+
+    sorting1 = run_sorter_by_property(sorter_name, rec1, "group", working_folder2,
+                                      engine='loop', verbose=False)
+    assert "group" in sorting1.get_property_keys()
+    assert all([g in group_names1 for g in sorting1.get_property("group")])
 
 
 def test_run_sorters_with_dict():
@@ -170,10 +191,10 @@ def test_sorter_installation():
 
 
 if __name__ == '__main__':
-    pass
+    #pass
     # test_run_sorters_with_list()
-    # test_run_sorter_by_property()
+    test_run_sorter_by_property()
     # test_run_sorters_with_dict()
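
Usage note (not part of the patch): with the renamed iter_working_folder and the spikeinterface_job.json now written next to every sorter output, a working folder produced by run_sorters / run_sorter_by_property can be re-read later without re-running anything. The sketch below is illustrative only; it assumes tridesclous is installed (any installed sorter name works) and uses arbitrary example folder names.

from spikeinterface.core import set_global_tmp_folder
from spikeinterface.extractors import toy_example
from spikeinterface.sorters import (run_sorter_by_property, iter_working_folder,
                                    collect_sorting_outputs)

# Toy recording with two channel groups; labels can be int or str, which is why
# run_sorter_by_property now builds the "group" property with an explicit dtype
# instead of concatenating onto an empty float array.
recording, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)
recording.set_channel_groups([0] * 4 + [1] * 4)

# make the toy recording dumpable, mirroring the test above (example paths)
set_global_tmp_folder("si_tmp_example")
recording = recording.save(name="toy_rec_example")

working_folder = "sorting_by_group_example"  # arbitrary example path

# sort each channel group separately; the aggregated sorting carries a "group" property
sorting = run_sorter_by_property("tridesclous", recording, "group", working_folder,
                                 engine="loop", verbose=False)
print(sorting.get_property("group"))

# each <rec_name>/<sorter_name> folder now contains spikeinterface_job.json,
# so the working folder can be crawled back into (rec_name, sorter_name) results
for rec_name, sorter_name, output_folder in iter_working_folder(working_folder):
    print(rec_name, sorter_name, output_folder)

results = collect_sorting_outputs(working_folder)  # {(rec_name, sorter_name): sorting}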