Skip to content

Commit

Permalink
Adding ability to load model from string in R (#472)
Browse files Browse the repository at this point in the history
* Update lgb.Booster.R

Added method to call LGBM_BoosterLoadModelFromString_R if model_str is provided for initialization, added option to load from model_str in lgb.load

* Update lightgbm_R.cpp

Adding LGBM_BoosterLoadModelFromString_R

* Update lightgbm_R.cpp

Added LGBM_BoosterSaveModelToString_R

* Update lightgbm_R.cpp

* Update lgb.Booster.R

Added save_model_to_string method

* Update lgb.Booster.R

Implemented @Laurae2 comments

* Update lgb.Booster.R

* Update lightgbm_R.h

Added load and save model from/to string exports
  • Loading branch information
JesseLimtiaco authored and guolinke committed May 5, 2017
1 parent 2a64bfe commit beb5fc5
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 6 deletions.
56 changes: 50 additions & 6 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Booster <- R6Class(
initialize = function(params = list(),
train_set = NULL,
modelfile = NULL,
model_str = NULL,
...) {

# Create parameters and handle
Expand Down Expand Up @@ -73,10 +74,22 @@ Booster <- R6Class(
ret = handle,
lgb.c_str(modelfile))

} else if (!is.null(model_str)) {

# Do we have a model_str as character?
if (!is.character(model_str)) {
stop("lgb.Booster: Can only use a string as model_str")
}

# Create booster from model
handle <- lgb.call("LGBM_BoosterLoadModelFromString_R",
ret = handle,
lgb.c_str(model_str))

} else {

# Booster non existent
stop("lgb.Booster: Need at least either training dataset or model file to create booster instance")
stop("lgb.Booster: Need at least either training dataset, model file, or model_str to create booster instance")

}

Expand Down Expand Up @@ -343,6 +356,21 @@ Booster <- R6Class(
return(self)
},

# Serialize the model into a character string
save_model_to_string = function(num_iteration = NULL) {

  # Fall back to the booster's best iteration when no iteration was requested
  iteration_to_save <- if (is.null(num_iteration)) self$best_iter else num_iteration

  # Hand the request to the backend and give back the model string it produces
  lgb.call.return.str("LGBM_BoosterSaveModelToString_R",
                      private$handle,
                      as.integer(iteration_to_save))

},

# Dump model in memory
dump_model = function(num_iteration = NULL) {

Expand Down Expand Up @@ -645,9 +673,12 @@ predict.lgb.Booster <- function(object, data,

#' Load LightGBM model
#'
#' Load LightGBM model from saved model file
#' Load LightGBM model from saved model file or string.
#' \code{lgb.load} takes in either a file path or a model string.
#' If both are provided, \code{lgb.load} will default to loading from file.
#'
#' @param filename path of model file
#' @param model_str a string containing the model
#'
#' @return booster
#'
Expand All @@ -671,19 +702,32 @@ predict.lgb.Booster <- function(object, data,
#' early_stopping_rounds = 10)
#' lgb.save(model, "model.txt")
#' load_booster <- lgb.load("model.txt")
#' load_booster_from_str <- lgb.load(model_str = model$save_model_to_string())
#' }
#'
#' @rdname lgb.load
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  # Purpose: build a Booster either from a model file on disk or from a model
  # string held in memory. `filename` wins when both are supplied.
  # Errors: stops when neither source is given, when a supplied source is not
  # character, or when the supplied file does not exist.

  # At least one model source is required
  if (is.null(filename) && is.null(model_str)) {
    stop("lgb.load: either filename or model_str must be given")
  }

  # Load from filename (takes precedence over model_str)
  if (!is.null(filename)) {

    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }

    if (!file.exists(filename)) {
      stop("lgb.load: file does not exist for supplied filename")
    }

    # Return new booster
    return(Booster$new(modelfile = filename))

  }

  # Load from model_str (known to be non-NULL at this point)
  if (!is.character(model_str)) {
    stop("lgb.load: model_str should be character")
  }

  # Return new booster
  Booster$new(model_str = model_str)

}

Expand Down
31 changes: 31 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,18 @@ LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
R_API_END();
}

// R entry point: creates a Booster from a model string and stores the
// resulting handle in `out`.
LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
                                          LGBM_SE out,
                                          LGBM_SE call_state) {
  R_API_BEGIN();
  // Filled in by the C call through its out-parameter; not used afterwards.
  int num_iters_loaded = 0;
  BoosterHandle booster;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(R_CHAR_PTR(model_str), &num_iters_loaded, &booster));
  // Publish the new handle to the R side.
  R_SET_PTR(out, booster);
  R_API_END();
}

LGBM_SE LGBM_BoosterMerge_R(LGBM_SE handle,
LGBM_SE other_handle,
LGBM_SE call_state) {
Expand Down Expand Up @@ -579,6 +591,25 @@ LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
R_API_END();
}

// R entry point: asks LGBM_BoosterSaveModelToString for the model text of
// `handle`, using a temporary buffer of `buffer_len` chars, then passes that
// buffer to EncodeChar together with `out_str` and `actual_len`.
LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
LGBM_SE call_state) {
R_API_BEGIN();
// Length reported by the C call through its out-parameter.
int out_len = 0;
// Scratch buffer sized by the caller-supplied buffer_len.
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
// NOTE(review): EncodeChar runs here unconditionally and again below with
// identical arguments when out_len < buffer_len. One of the two calls looks
// redundant, but this call is also the only one executed when
// out_len >= buffer_len -- confirm against EncodeChar and the R caller
// before removing either.
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
if (out_len < R_AS_INT(buffer_len)) {
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
} else {
// Reported length is not below buffer_len: expose it through actual_len
// (presumably so the R side can retry with a larger buffer -- TODO confirm).
R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
}
R_API_END();
}

LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE buffer_len,
Expand Down
24 changes: 24 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
LGBM_SE out,
LGBM_SE call_state);

/*!
* \brief load an existing booster from a model string
* \param model_str string containing the model
* \param out handle of created Booster
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
LGBM_SE out,
LGBM_SE call_state);

/*!
* \brief Merge model in two boosters to first handle
* \param handle handle, will merge other handle to this
Expand Down Expand Up @@ -468,6 +478,20 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
LGBM_SE filename,
LGBM_SE call_state);

/*!
* \brief create string containing model
* \param handle handle
* \param num_iteration number of iterations to save, <= 0 means save all
* \param buffer_len length of the buffer available for the model string
* \param actual_len receives the length reported for the model string
*        (NOTE(review): confirm exact semantics against the implementation)
* \param out_str string of model
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
LGBM_SE call_state);

/*!
* \brief dump model to json
* \param handle handle
Expand Down

0 comments on commit beb5fc5

Please sign in to comment.