diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py
index bbce3998c6..23d60a5ac5 100644
--- a/src/spikeinterface/core/globals.py
+++ b/src/spikeinterface/core/globals.py
@@ -42,9 +42,9 @@ def set_global_tmp_folder(folder):
     temp_folder_set = True
 
 
-def is_set_global_tmp_folder():
+def is_set_global_tmp_folder() -> bool:
     """
-    Check is the global path temporary folder have been manually set.
+    Check if the global path temporary folder has been manually set.
     """
     global temp_folder_set
     return temp_folder_set
@@ -88,9 +88,9 @@ def set_global_dataset_folder(folder):
     dataset_folder_set = True
 
 
-def is_set_global_dataset_folder():
+def is_set_global_dataset_folder() -> bool:
     """
-    Check is the global path dataset folder have been manually set.
+    Check if the global path dataset folder has been manually set.
     """
     global dataset_folder_set
     return dataset_folder_set
@@ -138,7 +138,10 @@ def reset_global_job_kwargs():
     global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True)
 
 
-def is_set_global_job_kwargs_set():
+def is_set_global_job_kwargs_set() -> bool:
+    """
+    Check if the global job kwargs have been manually set.
+    """
     global global_job_kwargs_set
     return global_job_kwargs_set
 
diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index 6161d5f064..fa79a8ce01 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -28,8 +28,9 @@
             Total memory usage (e.g. "500M", "2G")
         - chunk_duration : str or float or None
             Chunk duration in s if float or with units if str (e.g. "1s", "500ms")
-    * n_jobs: int
-        Number of jobs to use. With -1 the number of jobs is the same as number of cores
+    * n_jobs: int | float
+        Number of jobs to use. With -1 the number of jobs is the same as number of cores.
+        Using a float between 0 and 1 will use that fraction of the total cores.
     * progress_bar: bool
         If True, a progress bar is printed
     * mp_context: "fork" | "spawn" | None, default: None
@@ -60,7 +61,7 @@
 
 
 def fix_job_kwargs(runtime_job_kwargs):
-    from .globals import get_global_job_kwargs
+    from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set
 
     job_kwargs = get_global_job_kwargs()
 
@@ -99,6 +100,15 @@ def fix_job_kwargs(runtime_job_kwargs):
 
         job_kwargs["n_jobs"] = max(n_jobs, 1)
 
+    if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_set_global_job_kwargs_set():
+        warnings.warn(
+            "`n_jobs` is not set so parallel processing is disabled! "
+            "To speed up computations, it is recommended to set n_jobs either "
+            "globally (with the `spikeinterface.set_global_job_kwargs()` function) or "
+            "locally (with the `n_jobs` argument). Use `spikeinterface.set_global_job_kwargs?` "
+            "for more information about job_kwargs."
+        )
+
     return job_kwargs
 
 
diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py
index d0672405d6..a45bb6f49c 100644
--- a/src/spikeinterface/core/tests/test_globals.py
+++ b/src/spikeinterface/core/tests/test_globals.py
@@ -1,4 +1,5 @@
 import pytest
+import warnings
 from pathlib import Path
 
 from spikeinterface import (
@@ -39,11 +40,22 @@ def test_global_tmp_folder():
 def test_global_job_kwargs():
     job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
     global_job_kwargs = get_global_job_kwargs()
+
+    # test warning when not setting n_jobs and calling fix_job_kwargs
+    with pytest.warns(UserWarning):
+        job_kwargs_split = fix_job_kwargs({})
+
     assert global_job_kwargs == dict(
         n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
     )
     set_global_job_kwargs(**job_kwargs)
     assert get_global_job_kwargs() == job_kwargs
+
+    # after setting global job kwargs, fix_job_kwargs should not raise a warning
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        job_kwargs_split = fix_job_kwargs({})
+
     # test updating only one field
     partial_job_kwargs = dict(n_jobs=2)
     set_global_job_kwargs(**partial_job_kwargs)
diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py
index 1465d205b4..4cef5e9966 100644
--- a/src/spikeinterface/sorters/basesorter.py
+++ b/src/spikeinterface/sorters/basesorter.py
@@ -17,15 +17,14 @@
 
 from spikeinterface.core import load_extractor, BaseRecordingSnippets
 from spikeinterface.core.core_tools import check_json
+from spikeinterface.core.globals import get_global_job_kwargs
 from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
 from .utils import SpikeSortingError, ShellScript
 
 
-default_job_kwargs = {"n_jobs": -1}
-
 default_job_kwargs_description = {
-    "n_jobs": "Number of jobs (when saving ti binary) - default -1 (all cores)",
-    "chunk_size": "Number of samples per chunk (when saving ti binary) - default global",
+    "n_jobs": "Number of jobs (when saving to binary) - default global",
+    "chunk_size": "Number of samples per chunk (when saving to binary) - default global",
     "chunk_memory": "Memory usage for each job (e.g. '100M', '1G') (when saving to binary) - default global",
     "total_memory": "Total memory usage (e.g. '500M', '2G') (when saving to binary) - default global",
     "chunk_duration": "Chunk duration in s if float or with units if str (e.g. '1s', '500ms') (when saving to binary)"
@@ -156,7 +155,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo
     def default_params(cls):
         p = copy.deepcopy(cls._default_params)
         if cls.requires_binary_data:
-            job_kwargs = fix_job_kwargs(default_job_kwargs)
+            job_kwargs = get_global_job_kwargs()
             p.update(job_kwargs)
         return p
 
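
For reference, a minimal usage sketch of the behavior this patch introduces (not part of the diff; it assumes a spikeinterface install with these changes applied, and the variable names are illustrative):

    import spikeinterface as si
    from spikeinterface.core.job_tools import fix_job_kwargs

    # n_jobs not given and no global job kwargs set: the new UserWarning
    # in fix_job_kwargs fires and processing stays single-process (n_jobs=1).
    job_kwargs = fix_job_kwargs({})

    # Setting job kwargs globally, as the warning suggests, silences it.
    si.set_global_job_kwargs(n_jobs=4)
    job_kwargs = fix_job_kwargs({})  # no warning; n_jobs == 4

    # Per the updated docstring, a float between 0 and 1 requests that
    # fraction of the total cores (e.g. half of the machine here).
    job_kwargs = fix_job_kwargs(dict(n_jobs=0.5))

Because the warning condition checks `is_set_global_job_kwargs_set()` in addition to the runtime kwargs, it only fires when the user has given no signal about parallelism at all, neither locally via `n_jobs` nor globally via `set_global_job_kwargs`.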