From 2b20c6151bab8a2ee218b851f40d54133f9807a2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 31 Jul 2014 13:39:43 -0700 Subject: [PATCH] Small doc and style updates --- .../spark/examples/mllib/DecisionTreeRunner.scala | 10 ++++------ .../org/apache/spark/mllib/tree/DecisionTree.scala | 10 +++++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 9736b624e415d..cf3d2cca81ff6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -144,18 +144,16 @@ object DecisionTreeRunner { println(s"numClasses = $numClasses.") println(s"Per-class example fractions, counts:") println(s"Class\tFrac\tCount") - sortedClasses.foreach { c => { + sortedClasses.foreach { c => val frac = classCounts(c) / numExamples.toDouble println(s"$c\t$frac\t${classCounts(c)}") - }} + } (examples, numClasses) } - case Regression => { + case Regression => (origExamples, 0) - } - case _ => { + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") - } } // Split into training, test. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 31574182094cb..7d123dd6ae996 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -43,7 +43,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { @@ -206,7 +206,7 @@ object DecisionTree extends Serializable with Logging { * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).train(input) @@ -225,7 +225,7 @@ object DecisionTree extends Serializable with Logging { * @param impurity impurity criterion used for information gain calculation * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -250,7 +250,7 @@ object DecisionTree extends Serializable with Logging { * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -284,7 +284,7 @@ object DecisionTree extends Serializable with Logging { * an entry (n -> k) implies the feature n is categorical with k * categories 0, 1, 2, ... , k-1. It's important to note that * features are zero-indexed. - * @return DecisionTreeModel which can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint],