Skip to content

Commit

Permalink
corrected ranger.unify
Browse files Browse the repository at this point in the history
+ code style
  • Loading branch information
konrad-komisarczyk committed Oct 27, 2020
1 parent 94f61af commit eff70d8
Showing 1 changed file with 42 additions and 34 deletions.
76 changes: 42 additions & 34 deletions R/unifiers.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ xgboost.unify <- function(xgb_model) {
stop("Package \"xgboost\" needed for this function to work. Please install it.",
call. = FALSE)
}
xgbtree <- xgboost::xgb.model.dt.tree(model = xgb_model)
xgbtree <- xgboost::xgb.model.dt.tree(model = xgb_model)
stopifnot(c("Tree", "Node", "ID", "Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover") %in% colnames(xgbtree))
xgbtree$Yes <- match(xgbtree$Yes, xgbtree$ID)
xgbtree$No <- match(xgbtree$No, xgbtree$ID)
xgbtree$Missing <- match(xgbtree$Missing, xgbtree$ID)
xgbtree[xgbtree$Feature == 'Leaf', 'Feature'] <- NA
xgbtree <- xgbtree[,c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover")]
xgbtree <- xgbtree[, c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover")]
colnames(xgbtree) <- c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "Quality/Score", "Cover")
attr(xgbtree, "model") <- "xgboost"
return(xgbtree)
Expand Down Expand Up @@ -128,17 +128,17 @@ lightgbm.unify <- function(lgb_model) {
ret.second <- function(x) x[2]
tmp <- data.table::merge.data.table(df[, .(node_parent, tree_index, split_index)], df[, .(tree_index, split_index, default_left, decision_type)],
by.x = c("tree_index", "node_parent"), by.y = c("tree_index", "split_index"))
y_n_m <- unique(tmp[,.(Yes = ifelse(decision_type %in% c("<=", "<"), ret.first(split_index),
y_n_m <- unique(tmp[, .(Yes = ifelse(decision_type %in% c("<=", "<"), ret.first(split_index),
ifelse(decision_type %in% c(">=", ">"), ret.second(split_index), stop("Unknown decision_type"))),
No = ifelse(decision_type %in% c(">=", ">"), ret.first(split_index),
ifelse(decision_type %in% c("<=", "<"), ret.second(split_index), stop("Unknown decision_type"))),
Missing = ifelse(default_left, ret.first(split_index),ret.second(split_index))),
.(tree_index, node_parent)])
df <- data.table::merge.data.table(df[,c("tree_index", "depth", "split_index", "split_feature", "node_parent", "split_gain",
df <- data.table::merge.data.table(df[, c("tree_index", "depth", "split_index", "split_feature", "node_parent", "split_gain",
"threshold", "internal_value", "internal_count")],
y_n_m, by.x = c("tree_index", "split_index"),
by.y = c("tree_index", "node_parent"), all.x = TRUE)
df <- df[,c("tree_index", "split_index", "split_feature", "threshold", "Yes", "No", "Missing", "split_gain", "internal_count")]
df <- df[, c("tree_index", "split_index", "split_feature", "threshold", "Yes", "No", "Missing", "split_gain", "internal_count")]
colnames(df) <- c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "Quality/Score", "Cover")
attr(df, "model") <- "LightGBM"
attr(df, "sorted") <- NULL
Expand Down Expand Up @@ -198,25 +198,29 @@ lightgbm.unify <- function(lgb_model) {
#' gbm.unify(gbm_model)
#'}
gbm.unify <- function(gbm_model) {
if(class(gbm_model) != 'gbm'){stop('Object gbm_model was not of class "gbm"')}
if(any(gbm_model$var.type>0)){stop('Models built on data with categorical features are not supported - please encode them before training.')}
if(class(gbm_model) != 'gbm') {
stop('Object gbm_model was not of class "gbm"')
}
if(any(gbm_model$var.type > 0)) {
stop('Models built on data with categorical features are not supported - please encode them before training.')
}
x <- lapply(gbm_model$trees, data.table::as.data.table)
times_vec <- sapply(x, nrow)
y <- data.table::rbindlist(x)
data.table::setnames(y, c("Feature", "Split", "Yes",
"No", "Missing", "ErrorReduction", "Cover",
"Prediction"))
y[["Tree"]] <- rep(0:(length(gbm_model$trees)-1), times = times_vec)
y[["Node"]] <- unlist(lapply(times_vec, function(x) 0:(x-1)))
y[["Tree"]] <- rep(0:(length(gbm_model$trees) - 1), times = times_vec)
y[["Node"]] <- unlist(lapply(times_vec, function(x) 0:(x - 1)))
y <- y[, Feature:=as.character(Feature)]
y[y$Feature<0, "Feature"]<- NA
y[!is.na(y$Feature), "Feature"] <- attr(gbm_model$Terms, "term.labels")[as.integer(y[["Feature"]][!is.na(y$Feature)])+1]
y[!is.na(y$Feature), "Feature"] <- attr(gbm_model$Terms, "term.labels")[as.integer(y[["Feature"]][!is.na(y$Feature)]) + 1]
y[is.na(y$Feature), "ErrorReduction"] <- y[is.na(y$Feature), "Split"]
y[is.na(y$Feature), "Split"] <- NA
y[y$Yes<0, "Yes"] <- NA
y[y$No<0, "No"] <- NA
y[y$Missing<0, "Missing"] <- NA
y <- y[,c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "ErrorReduction", "Cover")]
y <- y[, c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "ErrorReduction", "Cover")]
colnames(y) <- c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "Quality/Score", "Cover")
attr(y, "model") <- "gbm"

Expand Down Expand Up @@ -277,9 +281,13 @@ gbm.unify <- function(gbm_model) {
#' # logging_level = 'Info'))
#' #catboost.unify(cat_model, dt.pool)

catboost.unify <- function(catboost_model, pool){
if(class(catboost_model) != "catboost.Model"){stop('Object catboost_model is not of type "catboost.Model"')}
if(class(pool) != "catboost.Pool"){stop('Object pool is not of type "catboost.Pool"')}
catboost.unify <- function(catboost_model, pool) {
if(class(catboost_model) != "catboost.Model") {
stop('Object catboost_model is not of type "catboost.Model"')
}
if(class(pool) != "catboost.Pool") {
stop('Object pool is not of type "catboost.Pool"')
}
if (!requireNamespace("catboost", quietly = TRUE)) {
stop("Package \"catboost\" needed for this function to work. Please install it.",
call. = FALSE)
Expand Down Expand Up @@ -402,26 +410,28 @@ catboost.unify <- function(catboost_model, pool){
#'}
randomForest.unify <- function(rf_model, data) {
if(!'randomForest' %in% class(rf_model)){stop('Object rf_model was not of class "randomForest"')}
if(any(attr(rf_model$terms, "dataClasses") != "numeric")){stop('Models built on data with categorical features are not supported - please encode them before training.')}
if(any(attr(rf_model$terms, "dataClasses") != "numeric")) {
stop('Models built on data with categorical features are not supported - please encode them before training.')
}
n <- rf_model$ntree
ret <- data.table()
x <- lapply(1:n, function(tree){
tree_data <- as.data.table(randomForest::getTree(rf_model, k = tree, labelVar = TRUE))
tree_data[, c("left daughter", "right daughter", "split var", "split point", "prediction")]
tree_data[, c("left daughter", "right daughter", "split var", "split point", "prediction")]
})
times_vec <- sapply(x, nrow)
y <- rbindlist(x)
y[, Tree := rep(0:(n-1), times = times_vec)]
y[, Node := unlist(lapply(times_vec, function(x) 0:(x-1)))]
y[, Tree := rep(0:(n - 1), times = times_vec)]
y[, Node := unlist(lapply(times_vec, function(x) 0:(x - 1)))]
y[, Missing := NA]
y[, Cover := 0]
setnames(y, c("No", "Yes", "Feature", "Split", "Quality/Score", "Tree", "Node", "Missing", "Cover"))
setcolorder(y, c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "Quality/Score", "Cover"))
y[, Feature:=as.character(Feature)]
y[, Yes:= Yes - 1]
y[, No:= No - 1]
y[y$Yes<0, "Yes"] <- NA
y[y$No<0, "No"] <- NA
y[, Feature := as.character(Feature)]
y[, Yes := Yes - 1]
y[, No := No - 1]
y[y$Yes < 0, "Yes"] <- NA
y[y$No < 0, "No"] <- NA
attr(y, "model") <- "randomForest"

ID <- paste0(y$Node, "-", y$Tree)
Expand Down Expand Up @@ -481,31 +491,29 @@ randomForest.unify <- function(rf_model, data) {
#' # ranger.unify(rf, data)
#'}
ranger.unify <- function(rf_model, data) {
  # Convert a fitted ranger forest into the unified tree table used by the
  # other *.unify functions in this file: one row per node with the columns
  # Tree, Node, Feature, Split, Yes, No, Missing, Quality/Score, Cover.
  #
  # rf_model: object inheriting from class "ranger".
  # data:     dataset handed to recalculate_covers() to fill in Cover; it is
  #           coerced with as.data.frame() first.
  # Returns the unified data.table carrying attribute model = "ranger".
  if (!inherits(rf_model, "ranger")) {
    stop('Object rf_model was not of class "ranger"')
  }
  n <- rf_model$num.trees

  # One node table per tree; seq_len() keeps the loop empty for n == 0.
  x <- lapply(seq_len(n), function(tree) {
    tree_data <- data.table::as.data.table(ranger::treeInfo(rf_model, tree = tree))
    tree_data[, c("nodeID", "leftChild", "rightChild", "splitvarName", "splitval", "prediction")]
  })
  times_vec <- vapply(x, nrow, integer(1))
  y <- data.table::rbindlist(x)

  # Trees are numbered from 0, as in the other unifiers.
  y[, Tree := rep(seq_len(n) - 1L, times = times_vec)]
  y[, Missing := NA]
  y[, Cover := 0]

  # Positional rename: leftChild becomes Yes and rightChild becomes No.
  data.table::setnames(y, c("Node", "Yes", "No", "Feature", "Split", "Quality/Score", "Tree", "Missing", "Cover"))
  data.table::setcolorder(y, c("Tree", "Node", "Feature", "Split", "Yes", "No", "Missing", "Quality/Score", "Cover"))
  y[, Feature := as.character(Feature)]

  # Child IDs are used exactly as treeInfo() reports them. Subtracting 1
  # (as an earlier revision did) breaks the Node lookup below, because the
  # children are matched against the unshifted Node column.
  y[!is.na(Yes) & Yes < 0, Yes := NA]
  y[!is.na(No) & No < 0, No := NA]
  attr(y, "model") <- "ranger"

  # Replace per-tree child IDs by row positions in the combined table,
  # using "<node>-<tree>" as the lookup key.
  ID <- paste0(y$Node, "-", y$Tree)
  y$Yes <- match(paste0(y$Yes, "-", y$Tree), ID)
  y$No <- match(paste0(y$No, "-", y$Tree), ID)

  # Observations with a missing value are routed the same way as "Yes".
  y[, Missing := Yes]

  # Cover is a placeholder until recomputed from the supplied data.
  y <- recalculate_covers(y, as.data.frame(data))
  return(y)
}

0 comments on commit eff70d8

Please sign in to comment.