Skip to content

Commit

Permalink
[R-package] expose start_iteration to dump/save/lgb.model.dt.tree (#6398
Browse files Browse the repository at this point in the history
)
  • Loading branch information
mayer79 committed May 16, 2024
1 parent a70e832 commit e0ac635
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 39 deletions.
46 changes: 37 additions & 9 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,12 @@ Booster <- R6::R6Class(
},

# Save model
save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {
save_model = function(
filename
, num_iteration = NULL
, feature_importance_type = 0L
, start_iteration = 1L
) {

self$restore_handle()

Expand All @@ -432,12 +437,18 @@ Booster <- R6::R6Class(
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, filename
, as.integer(start_iteration) - 1L # Turn to 0-based
)

return(invisible(self))
},

save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {
save_model_to_string = function(
num_iteration = NULL
, feature_importance_type = 0L
, as_char = TRUE
, start_iteration = 1L
) {

self$restore_handle()

Expand All @@ -450,6 +461,7 @@ Booster <- R6::R6Class(
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, as.integer(start_iteration) - 1L # Turn to 0-based
)

if (as_char) {
Expand All @@ -461,7 +473,9 @@ Booster <- R6::R6Class(
},

# Dump model in memory
dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {
dump_model = function(
num_iteration = NULL, feature_importance_type = 0L, start_iteration = 1L
) {

self$restore_handle()

Expand All @@ -474,6 +488,7 @@ Booster <- R6::R6Class(
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, as.integer(start_iteration) - 1L # Turn to 0-based
)

return(model_str)
Expand Down Expand Up @@ -1288,8 +1303,11 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param filename Saved filename
#' @param num_iteration Number of iterations to save, NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to save.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "save the fifth, sixth, and seventh tree"
#'
#' @return lgb.Booster
#'
Expand Down Expand Up @@ -1322,7 +1340,9 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {
lgb.save <- function(
booster, filename, num_iteration = NULL, start_iteration = 1L
) {

if (!.is_Booster(x = booster)) {
stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
Expand All @@ -1338,6 +1358,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
invisible(booster$save_model(
filename = filename
, num_iteration = num_iteration
, start_iteration = start_iteration
))
)

Expand All @@ -1347,7 +1368,10 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param num_iteration Number of iterations to be dumped. NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to dump.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "dump the fifth, sixth, and seventh tree"
#'
#' @return json format of model
#'
Expand Down Expand Up @@ -1380,14 +1404,18 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {
lgb.dump <- function(booster, num_iteration = NULL, start_iteration = 1L) {

if (!.is_Booster(x = booster)) {
stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
}

# Return booster at requested iteration
return(booster$dump_model(num_iteration = num_iteration))
return(
booster$dump_model(
num_iteration = num_iteration, start_iteration = start_iteration
)
)

}

Expand Down
21 changes: 14 additions & 7 deletions R-package/R/lgb.model.dt.tree.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#' @name lgb.model.dt.tree
#' @title Parse a LightGBM model json dump
#' @description Parse a LightGBM model json dump into a \code{data.table} structure.
#' @param model object of class \code{lgb.Booster}
#' @param num_iteration number of iterations you want to predict with. NULL or
#' <= 0 means use best iteration
#' @param model object of class \code{lgb.Booster}.
#' @param num_iteration Number of iterations to include. NULL or <= 0 means use best iteration.
#' @param start_iteration Index (1-based) of the first boosting round to include in the output.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "return information about the fifth, sixth, and seventh trees".
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leafs.
#'
Expand Down Expand Up @@ -51,9 +53,15 @@
#' @importFrom data.table := rbindlist
#' @importFrom jsonlite fromJSON
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {

json_model <- lgb.dump(booster = model, num_iteration = num_iteration)
lgb.model.dt.tree <- function(
model, num_iteration = NULL, start_iteration = 1L
) {

json_model <- lgb.dump(
booster = model
, num_iteration = num_iteration
, start_iteration = start_iteration
)

parsed_json_model <- jsonlite::fromJSON(
txt = json_model
Expand Down Expand Up @@ -84,7 +92,6 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
tree_dt[, split_feature := feature_names]

return(tree_dt)

}


Expand Down
8 changes: 6 additions & 2 deletions R-package/man/lgb.dump.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions R-package/man/lgb.model.dt.tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 7 additions & 3 deletions R-package/man/lgb.save.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 16 additions & 11 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1093,32 +1093,35 @@ SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig,
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP filename) {
SEXP filename,
SEXP start_iteration) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
UNPROTECT(1);
return R_NilValue;
R_API_END();
}

SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type) {
SEXP feature_importance_type,
SEXP start_iteration) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration);
int start_iter = Rf_asInteger(start_iteration);
int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
SEXP model_str = PROTECT(safe_R_raw(out_len, &cont_token));
// if the model string was larger than the initial buffer, call the function again, writing directly to the R object
if (out_len > buf_len) {
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, reinterpret_cast<char*>(RAW(model_str))));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, out_len, &out_len, reinterpret_cast<char*>(RAW(model_str))));
} else {
std::copy(inner_char_buf.begin(), inner_char_buf.begin() + out_len, reinterpret_cast<char*>(RAW(model_str)));
}
Expand All @@ -1129,21 +1132,23 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,

SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type) {
SEXP feature_importance_type,
SEXP start_iteration) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP model_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration);
int start_iter = Rf_asInteger(start_iteration);
int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
}
model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(model_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
Expand Down Expand Up @@ -1261,9 +1266,9 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterPredictForMatSingleRow_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRow_R , 9},
{"LGBM_BoosterPredictForMatSingleRowFastInit_R", (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFastInit_R, 8},
{"LGBM_BoosterPredictForMatSingleRowFast_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFast_R , 3},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 4},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 4},
{"LGBM_NullBoosterHandleError_R" , (DL_FUNC) &LGBM_NullBoosterHandleError_R , 0},
{"LGBM_DumpParamAliases_R" , (DL_FUNC) &LGBM_DumpParamAliases_R , 0},
{"LGBM_GetMaxThreads_R" , (DL_FUNC) &LGBM_GetMaxThreads_R , 1},
Expand Down
12 changes: 9 additions & 3 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -809,39 +809,45 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMatSingleRowFast_R(
* \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param filename file name
* \param start_iteration Starting iteration (0 based)
* \return R NULL value
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP filename
SEXP filename,
SEXP start_iteration
);

/*!
* \brief create string containing model
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param start_iteration Starting iteration (0 based)
* \return R character vector (length=1) with model string
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type
SEXP feature_importance_type,
SEXP start_iteration
);

/*!
* \brief dump model to JSON
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param start_iteration Index of starting iteration (0 based)
* \return R character vector (length=1) with model JSON
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R(
SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type
SEXP feature_importance_type,
SEXP start_iteration
);

/*!
Expand Down

0 comments on commit e0ac635

Please sign in to comment.