Commit 6e18d3a

[pyspark] Handle the device parameter in pyspark. (dmlc#9390)
- Handle the new `device` parameter in PySpark.
- Deprecate the old `use_gpu` parameter.
1 parent 2a0ff20 commit 6e18d3a

10 files changed: +244 -169 lines changed

doc/tutorials/spark_estimator.rst

Lines changed: 12 additions & 11 deletions
@@ -35,13 +35,13 @@ We can create a ``SparkXGBRegressor`` estimator like:
   )


-The above snippet creates a spark estimator which can fit on a spark dataset,
-and return a spark model that can transform a spark dataset and generate dataset
-with prediction column. We can set almost all of xgboost sklearn estimator parameters
-as ``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden
-in spark estimator, and some parameters are replaced with pyspark specific parameters
-such as ``weight_col``, ``validation_indicator_col``, ``use_gpu``, for details please see
-``SparkXGBRegressor`` doc.
+The above snippet creates a spark estimator which can fit on a spark dataset, and return a
+spark model that can transform a spark dataset and generate dataset with prediction
+column. We can set almost all of xgboost sklearn estimator parameters as
+``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden in
+spark estimator, and some parameters are replaced with pyspark specific parameters such as
+``weight_col``, ``validation_indicator_col``, for details please see ``SparkXGBRegressor``
+doc.

 The following code snippet shows how to train a spark xgboost regressor model,
 first we need to prepare a training dataset as a spark dataframe contains
@@ -88,7 +88,7 @@ XGBoost PySpark fully supports GPU acceleration. Users are not only able to enab
 efficient training but also utilize their GPUs for the whole PySpark pipeline including
 ETL and inference. In below sections, we will walk through an example of training on a
 PySpark standalone GPU cluster. To get started, first we need to install some additional
-packages, then we can set the ``use_gpu`` parameter to ``True``.
+packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.

 Prepare the necessary packages
 ==============================
@@ -128,7 +128,7 @@ Write your PySpark application
 ==============================

 Below snippet is a small example for training xgboost model with PySpark. Notice that we are
-using a list of feature names and the additional parameter ``use_gpu``:
+using a list of feature names and the additional parameter ``device``:

 .. code-block:: python
@@ -148,12 +148,12 @@ using a list of feature names and the additional parameter ``use_gpu``:
     # get a list with feature column names
     feature_names = [x.name for x in train_df.schema if x.name != label_name]

-    # create a xgboost pyspark regressor estimator and set use_gpu=True
+    # create a xgboost pyspark regressor estimator and set device="cuda"
     regressor = SparkXGBRegressor(
         features_col=feature_names,
         label_col=label_name,
         num_workers=2,
-        use_gpu=True,
+        device="cuda",
     )

     # train and return the model
@@ -163,6 +163,7 @@ using a list of feature names and the additional parameter ``use_gpu``:
     predict_df = model.transform(test_df)
     predict_df.show()

+Like other distributed interfaces, the ``device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).

 Submit the PySpark application
 ==============================
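For pipelines written against the old API, the documented change amounts to swapping the deprecated boolean flag for the new string parameter. A minimal migration sketch, assuming an existing SparkSession and a training DataFrame `train_df` whose label column is named `label`:

    from xgboost.spark import SparkXGBRegressor

    feature_names = [c for c in train_df.columns if c != "label"]

    # Previously GPU training was requested with use_gpu=True; it now goes
    # through device="cuda" (or "gpu"). use_gpu still works but is deprecated.
    regressor = SparkXGBRegressor(
        features_col=feature_names,
        label_col="label",
        num_workers=2,
        device="cuda",  # no ordinal: Spark assigns a GPU to each task
    )
    model = regressor.fit(train_df)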

python-package/xgboost/core.py

Lines changed: 21 additions & 0 deletions
@@ -276,6 +276,27 @@ def _check_call(ret: int) -> None:
         raise XGBoostError(py_str(_LIB.XGBGetLastError()))


+def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
+    """Validate parameters in distributed environments."""
+    device = kwargs.get("device", None)
+    if device and not isinstance(device, str):
+        msg = "Invalid type for the `device` parameter"
+        msg += _expect((str,), type(device))
+        raise TypeError(msg)
+
+    if device and device.find(":") != -1:
+        raise ValueError(
+            "Distributed training doesn't support selecting device ordinal as GPUs are"
+            " managed by the distributed framework. use `device=cuda` or `device=gpu`"
+            " instead."
+        )
+
+    if kwargs.get("booster", None) == "gblinear":
+        raise NotImplementedError(
+            f"booster `{kwargs['booster']}` is not supported for distributed training."
+        )
+
+
 def build_info() -> dict:
     """Build information of XGBoost. The returned value format is not stable. Also,
     please note that build time dependency is not the same as runtime dependency. For
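The new helper centralizes checks that were previously duplicated per interface. It is private, but its behaviour is easy to read off the diff above; a small illustration of what it accepts and rejects (calling it directly like this is for demonstration only):

    from xgboost.core import _check_distributed_params

    _check_distributed_params({"device": "cuda"})        # ok: no ordinal, default booster
    _check_distributed_params({"device": "cuda:0"})      # ValueError: ordinals are picked
                                                         # by the distributed framework
    _check_distributed_params({"booster": "gblinear"})   # NotImplementedError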

python-package/xgboost/dask.py

Lines changed: 2 additions & 11 deletions
@@ -70,6 +70,7 @@
     Metric,
     Objective,
     QuantileDMatrix,
+    _check_distributed_params,
     _deprecate_positional_args,
     _expect,
 )
@@ -924,17 +925,7 @@ async def _train_async(
 ) -> Optional[TrainReturnT]:
     workers = _get_workers_from_data(dtrain, evals)
     _rabit_args = await _get_rabit_args(len(workers), dconfig, client)
-
-    if params.get("booster", None) == "gblinear":
-        raise NotImplementedError(
-            f"booster `{params['booster']}` is not yet supported for dask."
-        )
-    device = params.get("device", None)
-    if device and device.find(":") != -1:
-        raise ValueError(
-            "The dask interface for XGBoost doesn't support selecting specific device"
-            " ordinal. Use `device=cpu` or `device=cuda` instead."
-        )
+    _check_distributed_params(params)

     def dispatched_train(
         parameters: Dict,
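For the dask interface the net effect is that the shared checks now run before any boosting starts. A rough usage sketch, assuming a small local cluster and synthetic placeholder data just to show where the validation kicks in:

    import dask.array as da
    from dask.distributed import Client, LocalCluster
    from xgboost import dask as dxgb

    with Client(LocalCluster(n_workers=2)) as client:
        X = da.random.random((1000, 10), chunks=(250, 10))
        y = da.random.random(1000, chunks=250)
        dtrain = dxgb.DaskDMatrix(client, X, y)

        # Passes validation: no device ordinal, default booster.
        dxgb.train(client, {"tree_method": "hist"}, dtrain, num_boost_round=5)

        # Rejected by _check_distributed_params with a ValueError before training.
        dxgb.train(client, {"device": "cuda:0"}, dtrain, num_boost_round=5)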

python-package/xgboost/sklearn.py

Lines changed: 7 additions & 3 deletions
@@ -1004,13 +1004,17 @@ def fit(
             Validation metrics will help us track the performance of the model.

         eval_metric : str, list of str, or callable, optional
+
             .. deprecated:: 1.6.0
-                Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
+
+            Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.

         early_stopping_rounds : int
+
             .. deprecated:: 1.6.0
-                Use `early_stopping_rounds` in :py:meth:`__init__` or
-                :py:meth:`set_params` instead.
+
+            Use `early_stopping_rounds` in :py:meth:`__init__` or :py:meth:`set_params`
+            instead.
         verbose :
             If `verbose` is True and an evaluation set is used, the evaluation metric
             measured on the validation set is printed to stdout at each boosting stage.
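The reflowed notices describe the recommended replacement for the deprecated fit-time arguments; a short sketch of that pattern (training and validation arrays are placeholders):

    from xgboost import XGBRegressor

    # eval_metric and early_stopping_rounds go to the constructor (or
    # set_params), not to fit(), as the deprecation notes above recommend.
    model = XGBRegressor(
        n_estimators=500,
        eval_metric="rmse",
        early_stopping_rounds=10,
    )
    model.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])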

python-package/xgboost/spark/core.py

Lines changed: 41 additions & 30 deletions
@@ -60,7 +60,7 @@
 import xgboost
 from xgboost import XGBClassifier
 from xgboost.compat import is_cudf_available
-from xgboost.core import Booster
+from xgboost.core import Booster, _check_distributed_params
 from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
 from xgboost.training import train as worker_train

@@ -92,6 +92,7 @@
     get_class_name,
     get_logger,
     serialize_booster,
+    use_cuda,
 )

 # Put pyspark specific params here, they won't be passed to XGBoost.
@@ -108,7 +109,6 @@
     "arbitrary_params_dict",
     "force_repartition",
     "num_workers",
-    "use_gpu",
     "feature_names",
     "features_cols",
     "enable_sparse_data_optim",
@@ -132,8 +132,7 @@
 _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.items()}

 _unsupported_xgb_params = [
-    "gpu_id",  # we have "use_gpu" pyspark param instead.
-    "device",  # we have "use_gpu" pyspark param instead.
+    "gpu_id",  # we have "device" pyspark param instead.
     "enable_categorical",  # Use feature_types param to specify categorical feature instead
     "use_label_encoder",
     "n_jobs",  # Do not allow user to set it, will use `spark.task.cpus` value instead.
@@ -198,11 +197,24 @@ class _SparkXGBParams(
         "The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.",
         TypeConverters.toInt,
     )
+    device = Param(
+        Params._dummy(),
+        "device",
+        (
+            "The device type for XGBoost executors. Available options are `cpu`,`cuda`"
+            " and `gpu`. Set `device` to `cuda` or `gpu` if the executors are running "
+            "on GPU instances. Currently, only one GPU per task is supported."
+        ),
+        TypeConverters.toString,
+    )
     use_gpu = Param(
         Params._dummy(),
         "use_gpu",
-        "A boolean variable. Set use_gpu=true if the executors "
-        + "are running on GPU instances. Currently, only one GPU per task is supported.",
+        (
+            "Deprecated, use `device` instead. A boolean variable. Set use_gpu=true "
+            "if the executors are running on GPU instances. Currently, only one GPU per"
+            " task is supported."
+        ),
         TypeConverters.toBoolean,
     )
     force_repartition = Param(
@@ -336,10 +348,20 @@ def _validate_params(self) -> None:
                 f"It cannot be less than 1 [Default is 1]"
             )

+        tree_method = self.getOrDefault(self.getParam("tree_method"))
+        if (
+            self.getOrDefault(self.use_gpu) or use_cuda(self.getOrDefault(self.device))
+        ) and not _can_use_qdm(tree_method):
+            raise ValueError(
+                f"The `{tree_method}` tree method is not supported on GPU."
+            )
+
         if self.getOrDefault(self.features_cols):
-            if not self.getOrDefault(self.use_gpu):
+            if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault(
+                self.use_gpu
+            ):
                 raise ValueError(
-                    "features_col param with list value requires enabling use_gpu."
+                    "features_col param with list value requires `device=cuda`."
                 )

         if self.getOrDefault("objective") is not None:
@@ -392,17 +414,7 @@ def _validate_params(self) -> None:
                     "`pyspark.ml.linalg.Vector` type."
                 )

-        if self.getOrDefault(self.use_gpu):
-            tree_method = self.getParam("tree_method")
-            if (
-                self.getOrDefault(tree_method) is not None
-                and self.getOrDefault(tree_method) != "gpu_hist"
-            ):
-                raise ValueError(
-                    f"tree_method should be 'gpu_hist' or None when use_gpu is True,"
-                    f"found {self.getOrDefault(tree_method)}."
-                )
-
+        if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
             gpu_per_task = (
                 _get_spark_session()
                 .sparkContext.getConf()
@@ -424,8 +436,8 @@ def _validate_params(self) -> None:
                 # so it's okay for printing the below warning instead of checking the real
                 # gpu numbers and raising the exception.
                 get_logger(self.__class__.__name__).warning(
-                    "You enabled use_gpu in spark local mode. Please make sure your local node "
-                    "has at least %d GPUs",
+                    "You enabled GPU in spark local mode. Please make sure your local "
+                    "node has at least %d GPUs",
                     self.getOrDefault(self.num_workers),
                 )
             else:
@@ -558,6 +570,7 @@ def __init__(self) -> None:
         # they are added in `setParams`.
         self._setDefault(
             num_workers=1,
+            device="cpu",
             use_gpu=False,
             force_repartition=False,
             repartition_random_shuffle=False,
@@ -566,9 +579,7 @@
             arbitrary_params_dict={},
         )

-    def setParams(
-        self, **kwargs: Dict[str, Any]
-    ) -> None:  # pylint: disable=invalid-name
+    def setParams(self, **kwargs: Any) -> None:  # pylint: disable=invalid-name
         """
         Set params for the estimator.
         """
@@ -613,6 +624,8 @@ def setParams(
                 )
                 raise ValueError(err_msg)
             _extra_params[k] = v
+
+        _check_distributed_params(kwargs)
         _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
         self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})

@@ -709,9 +722,6 @@ def _get_distributed_train_params(self, dataset: DataFrame) -> Dict[str, Any]:
         # TODO: support "num_parallel_tree" for random forest
         params["num_boost_round"] = self.getOrDefault("n_estimators")

-        if self.getOrDefault(self.use_gpu):
-            params["tree_method"] = "gpu_hist"
-
         return params

     @classmethod
@@ -883,8 +893,9 @@ def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
             dmatrix_kwargs,
         ) = self._get_xgb_parameters(dataset)

-        use_gpu = self.getOrDefault(self.use_gpu)
-
+        run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(
+            self.use_gpu
+        )
         is_local = _is_local(_get_spark_session().sparkContext)

         num_workers = self.getOrDefault(self.num_workers)
@@ -903,7 +914,7 @@ def _train_booster(
             dev_ordinal = None
             use_qdm = _can_use_qdm(booster_params.get("tree_method", None))

-            if use_gpu:
+            if run_on_gpu:
                 dev_ordinal = (
                     context.partitionId() if is_local else _get_gpu_id(context)
                 )
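Several of the hunks above gate GPU behaviour on a `use_cuda` helper imported from the package's spark utilities. Its body is not part of this diff; judging from how it is called, it presumably reduces to something like the sketch below (an assumption, not code from this commit):

    from typing import Optional

    def use_cuda(device: Optional[str]) -> bool:
        # Hypothetical reconstruction of the helper imported from the spark
        # utilities module: treat "cuda" and "gpu" as requests for GPU training.
        return device in ("cuda", "gpu")

With that reading, `run_on_gpu` in `_fit` is True when either the new `device` parameter asks for a GPU or the deprecated `use_gpu` flag is still set, which matches the deprecation path described in the commit message.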
