Skip to content

Commit

Permalink
add min info gain and min instances per node parameters in decision tree
Browse files Browse the repository at this point in the history
  • Loading branch information
qiping.lqp committed Sep 9, 2014
1 parent 7db5339 commit ac42378
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -738,12 +738,15 @@ object DecisionTree extends Serializable with Logging {
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count

val totalCount = leftCount + rightCount
if (totalCount == 0) {
// Return arbitrary prediction.
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
// If left child or right child doesn't satisfy minimum instances per node,
// then this split is invalid, return invalid information gain stats
if ((leftCount < metadata.minInstancesPerNode) ||
(rightCount < metadata.minInstancesPerNode)) {
return InformationGainStats.invalidInformationGainStats
}

val totalCount = leftCount + rightCount

val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
// impurity of parent node
Expand All @@ -763,6 +766,9 @@ object DecisionTree extends Serializable with Logging {
val rightWeight = rightCount / totalCount.toDouble

val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
if (gain < metadata.minInfoGain) {
return InformationGainStats.invalidInformationGainStats
}

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
}
Expand Down Expand Up @@ -807,6 +813,9 @@ object DecisionTree extends Serializable with Logging {
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
}
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
Expand All @@ -820,6 +829,9 @@ object DecisionTree extends Serializable with Logging {
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
}
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
Expand Down Expand Up @@ -891,6 +903,9 @@ object DecisionTree extends Serializable with Logging {
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
}
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class Strategy (
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val minInstancesPerNode: Int = 0,
val minInfoGain: Double = 0.0,
val maxMemoryInMB: Int = 128) extends Serializable {

if (algo == Classification) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata(
val unorderedFeatures: Set[Int],
val numBins: Array[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy) extends Serializable {
val quantileStrategy: QuantileStrategy,
val minInstancesPerNode: Int,
val minInfoGain: Double) extends Serializable {

def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)

Expand Down Expand Up @@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata {

new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy)
strategy.impurity, strategy.quantileCalculationStrategy,
strategy.minInstancesPerNode, strategy.minInfoGain)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,8 @@ class InformationGainStats(
.format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
}
}


private[tree] object InformationGainStats {
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, 0.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -66,3 +68,7 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType)
private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())


private[tree] object Split {
val noSplit = new Split(-1, Double.MinValue, FeatureType.Continuous, List())
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
import org.apache.spark.mllib.tree.model.{Split, DecisionTreeModel, Node}
import org.apache.spark.mllib.util.LocalSparkContext


Expand Down Expand Up @@ -684,6 +684,45 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
validateClassifier(model, arr, 0.6)
}

test("split must satisfy min instances per node requirements") {
val arr = new Array[LabeledPoint](3)
arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))

val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClassesForClassification = 2, minInstancesPerNode = 4)

val model = DecisionTree.train(input, strategy)
assert(model.topNode.isLeaf)
assert(model.topNode.predict == 0.0)
assert(model.topNode.split.get == Split.noSplit)
val predicts = input.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
}
}

test("split must satisfy min info gain requirements") {
val arr = new Array[LabeledPoint](3)
arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))

val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClassesForClassification = 2, minInfoGain = 1.0)

val model = DecisionTree.train(input, strategy)
assert(model.topNode.isLeaf)
assert(model.topNode.predict == 0.0)
assert(model.topNode.split.get == Split.noSplit)
val predicts = input.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
}
}
}

object DecisionTreeSuite {
Expand Down

0 comments on commit ac42378

Please sign in to comment.