diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 40a806c41bad0..82279be6fbe77 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -52,12 +52,14 @@ summary.treeEnsemble <- function(model) { numFeatures <- callJMethod(jobj, "numFeatures") features <- callJMethod(jobj, "features") featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + maxDepth <- callJMethod(jobj, "maxDepth") numTrees <- callJMethod(jobj, "numTrees") treeWeights <- callJMethod(jobj, "treeWeights") list(formula = formula, numFeatures = numFeatures, features = features, featureImportances = featureImportances, + maxDepth = maxDepth, numTrees = numTrees, treeWeights = treeWeights, jobj = jobj) @@ -70,6 +72,7 @@ print.summary.treeEnsemble <- function(x) { cat("\nNumber of features: ", x$numFeatures) cat("\nFeatures: ", unlist(x$features)) cat("\nFeature importances: ", x$featureImportances) + cat("\nMax Depth: ", x$maxDepth) cat("\nNumber of trees: ", x$numTrees) cat("\nTree weights: ", unlist(x$treeWeights)) @@ -197,8 +200,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method #' @export @@ -403,8 +406,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method #' @export diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index e6fda251ebea2..e0802a9b02d13 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -39,6 +39,7 @@ test_that("spark.gbt", { tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_equal(stats$formula, "Employed ~ .") expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) @@ -53,6 +54,7 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, stats2$numFeatures) expect_equal(stats$features, stats2$features) expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) expect_equal(stats$numTrees, stats2$numTrees) expect_equal(stats$treeWeights, stats2$treeWeights) @@ -66,6 +68,7 @@ test_that("spark.gbt", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) predictions <- collect(predict(model, data))$prediction @@ -93,6 +96,7 @@ test_that("spark.gbt", { expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) expect_equal(s$numFeatures, 5) expect_equal(s$numTrees, 20) + expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), @@ -116,6 +120,7 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numTrees, 1) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) @@ -129,6 +134,7 @@ test_that("spark.randomForest", { tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") write.ml(model, modelPath) @@ -141,6 +147,7 @@ test_that("spark.randomForest", { expect_equal(stats$features, stats2$features) expect_equal(stats$featureImportances, stats2$featureImportances) expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) @@ -153,6 +160,7 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) # Test string prediction values @@ -187,6 +195,8 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + # Test numeric prediction values predictions <- collect(predict(model, data))$prediction expect_equal(length(grep("1.0", predictions)), 50) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala index aacb41ee2659b..c07eadb30a4d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -44,6 +44,7 @@ private[r] class GBTClassifierWrapper private ( lazy val featureImportances: Vector = gbtcModel.featureImportances lazy val numTrees: Int = gbtcModel.getNumTrees lazy val treeWeights: Array[Double] = gbtcModel.treeWeights + lazy val maxDepth: Int = gbtcModel.getMaxDepth def summary: String = gbtcModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala index 585077588eb9b..b568d7859221f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala @@ -42,6 +42,7 @@ private[r] class GBTRegressorWrapper private ( lazy val featureImportances: Vector = gbtrModel.featureImportances lazy val numTrees: Int = gbtrModel.getNumTrees lazy val treeWeights: Array[Double] = gbtrModel.treeWeights + lazy val maxDepth: Int = gbtrModel.getMaxDepth def summary: String = gbtrModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 366f375b58582..8a83d4e980f7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -44,6 +44,7 @@ private[r] class RandomForestClassifierWrapper private ( lazy val featureImportances: Vector = rfcModel.featureImportances lazy val numTrees: Int = rfcModel.getNumTrees lazy val treeWeights: Array[Double] = rfcModel.treeWeights + lazy val maxDepth: Int = rfcModel.getMaxDepth def summary: String = rfcModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala index 4b9a3a731da9b..038bd79c7022b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -42,6 +42,7 @@ private[r] class RandomForestRegressorWrapper private ( lazy val featureImportances: Vector = rfrModel.featureImportances lazy val numTrees: Int = rfrModel.getNumTrees lazy val treeWeights: Array[Double] = rfrModel.treeWeights + lazy val maxDepth: Int = rfrModel.getMaxDepth def summary: String = rfrModel.toDebugString