Skip to content

Commit e228c1a

Browse files
authored
[EM] Make page concatenation optional. (dmlc#10826)
This PR introduces a new parameter `extmem_concat_pages` to make page concatenation optional for GPU hist. In addition, the documentation is updated for the new GPU-based external memory implementation.
1 parent 215da76 commit e228c1a

31 files changed

+687
-385
lines changed

demo/guide-python/external_memory.py

Lines changed: 83 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,13 @@
1010
1111
See :doc:`the tutorial </tutorials/external_memory>` for more details.
1212
13+
.. versionchanged:: 3.0.0
14+
15+
Added :py:class:`~xgboost.ExtMemQuantileDMatrix`.
16+
1317
"""
1418

19+
import argparse
1520
import os
1621
import tempfile
1722
from typing import Callable, List, Tuple
@@ -43,30 +48,40 @@ def make_batches(
4348
class Iterator(xgboost.DataIter):
4449
"""A custom iterator for loading files in batches."""
4550

46-
def __init__(self, file_paths: List[Tuple[str, str]]) -> None:
51+
def __init__(self, device: str, file_paths: List[Tuple[str, str]]) -> None:
52+
self.device = device
53+
4754
self._file_paths = file_paths
4855
self._it = 0
49-
# XGBoost will generate some cache files under current directory with the prefix
50-
# "cache"
56+
# XGBoost will generate some cache files under the current directory with the
57+
# prefix "cache"
5158
super().__init__(cache_prefix=os.path.join(".", "cache"))
5259

5360
def load_file(self) -> Tuple[np.ndarray, np.ndarray]:
61+
"""Load a single batch of data."""
5462
X_path, y_path = self._file_paths[self._it]
55-
X = np.load(X_path)
56-
y = np.load(y_path)
63+
# When the `ExtMemQuantileDMatrix` is used, the device must match. This
64+
# constraint will be relaxed in the future.
65+
if self.device == "cpu":
66+
X = np.load(X_path)
67+
y = np.load(y_path)
68+
else:
69+
X = cp.load(X_path)
70+
y = cp.load(y_path)
71+
5772
assert X.shape[0] == y.shape[0]
5873
return X, y
5974

6075
def next(self, input_data: Callable) -> int:
61-
"""Advance the iterator by 1 step and pass the data to XGBoost. This function is
62-
called by XGBoost during the construction of ``DMatrix``
76+
"""Advance the iterator by 1 step and pass the data to XGBoost. This function
77+
is called by XGBoost during the construction of ``DMatrix``
6378
6479
"""
6580
if self._it == len(self._file_paths):
6681
# return 0 to let XGBoost know this is the end of iteration
6782
return 0
6883

69-
# input_data is a function passed in by XGBoost who has the similar signature to
84+
# input_data is a function passed in by XGBoost and has a similar signature to
7085
# the ``DMatrix`` constructor.
7186
X, y = self.load_file()
7287
input_data(data=X, label=y)
@@ -78,27 +93,74 @@ def reset(self) -> None:
7893
self._it = 0
7994

8095

81-
def main(tmpdir: str) -> xgboost.Booster:
82-
# generate some random data for demo
83-
files = make_batches(1024, 17, 31, tmpdir)
84-
it = Iterator(files)
96+
def hist_train(it: Iterator) -> None:
97+
"""The hist tree method can use a special data structure `ExtMemQuantileDMatrix` for
98+
faster initialization and lower memory usage.
99+
100+
.. versionadded:: 3.0.0
101+
102+
"""
85103
# For non-data arguments, specify them here once instead of passing them by the `next`
86104
# method.
87-
missing = np.nan
88-
Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)
105+
Xy = xgboost.ExtMemQuantileDMatrix(it, missing=np.nan, enable_categorical=False)
106+
booster = xgboost.train(
107+
{"tree_method": "hist", "max_depth": 4, "device": it.device},
108+
Xy,
109+
evals=[(Xy, "Train")],
110+
num_boost_round=10,
111+
)
112+
booster.predict(Xy)
113+
114+
115+
def approx_train(it: Iterator) -> None:
116+
"""The approx tree method uses the basic `DMatrix`."""
89117

90-
# ``approx`` is also supported, but less efficient due to sketching. GPU behaves
91-
# differently than CPU tree methods as it uses a hybrid approach. See tutorial in
92-
# doc for details.
118+
# For non-data arguments, specify them here once instead of passing them by the `next`
119+
# method.
120+
Xy = xgboost.DMatrix(it, missing=np.nan, enable_categorical=False)
121+
# ``approx`` is also supported, but less efficient due to sketching. It's
122+
# recommended to use `hist` instead.
93123
booster = xgboost.train(
94-
{"tree_method": "hist", "max_depth": 4},
124+
{"tree_method": "approx", "max_depth": 4, "device": it.device},
95125
Xy,
96126
evals=[(Xy, "Train")],
97127
num_boost_round=10,
98128
)
99-
return booster
129+
booster.predict(Xy)
130+
131+
132+
def main(tmpdir: str, args: argparse.Namespace) -> None:
133+
"""Entry point for training."""
134+
135+
# generate some random data for demo
136+
files = make_batches(
137+
n_samples_per_batch=1024, n_features=17, n_batches=31, tmpdir=tmpdir
138+
)
139+
it = Iterator(args.device, files)
140+
141+
hist_train(it)
142+
approx_train(it)
100143

101144

102145
if __name__ == "__main__":
103-
with tempfile.TemporaryDirectory() as tmpdir:
104-
main(tmpdir)
146+
parser = argparse.ArgumentParser()
147+
parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
148+
args = parser.parse_args()
149+
if args.device == "cuda":
150+
import cupy as cp
151+
import rmm
152+
from rmm.allocators.cupy import rmm_cupy_allocator
153+
154+
# It's important to use RMM for GPU-based external memory to improve performance.
155+
# If XGBoost is not built with RMM support, a warning will be raised.
156+
mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
157+
rmm.mr.set_current_device_resource(mr)
158+
# Set the allocator for cupy as well.
159+
cp.cuda.set_allocator(rmm_cupy_allocator)
160+
# Make sure XGBoost is using RMM for all allocations.
161+
with xgboost.config_context(use_rmm=True):
162+
with tempfile.TemporaryDirectory() as tmpdir:
163+
main(tmpdir, args)
164+
else:
165+
with tempfile.TemporaryDirectory() as tmpdir:
166+
main(tmpdir, args)

doc/jvm/xgboost_spark_migration.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ When submitting the XGBoost application to the Spark cluster, you only need to s
5555
--jars xgboost-spark_2.12-3.0.0.jar \
5656
... \
5757
58-
**************
58+
***************
5959
XGBoost Ranking
60-
**************
60+
***************
6161

6262
Learning to rank using XGBoostRegressor has been replaced by a dedicated `XGBoostRanker`, which is specifically designed
6363
to support ranking algorithms.

doc/parameter.rst

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,15 +230,35 @@ Parameters for Tree Booster
230230
- ``one_output_per_tree``: One model for each target.
231231
- ``multi_output_tree``: Use multi-target trees.
232232

233+
234+
Parameters for Non-Exact Tree Methods
235+
=====================================
236+
233237
* ``max_cached_hist_node``, [default = 65536]
234238

235-
Maximum number of cached nodes for histogram.
239+
Maximum number of cached nodes for histogram. This can be used with the ``hist`` and the
240+
``approx`` tree methods.
236241

237242
.. versionadded:: 2.0.0
238243

239244
- In most cases this parameter should not be set, except when growing deep
240245
trees. After 3.0, this parameter affects GPU algorithms as well.
241246

247+
248+
* ``extmem_concat_pages``, [default = ``false``]
249+
250+
This parameter is only used for the ``hist`` tree method with ``device=cuda`` and
251+
``subsample != 1.0``. Before 3.0, pages were always concatenated.
252+
253+
.. versionadded:: 3.0.0
254+
255+
Whether the GPU-based ``hist`` tree method should concatenate the training data into a
256+
single batch instead of fetching data on-demand when external memory is used. For GPU
257+
devices that don't support address translation services, external memory training is
258+
expensive. This parameter can be used in combination with subsampling to reduce overall
259+
memory usage without significant overhead. See :doc:`/tutorials/external_memory` for
260+
more information.
261+
242262
.. _cat-param:
243263

244264
Parameters for Categorical Feature

doc/python/python_api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ Core Data Structure
2626

2727
.. autoclass:: xgboost.QuantileDMatrix
2828
:members:
29+
:inherited-members:
30+
:show-inheritance:
31+
32+
.. autoclass:: xgboost.ExtMemQuantileDMatrix
33+
:members:
34+
:inherited-members:
2935
:show-inheritance:
3036

3137
.. autoclass:: xgboost.Booster

0 commit comments

Comments
 (0)