From 1cf54972f4209b9ad63de95df5a72d5c35b5fd87 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 12 Apr 2024 08:35:34 -0400 Subject: [PATCH 1/5] remove default all core for sorters for n_jobs --- src/spikeinterface/sorters/basesorter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 1465d205b4..5dc99d2a84 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -21,11 +21,11 @@ from .utils import SpikeSortingError, ShellScript -default_job_kwargs = {"n_jobs": -1} +default_job_kwargs = {} default_job_kwargs_description = { - "n_jobs": "Number of jobs (when saving ti binary) - default -1 (all cores)", - "chunk_size": "Number of samples per chunk (when saving ti binary) - default global", + "n_jobs": "Number of jobs (when saving to binary) - default global", + "chunk_size": "Number of samples per chunk (when saving to binary) - default global", "chunk_memory": "Memory usage for each job (e.g. '100M', '1G') (when saving to binary) - default global", "total_memory": "Total memory usage (e.g. '500M', '2G') (when saving to binary) - default global", "chunk_duration": "Chunk duration in s if float or with units if str (e.g. '1s', '500ms') (when saving to binary)" From aec9c16e08af3052131cfabcb5777fc9d56311c6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 14 May 2024 13:04:37 +0200 Subject: [PATCH 2/5] Remove sorter default jobs --- src/spikeinterface/sorters/basesorter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 5dc99d2a84..4cef5e9966 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -17,12 +17,11 @@ from spikeinterface.core import load_extractor, BaseRecordingSnippets from spikeinterface.core.core_tools import check_json +from spikeinterface.core.globals import get_global_job_kwargs from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from .utils import SpikeSortingError, ShellScript -default_job_kwargs = {} - default_job_kwargs_description = { "n_jobs": "Number of jobs (when saving to binary) - default global", "chunk_size": "Number of samples per chunk (when saving to binary) - default global", @@ -156,7 +155,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo def default_params(cls): p = copy.deepcopy(cls._default_params) if cls.requires_binary_data: - job_kwargs = fix_job_kwargs(default_job_kwargs) + job_kwargs = get_global_job_kwargs() p.update(job_kwargs) return p From cb51854309adb1e2de3bf0cc269b77c2a6441f35 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 14 May 2024 13:25:02 +0200 Subject: [PATCH 3/5] Add warning for n_jobs --- src/spikeinterface/core/job_tools.py | 17 +++++++++++++---- src/spikeinterface/core/tests/test_globals.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 779414b337..984d306be1 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -10,7 +10,6 @@ import warnings import sys -import contextlib from tqdm.auto import tqdm from concurrent.futures import ProcessPoolExecutor @@ -28,8 +27,9 @@ Total memory usage (e.g. "500M", "2G") - chunk_duration : str or float or None Chunk duration in s if float or with units if str (e.g. "1s", "500ms") - * n_jobs: int - Number of jobs to use. With -1 the number of jobs is the same as number of cores + * n_jobs: int | float + Number of jobs to use. With -1 the number of jobs is the same as number of cores. + Using a float between 0 and 1 will use that fraction of the total cores. * progress_bar: bool If True, a progress bar is printed * mp_context: "fork" | "spawn" | None, default: None @@ -60,7 +60,7 @@ def fix_job_kwargs(runtime_job_kwargs): - from .globals import get_global_job_kwargs + from .globals import get_global_job_kwargs, global_job_kwargs_set job_kwargs = get_global_job_kwargs() @@ -99,6 +99,15 @@ def fix_job_kwargs(runtime_job_kwargs): job_kwargs["n_jobs"] = max(n_jobs, 1) + if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not global_job_kwargs_set: + warnings.warn( + "`n_jobs` is not set so parallel processing is disabled! " + "To speed up computations, it is recommended to set n_jobs either " + "globally (with the `spikeinterface.set_global_job_kwargs()` function) or " + "locally (with the `n_jobs` argument). Use `spikeinterface.set_global_job_kwargs?` " + "for more information about job_kwargs." + ) + return job_kwargs diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index d0672405d6..a45bb6f49c 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -1,4 +1,5 @@ import pytest +import warnings from pathlib import Path from spikeinterface import ( @@ -39,11 +40,22 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() + + # test warning when not setting n_jobs and calling fix_job_kwargs + with pytest.warns(UserWarning): + job_kwargs_split = fix_job_kwargs({}) + assert global_job_kwargs == dict( n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs + + # after setting global job kwargs, fix_job_kwargs should not raise a warning + with warnings.catch_warnings(): + warnings.simplefilter("error") + job_kwargs_split = fix_job_kwargs({}) + # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) From 24b5eb22170c36640e67e3070fd5219a82661c5c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 May 2024 13:14:51 +0200 Subject: [PATCH 4/5] Use function instead of global variable --- src/spikeinterface/core/globals.py | 8 ++++++++ src/spikeinterface/core/job_tools.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index bbce3998c6..8386c358d3 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -130,6 +130,14 @@ def set_global_job_kwargs(**job_kwargs): global_job_kwargs_set = True +def is_global_job_kwargs_set() -> bool: + """ + Check is the global job kwargs have been manually set. + """ + global global_job_kwargs_set + return global_job_kwargs_set + + def reset_global_job_kwargs(): """ Reset the global job kwargs. diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 984d306be1..f7d9f216ff 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -60,7 +60,7 @@ def fix_job_kwargs(runtime_job_kwargs): - from .globals import get_global_job_kwargs, global_job_kwargs_set + from .globals import get_global_job_kwargs, is_global_job_kwargs_set job_kwargs = get_global_job_kwargs() @@ -99,7 +99,7 @@ def fix_job_kwargs(runtime_job_kwargs): job_kwargs["n_jobs"] = max(n_jobs, 1) - if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not global_job_kwargs_set: + if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_global_job_kwargs_set(): warnings.warn( "`n_jobs` is not set so parallel processing is disabled! " "To speed up computations, it is recommended to set n_jobs either " From 2da00a424039d7f6770af96910918d17fae81685 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 17 May 2024 11:20:19 -0400 Subject: [PATCH 5/5] use original is_set_global_job_kwargs_set function --- src/spikeinterface/core/globals.py | 21 ++++++++------------- src/spikeinterface/core/job_tools.py | 4 ++-- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index 8386c358d3..23d60a5ac5 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -42,9 +42,9 @@ def set_global_tmp_folder(folder): temp_folder_set = True -def is_set_global_tmp_folder(): +def is_set_global_tmp_folder() -> bool: """ - Check is the global path temporary folder have been manually set. + Check if the global path temporary folder have been manually set. """ global temp_folder_set return temp_folder_set @@ -88,9 +88,9 @@ def set_global_dataset_folder(folder): dataset_folder_set = True -def is_set_global_dataset_folder(): +def is_set_global_dataset_folder() -> bool: """ - Check is the global path dataset folder have been manually set. + Check if the global path dataset folder has been manually set. """ global dataset_folder_set return dataset_folder_set @@ -130,14 +130,6 @@ def set_global_job_kwargs(**job_kwargs): global_job_kwargs_set = True -def is_global_job_kwargs_set() -> bool: - """ - Check is the global job kwargs have been manually set. - """ - global global_job_kwargs_set - return global_job_kwargs_set - - def reset_global_job_kwargs(): """ Reset the global job kwargs. @@ -146,7 +138,10 @@ def reset_global_job_kwargs(): global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) -def is_set_global_job_kwargs_set(): +def is_set_global_job_kwargs_set() -> bool: + """ + Check if the global job kwargs have been manually set. + """ global global_job_kwargs_set return global_job_kwargs_set diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index f7d9f216ff..201b11db2b 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -60,7 +60,7 @@ def fix_job_kwargs(runtime_job_kwargs): - from .globals import get_global_job_kwargs, is_global_job_kwargs_set + from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set job_kwargs = get_global_job_kwargs() @@ -99,7 +99,7 @@ def fix_job_kwargs(runtime_job_kwargs): job_kwargs["n_jobs"] = max(n_jobs, 1) - if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_global_job_kwargs_set(): + if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_set_global_job_kwargs_set(): warnings.warn( "`n_jobs` is not set so parallel processing is disabled! " "To speed up computations, it is recommended to set n_jobs either "