Skip to content

Commit

Permalink
[R-package] Speed-up lgb.importance() (#6364)
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Apr 10, 2024
1 parent 628e91a commit 5dfe716
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 29 deletions.
56 changes: 27 additions & 29 deletions R-package/R/lgb.model.dt.tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,25 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {

#' @importFrom data.table := data.table rbindlist
.single_tree_parse <- function(lgb_tree) {
tree_info_cols <- c(
"split_index"
, "split_feature"
, "split_gain"
, "threshold"
, "decision_type"
, "default_left"
, "internal_value"
, "internal_count"
)

# Traverse tree function
pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) {

if (is.null(env)) {
# Setup initial default data.table with default types
env <- new.env(parent = emptyenv())
env$single_tree_dt <- data.table::data.table(
env$single_tree_dt <- list()
env$single_tree_dt[[1L]] <- data.table::data.table(
tree_index = integer(0L)
, depth = integer(0L)
, split_index = integer(0L)
Expand Down Expand Up @@ -127,19 +138,10 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
if (!is.null(tree_node_leaf$split_index)) {

# update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("split_index",
"split_feature",
"split_gain",
"threshold",
"decision_type",
"default_left",
"internal_value",
"internal_count")],
"depth" = current_depth,
"node_parent" = parent_index)),
use.names = TRUE,
fill = TRUE)
env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c(
tree_node_leaf[tree_info_cols]
, list("depth" = current_depth, "node_parent" = parent_index)
)

# Traverse tree again both left and right
pre_order_traversal(
Expand All @@ -154,31 +156,27 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
, current_depth = current_depth + 1L
, parent_index = tree_node_leaf$split_index
)

} else if (!is.null(tree_node_leaf$leaf_index)) {

# update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("leaf_index",
"leaf_value",
"leaf_count")],
"depth" = current_depth,
"leaf_parent" = parent_index)),
use.names = TRUE,
fill = TRUE)

# update list
env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c(
tree_node_leaf[c("leaf_index", "leaf_value", "leaf_count")]
, list("depth" = current_depth, "leaf_parent" = parent_index)
)
}

}
return(env$single_tree_dt)
}

# Traverse structure
single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
# Traverse structure and rowbind everything
single_tree_dt <- data.table::rbindlist(
pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
, use.names = TRUE
, fill = TRUE
)

# Store index
single_tree_dt[, tree_index := lgb_tree$tree_index]

return(single_tree_dt)

}
158 changes: 158 additions & 0 deletions R-package/tests/testthat/test_lgb.model.dt.tree.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
NROUNDS <- 10L
MAX_DEPTH <- 3L
N <- nrow(iris)
X <- data.matrix(iris[2L:4L])
FEAT <- colnames(X)
NCLASS <- nlevels(iris[, 5L])

model_reg <- lgb.train(
params = list(
objective = "regression"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
)
, data = lgb.Dataset(X, label = iris[, 1L])
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)

model_binary <- lgb.train(
params = list(
objective = "binary"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
)
, data = lgb.Dataset(X, label = iris[, 5L] == "setosa")
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)

model_multiclass <- lgb.train(
params = list(
objective = "multiclass"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
, num_classes = NCLASS
)
, data = lgb.Dataset(X, label = as.integer(iris[, 5L]) - 1L)
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)

model_rank <- lgb.train(
params = list(
objective = "lambdarank"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
, lambdarank_truncation_level = 3L
)
, data = lgb.Dataset(
X
, label = as.integer(iris[, 1L] > 5.8)
, group = rep(10L, times = 15L)
)
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)

models <- list(
reg = model_reg
, bin = model_binary
, multi = model_multiclass
, rank = model_rank
)

for (model_name in names(models)) {
model <- models[[model_name]]
expected_n_trees <- NROUNDS
if (model_name == "multi") {
expected_n_trees <- NROUNDS * NCLASS
}
df <- as.data.frame(lgb.model.dt.tree(model))
df_list <- split(df, f = df$tree_index, drop = TRUE)

df_leaf <- df[!is.na(df$leaf_index), ]
df_internal <- df[is.na(df$leaf_index), ]

test_that("lgb.model.dt.tree() returns the right number of trees", {
expect_equal(length(unique(df$tree_index)), expected_n_trees)
})

test_that("num_iteration can return less trees", {
expect_equal(
length(unique(lgb.model.dt.tree(model, num_iteration = 2L)$tree_index))
, 2L * (if (model_name == "multi") NCLASS else 1L)
)
})

test_that("Tree index from lgb.model.dt.tree() is in 0:(NROUNS-1)", {
expect_equal(unique(df$tree_index), (0L:(expected_n_trees - 1L)))
})

test_that("Depth calculated from lgb.model.dt.tree() respects max.depth", {
expect_true(max(df$depth) <= MAX_DEPTH)
})

test_that("Each tree from lgb.model.dt.tree() has single root node", {
expect_equal(
unname(sapply(df_list, function(df) sum(df$depth == 0L)))
, rep(1L, expected_n_trees)
)
})

test_that("Each tree from lgb.model.dt.tree() has two depth 1 nodes", {
expect_equal(
unname(sapply(df_list, function(df) sum(df$depth == 1L)))
, rep(2L, expected_n_trees)
)
})

test_that("leaves from lgb.model.dt.tree() do not have split info", {
internal_node_cols <- c(
"split_index"
, "split_feature"
, "split_gain"
, "threshold"
, "decision_type"
, "default_left"
, "internal_value"
, "internal_count"
)
expect_true(all(is.na(df_leaf[internal_node_cols])))
})

test_that("leaves from lgb.model.dt.tree() have valid leaf info", {
expect_true(all(df_leaf$leaf_index %in% 0L:(2.0^MAX_DEPTH - 1.0)))
expect_true(all(is.finite(df_leaf$leaf_value)))
expect_true(all(df_leaf$leaf_count > 0L & df_leaf$leaf_count <= N))
})

test_that("non-leaves from lgb.model.dt.tree() do not have leaf info", {
leaf_node_cols <- c(
"leaf_index", "leaf_parent", "leaf_value", "leaf_count"
)
expect_true(all(is.na(df_internal[leaf_node_cols])))
})

test_that("non-leaves from lgb.model.dt.tree() have valid split info", {
expect_true(
all(
sapply(
split(df_internal, df_internal$tree_index),
function(x) all(x$split_index %in% 0L:(nrow(x) - 1L))
)
)
)

expect_true(all(df_internal$split_feature %in% FEAT))

num_cols <- c("split_gain", "threshold", "internal_value")
expect_true(all(is.finite(unlist(df_internal[, num_cols]))))

# range of decision type?
expect_true(all(df_internal$default_left %in% c(TRUE, FALSE)))

counts <- df_internal$internal_count
expect_true(all(counts > 1L & counts <= N))
})
}

0 comments on commit 5dfe716

Please sign in to comment.