Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
qiping.lqp committed Sep 9, 2014
1 parent 845c6fa commit e72c7e4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ object DecisionTree extends Serializable with Logging {
val rightCount = rightImpurityCalculator.count

// If left child or right child doesn't satisfy minimum instances per node,
// then this split is invalid, return invalid information gain stats
// then this split is invalid, return invalid information gain stats.
if ((leftCount < metadata.minInstancesPerNode) ||
(rightCount < metadata.minInstancesPerNode)) {
return InformationGainStats.invalidInformationGainStats
Expand All @@ -764,13 +764,23 @@ object DecisionTree extends Serializable with Logging {
val rightWeight = rightCount / totalCount.toDouble

val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity

// if information gain doesn't satisfy minimum information gain,
// then this split is invalid, return invalid information gain stats.
if (gain < metadata.minInfoGain) {
return InformationGainStats.invalidInformationGainStats
}

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
}

/**
* Calculate predict value for current node, given stats of any split.
* Note that this function is called only once for each node.
* @param leftImpurityCalculator left node aggregates for a split
* @param rightImpurityCalculator right node aggregates for a split
* @return predict value for current node
*/
private def calculatePredict(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator): Predict = {
Expand Down Expand Up @@ -799,6 +809,7 @@ object DecisionTree extends Serializable with Logging {

logDebug("node impurity = " + nodeImpurity)

// calculate predict only once
var predict: Option[Predict] = None

// For each (feature, split), calculate the gain, and select the best (feature, split).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,10 @@ class InformationGainStats(


private[tree] object InformationGainStats {
  /**
   * Sentinel [[org.apache.spark.mllib.tree.model.InformationGainStats]] marking a
   * split as invalid, i.e. one that does not satisfy the minimum information gain
   * or the minimum number of instances per node.
   */
  val invalidInformationGainStats: InformationGainStats =
    new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType)
/**
 * A [[Split]] built from the given feature and feature type, passing
 * `Double.MaxValue` as the second constructor argument and an empty list as the last.
 * NOTE(review): presumably a placeholder split for categorical features, as the
 * name suggests -- confirm against the callers.
 */
private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())


private[tree] object Split {
  /**
   * Placeholder [[org.apache.spark.mllib.tree.model.Split]] used when no valid
   * split can be found, i.e. none satisfies the minimum information gain or the
   * minimum number of instances per node.
   */
  val noSplit: Split =
    new Split(-1, Double.MinValue, FeatureType.Continuous, List())
}

0 comments on commit e72c7e4

Please sign in to comment.