Skip to content

Commit ed2465c

Browse files
authored
Add configuration to R interface. (dmlc#5217)
* Save and load internal parameter configuration as JSON.
1 parent 8ca9744 commit ed2465c

File tree

8 files changed

+114
-2
lines changed

8 files changed

+114
-2
lines changed

R-package/NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ S3method(setinfo,xgb.DMatrix)
1414
S3method(slice,xgb.DMatrix)
1515
export("xgb.attr<-")
1616
export("xgb.attributes<-")
17+
export("xgb.config<-")
1718
export("xgb.parameters<-")
1819
export(cb.cv.predict)
1920
export(cb.early.stop)
@@ -30,6 +31,7 @@ export(xgb.DMatrix)
3031
export(xgb.DMatrix.save)
3132
export(xgb.attr)
3233
export(xgb.attributes)
34+
export(xgb.config)
3335
export(xgb.create.features)
3436
export(xgb.cv)
3537
export(xgb.dump)

R-package/R/xgb.Booster.R

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,35 @@ xgb.attributes <- function(object) {
503503
object
504504
}
505505

506+
#' Accessors for model parameters as JSON string.
507+
#'
508+
#' @param object Object of class \code{xgb.Booster}
509+
#' @param value A JSON string.
510+
#'
511+
#' @examples
512+
#' data(agaricus.train, package='xgboost')
513+
#' train <- agaricus.train
514+
#'
515+
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
516+
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
517+
#' config <- xgb.config(bst)
518+
#'
519+
#' @rdname xgb.config
520+
#' @export
521+
xgb.config <- function(object) {
522+
handle <- xgb.get.handle(object)
523+
.Call(XGBoosterSaveJsonConfig_R, handle);
524+
}
525+
526+
#' @rdname xgb.config
527+
#' @export
528+
`xgb.config<-` <- function(object, value) {
529+
handle <- xgb.get.handle(object)
530+
.Call(XGBoosterLoadJsonConfig_R, handle, value)
531+
object$raw <- xgb.Booster.complete(object)
532+
object
533+
}
534+
506535
#' Accessors for model parameters.
507536
#'
508537
#' Only the setter for xgboost parameters is currently implemented.

R-package/man/xgb.config.Rd

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/src/init.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ extern SEXP XGBoosterGetAttrNames_R(SEXP);
2323
extern SEXP XGBoosterGetAttr_R(SEXP, SEXP);
2424
extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
2525
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
26+
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
27+
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
2628
extern SEXP XGBoosterModelToRaw_R(SEXP);
2729
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
2830
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
@@ -49,6 +51,8 @@ static const R_CallMethodDef CallEntries[] = {
4951
{"XGBoosterGetAttr_R", (DL_FUNC) &XGBoosterGetAttr_R, 2},
5052
{"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2},
5153
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
54+
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
55+
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
5256
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
5357
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
5458
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},

R-package/src/xgboost_R.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,24 @@ SEXP XGBoosterModelToRaw_R(SEXP handle) {
362362
return ret;
363363
}
364364

365+
SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
366+
const char* ret;
367+
R_API_BEGIN();
368+
bst_ulong len {0};
369+
CHECK_CALL(XGBoosterSaveJsonConfig(R_ExternalPtrAddr(handle),
370+
&len,
371+
&ret));
372+
R_API_END();
373+
return mkString(ret);
374+
}
375+
376+
SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) {
377+
R_API_BEGIN();
378+
XGBoosterLoadJsonConfig(R_ExternalPtrAddr(handle), CHAR(asChar(value)));
379+
R_API_END();
380+
return R_NilValue;
381+
}
382+
365383
SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) {
366384
SEXP out;
367385
R_API_BEGIN();

R-package/src/xgboost_R.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,22 @@ XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);
179179
* \brief save model into R's raw array
180180
* \param handle handle
181181
* \return raw array
182-
*/
182+
*/
183183
XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);
184184

185+
/*!
186+
* \brief Save internal parameters as a JSON string
187+
* \param handle handle
188+
* \return JSON string
189+
*/
190+
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
191+
/*!
192+
* \brief Load the JSON string returnd by XGBoosterSaveJsonConfig_R
193+
* \param handle handle
194+
* \param value JSON string
195+
* \return R_NilValue
196+
*/
197+
XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
185198
/*!
186199
* \brief dump model into a string
187200
* \param handle handle

R-package/tests/testthat/test_basic.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,13 @@ test_that("colsample_bytree works", {
324324
# in the 100 trees
325325
expect_gte(nrow(xgb.importance(model = bst)), 30)
326326
})
327+
328+
test_that("Configuration works", {
329+
bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
330+
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic",
331+
eval_metric = 'error', eval_metric = 'auc', eval_metric = "logloss")
332+
config <- xgb.config(bst)
333+
xgb.config(bst) <- config
334+
reloaded_config <- xgb.config(bst)
335+
expect_equal(config, reloaded_config);
336+
})

doc/tutorials/saving_model.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ comments in the script for more details.
102102
Saving and Loading the internal parameters configuration
103103
********************************************************
104104

105-
XGBoost's ``C API`` and ``Python API`` supports saving and loading the internal
105+
XGBoost's ``C API``, ``Python API`` and ``R API`` support saving and loading the internal
106106
configuration directly as a JSON string. In Python package:
107107

108108
.. code-block:: python
@@ -111,6 +111,14 @@ configuration directly as a JSON string. In Python package:
111111
config = bst.save_config()
112112
print(config)
113113
114+
115+
or
116+
117+
.. code-block:: R
118+
119+
config <- xgb.config(bst)
120+
print(config)
121+
114122
Will print out something similiar to (not actual output as it's too long for demonstration):
115123

116124
.. code-block:: json

0 commit comments

Comments
 (0)