
Commit ccc5f05

[EM] Add tests for irregular data shapes. (dmlc#10980)
- More tests.
- Recommend arena in the document.
1 parent 3b8f432 commit ccc5f05

8 files changed (+64, -18)

demo/dask/forward_logging.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
-"""Example of forwarding evaluation logs to the client
+"""
+Example of forwarding evaluation logs to the client
 ===================================================
 
 The example runs on GPU. Two classes are defined to show how to use Dask builtins to

demo/guide-python/distributed_extmem_basic.py

Lines changed: 17 additions & 5 deletions
@@ -13,6 +13,7 @@
 If `device` is `cuda`, following are also needed:
 
 - cupy
+- python-cuda
 - rmm
 
 """
@@ -104,11 +105,22 @@ def setup_rmm() -> None:
     if not xgboost.build_info()["USE_RMM"]:
         return
 
-    # The combination of pool and async is by design. As XGBoost needs to allocate large
-    # pages repeatly, it's not easy to handle fragmentation. We can use more experiments
-    # here.
-    mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
-    rmm.mr.set_current_device_resource(mr)
+    try:
+        from cuda import cudart
+        from rmm.mr import ArenaMemoryResource
+
+        status, free, total = cudart.cudaMemGetInfo()
+        if status != cudart.cudaError_t.cudaSuccess:
+            raise RuntimeError(cudart.cudaGetErrorString(status))
+
+        mr = rmm.mr.CudaMemoryResource()
+        mr = ArenaMemoryResource(mr, arena_size=int(total * 0.9))
+    except ImportError:
+        # The combination of pool and async is by design. As XGBoost needs to allocate
+        # large pages repeatedly, it's not easy to handle fragmentation. We could run
+        # more experiments here.
+        mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
+    rmm.mr.set_current_device_resource(mr)
     # Set the allocator for cupy as well.
     cp.cuda.set_allocator(rmm_cupy_allocator)
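The fallback branch keeps the previous pool-over-async configuration, while the new path sizes an arena at 90% of total device memory. As a quick illustration of the memory query used above, here is a minimal sketch (assumes the python-cuda package from the updated dependency list is installed; values are in bytes):

    from cuda import cudart

    # cudart returns a status code first; check it before using the values.
    status, free, total = cudart.cudaMemGetInfo()
    if status != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(cudart.cudaGetErrorString(status))
    print(f"free={free / 2**30:.1f} GiB, total={total / 2**30:.1f} GiB")
    print(f"arena size at 90% of total: {int(total * 0.9)} bytes")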

doc/conf.py

Lines changed: 1 addition & 0 deletions
@@ -294,6 +294,7 @@ def is_readthedocs_build():
     "dask": ("https://docs.dask.org/en/stable/", None),
     "distributed": ("https://distributed.dask.org/en/stable/", None),
     "pyspark": ("https://spark.apache.org/docs/latest/api/python/", None),
+    "rmm": ("https://docs.rapids.ai/api/rmm/nightly/", None),
 }
 
 
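This intersphinx entry is what lets the tutorial below cross-reference RMM: with the mapping in place, roles such as :py:class:`~rmm.mr.ArenaMemoryResource` resolve to links into the RMM API reference instead of rendering as plain text.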
doc/tutorials/external_memory.rst

Lines changed: 11 additions & 7 deletions
@@ -138,6 +138,8 @@ the GPU. Following is a snippet from :ref:`sphx_glr_python_examples_external_mem
 
     # It's important to use RMM for GPU-based external memory to improve performance.
     # If XGBoost is not built with RMM support, a warning will be raised.
+    # We use the pool memory resource here; you can also try the `ArenaMemoryResource`
+    # for improved memory fragmentation handling.
     mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
     rmm.mr.set_current_device_resource(mr)
     # Set the allocator for cupy as well.
@@ -278,13 +280,15 @@ determines the time it takes to run inference, even if a C2C link is available.
     Xy_valid = xgboost.ExtMemQuantileDMatrix(it_valid, max_bin=n_bins, ref=Xy_train)
 
 In addition, since the GPU implementation relies on asynchronous memory pool, which is
-subject to memory fragmentation even if the ``CudaAsyncMemoryResource`` is used. You might
-want to start the training with a fresh pool instead of starting training right after the
-ETL process. If you run into out-of-memory errors and you are convinced that the pool is
-not full yet (pool memory usage can be profiled with ``nsight-system``), consider tuning
-the RMM memory resource like using ``rmm.mr.CudaAsyncMemoryResource`` in conjunction with
-``rmm.mr.BinningMemoryResource(mr, 21, 25)`` instead of the
-``rmm.mr.PoolMemoryResource(mr)`` shown in the example.
+subject to memory fragmentation even if the :py:class:`~rmm.mr.CudaAsyncMemoryResource` is
+used. You might want to start the training with a fresh pool instead of starting training
+right after the ETL process. If you run into out-of-memory errors and you are convinced
+that the pool is not full yet (pool memory usage can be profiled with ``nsight-system``),
+consider tuning the RMM memory resource, such as using
+:py:class:`~rmm.mr.CudaAsyncMemoryResource` in conjunction with
+:py:class:`BinningMemoryResource(mr, 21, 25) <rmm.mr.BinningMemoryResource>` instead of
+the :py:class:`~rmm.mr.PoolMemoryResource`. Alternatively, the
+:py:class:`~rmm.mr.ArenaMemoryResource` is also an excellent option.
 
 During CPU benchmarking, we used an NVMe connected to a PCIe-4 slot. Other types of
 storage can be too slow for practical usage. However, your system will likely perform some
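The alternatives recommended in the tutorial take only a few lines to configure. A minimal sketch, assuming rmm is installed and using the 2^21-2^25 bin bounds quoted in the text above:

    import rmm

    # Option 1: async upstream with a binning layer for allocations between
    # 2^21 and 2^25 bytes, limiting fragmentation from repeated large pages.
    upstream = rmm.mr.CudaAsyncMemoryResource()
    mr = rmm.mr.BinningMemoryResource(upstream, 21, 25)

    # Option 2: an arena allocator over plain device memory.
    # mr = rmm.mr.ArenaMemoryResource(rmm.mr.CudaMemoryResource())

    rmm.mr.set_current_device_resource(mr)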

python-package/xgboost/testing/data_iter.py

Lines changed: 19 additions & 1 deletion
@@ -6,7 +6,7 @@
 
 from xgboost import testing as tm
 
-from ..core import DataIter, ExtMemQuantileDMatrix, QuantileDMatrix
+from ..core import DataIter, DMatrix, ExtMemQuantileDMatrix, QuantileDMatrix
 
 
 def run_mixed_sparsity(device: str) -> None:
@@ -78,6 +78,24 @@ def reset(self) -> None:
         ExtMemQuantileDMatrix(it, enable_categorical=True)
 
 
+def check_uneven_sizes(device: str) -> None:
+    """Tests for irregular data shapes."""
+    batches = [
+        tm.make_regression(n_samples, 16, use_cupy=device == "cuda")
+        for n_samples in [512, 256, 1024]
+    ]
+    unzip = list(zip(*batches))
+    it = tm.IteratorForTest(unzip[0], unzip[1], None, cache="cache", on_host=True)
+
+    Xy = DMatrix(it)
+    assert Xy.num_col() == 16
+    assert Xy.num_row() == sum(x.shape[0] for x in unzip[0])
+
+    Xy = ExtMemQuantileDMatrix(it)
+    assert Xy.num_col() == 16
+    assert Xy.num_row() == sum(x.shape[0] for x in unzip[0])
+
+
 class CatIter(DataIter):  # pylint: disable=too-many-instance-attributes
     """An iterator for testing categorical features."""
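For readers unfamiliar with the iterator protocol exercised here, below is a self-contained sketch of feeding uneven batches through a custom DataIter. The class name UnevenIter and the cache prefix are illustrative; tm.IteratorForTest used above wraps this same protocol (assumes a recent XGBoost with external-memory support):

    import numpy as np
    import xgboost


    class UnevenIter(xgboost.DataIter):
        """Yield three batches with unequal row counts (512, 256, 1024)."""

        def __init__(self) -> None:
            rng = np.random.default_rng(0)
            self._batches = [
                (rng.normal(size=(n, 16)), rng.normal(size=n))
                for n in (512, 256, 1024)
            ]
            self._it = 0
            # A cache prefix makes DMatrix build an external-memory cache on disk.
            super().__init__(cache_prefix="cache")

        def next(self, input_data) -> bool:
            if self._it == len(self._batches):
                return False  # no more batches
            X, y = self._batches[self._it]
            input_data(data=X, label=y)
            self._it += 1
            return True

        def reset(self) -> None:
            self._it = 0


    Xy = xgboost.DMatrix(UnevenIter())
    assert Xy.num_row() == 512 + 256 + 1024
    assert Xy.num_col() == 16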

src/data/ellpack_page_source.cu

Lines changed: 3 additions & 2 deletions
@@ -404,8 +404,9 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
                  this->GetCuts()};
     this->info_->Extend(proxy_->Info(), false, true);
   });
-  // The size of ellpack is logged in write cache.
-  LOG(INFO) << "Estimated batch size:"
+  LOG(INFO) << "Generated an Ellpack page with size: "
+            << common::HumanMemUnit(this->page_->Impl()->MemCostBytes())
+            << " from a batch with estimated size: "
             << cuda_impl::Dispatch<false>(proxy_, [](auto const& adapter) {
                  return common::HumanMemUnit(adapter->SizeBytes());
                });

tests/python-gpu/test_gpu_data_iterator.py

Lines changed: 5 additions & 1 deletion
@@ -7,7 +7,7 @@
 import xgboost as xgb
 from xgboost import testing as tm
 from xgboost.testing import no_cupy
-from xgboost.testing.data_iter import check_invalid_cat_batches
+from xgboost.testing.data_iter import check_invalid_cat_batches, check_uneven_sizes
 from xgboost.testing.updater import (
     check_categorical_missing,
     check_categorical_ohe,
@@ -231,3 +231,7 @@ def test_categorical_ohe(tree_method: str) -> None:
 @pytest.mark.skipif(**tm.no_cupy())
 def test_invalid_cat_batches() -> None:
     check_invalid_cat_batches("cuda")
+
+
+def test_uneven_sizes() -> None:
+    check_uneven_sizes("cuda")

tests/python/test_data_iterator.py

Lines changed: 6 additions & 1 deletion
@@ -12,7 +12,7 @@
 from xgboost import testing as tm
 from xgboost.data import SingleBatchInternalIter as SingleBatch
 from xgboost.testing import IteratorForTest, make_batches, non_increasing
-from xgboost.testing.data_iter import check_invalid_cat_batches
+from xgboost.testing.data_iter import check_invalid_cat_batches, check_uneven_sizes
 from xgboost.testing.updater import (
     check_categorical_missing,
     check_categorical_ohe,
@@ -375,3 +375,8 @@ def test_categorical_ohe(tree_method: str) -> None:
 
 def test_invalid_cat_batches() -> None:
     check_invalid_cat_batches("cpu")
+
+
+@pytest.mark.skipif(**tm.no_cupy())
+def test_uneven_sizes() -> None:
+    check_uneven_sizes("cpu")

0 commit comments
