Skip to content

Commit ab982e7

Browse files
[R] Redesigned xgboost() interface skeleton (dmlc#10456)
--------- Co-authored-by: Michael Mayer <mayermichael79@gmail.com>
1 parent 17c6430 commit ab982e7

35 files changed

+1997
-242
lines changed

R-package/DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ Suggests:
5757
igraph (>= 1.0.1),
5858
float,
5959
titanic,
60-
RhpcBLASctl
60+
RhpcBLASctl,
61+
survival
6162
Depends:
6263
R (>= 4.3.0)
6364
Imports:

R-package/NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ S3method(predict,xgb.Booster)
1313
S3method(print,xgb.Booster)
1414
S3method(print,xgb.DMatrix)
1515
S3method(print,xgb.cv.synchronous)
16+
S3method(print,xgboost)
1617
S3method(setinfo,xgb.Booster)
1718
S3method(setinfo,xgb.DMatrix)
1819
S3method(variable.names,xgb.Booster)

R-package/R/utils.R

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,40 @@ NVL <- function(x, val) {
3030
return(c('rank:pairwise', 'rank:ndcg', 'rank:map'))
3131
}
3232

33+
.OBJECTIVES_NON_DEFAULT_MODE <- function() {
34+
return(c("reg:logistic", "binary:logitraw", "multi:softmax"))
35+
}
36+
37+
.BINARY_CLASSIF_OBJECTIVES <- function() {
38+
return(c("binary:logistic", "binary:hinge"))
39+
}
40+
41+
.MULTICLASS_CLASSIF_OBJECTIVES <- function() {
42+
return("multi:softprob")
43+
}
44+
45+
.SURVIVAL_RIGHT_CENSORING_OBJECTIVES <- function() { # nolint
46+
return(c("survival:cox", "survival:aft"))
47+
}
48+
49+
.SURVIVAL_ALL_CENSORING_OBJECTIVES <- function() { # nolint
50+
return("survival:aft")
51+
}
52+
53+
.REGRESSION_OBJECTIVES <- function() {
54+
return(c(
55+
"reg:squarederror", "reg:squaredlogerror", "reg:logistic", "reg:pseudohubererror",
56+
"reg:absoluteerror", "reg:quantileerror", "count:poisson", "reg:gamma", "reg:tweedie"
57+
))
58+
}
59+
60+
.MULTI_TARGET_OBJECTIVES <- function() {
61+
return(c(
62+
"reg:squarederror", "reg:squaredlogerror", "reg:logistic", "reg:pseudohubererror",
63+
"reg:quantileerror", "reg:gamma"
64+
))
65+
}
66+
3367

3468
#
3569
# Low-level functions for boosting --------------------------------------------

R-package/R/xgb.Booster.R

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -663,9 +663,8 @@ validate.features <- function(bst, newdata) {
663663
#' data(agaricus.train, package = "xgboost")
664664
#' train <- agaricus.train
665665
#'
666-
#' bst <- xgboost(
667-
#' data = train$data,
668-
#' label = train$label,
666+
#' bst <- xgb.train(
667+
#' data = xgb.DMatrix(train$data, label = train$label),
669668
#' max_depth = 2,
670669
#' eta = 1,
671670
#' nthread = 2,
@@ -767,9 +766,8 @@ xgb.attributes <- function(object) {
767766
#' data.table::setDTthreads(nthread)
768767
#' train <- agaricus.train
769768
#'
770-
#' bst <- xgboost(
771-
#' data = train$data,
772-
#' label = train$label,
769+
#' bst <- xgb.train(
770+
#' data = xgb.DMatrix(train$data, label = train$label),
773771
#' max_depth = 2,
774772
#' eta = 1,
775773
#' nthread = nthread,
@@ -817,9 +815,8 @@ xgb.config <- function(object) {
817815
#' data(agaricus.train, package = "xgboost")
818816
#' train <- agaricus.train
819817
#'
820-
#' bst <- xgboost(
821-
#' data = train$data,
822-
#' label = train$label,
818+
#' bst <- xgb.train(
819+
#' data = xgb.DMatrix(train$data, label = train$label),
823820
#' max_depth = 2,
824821
#' eta = 1,
825822
#' nthread = 2,
@@ -1230,9 +1227,8 @@ xgb.is.same.Booster <- function(obj1, obj2) {
12301227
#' data(agaricus.train, package = "xgboost")
12311228
#' train <- agaricus.train
12321229
#'
1233-
#' bst <- xgboost(
1234-
#' data = train$data,
1235-
#' label = train$label,
1230+
#' bst <- xgb.train(
1231+
#' data = xgb.DMatrix(train$data, label = train$label),
12361232
#' max_depth = 2,
12371233
#' eta = 1,
12381234
#' nthread = 2,

R-package/R/xgb.DMatrix.R

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -853,36 +853,6 @@ xgb.DMatrix.hasinfo <- function(object, info) {
853853
}
854854

855855

856-
# get dmatrix from data, label
857-
# internal helper method
858-
xgb.get.DMatrix <- function(data, label, missing, weight, nthread) {
859-
if (inherits(data, "dgCMatrix") || is.matrix(data)) {
860-
if (is.null(label)) {
861-
stop("label must be provided when data is a matrix")
862-
}
863-
dtrain <- xgb.DMatrix(data, label = label, missing = missing, nthread = nthread)
864-
if (!is.null(weight)) {
865-
setinfo(dtrain, "weight", weight)
866-
}
867-
} else {
868-
if (!is.null(label)) {
869-
warning("xgboost: label will be ignored.")
870-
}
871-
if (is.character(data)) {
872-
data <- path.expand(data)
873-
dtrain <- xgb.DMatrix(data[1])
874-
} else if (inherits(data, "xgb.DMatrix")) {
875-
dtrain <- data
876-
} else if (inherits(data, "data.frame")) {
877-
stop("xgboost doesn't support data.frame as input. Convert it to matrix first.")
878-
} else {
879-
stop("xgboost: invalid input data")
880-
}
881-
}
882-
return(dtrain)
883-
}
884-
885-
886856
#' Dimensions of xgb.DMatrix
887857
#'
888858
#' Returns a vector with the number of rows and the number of columns of an \code{xgb.DMatrix}.

R-package/R/xgb.dump.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
#' data(agaricus.test, package='xgboost')
3030
#' train <- agaricus.train
3131
#' test <- agaricus.test
32-
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
33-
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
32+
#' bst <- xgb.train(data = xgb.DMatrix(train$data, label = train$label), max_depth = 2,
33+
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
3434
#' # save the model dump in file 'model.dump' under tempdir()
3535
#' dump_path = file.path(tempdir(), 'model.dump')
3636
#' xgb.dump(bst, dump_path, with_stats = TRUE)

R-package/R/xgb.importance.R

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@
4646
#' # binomial classification using "gbtree":
4747
#' data(agaricus.train, package = "xgboost")
4848
#'
49-
#' bst <- xgboost(
50-
#' data = agaricus.train$data,
51-
#' label = agaricus.train$label,
49+
#' bst <- xgb.train(
50+
#' data = xgb.DMatrix(agaricus.train$data, label = agaricus.train$label),
5251
#' max_depth = 2,
5352
#' eta = 1,
5453
#' nthread = 2,
@@ -59,9 +58,8 @@
5958
#' xgb.importance(model = bst)
6059
#'
6160
#' # binomial classification using "gblinear":
62-
#' bst <- xgboost(
63-
#' data = agaricus.train$data,
64-
#' label = agaricus.train$label,
61+
#' bst <- xgb.train(
62+
#' data = xgb.DMatrix(agaricus.train$data, label = agaricus.train$label),
6563
#' booster = "gblinear",
6664
#' eta = 0.3,
6765
#' nthread = 1,
@@ -73,9 +71,11 @@
7371
#' # multiclass classification using "gbtree":
7472
#' nclass <- 3
7573
#' nrounds <- 10
76-
#' mbst <- xgboost(
77-
#' data = as.matrix(iris[, -5]),
78-
#' label = as.numeric(iris$Species) - 1,
74+
#' mbst <- xgb.train(
75+
#' data = xgb.DMatrix(
76+
#' as.matrix(iris[, -5]),
77+
#' label = as.numeric(iris$Species) - 1
78+
#' ),
7979
#' max_depth = 3,
8080
#' eta = 0.2,
8181
#' nthread = 2,
@@ -99,9 +99,11 @@
9999
#' )
100100
#'
101101
#' # multiclass classification using "gblinear":
102-
#' mbst <- xgboost(
103-
#' data = scale(as.matrix(iris[, -5])),
104-
#' label = as.numeric(iris$Species) - 1,
102+
#' mbst <- xgb.train(
103+
#' data = xgb.DMatrix(
104+
#' scale(as.matrix(iris[, -5])),
105+
#' label = as.numeric(iris$Species) - 1
106+
#' ),
105107
#' booster = "gblinear",
106108
#' eta = 0.2,
107109
#' nthread = 1,

R-package/R/xgb.model.dt.tree.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@
4343
#' nthread <- 1
4444
#' data.table::setDTthreads(nthread)
4545
#'
46-
#' bst <- xgboost(
47-
#' data = agaricus.train$data,
48-
#' label = agaricus.train$label,
46+
#' bst <- xgb.train(
47+
#' data = xgb.DMatrix(agaricus.train$data, label = agaricus.train$label),
4948
#' max_depth = 2,
5049
#' eta = 1,
5150
#' nthread = nthread,

R-package/R/xgb.plot.deepness.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@
4848
#' data.table::setDTthreads(nthread)
4949
#'
5050
#' ## Change max_depth to a higher number to get a more significant result
51-
#' bst <- xgboost(
52-
#' data = agaricus.train$data,
53-
#' label = agaricus.train$label,
51+
#' bst <- xgb.train(
52+
#' data = xgb.DMatrix(agaricus.train$data, label = agaricus.train$label),
5453
#' max_depth = 6,
5554
#' nthread = nthread,
5655
#' nrounds = 50,

R-package/R/xgb.plot.importance.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@
5151
#' nthread <- 2
5252
#' data.table::setDTthreads(nthread)
5353
#'
54-
#' bst <- xgboost(
55-
#' data = agaricus.train$data,
56-
#' label = agaricus.train$label,
54+
#' bst <- xgb.train(
55+
#' data = xgb.DMatrix(agaricus.train$data, label = agaricus.train$label),
5756
#' max_depth = 3,
5857
#' eta = 1,
5958
#' nthread = nthread,

0 commit comments

Comments
 (0)