Commit
rename class method train -> run
mengxr committed Nov 19, 2014
1 parent 19030a5 commit 751da4e
Showing 6 changed files with 40 additions and 27 deletions.
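
The motivation is spelled out in the DecisionTree deprecation note further down: an instance method named train hides the companion object's static train helpers from Java, so the instance-level entry points are renamed to run while the static train helpers keep their names. A minimal caller-side sketch of the change, in Java (the wrapper class and method below are illustrative, not part of this commit; DecisionTree, Strategy, LabeledPoint, and DecisionTreeModel are the existing MLlib types):

// Illustrative migration sketch; everything except the MLlib types is made up for the example.
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;

public class TrainToRunSketch {
  static DecisionTreeModel fit(JavaRDD<LabeledPoint> data, Strategy strategy) {
    DecisionTree learner = new DecisionTree(strategy);
    // Before this commit: learner.train(data.rdd());
    // After this commit: train on the instance is deprecated in favor of run.
    return learner.run(data.rdd());
  }
}
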

JavaGradientBoostedTreesRunner.java
@@ -35,10 +35,10 @@
/**
* Classification and regression using gradient-boosted decision trees.
*/
-public final class JavaGradientBoostedTrees {
+public final class JavaGradientBoostedTreesRunner {

private static void usage() {
System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" +
" <Classification/Regression>");
System.exit(-1);
}
@@ -55,7 +55,7 @@ public static void main(String[] args) {
if (args.length > 2) {
usage();
}
-SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
@@ -73,10 +73,10 @@ public static void main(String[] args) {
return p.label();
}
}).countByValue().size();
-boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);

// Train a GradientBoosting model for classification.
-final TreeEnsembleModel model = GradientBoostedTrees.trainClassifier(data, boostingStrategy);
+final TreeEnsembleModel model = GradientBoostedTrees.train(data, boostingStrategy);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
@@ -95,7 +95,7 @@ public static void main(String[] args) {
System.out.println("Learned classification tree model:\n" + model);
} else if (algo.equals("Regression")) {
// Train a GradientBoosting model for classification.
-final TreeEnsembleModel model = GradientBoostedTrees.trainRegressor(data, boostingStrategy);
+final TreeEnsembleModel model = GradientBoostedTrees.train(data, boostingStrategy);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
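
For context on the treeStrategy() change above: numClasses now lives on the embedded tree Strategy rather than on BoostingStrategy itself. A minimal configuration sketch in Java (the helper class and the concrete values are placeholders; BoostingStrategy.defaultParams is assumed to be the usual way to obtain default parameters):

// Hypothetical helper showing the configuration pattern used in the example above.
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;

public class BoostingConfigSketch {
  static BoostingStrategy binaryClassificationParams() {
    BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");
    // numClasses is set on the wrapped tree Strategy, not on BoostingStrategy directly.
    boostingStrategy.treeStrategy().setNumClassesForClassification(2);
    return boostingStrategy;
  }
}
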

GradientBoostedTreesRunner.scala
@@ -28,14 +28,14 @@ import org.apache.spark.util.Utils
/**
* An example runner for Gradient Boosting using decision trees as weak learners. Run with
* {{{
-* ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+* ./bin/run-example mllib.GradientBoostedTreesRunner [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*
* Note: This script treats all features as real-valued (not categorical).
* To include categorical features, modify categoricalFeaturesInfo.
*/
-object GradientBoostedTrees {
+object GradientBoostedTreesRunner {

case class Params(
input: String = null,
@@ -93,10 +93,10 @@

def run(params: Params) {

-val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params")
val sc = new SparkContext(conf)

println(s"GradientBoostedTrees with parameters:\n$params")
println(s"GradientBoostedTreesRunner with parameters:\n$params")

// Load training and test data and cache it.
val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
@@ -127,7 +127,7 @@
println(s"Test accuracy = $testAccuracy")
} else if (params.algo == "Regression") {
val startTime = System.nanoTime()
-val model = GradientBoostedTrees.trainRegressor(training, boostingStrategy)
+val model = GradientBoostedTrees.train(training, boostingStrategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {

DecisionTree.scala
@@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return DecisionTreeModel that can be used for prediction
*/
-def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
-val rfModel = rf.train(input)
+val rfModel = rf.run(input)
rfModel.trees(0)
}

+/**
+* Trains a decision tree model over an RDD. This is deprecated because it hides the static
+* methods with the same name in Java.
+*/
+@deprecated("Please use DecisionTree.run instead.", "1.2.0")
+def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input)
}

object DecisionTree extends Serializable with Logging {
@@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging {
* @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
-new DecisionTree(strategy).train(input)
+new DecisionTree(strategy).run(input)
}

/**
@@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging {
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth)
-new DecisionTree(strategy).train(input)
+new DecisionTree(strategy).run(input)
}

/**
@@ -140,7 +146,7 @@
maxDepth: Int,
numClassesForClassification: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
-new DecisionTree(strategy).train(input)
+new DecisionTree(strategy).run(input)
}

/**
@@ -177,7 +183,7 @@
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
-new DecisionTree(strategy).train(input)
+new DecisionTree(strategy).run(input)
}

/**

GradientBoostedTrees.scala
@@ -47,15 +47,15 @@ import org.apache.spark.storage.StorageLevel
* @param boostingStrategy Parameters for the gradient boosting algorithm.
*/
@Experimental
-class GradientBoostedTrees (
-private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
+class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
+extends Serializable with Logging {

/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return WeightedEnsembleModel that can be used for prediction
*/
-def train(input: RDD[LabeledPoint]): TreeEnsembleModel = {
+def run(input: RDD[LabeledPoint]): TreeEnsembleModel = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
@@ -67,6 +67,13 @@ class GradientBoostedTrees (
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
}

+/**
+* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
+*/
+def run(input: JavaRDD[LabeledPoint]): TreeEnsembleModel = {
+run(input.rdd)
+}
}


@@ -84,7 +91,7 @@ object GradientBoostedTrees extends Logging {
def train(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
-new GradientBoostedTrees(boostingStrategy).train(input)
+new GradientBoostedTrees(boostingStrategy).run(input)
}

/**
@@ -137,7 +144,7 @@

// Initialize tree
timer.start("building tree 0")
-val firstTreeModel = new DecisionTree(ensembleStrategy).train(data)
+val firstTreeModel = new DecisionTree(ensembleStrategy).run(data)
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = 1.0
val startingModel = new TreeEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, Sum)
@@ -155,7 +162,7 @@
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
-val model = new DecisionTree(ensembleStrategy).train(data)
+val model = new DecisionTree(ensembleStrategy).run(data)
timer.stop(s"building tree $m")
// Create partial model
baseLearners(m) = model
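
The run(JavaRDD) overload added above lets Java callers pass a JavaRDD directly instead of unwrapping it with .rdd(). A small usage sketch (the wrapper class is illustrative; the types follow the signatures shown in this diff, and the TreeEnsembleModel import path is assumed to be the standard mllib.tree.model package):

// Hypothetical caller of the new Java-friendly overload.
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.TreeEnsembleModel;

public class JavaFriendlyRunSketch {
  static TreeEnsembleModel fit(JavaRDD<LabeledPoint> data, BoostingStrategy boostingStrategy) {
    // run(JavaRDD) simply delegates to run(input.rdd) internally, as added in this commit.
    return new GradientBoostedTrees(boostingStrategy).run(data);
  }
}
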

RandomForest.scala
@@ -81,7 +81,7 @@ private class RandomForest (
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return WeightedEnsembleModel that can be used for prediction
*/
-def train(input: RDD[LabeledPoint]): TreeEnsembleModel = {
+def run(input: RDD[LabeledPoint]): TreeEnsembleModel = {

val timer = new TimeTracker()

@@ -245,7 +245,7 @@ object RandomForest extends Serializable with Logging {
require(strategy.algo == Classification,
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
-rf.train(input)
+rf.run(input)
}

/**
@@ -333,7 +333,7 @@ object RandomForest extends Serializable with Logging {
require(strategy.algo == Regression,
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
-rf.train(input)
+rf.run(input)
}

/**

Java DecisionTree test suite
@@ -74,7 +74,7 @@ public void runDTUsingConstructor() {
maxBins, categoricalFeaturesInfo);

DecisionTree learner = new DecisionTree(strategy);
-DecisionTreeModel model = learner.train(rdd.rdd());
+DecisionTreeModel model = learner.run(rdd.rdd());

int numCorrect = validatePrediction(arr, model);
Assert.assertTrue(numCorrect == rdd.count());