Skip to content

Commit

Permalink
Merge branch 'decisiontree-bugfix' into decisiontree-python-new
Browse files: browse the repository at this point in the history
  • Loading branch information
jkbradley committed Aug 1, 2014
2 parents 6622247 + 2b20c61 commit 188cb0d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,18 +144,16 @@ object DecisionTreeRunner {
println(s"numClasses = $numClasses.")
println(s"Per-class example fractions, counts:")
println(s"Class\tFrac\tCount")
sortedClasses.foreach { c => {
sortedClasses.foreach { c =>
val frac = classCounts(c) / numExamples.toDouble
println(s"$c\t$frac\t${classCounts(c)}")
}}
}
(examples, numClasses)
}
case Regression => {
case Regression =>
(origExamples, 0)
}
case _ => {
case _ =>
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
}

// Split into training, test.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return DecisionTreeModel which can be used for prediction
* @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {

Expand Down Expand Up @@ -206,7 +206,7 @@ object DecisionTree extends Serializable with Logging {
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of algorithm (classification, regression, etc.), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
* @return DecisionTreeModel which can be used for prediction
* @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
new DecisionTree(strategy).train(input)
Expand All @@ -225,7 +225,7 @@ object DecisionTree extends Serializable with Logging {
* @param impurity impurity criterion used for information gain calculation
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @return DecisionTreeModel which can be used for prediction
* @return DecisionTreeModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
Expand All @@ -250,7 +250,7 @@ object DecisionTree extends Serializable with Logging {
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @param numClassesForClassification number of classes for classification. Default value of 2.
* @return DecisionTreeModel which can be used for prediction
* @return DecisionTreeModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
Expand Down Expand Up @@ -284,7 +284,7 @@ object DecisionTree extends Serializable with Logging {
* an entry (n -> k) implies the feature n is categorical with k
* categories 0, 1, 2, ... , k-1. It's important to note that
* features are zero-indexed.
* @return DecisionTreeModel which can be used for prediction
* @return DecisionTreeModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
Expand Down

0 comments on commit 188cb0d

Please sign in to comment.