Skip to content

Commit

Permalink
fix unit tests
Browse files: browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 19, 2014
1 parent 751da4e commit ea4c467
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ object GradientBoostedTrees extends Logging {
val baseLearnerWeights = new Array[Double](numIterations)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
// Prepare strategy for tree ensembles. Tree ensembles use regression with variance impurity.
val ensembleStrategy = boostingStrategy.treeStrategy.copy
ensembleStrategy.algo = Regression
ensembleStrategy.impurity = impurity.Variance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration
import org.apache.spark.annotation.DeveloperApi

/**
* :: Experimental ::
* :: DeveloperApi ::
* Enum to select ensemble combining strategy for base learners
*/
@DeveloperApi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ object EnsembleTestHelper {
val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
prediction - label
}
println(predictions.zip(input.map(_.label)).toSeq)
val metric = metricName match {
case "mse" =>
errors.map(err => err * err).sum / errors.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
GradientBoostedTreesSuite.testCombinations.foreach {
case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
val rdd = sc.parallelize(arr)
val rdd = sc.parallelize(arr, 2)

val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
val boostingStrategy =
new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
new BoostingStrategy(treeStrategy, AbsoluteError, numIterations, learningRate)

val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)

assert(gbt.trees.size === numIterations)
EnsembleTestHelper.validateRegressor(gbt, arr, 0.03)
EnsembleTestHelper.validateRegressor(gbt, arr, 0.85, "mae")

val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
val dt = DecisionTree.train(remappedInput, treeStrategy)
Expand All @@ -84,7 +84,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
GradientBoostedTreesSuite.testCombinations.foreach {
case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
val rdd = sc.parallelize(arr)
val rdd = sc.parallelize(arr, 2)

val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2,
numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty,
Expand Down

0 comments on commit ea4c467

Please sign in to comment.