Merge pull request #2712 from zm711/n_jobs_sorter
Remove separate default job_kwarg `n_jobs` for sorters
samuelgarcia committed May 21, 2024
2 parents 26c145c + 2ddd206 commit 72c7717
Showing 4 changed files with 37 additions and 13 deletions.
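Reviewer note: this PR drops the sorter-specific default of `n_jobs=-1`, so the global job kwargs become the single source of truth. A minimal usage sketch of the intended flow (my illustration, not part of the diff; the sorter name is arbitrary and `get_default_sorter_params` reaching through `default_params()` is an assumption about the surrounding API):

```python
import spikeinterface.full as si

# After this change, sorters no longer carry their own n_jobs=-1 default;
# whatever is configured globally flows into the sorter's default params.
si.set_global_job_kwargs(n_jobs=8, chunk_duration="1s", progress_bar=True)

# For a sorter that writes binary data, default params now reflect the globals.
params = si.get_default_sorter_params("tridesclous")  # sorter name illustrative
print(params.get("n_jobs"))  # expected: 8, not -1
```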
13 changes: 8 additions & 5 deletions src/spikeinterface/core/globals.py
@@ -42,9 +42,9 @@ def set_global_tmp_folder(folder):
     temp_folder_set = True


-def is_set_global_tmp_folder():
+def is_set_global_tmp_folder() -> bool:
     """
-    Check is the global path temporary folder have been manually set.
+    Check if the global path temporary folder has been manually set.
     """
     global temp_folder_set
     return temp_folder_set
@@ -88,9 +88,9 @@ def set_global_dataset_folder(folder):
     dataset_folder_set = True


-def is_set_global_dataset_folder():
+def is_set_global_dataset_folder() -> bool:
     """
-    Check is the global path dataset folder have been manually set.
+    Check if the global path dataset folder has been manually set.
     """
     global dataset_folder_set
     return dataset_folder_set
@@ -138,7 +138,10 @@ def reset_global_job_kwargs():
     global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True)


-def is_set_global_job_kwargs_set():
+def is_set_global_job_kwargs_set() -> bool:
+    """
+    Check if the global job kwargs have been manually set.
+    """
     global global_job_kwargs_set
     return global_job_kwargs_set

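The new `-> bool` annotations and the "manually set" flag are what `fix_job_kwargs` (next file) keys off. A quick sketch of the intended semantics (my example; assumes a fresh interpreter session):

```python
from spikeinterface.core.globals import (
    get_global_job_kwargs,
    is_set_global_job_kwargs_set,
    set_global_job_kwargs,
)

# In a fresh session the defaults apply and the "manually set" flag is False.
print(get_global_job_kwargs())          # n_jobs=1, chunk_duration="1s", ...
print(is_set_global_job_kwargs_set())   # False

# Setting any field flips the flag; fix_job_kwargs uses it to decide whether
# to warn that parallel processing is disabled.
set_global_job_kwargs(n_jobs=4)
print(is_set_global_job_kwargs_set())   # True
```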
16 changes: 13 additions & 3 deletions src/spikeinterface/core/job_tools.py
@@ -28,8 +28,9 @@
         Total memory usage (e.g. "500M", "2G")
     - chunk_duration : str or float or None
         Chunk duration in s if float or with units if str (e.g. "1s", "500ms")
-    * n_jobs: int
-        Number of jobs to use. With -1 the number of jobs is the same as number of cores
+    * n_jobs: int | float
+        Number of jobs to use. With -1 the number of jobs is the same as the number of cores.
+        Using a float between 0 and 1 will use that fraction of the total cores.
     * progress_bar: bool
         If True, a progress bar is printed
     * mp_context: "fork" | "spawn" | None, default: None
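The float option is the new part of this docstring. A rough sketch of how a fractional `n_jobs` maps to a worker count (my paraphrase of the rule stated above, not the library's exact code):

```python
import os


def resolve_n_jobs(n_jobs):
    """Map an n_jobs spec to a concrete worker count (approximation).

    -1 -> all cores; a float between 0 and 1 -> that fraction of the cores;
    any other value -> an explicit integer count (floored at 1).
    """
    n_cores = os.cpu_count() or 1
    if n_jobs == -1:
        return n_cores
    if isinstance(n_jobs, float) and 0 < n_jobs < 1:
        return max(int(n_jobs * n_cores), 1)
    return max(int(n_jobs), 1)


# On an 8-core machine: resolve_n_jobs(0.5) -> 4, resolve_n_jobs(-1) -> 8
```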
@@ -60,7 +61,7 @@


 def fix_job_kwargs(runtime_job_kwargs):
-    from .globals import get_global_job_kwargs
+    from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set

     job_kwargs = get_global_job_kwargs()

@@ -99,6 +100,15 @@ def fix_job_kwargs(runtime_job_kwargs):

     job_kwargs["n_jobs"] = max(n_jobs, 1)

+    if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_set_global_job_kwargs_set():
+        warnings.warn(
+            "`n_jobs` is not set so parallel processing is disabled! "
+            "To speed up computations, it is recommended to set n_jobs either "
+            "globally (with the `spikeinterface.set_global_job_kwargs()` function) or "
+            "locally (with the `n_jobs` argument). Use `spikeinterface.set_global_job_kwargs?` "
+            "for more information about job_kwargs."
+        )
+
     return job_kwargs

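Net effect of the new branch, sketched as an interactive session (illustrative; assumes the global defaults have not been touched yet):

```python
import warnings

from spikeinterface.core.globals import set_global_job_kwargs
from spikeinterface.core.job_tools import fix_job_kwargs

# Nothing set anywhere: n_jobs resolves to 1 and the new UserWarning fires.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fix_job_kwargs({})
print(len(caught))  # 1 on a fresh session

# Either of these silences it:
set_global_job_kwargs(n_jobs=4)   # globally, via the flag checked above
fix_job_kwargs({"n_jobs": 4})     # locally, by passing n_jobs at the call site
```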
12 changes: 12 additions & 0 deletions src/spikeinterface/core/tests/test_globals.py
@@ -1,4 +1,5 @@
 import pytest
+import warnings
 from pathlib import Path

 from spikeinterface import (
@@ -39,11 +40,22 @@ def test_global_tmp_folder():
 def test_global_job_kwargs():
     job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
     global_job_kwargs = get_global_job_kwargs()
+
+    # test warning when not setting n_jobs and calling fix_job_kwargs
+    with pytest.warns(UserWarning):
+        job_kwargs_split = fix_job_kwargs({})
+
     assert global_job_kwargs == dict(
         n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
     )
     set_global_job_kwargs(**job_kwargs)
     assert get_global_job_kwargs() == job_kwargs
+
+    # after setting global job kwargs, fix_job_kwargs should not raise a warning
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        job_kwargs_split = fix_job_kwargs({})
+
     # test updating only one field
     partial_job_kwargs = dict(n_jobs=2)
     set_global_job_kwargs(**partial_job_kwargs)
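For reference, the two assertion patterns this test relies on, reduced to a self-contained sketch (toy function, not spikeinterface code):

```python
import warnings

import pytest


def noisy(flag):
    """Toy stand-in for fix_job_kwargs: warns only when flag is True."""
    if flag:
        warnings.warn("n_jobs is not set", UserWarning)


def test_warning_fires():
    # pytest.warns fails the test if no matching warning is emitted.
    with pytest.warns(UserWarning):
        noisy(True)


def test_no_warning():
    # simplefilter("error") escalates any warning to an exception,
    # so the test fails if a warning sneaks through.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        noisy(False)
```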
9 changes: 4 additions & 5 deletions src/spikeinterface/sorters/basesorter.py
@@ -17,15 +17,14 @@

 from spikeinterface.core import load_extractor, BaseRecordingSnippets
 from spikeinterface.core.core_tools import check_json
+from spikeinterface.core.globals import get_global_job_kwargs
 from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
 from .utils import SpikeSortingError, ShellScript


-default_job_kwargs = {"n_jobs": -1}
-
 default_job_kwargs_description = {
-    "n_jobs": "Number of jobs (when saving ti binary) - default -1 (all cores)",
-    "chunk_size": "Number of samples per chunk (when saving ti binary) - default global",
+    "n_jobs": "Number of jobs (when saving to binary) - default global",
+    "chunk_size": "Number of samples per chunk (when saving to binary) - default global",
     "chunk_memory": "Memory usage for each job (e.g. '100M', '1G') (when saving to binary) - default global",
     "total_memory": "Total memory usage (e.g. '500M', '2G') (when saving to binary) - default global",
     "chunk_duration": "Chunk duration in s if float or with units if str (e.g. '1s', '500ms') (when saving to binary)"
@@ -156,7 +155,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo
     def default_params(cls):
         p = copy.deepcopy(cls._default_params)
         if cls.requires_binary_data:
-            job_kwargs = fix_job_kwargs(default_job_kwargs)
+            job_kwargs = get_global_job_kwargs()
             p.update(job_kwargs)
         return p

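Condensed view of the behavioral change in `default_params`, as a runnable sketch (`ToySorter` is hypothetical; the classmethod body mirrors the diff):

```python
import copy

from spikeinterface.core.globals import get_global_job_kwargs, set_global_job_kwargs


class ToySorter:
    """Hypothetical stand-in for a BaseSorter subclass."""

    requires_binary_data = True
    _default_params = {"detect_threshold": 5.0}

    @classmethod
    def default_params(cls):
        p = copy.deepcopy(cls._default_params)
        if cls.requires_binary_data:
            # old behavior: job_kwargs = fix_job_kwargs({"n_jobs": -1})
            job_kwargs = get_global_job_kwargs()  # new: follow the globals
            p.update(job_kwargs)
        return p


set_global_job_kwargs(n_jobs=2)
print(ToySorter.default_params()["n_jobs"])  # 2, not the old hard-coded -1
```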
