Commit bed7ae4
Loop over thrust::reduce. (dmlc#6229)
* Check input chunk size of dqdm.
* Add doc for current limitation.

1 parent 734a911

10 files changed: +46 -8 lines changed

doc/tutorials/saving_model.rst

Lines changed: 1 addition & 1 deletion

@@ -167,7 +167,7 @@ or in R:
 
 Will print out something similiar to (not actual output as it's too long for demonstration):
 
-.. code-block:: json
+.. code-block:: javascript
 
   {
     "Learner": {

python-package/xgboost/core.py

Lines changed: 2 additions & 0 deletions

@@ -871,6 +871,8 @@ class DeviceQuantileDMatrix(DMatrix):
 
     .. versionadded:: 1.1.0
 
+    Known limitation:
+    The data size (rows * cols) can not exceed 2 ** 31 - 1000
     """
 
     def __init__(self, data, label=None, weight=None,  # pylint: disable=W0231

python-package/xgboost/dask.py

Lines changed: 4 additions & 0 deletions

@@ -509,6 +509,10 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
     max_bin: Number of bins for histogram construction.
 
 
+    Know issue:
+    The size of each chunk (rows * cols for a single dask chunk/partition) can
+    not exceed 2 ** 31 - 1000
+
     '''
     def __init__(self, client,
                  data,

src/common/device_helpers.cuh

Lines changed: 17 additions & 0 deletions

@@ -1132,4 +1132,21 @@ size_t SegmentedUnique(Inputs &&...inputs) {
   dh::XGBCachingDeviceAllocator<char> alloc;
   return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs&&>(inputs)...);
 }
+
+template <typename Policy, typename InputIt, typename Init, typename Func>
+auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce_op) {
+  size_t constexpr kLimit = std::numeric_limits<int32_t>::max() / 2;
+  size_t size = std::distance(first, second);
+  using Ty = std::remove_cv_t<Init>;
+  Ty aggregate = init;
+  for (size_t offset = 0; offset < size; offset += kLimit) {
+    auto begin_it = first + offset;
+    auto end_it = first + std::min(offset + kLimit, size);
+    size_t batch_size = std::distance(begin_it, end_it);
+    CHECK_LE(batch_size, size);
+    auto ret = thrust::reduce(policy, begin_it, end_it, init, reduce_op);
+    aggregate = reduce_op(aggregate, ret);
+  }
+  return aggregate;
+}
 } // namespace dh
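The helper added above caps each thrust::reduce call at kLimit elements and folds the per-chunk results back together with the same functor, which is how the commit works around the large-input failures referenced in the ellpack_page.cu comment below (NVIDIA/thrust#1299). A standalone sketch of the same chunked pattern outside of XGBoost, with illustrative names (BatchedReduce is not part of the library):

  // Standalone sketch of the chunked-reduction pattern (illustrative names, not XGBoost code):
  // split a large range into pieces small enough for a single thrust::reduce call,
  // then fold the partial results together with the same functor.
  #include <thrust/execution_policy.h>
  #include <thrust/functional.h>
  #include <thrust/iterator/counting_iterator.h>
  #include <thrust/reduce.h>

  #include <algorithm>
  #include <cstdint>
  #include <cstdio>
  #include <limits>

  template <typename Policy, typename InputIt, typename Init, typename Func>
  Init BatchedReduce(Policy policy, InputIt first, InputIt last, Init init, Func reduce_op) {
    size_t constexpr kLimit = std::numeric_limits<int32_t>::max() / 2;
    size_t const size = std::distance(first, last);
    Init aggregate = init;
    for (size_t offset = 0; offset < size; offset += kLimit) {
      auto begin_it = first + offset;
      auto end_it = first + std::min(offset + kLimit, size);
      // Each thrust::reduce call sees at most kLimit elements; the partial
      // result of every chunk is folded into the running aggregate on the host.
      aggregate = reduce_op(aggregate, thrust::reduce(policy, begin_it, end_it, init, reduce_op));
    }
    return aggregate;
  }

  int main() {
    // More elements than fit in a signed 32-bit counter.
    size_t const n = static_cast<size_t>(std::numeric_limits<uint32_t>::max());
    auto it = thrust::make_counting_iterator<size_t>(0);
    size_t const m = BatchedReduce(thrust::device, it, it + n, size_t{0}, thrust::maximum<size_t>{});
    std::printf("max = %zu (expected %zu)\n", m, n - 1);
    return m == n - 1 ? 0 : 1;
  }

Note that, exactly as in dh::Reduce above, init is passed to every chunk as well as used to seed the aggregate, so it should be an identity element for reduce_op (zero for a sum, a minimal value for a maximum).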

src/data/device_adapter.cuh

Lines changed: 1 addition & 1 deletion

@@ -221,7 +221,7 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
     }
   });
   dh::XGBCachingDeviceAllocator<char> alloc;
-  size_t row_stride = thrust::reduce(
+  size_t row_stride = dh::Reduce(
       thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
      thrust::device_pointer_cast(offset.data()) + offset.size(), size_t(0),
      thrust::maximum<size_t>());

src/data/ellpack_page.cu

Lines changed: 8 additions & 0 deletions

@@ -206,6 +206,14 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
       WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
       out(discard, functor);
   dh::XGBCachingDeviceAllocator<char> alloc;
+  // 1000 as a safe factor for inclusive_scan, otherwise it might generate overflow and
+  // lead to oom error.
+  // or:
+  // after reduction step 2: cudaErrorInvalidConfiguration: invalid configuration argument
+  // https://github.com/NVIDIA/thrust/issues/1299
+  CHECK_LE(batch.Size(), std::numeric_limits<int32_t>::max() - 1000)
+      << "Known limitation, size (rows * cols) of quantile based DMatrix "
+         "cannot exceed the limit of 32-bit integer.";
   thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
                          key_value_index_iter + batch.Size(), out,
                          [=] __device__(Tuple a, Tuple b) {

src/tree/gpu_hist/histogram.cu

Lines changed: 2 additions & 2 deletions

@@ -53,7 +53,7 @@ struct Pair {
   GradientPair first;
   GradientPair second;
 };
-XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
+__host__ XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
   return {lhs.first + rhs.first, lhs.second + rhs.second};
 }
 } // anonymous namespace
@@ -86,7 +86,7 @@ GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
   thrust::device_ptr<GradientPair const> gpair_end {gpair.data() + gpair.size()};
   auto beg = thrust::make_transform_iterator(gpair_beg, Clip());
   auto end = thrust::make_transform_iterator(gpair_end, Clip());
-  Pair p = thrust::reduce(thrust::cuda::par(alloc), beg, end, Pair{});
+  Pair p = dh::Reduce(thrust::cuda::par(alloc), beg, end, Pair{}, thrust::plus<Pair>{});
   GradientPair positive_sum {p.first}, negative_sum {p.second};
 
   auto histogram_rounding = GradientSumT {
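The added __host__ qualifier matters because dh::Reduce combines the per-chunk partial results on the host (the aggregate = reduce_op(aggregate, ret) line above), so the thrust::plus<Pair> it is given must be able to call operator+ from host code as well as from device code. A self-contained sketch of that requirement, with illustrative names (Vec2 and SumVec2 are not XGBoost types):

  // Sketch: a functor used with a chunked reduce is invoked per element on the
  // device and once per chunk on the host, so it needs both qualifiers.
  #include <thrust/device_vector.h>
  #include <thrust/reduce.h>

  struct Vec2 { float x; float y; };

  struct SumVec2 {
    __host__ __device__ Vec2 operator()(Vec2 const& a, Vec2 const& b) const {
      return {a.x + b.x, a.y + b.y};
    }
  };

  int main() {
    thrust::device_vector<Vec2> data(1000, Vec2{1.f, 2.f});
    // Device-side reduction of one chunk ...
    Vec2 chunk = thrust::reduce(data.begin(), data.end(), Vec2{}, SumVec2{});
    // ... followed by the host-side combine of partial results, which is the
    // call that fails to compile if operator() is __device__ only.
    Vec2 total = SumVec2{}(Vec2{}, chunk);
    return total.x == 1000.f && total.y == 2000.f ? 0 : 1;
  }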

src/tree/updater_gpu_hist.cu

Lines changed: 3 additions & 2 deletions

@@ -642,10 +642,11 @@ struct GPUHistMakerDevice {
   ExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
     constexpr bst_node_t kRootNIdx = 0;
     dh::XGBCachingDeviceAllocator<char> alloc;
-    GradientPair root_sum = thrust::reduce(
+    GradientPair root_sum = dh::Reduce(
         thrust::cuda::par(alloc),
         thrust::device_ptr<GradientPair const>(gpair.data()),
-        thrust::device_ptr<GradientPair const>(gpair.data() + gpair.size()));
+        thrust::device_ptr<GradientPair const>(gpair.data() + gpair.size()),
+        GradientPair{}, thrust::plus<GradientPair>{});
     rabit::Allreduce<rabit::op::Sum, float>(reinterpret_cast<float*>(&root_sum),
                                             2);
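The extra GradientPair{} and thrust::plus<GradientPair>{} arguments are needed because dh::Reduce, unlike thrust::reduce, has no overload that defaults to a value-initialized accumulator and thrust::plus. A small standalone check of the equivalence being preserved (plain Thrust, not XGBoost code):

  // The short form of thrust::reduce sums with a value-initialized init and
  // thrust::plus; dh::Reduce simply asks for both explicitly.
  #include <thrust/device_vector.h>
  #include <thrust/functional.h>
  #include <thrust/reduce.h>
  #include <thrust/sequence.h>

  int main() {
    thrust::device_vector<float> v(100);
    thrust::sequence(v.begin(), v.end());  // 0, 1, ..., 99
    float implicit_sum = thrust::reduce(v.begin(), v.end());
    float explicit_sum = thrust::reduce(v.begin(), v.end(), 0.0f, thrust::plus<float>{});
    return implicit_sum == explicit_sum ? 0 : 1;  // both are 4950
  }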

tests/cpp/common/test_device_helpers.cu

Lines changed: 8 additions & 1 deletion

@@ -1,4 +1,3 @@
-
 /*!
  * Copyright 2017 XGBoost contributors
  */
@@ -122,6 +121,14 @@ void TestSegmentedUniqueRegression(std::vector<SketchEntry> values, size_t n_dup
   ASSERT_EQ(segments.at(1), d_segments_out[1] + n_duplicated);
 }
 
+TEST(DeviceHelpers, Reduce) {
+  size_t kSize = std::numeric_limits<uint32_t>::max();
+  auto it = thrust::make_counting_iterator(0ul);
+  dh::XGBCachingDeviceAllocator<char> alloc;
+  auto batched = dh::Reduce(thrust::cuda::par(alloc), it, it + kSize, 0ul, thrust::maximum<size_t>{});
+  CHECK_EQ(batched, kSize - 1);
+}
+
 
 TEST(SegmentedUnique, Regression) {
   {

tests/cpp/data/test_ellpack_page.cu

Lines changed: 0 additions & 1 deletion

@@ -234,5 +234,4 @@ TEST(EllpackPage, Compact) {
 }
 }
 }
-
 } // namespace xgboost
