[bugfix, enhancement] Address affinity bug by using threadpoolctl/joblib for n_jobs dispatching #2364


Draft: icfaust wants to merge 51 commits into main (base) from dev/njobs_fix
Commits (51)
c78859f
Update _n_jobs_support.py
icfaust Mar 17, 2025
e207110
Update test_run_to_run_stability.py
icfaust Mar 17, 2025
043b09d
Update test_n_jobs_support.py
icfaust Mar 17, 2025
66b0b6d
add changes
icfaust Mar 17, 2025
bc66055
add other changes
icfaust Mar 17, 2025
280c0e0
add an affinity test
icfaust Mar 17, 2025
eb0df7f
reduce lines
icfaust Mar 17, 2025
2403e6d
use pylance
icfaust Mar 17, 2025
009348b
further fixes
icfaust Mar 17, 2025
29c318f
better docs
icfaust Mar 17, 2025
ed726de
better docs
icfaust Mar 18, 2025
2b58453
mark
icfaust Mar 18, 2025
78d07bb
Update _n_jobs_support.py
icfaust Mar 18, 2025
8da0891
Update _n_jobs_support.py
icfaust Mar 18, 2025
79ced00
Update test_n_jobs_support.py
icfaust Mar 18, 2025
e021335
Update _n_jobs_support.py
icfaust Mar 18, 2025
30f822a
Update test_n_jobs_support.py
icfaust Mar 18, 2025
6d02aea
Update _n_jobs_support.py
icfaust Mar 18, 2025
84f91ac
Update test_n_jobs_support.py
icfaust Mar 18, 2025
a2a499a
Update _n_jobs_support.py
icfaust Mar 18, 2025
ac16042
Update _n_jobs_support.py
icfaust Mar 18, 2025
e6fdd80
Update incremental_linear.py
icfaust Mar 18, 2025
04075dc
Update incremental_ridge.py
icfaust Mar 18, 2025
dd798fa
Update test_n_jobs_support.py
icfaust Mar 18, 2025
70d613e
Update _n_jobs_support.py
icfaust Mar 19, 2025
1b121a5
Update test_run_to_run_stability.py
icfaust Mar 20, 2025
7bd1fcb
Merge branch 'uxlfoundation:main' into dev/njobs_fix
icfaust Mar 20, 2025
bbb2337
Update test_run_to_run_stability.py
icfaust Mar 21, 2025
a3ceaf3
Update requirements-test.txt
icfaust Mar 23, 2025
97a906f
Update requirements-test.txt
icfaust Mar 23, 2025
1948e7d
Update requirements-test.txt
icfaust Mar 23, 2025
62c7d9f
Update requirements-test.txt
icfaust Mar 23, 2025
ce79ace
Update requirements-test.txt
icfaust Mar 24, 2025
8765c0a
Merge branch 'main' into dev/njobs_fix
icfaust Mar 24, 2025
e2fa126
return values, and reduce test
icfaust Mar 24, 2025
1ca56cd
Merge branch 'uxlfoundation:main' into dev/njobs_fix
icfaust Apr 8, 2025
ab1c1eb
Merge branch 'uxlfoundation:main' into dev/njobs_fix
icfaust Apr 20, 2025
e9b5da5
Merge branch 'uxlfoundation:main' into dev/njobs_fix
icfaust May 1, 2025
f979da3
Merge branch 'uxlfoundation:main' into dev/njobs_fix
icfaust May 26, 2025
b56729e
Merge branch 'uxlfoundation:main' into dev/njobs_fix
icfaust May 31, 2025
2b2749c
Update data_conversion.cpp
icfaust Jun 13, 2025
3f0155b
Update table.cpp
icfaust Jun 13, 2025
385ad80
Merge branch 'uxlfoundation:main' into dev/njobs_fix
icfaust Jun 13, 2025
052bcdd
Update data_conversion.cpp
icfaust Jun 13, 2025
c638473
Update test_memory_usage.py
icfaust Jun 13, 2025
e00feb3
Update _n_jobs_support.py
icfaust Jun 17, 2025
627df75
Merge branch 'uxlfoundation:main' into dev/njobs_fix
icfaust Jun 18, 2025
3bf60b5
Merge branch 'uxlfoundation:main' into dev/njobs_fix
icfaust Jun 20, 2025
df38221
Update test_n_jobs_support.py
icfaust Jun 20, 2025
51ccef3
Update run_test.sh
icfaust Jun 22, 2025
17de438
Update ci.yml
icfaust Jun 23, 2025
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -433,6 +433,7 @@ jobs:
- name: Set Environment Variables
id: set-env
run: |
python -c "import os; print(len(os.sched_getaffinity(0)))"
echo "NO_DIST=1" >> "$GITHUB_ENV"
# enable coverage report generation
echo "COVERAGE_RCFILE=$(readlink -f .coveragerc)" >> "$GITHUB_ENV"
9 changes: 5 additions & 4 deletions conda-recipe/run_test.sh
@@ -60,18 +60,19 @@ function generate_pytest_args {
printf -- "${ARGS[*]}"
}

${PYTHON} -c "from daal4py import num_threads;print(f'threads={num_threads()}:0')"
${PYTHON} -c "from sklearnex import patch_sklearn; patch_sklearn()"
return_code=$(($return_code + $?))

${PYTHON} -c "from daal4py import num_threads;print(f'threads={num_threads()}:1')"
pytest --verbose -s "${sklex_root}/tests" $@ $(generate_pytest_args legacy)
return_code=$(($return_code + $?))

${PYTHON} -c "from daal4py import num_threads;print(f'threads={num_threads()}:2')"
pytest --verbose --pyargs daal4py $@ $(generate_pytest_args daal4py)
return_code=$(($return_code + $?))

${PYTHON} -c "from daal4py import num_threads;print(f'threads={num_threads()}:3')"
pytest --verbose --pyargs sklearnex $@ $(generate_pytest_args sklearnex)
return_code=$(($return_code + $?))

${PYTHON} -c "from daal4py import num_threads;print(f'threads={num_threads()}:4')"
pytest --verbose --pyargs onedal $@ $(generate_pytest_args onedal)
return_code=$(($return_code + $?))

108 changes: 47 additions & 61 deletions daal4py/sklearn/_n_jobs_support.py
@@ -19,53 +19,50 @@
import threading
from functools import wraps
from inspect import Parameter, signature
from multiprocessing import cpu_count
from numbers import Integral
from warnings import warn

import threadpoolctl
from joblib import cpu_count

from daal4py import daalinit as set_n_threads
from daal4py import num_threads as get_n_threads
from daal4py import _get__daal_link_version__, daalinit, num_threads

from ._utils import sklearn_check_version

if sklearn_check_version("1.2"):
from sklearn.utils._param_validation import validate_parameter_constraints
else:

def validate_parameter_constraints(n_jobs):
if n_jobs is not None and not isinstance(n_jobs, Integral):
raise TypeError(
f"n_jobs must be an instance of int, not {n_jobs.__class__.__name__}."
)


class oneDALLibController(threadpoolctl.LibController):
user_api = "onedal"
internal_api = "onedal"

filename_prefixes = ("libonedal_thread", "libonedal")

def get_num_threads(self):
return num_threads()

def set_num_threads(self, nthreads):
Review comment from david-cortes-intel (Contributor), Mar 24, 2025:

I understand this setting would apply globally, which could lead to race conditions if users call this in parallel, for example through some framework that would parallelize estimator calls.

Could it somehow get a mutex (or use atomic ops) either here or on the oneDAL side?

Also, it would be better to add a warning that the setting is changed at a global level, so that a user would not try to call these inside multi-threaded code.

Follow-up (Contributor):

Actually, on a further look, it does already have a mutex on the DAAL side. Still, it would be better to document that this behavior is global.

Reply from icfaust (author):

Sounds good, will do!

daalinit(nthreads)

def get_version(self):
return _get__daal_link_version__()


threadpoolctl.register(oneDALLibController)

# Note: getting controller in global scope of this module is required
# to avoid overheads by its initialization per each function call
threadpool_controller = threadpoolctl.ThreadpoolController()


def get_suggested_n_threads(n_cpus):
"""
Function to get `n_threads` limit
if `n_jobs` is set in upper parallelization context.
Usually, limit is equal to `n_logical_cpus` // `n_jobs`.
Returns None if limit is not set.
"""
n_threads_map = {
lib_ctl.internal_api: lib_ctl.get_num_threads()
for lib_ctl in threadpool_controller.lib_controllers
if lib_ctl.internal_api != "mkl"
}
# openBLAS is limited to 24, 64 or 128 threads by default
# depending on SW/HW configuration.
# thus, these numbers of threads from openBLAS are uninformative
if "openblas" in n_threads_map and n_threads_map["openblas"] in [24, 64, 128]:
del n_threads_map["openblas"]
# remove default values equal to n_cpus as uninformative
for backend in list(n_threads_map.keys()):
if n_threads_map[backend] == n_cpus:
del n_threads_map[backend]
if len(n_threads_map) > 0:
return min(n_threads_map.values())
else:
return None


def _run_with_n_jobs(method):
"""
Decorator for running of methods containing oneDAL kernels with 'n_jobs'.
@@ -79,59 +76,46 @@ def _run_with_n_jobs(method):
@wraps(method)
def n_jobs_wrapper(self, *args, **kwargs):
# threading parallel backend branch
if not isinstance(threading.current_thread(), threading._MainThread):
warn(
"'Threading' parallel backend is not supported by "
"Extension for Scikit-learn*. "
"Falling back to usage of all available threads."
)
result = method(self, *args, **kwargs)
return result
# multiprocess parallel backends branch
# preemptive validation of n_jobs parameter is required
# because '_run_with_n_jobs' decorator is applied on top of method
# where validation takes place
if sklearn_check_version("1.2") and hasattr(self, "_parameter_constraints"):
if sklearn_check_version("1.2"):
validate_parameter_constraints(
parameter_constraints={"n_jobs": self._parameter_constraints["n_jobs"]},
params={"n_jobs": self.n_jobs},
caller_name=self.__class__.__name__,
)
# search for specified n_jobs
n_jobs = self.n_jobs
n_cpus = cpu_count()
else:
validate_parameter_constraints(self.n_jobs)

# receive n_threads limitation from upper parallelism context
# using `threadpoolctl.ThreadpoolController`
n_threads = get_suggested_n_threads(n_cpus)
# get real `n_jobs` number of threads for oneDAL
# using sklearn rules and `n_threads` from upper parallelism context
if n_jobs is None or n_jobs == 0:
if n_threads is None:
# default branch with no setting for n_jobs
return method(self, *args, **kwargs)
else:
n_jobs = n_threads
elif n_jobs < 0:
if n_threads is None:
n_jobs = max(1, n_cpus + n_jobs + 1)
else:
n_jobs = max(1, n_threads + n_jobs + 1)
# branch with set n_jobs
old_n_threads = get_n_threads()
if n_jobs == old_n_threads:

if self.n_jobs:
n_jobs = (
self.n_jobs if self.n_jobs > 0 else max(1, cpu_count() + self.n_jobs + 1)
)
elif self.n_jobs == 0:
# This is a small variation on joblib's equivalent error
raise ValueError("n_jobs == 0 has no meaning")
else:
return method(self, *args, **kwargs)

try:
# n_jobs value is attempting to be set
if (old_n_threads := num_threads()) != n_jobs:
logger = logging.getLogger("sklearnex")
cl = self.__class__
logger.debug(
f"{cl.__module__}.{cl.__name__}.{method.__name__}: "
f"setting {n_jobs} threads (previous - {old_n_threads})"
)
set_n_threads(n_jobs)
with threadpool_controller.limit(limits=n_jobs, user_api="onedal"):
return method(self, *args, **kwargs)
else:
return method(self, *args, **kwargs)
finally:
set_n_threads(old_n_threads)

return n_jobs_wrapper

@@ -185,6 +169,8 @@ def class_wrapper(original_class):
):
parameter_constraints = original_class._parameter_constraints
if "n_jobs" not in parameter_constraints:
# n_jobs = 0 is not allowed, but it is handled elsewhere
# This definition matches scikit-learn
parameter_constraints["n_jobs"] = [Integral, None]

@wraps(original_init)
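For readers unfamiliar with threadpoolctl, here is a minimal usage sketch (not part of the diff; it assumes only the public threadpoolctl 3.x API plus the class and module names shown above) of what registering oneDALLibController buys: oneDAL's thread pool becomes visible to the generic controller, and the rewritten n_jobs_wrapper can scope an n_jobs value to a single method call via controller.limit, with the previous value restored on exit. As noted in the review thread above, the underlying setting remains process-global while the context manager is active.

import threadpoolctl

# Importing the module registers oneDALLibController with threadpoolctl.
import daal4py.sklearn._n_jobs_support  # noqa: F401

controller = threadpoolctl.ThreadpoolController()

# oneDAL now appears alongside the BLAS/OpenMP libraries in the inventory.
print([info for info in controller.info() if info["user_api"] == "onedal"])

# Temporarily cap oneDAL at 2 threads; the previous value is restored on exit.
with controller.limit(limits=2, user_api="onedal"):
    pass  # call oneDAL-backed estimator methods here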
4 changes: 4 additions & 0 deletions daal4py/sklearn/ensemble/AdaBoostClassifier.py
@@ -34,6 +34,10 @@

@control_n_jobs(decorated_methods=["fit", "predict"])
class AdaBoostClassifier(ClassifierMixin, BaseEstimator):

if sklearn_check_version("1.2"):
_parameter_constraints = {}

def __init__(
self,
split_criterion="gini",
4 changes: 4 additions & 0 deletions daal4py/sklearn/ensemble/GBTDAAL.py
@@ -34,6 +34,10 @@


class GBTDAALBase(BaseEstimator, d4p.mb.GBTDAALBaseModel):

if sklearn_check_version("1.2"):
_parameter_constraints = {}

def __init__(
self,
split_method="inexact",
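The empty _parameter_constraints dicts added here and in AdaBoostClassifier give the control_n_jobs class wrapper (see the class_wrapper hunk above) a place to inject the "n_jobs": [Integral, None] definition. A small illustrative sketch of how scikit-learn (>= 1.2) then validates the value; this calls sklearn's private helper directly and is not part of the diff:

from numbers import Integral

from sklearn.utils._param_validation import validate_parameter_constraints

# Constraint definition matching what control_n_jobs injects.
constraints = {"n_jobs": [Integral, None]}

# Accepted: any integer (including negatives) or None.
validate_parameter_constraints(constraints, {"n_jobs": -2}, caller_name="Example")

# Rejected with InvalidParameterError: non-integer values such as 1.5 or "2".
# validate_parameter_constraints(constraints, {"n_jobs": 1.5}, caller_name="Example")

Note that n_jobs == 0 satisfies the Integral constraint but is rejected separately inside n_jobs_wrapper with ValueError("n_jobs == 0 has no meaning"), matching the comment added to class_wrapper.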
7 changes: 4 additions & 3 deletions onedal/datatypes/numpy/data_conversion.cpp
@@ -269,9 +269,10 @@ dal::table convert_to_table(py::object inp_obj, py::object queue, bool recursed)
return res;
}

static void free_capsule(PyObject *cap) {
template <class T>
void free_capsule(PyObject *cap) {
// TODO: check safe cast
dal::base *stored_array = static_cast<dal::base *>(PyCapsule_GetPointer(cap, NULL));
dal::array<T> *stored_array = static_cast<dal::array<T> *>(PyCapsule_GetPointer(cap, NULL));
if (stored_array) {
delete stored_array;
}
@@ -304,7 +305,7 @@ static PyObject *convert_to_numpy_impl(
throw std::invalid_argument("Conversion to numpy array failed");

void *opaque_value = static_cast<void *>(new dal::array<T>(host_array));
PyObject *cap = PyCapsule_New(opaque_value, NULL, free_capsule);
PyObject *cap = PyCapsule_New(opaque_value, NULL, free_capsule<T>);
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), cap);
return obj;
}
4 changes: 2 additions & 2 deletions onedal/datatypes/table.cpp
@@ -103,9 +103,9 @@ ONEDAL_PY_INIT_MODULE(table) {
return numpy::convert_to_table(obj, queue);
});

m.def("from_table", [](const dal::table& t) -> py::handle {
m.def("from_table", [](const dal::table& t) -> py::object {
auto* obj_ptr = numpy::convert_to_pyobject(t);
return obj_ptr;
return py::reinterpret_steal<py::object>(obj_ptr);
});
m.def("dlpack_memory_order", &dlpack::dlpack_memory_order);
py::enum_<DLDeviceType>(m, "DLDeviceType")
2 changes: 1 addition & 1 deletion sklearnex/linear_model/incremental_linear.py
@@ -144,7 +144,7 @@ class IncrementalLinearRegression(
_parameter_constraints: dict = {
"fit_intercept": ["boolean"],
"copy_X": ["boolean"],
"n_jobs": [Interval(numbers.Integral, -1, None, closed="left"), None],
"n_jobs": [numbers.Integral, None],
"batch_size": [Interval(numbers.Integral, 1, None, closed="left"), None],
}

2 changes: 1 addition & 1 deletion sklearnex/linear_model/incremental_ridge.py
@@ -113,7 +113,7 @@ class IncrementalRidge(MultiOutputMixin, RegressorMixin, oneDALEstimator, BaseEs
"fit_intercept": ["boolean"],
"alpha": [Interval(numbers.Real, 0, None, closed="left")],
"copy_X": ["boolean"],
"n_jobs": [Interval(numbers.Integral, -1, None, closed="left"), None],
Comment from icfaust (author):

Reviewers: let me know if I should remove n_jobs from this and the other estimator entirely in order to minimize maintenance cost (all other estimators get the kwarg from _n_jobs_support.py, as far as I can see).

"n_jobs": [numbers.Integral, None],
"batch_size": [Interval(numbers.Integral, 1, None, closed="left"), None],
}

2 changes: 1 addition & 1 deletion sklearnex/tests/test_memory_usage.py
@@ -122,7 +122,7 @@ def gen_functions(functions):

data_shapes = [
pytest.param((1000, 100), id="(1000, 100)"),
pytest.param((2000, 50), id="(2000, 50)"),
pytest.param((2000, 40), id="(2000, 40)"),
]

EXTRA_MEMORY_THRESHOLD = 0.15
42 changes: 33 additions & 9 deletions sklearnex/tests/test_n_jobs_support.py
@@ -16,19 +16,15 @@

import inspect
import logging
from multiprocessing import cpu_count
import os

import pytest
from joblib import cpu_count
from sklearn.datasets import make_classification
from sklearn.exceptions import NotFittedError
from threadpoolctl import threadpool_info

from sklearnex.tests.utils import (
PATCHED_MODELS,
SPECIAL_INSTANCES,
call_method,
gen_dataset,
gen_models_info,
)
from sklearnex.tests.utils import PATCHED_MODELS, SPECIAL_INSTANCES, call_method

_X, _Y = make_classification(n_samples=40, n_features=4, random_state=42)

@@ -49,7 +45,7 @@ def _check_n_jobs_entry_in_logs(records, function_name, n_jobs):
if f"{function_name}: setting {expected_n_jobs} threads" in rec:
return True
# False if n_jobs is set and not found in logs
return n_jobs is None
return n_jobs is None or expected_n_jobs == cpu_count()


@pytest.mark.parametrize("estimator", {**PATCHED_MODELS, **SPECIAL_INSTANCES}.keys())
@@ -106,3 +102,31 @@ def test_n_jobs_support(estimator, n_jobs, caplog):

messages = [msg.message for msg in caplog.records]
assert _check_n_jobs_entry_in_logs(messages, method_name, n_jobs)


@pytest.mark.skipif(
not hasattr(os, "sched_setaffinity") or len(os.sched_getaffinity(0)) < 4,
reason="python CPU affinity control unavailable or too few threads",
)
@pytest.mark.parametrize("estimator", {**PATCHED_MODELS, **SPECIAL_INSTANCES}.keys())
def test_n_jobs_affinity(estimator, caplog):
# verify that n_jobs 1) starts at default value of cpu_count
# 2) respects os.sched_setaffinity on supported machines
n_t = next(i for i in threadpool_info() if i["user_api"] == "onedal")["num_threads"]

# get affinity mask of calling process
mask = os.sched_getaffinity(0)
# by default, oneDAL should match the number of threads made available to the sklearnex pytest suite
# This is currently disabled due to thread setting occurring in test_run_to_run_stability
# assert len(mask) == n_t

try:
# use half of the available threads
newmask = set(list(mask)[: len(mask) // 2])
os.sched_setaffinity(0, newmask)
# -2 is used as this forces n_jobs to be based on cpu_count and must value match in test
test_n_jobs_support(estimator, -2, caplog)

finally:
# reset affinity mask no matter what
os.sched_setaffinity(0, mask)
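For context on the cpu_count swap in this file and in _n_jobs_support.py: joblib.cpu_count respects the process affinity mask on platforms that support os.sched_getaffinity, whereas multiprocessing.cpu_count reports all online CPUs regardless of affinity, which is the kind of mismatch the affinity bug in the PR title refers to. A small Linux-only sketch (illustrative, not part of the diff):

import multiprocessing
import os

from joblib import cpu_count

full_mask = os.sched_getaffinity(0)
try:
    # Restrict the current process to half of its allowed CPUs.
    os.sched_setaffinity(0, set(list(full_mask)[: max(1, len(full_mask) // 2)]))
    print(multiprocessing.cpu_count())  # still reports all online CPUs
    print(cpu_count())                  # reflects the reduced affinity mask
finally:
    os.sched_setaffinity(0, full_mask)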