|
3 | 3 | */ |
4 | 4 | #include "xgboost/c_api.h" |
5 | 5 |
|
6 | | -#include <algorithm> // for copy, transform |
7 | | -#include <cinttypes> // for strtoimax |
8 | | -#include <cmath> // for nan |
9 | | -#include <cstring> // for strcmp |
10 | | -#include <limits> // for numeric_limits |
11 | | -#include <map> // for operator!=, _Rb_tree_const_iterator, _Rb_tre... |
12 | | -#include <memory> // for shared_ptr, allocator, __shared_ptr_access |
13 | | -#include <string> // for char_traits, basic_string, operator==, string |
14 | | -#include <system_error> // for errc |
15 | | -#include <utility> // for pair |
16 | | -#include <vector> // for vector |
17 | | - |
18 | | -#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry |
19 | | -#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch... |
20 | | -#include "../common/error_msg.h" // for NoFederated |
21 | | -#include "../common/hist_util.h" // for HistogramCuts |
22 | | -#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf... |
23 | | -#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor |
24 | | -#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte... |
25 | | -#include "../data/ellpack_page.h" // for EllpackPage |
26 | | -#include "../data/proxy_dmatrix.h" // for DMatrixProxy |
27 | | -#include "../data/simple_dmatrix.h" // for SimpleDMatrix |
28 | | -#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN |
29 | | -#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM... |
30 | | -#include "dmlc/base.h" // for BeginPtr |
31 | | -#include "dmlc/io.h" // for Stream |
32 | | -#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager |
33 | | -#include "dmlc/thread_local.h" // for ThreadLocalStore |
34 | | -#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat... |
35 | | -#include "xgboost/context.h" // for Context |
36 | | -#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage |
37 | | -#include "xgboost/feature_map.h" // for FeatureMap |
38 | | -#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal... |
39 | | -#include "xgboost/host_device_vector.h" // for HostDeviceVector |
40 | | -#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String |
41 | | -#include "xgboost/learner.h" // for Learner, PredictionType |
42 | | -#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ |
43 | | -#include "xgboost/predictor.h" // for PredictionCacheEntry |
44 | | -#include "xgboost/span.h" // for Span |
45 | | -#include "xgboost/string_view.h" // for StringView, operator<< |
46 | | -#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS... |
| 6 | +#include <algorithm> // for copy, transform |
| 7 | +#include <cinttypes> // for strtoimax |
| 8 | +#include <cmath> // for nan |
| 9 | +#include <cstring> // for strcmp |
| 10 | +#include <limits> // for numeric_limits |
| 11 | +#include <map> // for operator!=, _Rb_tree_const_iterator, _Rb_tre... |
| 12 | +#include <memory> // for shared_ptr, allocator, __shared_ptr_access |
| 13 | +#include <string> // for char_traits, basic_string, operator==, string |
| 14 | +#include <system_error> // for errc |
| 15 | +#include <utility> // for pair |
| 16 | +#include <vector> // for vector |
| 17 | + |
| 18 | +#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry |
| 19 | +#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch... |
| 20 | +#include "../common/error_msg.h" // for NoFederated |
| 21 | +#include "../common/hist_util.h" // for HistogramCuts |
| 22 | +#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf... |
| 23 | +#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor |
| 24 | +#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte... |
| 25 | +#include "../data/batch_utils.h" // for MatchingPageBytes, CachePageRatio |
| 26 | +#include "../data/ellpack_page.h" // for EllpackPage |
| 27 | +#include "../data/proxy_dmatrix.h" // for DMatrixProxy |
| 28 | +#include "../data/simple_dmatrix.h" // for SimpleDMatrix |
| 29 | +#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN |
| 30 | +#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM... |
| 31 | +#include "dmlc/base.h" // for BeginPtr |
| 32 | +#include "dmlc/io.h" // for Stream |
| 33 | +#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager |
| 34 | +#include "dmlc/thread_local.h" // for ThreadLocalStore |
| 35 | +#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat... |
| 36 | +#include "xgboost/context.h" // for Context |
| 37 | +#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage |
| 38 | +#include "xgboost/feature_map.h" // for FeatureMap |
| 39 | +#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal... |
| 40 | +#include "xgboost/host_device_vector.h" // for HostDeviceVector |
| 41 | +#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String |
| 42 | +#include "xgboost/learner.h" // for Learner, PredictionType |
| 43 | +#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ |
| 44 | +#include "xgboost/predictor.h" // for PredictionCacheEntry |
| 45 | +#include "xgboost/span.h" // for Span |
| 46 | +#include "xgboost/string_view.h" // for StringView, operator<< |
| 47 | +#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS... |
47 | 48 |
|
48 | 49 | using namespace xgboost; // NOLINT(*); |
49 | 50 |
|
@@ -296,15 +297,20 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy |
296 | 297 | auto jconfig = Json::Load(StringView{config}); |
297 | 298 | auto missing = GetMissing(jconfig); |
298 | 299 | std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__); |
299 | | - auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0); |
| 300 | + std::int32_t n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0); |
300 | 301 | auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false); |
| 302 | + auto min_cache_page_bytes = OptionalArg<Integer, std::int64_t>(jconfig, "min_cache_page_bytes", |
| 303 | + cuda_impl::MatchingPageBytes()); |
| 304 | + CHECK_EQ(min_cache_page_bytes, cuda_impl::MatchingPageBytes()) |
| 305 | + << "Page concatenation is not supported by the DMatrix yet."; |
301 | 306 |
|
302 | 307 | xgboost_CHECK_C_ARG_PTR(next); |
303 | 308 | xgboost_CHECK_C_ARG_PTR(reset); |
304 | 309 | xgboost_CHECK_C_ARG_PTR(out); |
305 | 310 |
|
| 311 | + auto config = ExtMemConfig{cache, on_host, min_cache_page_bytes, missing, n_threads}; |
306 | 312 | *out = new std::shared_ptr<xgboost::DMatrix>{ |
307 | | - xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache, on_host)}; |
| 313 | + xgboost::DMatrix::Create(iter, proxy, reset, next, config)}; |
308 | 314 | API_END(); |
309 | 315 | } |
310 | 316 |
|
@@ -368,17 +374,20 @@ XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr |
368 | 374 | xgboost_CHECK_C_ARG_PTR(config); |
369 | 375 | auto jconfig = Json::Load(StringView{config}); |
370 | 376 | auto missing = GetMissing(jconfig); |
371 | | - auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0); |
| 377 | + std::int32_t n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0); |
372 | 378 | auto max_bin = OptionalArg<Integer, std::int64_t>(jconfig, "max_bin", 256); |
373 | 379 | auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false); |
374 | 380 | std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__); |
| 381 | + auto min_cache_page_bytes = OptionalArg<Integer, std::int64_t>(jconfig, "min_cache_page_bytes", |
| 382 | + cuda_impl::AutoCachePageBytes()); |
375 | 383 |
|
376 | 384 | xgboost_CHECK_C_ARG_PTR(next); |
377 | 385 | xgboost_CHECK_C_ARG_PTR(reset); |
378 | 386 | xgboost_CHECK_C_ARG_PTR(out); |
379 | 387 |
|
380 | | - *out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create( |
381 | | - iter, proxy, p_ref, reset, next, missing, n_threads, max_bin, cache, on_host)}; |
| 388 | + auto config = ExtMemConfig{cache, on_host, min_cache_page_bytes, missing, n_threads}; |
| 389 | + *out = new std::shared_ptr<xgboost::DMatrix>{ |
| 390 | + xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, max_bin, config)}; |
382 | 391 | API_END(); |
383 | 392 | } |
384 | 393 |
|
|
0 commit comments