Skip to content

Commit

Permalink
Remove Split.noSplit and make Predict private[tree]
Browse files (browse the repository at this point in the history)
  • Loading branch information
chouqin committed Sep 10, 2014
1 parent d593ec7 commit 0278a11
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -835,11 +835,7 @@ object DecisionTree extends Serializable with Logging {
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
} else {
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
}
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
Expand All @@ -853,11 +849,7 @@ object DecisionTree extends Serializable with Logging {
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
} else {
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
}
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
Expand Down Expand Up @@ -929,15 +921,11 @@ object DecisionTree extends Serializable with Logging {
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
} else {
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.annotation.DeveloperApi
* @param prob probability of the label (classification only)
*/
@DeveloperApi
class Predict(
private[tree] class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable{

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,3 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType)
private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())

private[tree] object Split {
/**
* A [[org.apache.spark.mllib.tree.model.Split]] object to denote that
* we can't find a valid split that satisfies minimum info gain
* or minimum number of instances per node.
*/
val noSplit = new Split(-1, Double.MinValue, FeatureType.Continuous, List())
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{Split, DecisionTreeModel, Node}
import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
import org.apache.spark.mllib.util.LocalSparkContext

class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Expand Down Expand Up @@ -689,11 +689,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(input, strategy)
assert(model.topNode.isLeaf)
assert(model.topNode.predict == 0.0)
assert(model.topNode.split.get == Split.noSplit)
val predicts = input.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
}

// test for findBestSplits when no valid split can be found
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)

assert(bestSplits.length === 1)
val bestInfoStats = bestSplits(0)._2
assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
}

test("split must satisfy min info gain requirements") {
Expand All @@ -709,11 +719,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(input, strategy)
assert(model.topNode.isLeaf)
assert(model.topNode.predict == 0.0)
assert(model.topNode.split.get == Split.noSplit)
val predicts = input.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
}

// test for findBestSplits when no valid split can be found
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)

assert(bestSplits.length === 1)
val bestInfoStats = bestSplits(0)._2
assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
}
}

Expand Down

0 comments on commit 0278a11

Please sign in to comment.