Skip to content

Commit

Permalink
Add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Jun 29, 2016
1 parent 67b401a commit e8b8914
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ private[spark] class DTStatsAggregator(
impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}

/**
* Calculate gain for a given (featureOffset, leftBin, parentBin).
*
* @param featureOffset This is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
* @param leftBinIndex Index of the leftChild in allStats
* Given by featureOffset + leftBinIndex * statsSize
* @param parentBinIndex Index of the parent in allStats
* Given by featureOffset + parentBinIndex * statsSize
*/
def calculateGain(
featureOffset: Int,
leftBinIndex: Int,
Expand All @@ -115,6 +125,14 @@ private[spark] class DTStatsAggregator(
gain
}

/**
* Calculate gain for a given (featureOffset, leftBin).
* The stats of the parent are inferred from parentStats.
* @param featureOffset This is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
* @param leftBinIndex Index of the leftChild in allStats
* Given by featureOffset + leftBinIndex * statsSize
*/
def calculateGain(
featureOffset: Int,
leftBinIndex: Int): Double = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ private[spark] object RandomForest extends Logging {
}
// Find best split.
val (bestFeatureSplitIndex, maxGain) =
Range(0, numSplits).map { case splitIdx =>
Range(0, numSplits).map { splitIdx =>
val gain = binAggregates.calculateGain(nodeFeatureOffset, splitIdx, numSplits)
(splitIdx, gain)
}.maxBy(_._2)
Expand All @@ -658,7 +658,7 @@ private[spark] object RandomForest extends Logging {
// Unordered categorical feature
val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val (bestFeatureSplitIndex, maxGain) =
Range(0, numSplits).map { case splitIdx =>
Range(0, numSplits).map { splitIdx =>
val gain = binAggregates.calculateGain(leftChildOffset, splitIdx)
(splitIdx, gain)
}.maxBy(_._2)
Expand Down Expand Up @@ -723,7 +723,7 @@ private[spark] object RandomForest extends Logging {
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, maxGain) =
Range(0, numSplits).map { case splitIdx =>
Range(0, numSplits).map { splitIdx =>
val featureValue = categoriesSortedByCentroid(splitIdx)._1
val gain = binAggregates.calculateGain(nodeFeatureOffset, featureValue, lastCategory)
(splitIdx, gain)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,21 @@ object Entropy extends Impurity {
@Since("1.1.0")
def instance: this.type = this

/**
* Information gain calculation.
* allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity
* information of the leftChild.
* parentsStats(parentOffset: parentOffset + statsSize) contains the impurity
* information of the parent.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param leftChildOffset Start index of stats for the left child.
* @param parentStats Flat stats array for impurity calculation of the parent.
* @param parentOffset Start index of stats for the parent.
* @param statsSize Size of the stats for the left child and the parent.
* @param minInstancePerNode minimum no. of instances in the child nodes for non-zero gain.
* @param minInfoGain return zero if gain < minInfoGain.
* @return information gain.
*/
override def calculateGain(
allStats: Array[Double],
leftChildOffset: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,21 @@ object Gini extends Impurity {
@Since("1.1.0")
def instance: this.type = this

/**
* Information gain calculation.
* allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity
* information of the leftChild.
* parentsStats(parentOffset: parentOffset + statsSize) contains the impurity
* information of the parent.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param leftChildOffset Start index of stats for the left child.
* @param parentStats Flat stats array for impurity calculation of the parent.
* @param parentOffset Start index of stats for the parent.
* @param statsSize Size of the stats for the left child and the parent.
* @param minInstancePerNode minimum no. of instances in the child nodes for non-zero gain.
* @param minInfoGain return zero if gain < minInfoGain.
* @return information gain.
*/
override def calculateGain(
allStats: Array[Double],
leftChildOffset: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ trait Impurity extends Serializable {
@DeveloperApi
def calculate(count: Double, sum: Double, sumSquares: Double): Double

/**
* Information gain calculation.
* allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity
* information of the leftChild.
* parentsStats(parentOffset: parentOffset + statsSize) contains the impurity
* information of the parent.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param leftChildOffset Start index of stats for the left child.
* @param parentStats Flat stats array for impurity calculation of the parent.
* @param parentOffset Start index of stats for the parent.
* @param statsSize Size of the stats for the left child and the parent.
* @param minInstancePerNode minimum no. of instances in the child nodes for non-zero gain.
* @param minInfoGain return zero if gain < minInfoGain.
* @return information gain.
*/
protected def calculateGain(
allStats: Array[Double],
leftChildOffset: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ object Variance extends Impurity {
@Since("1.0.0")
def instance: this.type = this

/**
* Information gain calculation.
* allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity
* information of the leftChild.
* parentsStats(parentOffset: parentOffset + statsSize) contains the impurity
* information of the parent.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param leftChildOffset Start index of stats for the left child.
* @param parentStats Flat stats array for impurity calculation of the parent.
* @param parentOffset Start index of stats for the parent.
* @param statsSize Size of the stats for the left child and the parent.
* @param minInstancePerNode minimum no. of instances in the child nodes for non-zero gain.
* @param minInfoGain return zero if gain < minInfoGain.
* @return information gain.
*/
override def calculateGain(
allStats: Array[Double],
leftChildOffset: Int,
Expand Down

0 comments on commit e8b8914

Please sign in to comment.