Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc #3374

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -27,18 +27,18 @@
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoosting;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;

/**
* Classification and regression using gradient-boosted decision trees.
*/
public final class JavaGradientBoostedTrees {
public final class JavaGradientBoostedTreesRunner {

private static void usage() {
System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" +
" <Classification/Regression>");
System.exit(-1);
}
Expand All @@ -55,7 +55,7 @@ public static void main(String[] args) {
if (args.length > 2) {
usage();
}
SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
Expand All @@ -64,7 +64,7 @@ public static void main(String[] args) {
// Note: All features are treated as continuous.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
boostingStrategy.setNumIterations(10);
boostingStrategy.weakLearnerParams().setMaxDepth(5);
boostingStrategy.treeStrategy().setMaxDepth(5);

if (algo.equals("Classification")) {
// Compute the number of classes from the data.
Expand All @@ -73,10 +73,10 @@ public static void main(String[] args) {
return p.label();
}
}).countByValue().size();
boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);

// Train a GradientBoosting model for classification.
final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
Expand All @@ -95,7 +95,7 @@ public static void main(String[] args) {
System.out.println("Learned classification tree model:\n" + model);
} else if (algo.equals("Regression")) {
// Train a GradientBoosting model for regression.
final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
Expand Down
Expand Up @@ -22,11 +22,11 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -349,24 +349,14 @@ object DecisionTreeRunner {
sc.stop()
}

/**
* Calculates the mean squared error for regression.
*/
private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err
}.mean()
}

/**
 * Calculates the mean squared error for regression.
 *
 * NOTE(review): the structural type for `model` is resolved via reflection on
 * every `predict` call; a shared `predict(Vector): Double` trait would avoid
 * that cost, but this keeps both tree-model APIs usable without a new type.
 *
 * @param model Any model exposing `predict(features: Vector): Double`.
 * @param data  Labeled points to evaluate against.
 * @return Mean of the squared prediction errors over `data`.
 */
private[mllib] def meanSquaredError(
    model: { def predict(features: Vector): Double },
    data: RDD[LabeledPoint]): Double = {
  data.map { y =>
    // Squared error per point; `mean` aggregates across the RDD.
    val err = model.predict(y.features) - y.label
    err * err
  }.mean()
}
Expand Down
Expand Up @@ -21,21 +21,21 @@ import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.tree.GradientBoosting
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.util.Utils

/**
* An example runner for Gradient Boosting using decision trees as weak learners. Run with
* {{{
* ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
* ./bin/run-example mllib.GradientBoostedTreesRunner [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*
* Note: This script treats all features as real-valued (not categorical).
* To include categorical features, modify categoricalFeaturesInfo.
*/
object GradientBoostedTrees {
object GradientBoostedTreesRunner {

case class Params(
input: String = null,
Expand Down Expand Up @@ -93,24 +93,24 @@ object GradientBoostedTrees {

def run(params: Params) {

val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params")
val sc = new SparkContext(conf)

println(s"GradientBoostedTrees with parameters:\n$params")
println(s"GradientBoostedTreesRunner with parameters:\n$params")

// Load training and test data and cache it.
val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)

val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
boostingStrategy.numClassesForClassification = numClasses
boostingStrategy.treeStrategy.numClassesForClassification = numClasses
boostingStrategy.numIterations = params.numIterations
boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
boostingStrategy.treeStrategy.maxDepth = params.maxDepth

val randomSeed = Utils.random.nextInt()
if (params.algo == "Classification") {
val startTime = System.nanoTime()
val model = GradientBoosting.trainClassifier(training, boostingStrategy)
val model = GradientBoostedTrees.train(training, boostingStrategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
Expand All @@ -127,7 +127,7 @@ object GradientBoostedTrees {
println(s"Test accuracy = $testAccuracy")
} else if (params.algo == "Regression") {
val startTime = System.nanoTime()
val model = GradientBoosting.trainRegressor(training, boostingStrategy)
val model = GradientBoostedTrees.train(training, boostingStrategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
Expand Down
Expand Up @@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return DecisionTreeModel that can be used for prediction
*/
/**
 * Trains a single decision tree by delegating to a one-tree random forest.
 *
 * The diff residue interleaved the old `train`/`rf.train`/`weakHypotheses`
 * lines with the new `run`/`rf.run`/`trees` lines; this is the post-change form.
 *
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @return DecisionTreeModel that can be used for prediction
 */
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
  // Note: random seed will not be used since numTrees = 1.
  val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
  val rfModel = rf.run(input)
  // A forest of one tree: the single member is the decision tree model.
  rfModel.trees(0)
}

/**
 * Trains a decision tree model over an RDD. This is deprecated because it hides the static
 * methods with the same name in Java.
 *
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @return DecisionTreeModel that can be used for prediction
 */
@deprecated("Please use DecisionTree.run instead.", "1.2.0")
def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input)
}

object DecisionTree extends Serializable with Logging {
Expand All @@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging {
* @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
  // Use the new non-deprecated instance entry point; the residue line calling
  // the deprecated instance `train` is dropped per the post-diff code.
  new DecisionTree(strategy).run(input)
}

/**
Expand All @@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging {
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth)
new DecisionTree(strategy).train(input)
new DecisionTree(strategy).run(input)
}

/**
Expand Down Expand Up @@ -140,7 +146,7 @@ object DecisionTree extends Serializable with Logging {
maxDepth: Int,
numClassesForClassification: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
new DecisionTree(strategy).train(input)
new DecisionTree(strategy).run(input)
}

/**
Expand Down Expand Up @@ -177,7 +183,7 @@ object DecisionTree extends Serializable with Logging {
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
new DecisionTree(strategy).train(input)
new DecisionTree(strategy).run(input)
}

/**
Expand Down