Skip to content

Commit

Permalink
add RandomForestModel and GradientBoostedTreesModel, hide CombiningSt…
Browse files Browse the repository at this point in the history
…rategy
  • Loading branch information
mengxr committed Nov 20, 2014
1 parent ea4c467 commit 4aae3b7
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 122 deletions.
Expand Up @@ -29,7 +29,7 @@
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.TreeEnsembleModel;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;

/**
Expand Down Expand Up @@ -76,7 +76,7 @@ public static void main(String[] args) {
boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);

// Train a GradientBoosting model for classification.
final TreeEnsembleModel model = GradientBoostedTrees.train(data, boostingStrategy);
final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
Expand All @@ -95,7 +95,7 @@ public static void main(String[] args) {
System.out.println("Learned classification tree model:\n" + model);
} else if (algo.equals("Regression")) {
// Train a GradientBoosting model for regression.
final TreeEnsembleModel model = GradientBoostedTrees.train(data, boostingStrategy);
final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
Expand Down
Expand Up @@ -22,11 +22,11 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.model.{TreeEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -349,24 +349,14 @@ object DecisionTreeRunner {
sc.stop()
}

/**
 * Computes the mean squared error of a decision tree over a labeled dataset:
 * the average, across all points, of the squared difference between the
 * tree's prediction for the point's features and the point's label.
 */
private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
  val squaredErrors = data.map { point =>
    val residual = tree.predict(point.features) - point.label
    residual * residual
  }
  squaredErrors.mean()
}

/**
* Calculates the mean squared error for regression.
*/
private[mllib] def meanSquaredError(
tree: TreeEnsembleModel,
model: { def predict(features: Vector): Double },
data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
val err = model.predict(y.features) - y.label
err * err
}.mean()
}
Expand Down
Expand Up @@ -23,9 +23,8 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

Expand Down Expand Up @@ -53,9 +52,9 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return WeightedEnsembleModel that can be used for prediction
* @return a gradient boosted trees model that can be used for prediction
*/
def run(input: RDD[LabeledPoint]): TreeEnsembleModel = {
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
Expand All @@ -71,7 +70,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
*/
def run(input: JavaRDD[LabeledPoint]): TreeEnsembleModel = {
def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
run(input.rdd)
}
}
Expand All @@ -86,11 +85,11 @@ object GradientBoostedTrees extends Logging {
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
* @return a tree ensemble model that can be used for prediction
* @return a gradient boosted trees model that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
new GradientBoostedTrees(boostingStrategy).run(input)
}

Expand All @@ -99,19 +98,19 @@ object GradientBoostedTrees extends Logging {
*/
def train(
input: JavaRDD[LabeledPoint],
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
train(input.rdd, boostingStrategy)
}

/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
* @param boostingStrategy boosting parameters
* @return a tree ensemble model that can be used for prediction
* @return a gradient boosted trees model that can be used for prediction
*/
private def boost(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {

val timer = new TimeTracker()
timer.start("total")
Expand Down Expand Up @@ -148,7 +147,7 @@ object GradientBoostedTrees extends Logging {
val firstTreeModel = new DecisionTree(ensembleStrategy).run(data)
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = 1.0
val startingModel = new TreeEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, Sum)
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
logDebug("error of gbt = " + loss.computeError(startingModel, input))
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")
Expand All @@ -172,8 +171,8 @@ object GradientBoostedTrees extends Logging {
// However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
// Note: A model of type regression is used since we require raw prediction
val partialModel = new TreeEnsembleModel(baseLearners.slice(0, m + 1),
baseLearnerWeights.slice(0, m + 1), Regression, Sum)
val partialModel = new GradientBoostedTreesModel(
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
logDebug("error of gbt = " + loss.computeError(partialModel, input))
// Update data with pseudo-residuals
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
Expand All @@ -186,6 +185,7 @@ object GradientBoostedTrees extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")

new TreeEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.treeStrategy.algo, Sum)
new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
}
}
34 changes: 16 additions & 18 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
Expand Up @@ -17,18 +17,18 @@

package org.apache.spark.mllib.tree

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.JavaConverters._

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache }
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache,
TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -79,9 +79,9 @@ private class RandomForest (
/**
* Method to train a random forest model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return WeightedEnsembleModel that can be used for prediction
* @return a random forest model that can be used for prediction
*/
def run(input: RDD[LabeledPoint]): TreeEnsembleModel = {
def run(input: RDD[LabeledPoint]): RandomForestModel = {

val timer = new TimeTracker()

Expand Down Expand Up @@ -212,8 +212,7 @@ private class RandomForest (
}

val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
val treeWeights = Array.fill[Double](numTrees)(1.0)
new TreeEnsembleModel(trees, treeWeights, strategy.algo, Average)
new RandomForestModel(strategy.algo, trees)
}

}
Expand All @@ -234,14 +233,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return WeightedEnsembleModel that can be used for prediction
* @return a random forest model that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
seed: Int): TreeEnsembleModel = {
seed: Int): RandomForestModel = {
require(strategy.algo == Classification,
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
Expand Down Expand Up @@ -272,7 +271,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return WeightedEnsembleModel that can be used for prediction
* @return a random forest model that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
Expand All @@ -283,7 +282,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
seed: Int = Utils.random.nextInt()): TreeEnsembleModel = {
seed: Int = Utils.random.nextInt()): RandomForestModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Classification, impurityType, maxDepth,
numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
Expand All @@ -302,7 +301,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
seed: Int): TreeEnsembleModel = {
seed: Int): RandomForestModel = {
trainClassifier(input.rdd, numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
Expand All @@ -322,14 +321,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return WeightedEnsembleModel that can be used for prediction
* @return a random forest model that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
seed: Int): TreeEnsembleModel = {
seed: Int): RandomForestModel = {
require(strategy.algo == Regression,
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
Expand Down Expand Up @@ -359,7 +358,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return WeightedEnsembleModel that can be used for prediction
* @return a random forest model that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
Expand All @@ -369,7 +368,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
seed: Int = Utils.random.nextInt()): TreeEnsembleModel = {
seed: Int = Utils.random.nextInt()): RandomForestModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Regression, impurityType, maxDepth,
0, maxBins, Sort, categoricalFeaturesInfo)
Expand Down Expand Up @@ -479,5 +478,4 @@ object RandomForest extends Serializable with Logging {
3 * totalBins
}
}

}
Expand Up @@ -17,14 +17,10 @@

package org.apache.spark.mllib.tree.configuration

import org.apache.spark.annotation.DeveloperApi

/**
* :: DeveloperApi ::
* Enum to select ensemble combining strategy for base learners
*/
@DeveloperApi
object EnsembleCombiningStrategy extends Enumeration {
private[tree] object EnsembleCombiningStrategy extends Enumeration {
type EnsembleCombiningStrategy = Value
val Sum, Average = Value
val Average, Sum, Vote = Value
}
Expand Up @@ -17,11 +17,11 @@

package org.apache.spark.mllib.tree.model

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector

/**
* :: Experimental ::
Expand Down

0 comments on commit 4aae3b7

Please sign in to comment.