Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions spikeinterface/comparison/groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .paircomparisons import compare_sorter_to_ground_truth

from .studytools import (setup_comparison_study, get_rec_names, get_recordings,
iter_output_folders, iter_computed_names, iter_computed_sorting, collect_run_times)
iter_working_folder, iter_computed_names, iter_computed_sorting, collect_run_times)


class GroundTruthStudy:
Expand Down Expand Up @@ -102,7 +102,7 @@ def copy_sortings(self):

log_olders.mkdir(parents=True, exist_ok=True)

for rec_name, sorter_name, output_folder in iter_output_folders(sorter_folders):
for rec_name, sorter_name, output_folder in iter_working_folder(sorter_folders):
SorterClass = sorter_dict[sorter_name]
fname = rec_name + '[#]' + sorter_name
npz_filename = sorting_folders / (fname + '.npz')
Expand Down
2 changes: 1 addition & 1 deletion spikeinterface/comparison/studytools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from spikeinterface.core import load_extractor
from spikeinterface.extractors import NpzSortingExtractor
from spikeinterface.sorters import sorter_dict
from spikeinterface.sorters.launcher import iter_output_folders, iter_sorting_output
from spikeinterface.sorters.launcher import iter_working_folder, iter_sorting_output

from .comparisontools import _perf_keys
from .paircomparisons import compare_sorter_to_ground_truth
Expand Down
2 changes: 1 addition & 1 deletion spikeinterface/sorters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .runsorter import *

from .launcher import (run_sorters, run_sorter_by_property,
collect_sorting_outputs, iter_output_folders, iter_sorting_output)
collect_sorting_outputs, iter_working_folder, iter_sorting_output)
71 changes: 48 additions & 23 deletions spikeinterface/sorters/launcher.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""
Utils functions to launch several sorter on several recording in parallel or not.
"""
import os
from pathlib import Path
import shutil
import numpy as np
import json

from spikeinterface.core import load_extractor, aggregate_units
from spikeinterface.core.core_tools import check_json

from .sorterlist import sorter_dict
from .runsorter import run_sorter, _common_param_doc, run_sorter
Expand Down Expand Up @@ -110,8 +110,7 @@ def run_sorter_by_property(sorter_name,

assert grouping_property in recording.get_property_keys(), f"The 'grouping_property' {grouping_property} is not " \
f"a recording property!"
recording_dict = recording.split_by(grouping_property)

recording_dict = recording.split_by(grouping_property)
sorting_output = run_sorters([sorter_name], recording_dict, working_folder,
mode_if_folder_exists=mode_if_folder_exists,
engine=engine,
Expand All @@ -122,13 +121,17 @@ def run_sorter_by_property(sorter_name,
singularity_images={sorter_name: singularity_image},
sorter_params={sorter_name: sorter_params})

grouping_property_values = np.array([])
grouping_property_values = None
sorting_list = []
for (output_name, sorting) in sorting_output.items():
prop_name, sorter_name = output_name
sorting_list.append(sorting)
grouping_property_values = np.concatenate(
(grouping_property_values, [prop_name] * len(sorting.get_unit_ids())))
if grouping_property_values is None:
grouping_property_values = np.array([prop_name] * len(sorting.get_unit_ids()),
dtype=np.dtype(type(prop_name)))
else:
grouping_property_values = np.concatenate(
(grouping_property_values, [prop_name] * len(sorting.get_unit_ids())))

aggregate_sorting = aggregate_units(sorting_list)
aggregate_sorting.set_property(
Expand Down Expand Up @@ -213,6 +216,10 @@ def run_sorters(sorter_list,
recording_dict = recording_dict_or_list
else:
raise ValueError('bad recording dict')

dtype_rec_name = np.dtype(type(list(recording_dict.keys())[0]))
assert dtype_rec_name.kind in ("i", "u", "S", "U"), "Dict keys can only be integers or strings!"


need_dump = engine != 'loop'
task_args_list = []
Expand Down Expand Up @@ -249,7 +256,7 @@ def run_sorters(sorter_list,

task_args = (sorter_name, recording_arg, output_folder,
verbose, params, docker_image, singularity_image, with_output)
task_args_list.append(task_args)
task_args_list.append(task_args)

if engine == 'loop':
# simple loop in main process
Expand All @@ -274,6 +281,18 @@ def run_sorters(sorter_list,

for task in tasks:
task.result()

# dump spikeinterface_job.json
for rec_name, recording in recording_dict.items():
for sorter_name in sorter_list:
output_folder = working_folder / str(rec_name) / sorter_name
with open(output_folder / "spikeinterface_job.json", "w") as f:
dump_dict = {"rec_name": rec_name,
"sorter_name": sorter_name,
"engine": engine}
if engine != "dask":
dump_dict.update({"engine_kwargs": engine_kwargs})
json.dump(check_json(dump_dict), f)

if with_output:
if engine == 'dask':
Expand All @@ -297,35 +316,41 @@ def is_log_ok(output_folder):
return False


def iter_output_folders(output_folders):
output_folders = Path(output_folders)
for rec_name in os.listdir(output_folders):
if not os.path.isdir(output_folders / rec_name):
def iter_working_folder(working_folder):
working_folder = Path(working_folder)
for rec_folder in working_folder.iterdir():
if not rec_folder.is_dir():
continue
for sorter_name in os.listdir(output_folders / rec_name):
output_folder = output_folders / rec_name / sorter_name
if not os.path.isdir(output_folder):
continue
if not is_log_ok(output_folder):
continue
yield rec_name, sorter_name, output_folder
for output_folder in rec_folder.iterdir():
if (output_folder / "spikeinterface_job.json").is_file():
with open(output_folder / "spikeinterface_job.json", "r") as f:
job_dict = json.load(f)
rec_name = job_dict["rec_name"]
sorter_name = job_dict["sorter_name"]
yield rec_name, sorter_name, output_folder
else:
if not output_folder.is_dir():
continue
if not is_log_ok(output_folder):
continue
yield rec_name, sorter_name, output_folder


def iter_sorting_output(output_folders):
def iter_sorting_output(working_folder):
"""Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting)."""
for rec_name, sorter_name, output_folder in iter_output_folders(output_folders):
for rec_name, sorter_name, output_folder in iter_working_folder(working_folder):
SorterClass = sorter_dict[sorter_name]
sorting = SorterClass.get_result_from_folder(output_folder)
yield rec_name, sorter_name, sorting


def collect_sorting_outputs(output_folders):
"""Collect results in a output_folders.
def collect_sorting_outputs(working_folder):
"""Collect results in a working_folder.

The output is a dict with double key access results[(rec_name, sorter_name)] of SortingExtractor.
"""
results = {}
for rec_name, sorter_name, sorting in iter_sorting_output(output_folders):
for rec_name, sorter_name, sorting in iter_sorting_output(working_folder):
results[(rec_name, sorter_name)] = sorting
return results

Expand Down
37 changes: 29 additions & 8 deletions spikeinterface/sorters/tests/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,42 @@ def test_run_sorters_with_list():


def test_run_sorter_by_property():
working_folder = cache_folder / 'test_run_sorter_by_property'
if working_folder.is_dir():
shutil.rmtree(working_folder)
working_folder1 = cache_folder / 'test_run_sorter_by_property1'
if working_folder1.is_dir():
shutil.rmtree(working_folder1)
working_folder2 = cache_folder / 'test_run_sorter_by_property2'
if working_folder2.is_dir():
shutil.rmtree(working_folder2)

rec0, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)
rec0.set_channel_groups(["0"] * 4 + ["1"] * 4)
rec0_by = rec0.split_by("group")
group_names0 = list(rec0_by.keys())

# make dumpable
set_global_tmp_folder(cache_folder)
rec0 = rec0.save(name='rec000')
sorter_name = 'tridesclous'

sorting = run_sorter_by_property(sorter_name, rec0, "group", working_folder,
engine='loop', verbose=False)
assert "group" in sorting.get_property_keys()
sorting0 = run_sorter_by_property(sorter_name, rec0, "group", working_folder1,
engine='loop', verbose=False)
assert "group" in sorting0.get_property_keys()
assert all([g in group_names0 for g in sorting0.get_property("group")])

rec1, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)
rec1.set_channel_groups([0] * 4 + [1] * 4)
rec1_by = rec1.split_by("group")
group_names1 = list(rec1_by.keys())

# make dumpable
set_global_tmp_folder(cache_folder)
rec1 = rec1.save(name='rec001')
sorter_name = 'tridesclous'

sorting1 = run_sorter_by_property(sorter_name, rec1, "group", working_folder2,
engine='loop', verbose=False)
assert "group" in sorting1.get_property_keys()
assert all([g in group_names1 for g in sorting1.get_property("group")])


def test_run_sorters_with_dict():
Expand Down Expand Up @@ -170,10 +191,10 @@ def test_sorter_installation():


if __name__ == '__main__':
pass
#pass
# test_run_sorters_with_list()

test_run_sorter_by_property()
test_run_sorter_by_property()

# test_run_sorters_with_dict()

Expand Down