[bugfix, enhancement] Address affinity bug by using threadpoolctl/joblib for n_jobs dispatching #2364
base: main
Changes from 14 commits
@@ -19,53 +19,50 @@
import threading
from functools import wraps
from inspect import Parameter, signature
from multiprocessing import cpu_count
from numbers import Integral
from warnings import warn

import threadpoolctl
from joblib import cpu_count

from daal4py import daalinit as set_n_threads
from daal4py import num_threads as get_n_threads
from daal4py import _get__daal_link_version__, daalinit, num_threads

from ._utils import sklearn_check_version

if sklearn_check_version("1.2"):
    from sklearn.utils._param_validation import validate_parameter_constraints
else:

    def validate_parameter_constraints(n_jobs):
        if n_jobs is not None and n_jobs.__class__ != int:
            raise TypeError(
                f"n_jobs must be an instance of int, not {n_jobs.__class__.__name__}."
            )

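On scikit-learn versions older than 1.2, the fallback above only type-checks `n_jobs`. A minimal, illustrative sketch of that behavior; the function body simply mirrors the fallback defined above and is not additional PR code:

```python
def validate_parameter_constraints(n_jobs):
    # mirrors the pre-1.2 fallback above
    if n_jobs is not None and n_jobs.__class__ != int:
        raise TypeError(
            f"n_jobs must be an instance of int, not {n_jobs.__class__.__name__}."
        )

validate_parameter_constraints(4)     # int: accepted
validate_parameter_constraints(None)  # None: accepted
try:
    validate_parameter_constraints("4")
except TypeError as exc:
    print(exc)  # n_jobs must be an instance of int, not str.
```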
class oneDALLibController(threadpoolctl.LibController):
    user_api = "oneDAL"

    internal_api = "oneDAL"

    filename_prefixes = ("libonedal_thread", "libonedal")

    def get_num_threads(self):
        return num_threads()

    def set_num_threads(self, nthreads):
Review discussion on this line:
- "I understand this setting would apply globally, which could lead to race conditions if users call this in parallel, for example through some framework that would parallelize estimator calls. Could it somehow get a mutex (or use atomic ops) either here or on the oneDAL side? Also, would be better to add a warning that the setting is changed at a global level, so that a user would not try to call these inside multi-threaded code."
- "Actually on a further look, it does already have a mutex on the daal side. Still better to document this behavior being global."
- "Sounds good, will do!"
        daalinit(nthreads)

    def get_version(self):
        return _get__daal_link_version__

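As the review discussion above notes, `set_num_threads` changes the oneDAL thread count process-wide, and `daalinit` is reported in that thread to already hold a mutex on the DAAL side. Purely as an illustration of that point, not code from this PR, a sketch of how caller-side code could serialize its own changes if it mutated the setting from several threads; the helper name `set_onedal_threads` and the lock are hypothetical:

```python
import threading

from daal4py import daalinit, num_threads

# Hypothetical helper, not part of the PR: serialize thread-count changes made
# by user code. daalinit itself is reportedly mutex-protected, but the setting
# is still global, so concurrent callers can observe each other's values.
_onedal_threads_lock = threading.Lock()

def set_onedal_threads(n: int) -> int:
    """Set the global oneDAL thread count and return the previous value."""
    with _onedal_threads_lock:
        previous = num_threads()
        daalinit(n)
        return previous
```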

threadpoolctl.register(oneDALLibController)

# Note: getting the controller in the global scope of this module is required
# to avoid the overhead of initializing it on each function call
threadpool_controller = threadpoolctl.ThreadpoolController()

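With the controller class registered, threadpoolctl can discover the oneDAL thread pool alongside the BLAS and OpenMP pools it already knows about. A hedged usage sketch, assuming daal4py is installed and a threadpoolctl version that supports `register` and custom `LibController` subclasses (3.2+):

```python
import threadpoolctl

# After threadpoolctl.register(oneDALLibController) has run, for example by
# importing the module above, the oneDAL pool is listed next to the others.
controller = threadpoolctl.ThreadpoolController()
for info in controller.info():
    print(info["user_api"], info.get("num_threads"))
```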

def get_suggested_n_threads(n_cpus):
    """
    Function to get `n_threads` limit
    if `n_jobs` is set in upper parallelization context.
    Usually, limit is equal to `n_logical_cpus` // `n_jobs`.
    Returns None if limit is not set.
    """
    n_threads_map = {
        lib_ctl.internal_api: lib_ctl.get_num_threads()
        for lib_ctl in threadpool_controller.lib_controllers
        if lib_ctl.internal_api != "mkl"
    }
    # openBLAS is limited to 24, 64 or 128 threads by default
    # depending on SW/HW configuration.
    # thus, these numbers of threads from openBLAS are uninformative
    if "openblas" in n_threads_map and n_threads_map["openblas"] in [24, 64, 128]:
        del n_threads_map["openblas"]
    # remove default values equal to n_cpus as uninformative
    for backend in list(n_threads_map.keys()):
        if n_threads_map[backend] == n_cpus:
            del n_threads_map[backend]
    if len(n_threads_map) > 0:
        return min(n_threads_map.values())
    else:
        return None

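To make the heuristic concrete: pools that still sit at a default value (openBLAS's hard-coded caps, or a count equal to the number of logical CPUs) are treated as uninformative, and the smallest remaining count is taken as the suggestion. A small hand-worked sketch of that filtering; the snapshot numbers below are hypothetical:

```python
# Hypothetical snapshot of pool sizes reported by threadpoolctl on a
# 16-logical-CPU machine inside an outer n_jobs=4 joblib context.
n_cpus = 16
n_threads_map = {
    "openmp": 4,     # limited by the outer context -> informative
    "openblas": 64,  # one of openBLAS's default caps -> dropped
    "oneDAL": 16,    # equal to n_cpus (default)      -> dropped
}

if "openblas" in n_threads_map and n_threads_map["openblas"] in [24, 64, 128]:
    del n_threads_map["openblas"]
for backend in list(n_threads_map):
    if n_threads_map[backend] == n_cpus:
        del n_threads_map[backend]

print(min(n_threads_map.values()) if n_threads_map else None)  # -> 4
```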

def _run_with_n_jobs(method):
    """
    Decorator for running of methods containing oneDAL kernels with 'n_jobs'.

@@ -79,59 +76,42 @@ def _run_with_n_jobs(method):
    @wraps(method)
    def n_jobs_wrapper(self, *args, **kwargs):
        # threading parallel backend branch
        if not isinstance(threading.current_thread(), threading._MainThread):
            warn(
                "'Threading' parallel backend is not supported by "
                "Intel(R) Extension for Scikit-learn*. "
                "Falling back to usage of all available threads."
            )
            result = method(self, *args, **kwargs)
            return result
        # multiprocess parallel backends branch
        # preemptive validation of n_jobs parameter is required
        # because '_run_with_n_jobs' decorator is applied on top of method
        # where validation takes place
        if sklearn_check_version("1.2") and hasattr(self, "_parameter_constraints"):
        if sklearn_check_version("1.2"):
            validate_parameter_constraints(
                parameter_constraints={"n_jobs": self._parameter_constraints["n_jobs"]},
                params={"n_jobs": self.n_jobs},
                caller_name=self.__class__.__name__,
            )
        # search for specified n_jobs
        n_jobs = self.n_jobs
        n_cpus = cpu_count()
        else:
            validate_parameter_constraints(self.n_jobs)

        # receive n_threads limitation from upper parallelism context
        # using `threadpoolctl.ThreadpoolController`
        n_threads = get_suggested_n_threads(n_cpus)
        # get real `n_jobs` number of threads for oneDAL
        # using sklearn rules and `n_threads` from upper parallelism context
        if n_jobs is None or n_jobs == 0:
            if n_threads is None:
                # default branch with no setting for n_jobs
                return method(self, *args, **kwargs)
            else:
                n_jobs = n_threads
        elif n_jobs < 0:
            if n_threads is None:
                n_jobs = max(1, n_cpus + n_jobs + 1)
            else:
                n_jobs = max(1, n_threads + n_jobs + 1)
        # branch with set n_jobs
        old_n_threads = get_n_threads()
        if n_jobs == old_n_threads:
            return method(self, *args, **kwargs)

        try:
        if not self.n_jobs:
            n_jobs = cpu_count()
Review discussion on this line:
- "Would this later on get limited to the number of physical cores from oneDAL side?"
- "I'll be honest, I'm not 100% sure yet. The default in the threading.h in daal will set it to the number of CPUs, but with the affinity I didn't spend the full time to track the default setting there."
- "Maybe @Alexsandruss could comment here on whether it'd end up limited to number of physical cores somewhere else?"
- "Looks like setting the number of threads like this would not result in that number later on getting limited to the number of physical cores. How about passing argument"
- "Tried adding a line to print this value here: https://github.com/uxlfoundation/oneDAL/blob/31cafec9950f1db352b639dafad5875971ca00fe/cpp/daal/src/threading/threading.cpp#L267"
- "Although from some further testing, this behavior also appears to be the same in the current main branch."
        else:
            n_jobs = (
                self.n_jobs if self.n_jobs > 0 else max(1, cpu_count() + self.n_jobs + 1)
            )

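The branch above follows scikit-learn's usual `n_jobs` convention: falsy values mean "use all logical CPUs", positive values are taken literally, and negative values count back from the total, so -1 means all CPUs and -2 means all but one. A hedged sketch of that mapping; the `resolve_n_jobs` helper is illustrative, not PR code, and the `only_physical_cores` line only points at the joblib option relevant to the physical-cores question in the review thread above:

```python
from joblib import cpu_count

def resolve_n_jobs(n_jobs):
    # Illustrative helper mirroring the rule in the wrapper above.
    if not n_jobs:  # None or 0 -> all logical CPUs
        return cpu_count()
    return n_jobs if n_jobs > 0 else max(1, cpu_count() + n_jobs + 1)

# On a machine where cpu_count() == 8:
#   resolve_n_jobs(None) -> 8
#   resolve_n_jobs(3)    -> 3
#   resolve_n_jobs(-1)   -> 8
#   resolve_n_jobs(-2)   -> 7
print(cpu_count(only_physical_cores=True))  # physical-core count, if ever needed
```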
        if (old_n_threads := num_threads()) != n_jobs:
            logger = logging.getLogger("sklearnex")
            cl = self.__class__
            logger.debug(
                f"{cl.__module__}.{cl.__name__}.{method.__name__}: "
                f"setting {n_jobs} threads (previous - {old_n_threads})"
            )
            set_n_threads(n_jobs)
            with threadpool_controller.limit(limits=n_jobs, user_api="oneDAL"):
                return method(self, *args, **kwargs)
        else:
            return method(self, *args, **kwargs)
        finally:
            set_n_threads(old_n_threads)

    return n_jobs_wrapper

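The net effect of the wrapper is that the requested thread count applies only for the duration of the wrapped call: `threadpool_controller.limit` caps the oneDAL pool on entry and restores the previous value on exit. A hedged sketch of that scoping, assuming daal4py is installed and `oneDALLibController` was registered (for example by importing the module above) before the controller is created:

```python
import threadpoolctl
from daal4py import num_threads

controller = threadpoolctl.ThreadpoolController()

print("before:", num_threads())
with controller.limit(limits=2, user_api="oneDAL"):
    # Inside the context the oneDAL pool is capped at 2 threads,
    # analogous to what the decorator does around the wrapped method.
    print("inside:", num_threads())
print("after:", num_threads())  # previous value restored on exit
```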
@@ -55,9 +55,6 @@
    sklearn_clone_dict,
)

# to reproduce errors even in CI
d4p.daalinit(nthreads=100)
Review discussion on this line:
- "Removing this causes all sorts of memory leak check test failures, not just in windows and not just with pandas."

_dataset_dict = {
    "classification": [
        partial(load_iris, return_X_y=True),