[SPARK-14351] [MLlib] [ML] Optimize findBestSplits method for decision trees (and random forest) #13959

Closed · wants to merge 8 commits
@@ -94,6 +94,64 @@ private[spark] class DTStatsAggregator(
     impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
   }

+  /**
+   * Calculate the information gain for a given (featureOffset, leftBin, parentBin).
+   *
+   * @param featureOffset Pre-computed (node, feature) offset from [[getFeatureOffset]].
+   * @param leftBinIndex Bin index of the left child; its stats start at
+   *                     featureOffset + leftBinIndex * statsSize in allStats.
+   * @param parentBinIndex Bin index of the parent; its stats start at
+   *                       featureOffset + parentBinIndex * statsSize in allStats.
+   */
+  def calculateGain(
+      featureOffset: Int,
+      leftBinIndex: Int,
+      parentBinIndex: Int): Double = {
+    val leftChildOffset = featureOffset + leftBinIndex * statsSize
+    val parentOffset = featureOffset + parentBinIndex * statsSize
+    metadata.impurity match {
+      case Gini => Gini.calculateGain(
+        allStats, leftChildOffset, allStats, parentOffset, statsSize,
+        metadata.minInstancesPerNode, metadata.minInfoGain)
+      case Entropy => Entropy.calculateGain(
+        allStats, leftChildOffset, allStats, parentOffset, statsSize,
+        metadata.minInstancesPerNode, metadata.minInfoGain)
+      case Variance => Variance.calculateGain(
+        allStats, leftChildOffset, allStats, parentOffset, statsSize,
+        metadata.minInstancesPerNode, metadata.minInfoGain)
+      case _ => throw new IllegalArgumentException(
+        s"Bad impurity parameter: ${metadata.impurity}")
+    }
+  }
+
+  /**
+   * Calculate the information gain for a given (featureOffset, leftBin).
+   * The parent stats are read from parentStats rather than from allStats.
+   *
+   * @param featureOffset Pre-computed (node, feature) offset from [[getFeatureOffset]].
+   * @param leftBinIndex Bin index of the left child; its stats start at
+   *                     featureOffset + leftBinIndex * statsSize in allStats.
+   */
+  def calculateGain(
+      featureOffset: Int,
+      leftBinIndex: Int): Double = {
+    val leftChildOffset = featureOffset + leftBinIndex * statsSize
+    metadata.impurity match {
+      case Gini => Gini.calculateGain(
+        allStats, leftChildOffset, parentStats, 0, statsSize,
+        metadata.minInstancesPerNode, metadata.minInfoGain)
+      case Entropy => Entropy.calculateGain(
+        allStats, leftChildOffset, parentStats, 0, statsSize,
+        metadata.minInstancesPerNode, metadata.minInfoGain)
+      case Variance => Variance.calculateGain(
+        allStats, leftChildOffset, parentStats, 0, statsSize,
+        metadata.minInstancesPerNode, metadata.minInfoGain)
+      case _ => throw new IllegalArgumentException(
+        s"Bad impurity parameter: ${metadata.impurity}")
+    }
+  }

   /**
    * Get an [[ImpurityCalculator]] for the parent node.
    */
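For orientation: the three-argument overload above compares a left-child bin against an explicit parent bin inside allStats (useful when bins hold cumulative statistics, so the last bin doubles as the parent), while the two-argument overload reads the parent from the pre-computed parentStats array. A minimal sketch of a caller scanning one feature's candidate splits; names such as binAggregates, numSplits, and featureIndexIdx are assumed from the surrounding diff, not defined here:

// Sketch: scan all candidate splits of one (node, feature) and keep the
// split index with the highest gain. For cumulative bins, bin `numSplits`
// holds the node totals and therefore serves as the parent bin.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val (bestSplitIdx, bestGain) =
  Range(0, numSplits).map { splitIdx =>
    (splitIdx, binAggregates.calculateGain(nodeFeatureOffset, splitIdx, numSplits))
  }.maxBy(_._2)
// For unordered categorical features, drop the parent-bin argument; the
// two-argument overload falls back to parentStats instead:
// binAggregates.calculateGain(nodeFeatureOffset, splitIdx)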
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala (144 changes: 41 additions, 103 deletions)

@@ -613,65 +613,6 @@ private[spark] object RandomForest extends Logging {
     }
   }

-  /**
-   * Calculate the impurity statistics for a given (feature, split) based upon left/right
-   * aggregates.
-   *
-   * @param stats reusable impurity statistics for all of this feature's splits;
-   *              only 'impurity' and 'impurityCalculator' are valid between iterations
-   * @param leftImpurityCalculator left node aggregates for this (feature, split)
-   * @param rightImpurityCalculator right node aggregates for this (feature, split)
-   * @param metadata learning and dataset metadata for DecisionTree
-   * @return Impurity statistics for this (feature, split)
-   */
-  private def calculateImpurityStats(
-      stats: ImpurityStats,
-      leftImpurityCalculator: ImpurityCalculator,
-      rightImpurityCalculator: ImpurityCalculator,
-      metadata: DecisionTreeMetadata): ImpurityStats = {
-
-    val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
-      leftImpurityCalculator.copy.add(rightImpurityCalculator)
-    } else {
-      stats.impurityCalculator
-    }
-
-    val impurity: Double = if (stats == null) {
-      parentImpurityCalculator.calculate()
-    } else {
-      stats.impurity
-    }
-
-    val leftCount = leftImpurityCalculator.count
-    val rightCount = rightImpurityCalculator.count
-
-    val totalCount = leftCount + rightCount
-
-    // If the left or right child does not satisfy the minimum number of instances
-    // per node, the split is invalid; return invalid information gain stats.
-    if ((leftCount < metadata.minInstancesPerNode) ||
-        (rightCount < metadata.minInstancesPerNode)) {
-      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
-    }
-
-    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
-    val rightImpurity = rightImpurityCalculator.calculate()
-
-    val leftWeight = leftCount / totalCount.toDouble
-    val rightWeight = rightCount / totalCount.toDouble
-
-    val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-
-    // If the information gain does not satisfy the minimum information gain,
-    // the split is invalid; return invalid information gain stats.
-    if (gain < metadata.minInfoGain) {
-      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
-    }
-
-    new ImpurityStats(gain, impurity, parentImpurityCalculator,
-      leftImpurityCalculator, rightImpurityCalculator)
-  }

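For reference, the removed helper above and the new Impurity.calculateGain methods compute the same weighted information gain; written out in standard decision-tree form (not copied from the PR):

\[
\mathrm{gain} \;=\; I(\mathrm{parent}) \;-\; \frac{n_L}{n_L + n_R}\, I(\mathrm{left}) \;-\; \frac{n_R}{n_L + n_R}\, I(\mathrm{right})
\]

where I is the impurity measure (entropy, Gini, or variance) and n_L, n_R are the instance counts of the two children; the split is rejected when either count is below minInstancesPerNode or the gain is below minInfoGain.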
   /**
    * Find the best split for a node.
    *
@@ -684,16 +625,10 @@
       featuresForNode: Option[Array[Int]],
       node: LearningNode): (Split, ImpurityStats) = {

-    // Calculate InformationGain and ImpurityStats if current node is top node
-    val level = LearningNode.indexToLevel(node.id)
-    var gainAndImpurityStats: ImpurityStats = if (level == 0) {
-      null
-    } else {
-      node.stats
-    }

     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    val (bestSplit, bestSplitStats) =
+    val (bestSplit, bestGain, bestFeatureOffset, bestSplitIndex) =
       Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
         val featureIndex = if (featuresForNode.nonEmpty) {
           featuresForNode.get.apply(featureIndexIdx)
@@ -712,30 +647,23 @@
             splitIndex += 1
           }
           // Find best split.
-          val (bestFeatureSplitIndex, bestFeatureGainStats) =
-            Range(0, numSplits).map { case splitIdx =>
-              val leftChildStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
-              rightChildStats.subtract(leftChildStats)
-              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
-                leftChildStats, rightChildStats, binAggregates.metadata)
-              (splitIdx, gainAndImpurityStats)
-            }.maxBy(_._2.gain)
-          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+          val (bestFeatureSplitIndex, maxGain) =
+            Range(0, numSplits).map { splitIdx =>
+              val gain = binAggregates.calculateGain(nodeFeatureOffset, splitIdx, numSplits)
+              (splitIdx, gain)
+            }.maxBy(_._2)
+          val bestFeatureSplit = splits(featureIndex)(bestFeatureSplitIndex)
+          (bestFeatureSplit, maxGain, nodeFeatureOffset, bestFeatureSplitIndex)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
           val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
-          val (bestFeatureSplitIndex, bestFeatureGainStats) =
-            Range(0, numSplits).map { splitIndex =>
-              val leftChildStats =
-                binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats = binAggregates.getParentImpurityCalculator()
-                .subtract(leftChildStats)
-              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
-                leftChildStats, rightChildStats, binAggregates.metadata)
-              (splitIndex, gainAndImpurityStats)
-            }.maxBy(_._2.gain)
-          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+          val (bestFeatureSplitIndex, maxGain) =
+            Range(0, numSplits).map { splitIdx =>
+              val gain = binAggregates.calculateGain(leftChildOffset, splitIdx)
+              (splitIdx, gain)
+            }.maxBy(_._2)
+          val bestFeatureSplit = splits(featureIndex)(bestFeatureSplitIndex)
+          (bestFeatureSplit, maxGain, leftChildOffset, bestFeatureSplitIndex)
         } else {
           // Ordered categorical feature
           val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
@@ -794,27 +722,37 @@ private[spark] object RandomForest extends Logging {
           // lastCategory = index of bin with total aggregates for this (node, feature)
           val lastCategory = categoriesSortedByCentroid.last._1
           // Find best split.
-          val (bestFeatureSplitIndex, bestFeatureGainStats) =
-            Range(0, numSplits).map { splitIndex =>
-              val featureValue = categoriesSortedByCentroid(splitIndex)._1
-              val leftChildStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
-              rightChildStats.subtract(leftChildStats)
-              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
-                leftChildStats, rightChildStats, binAggregates.metadata)
-              (splitIndex, gainAndImpurityStats)
-            }.maxBy(_._2.gain)
+          val (bestFeatureSplitIndex, maxGain) =
+            Range(0, numSplits).map { splitIdx =>
+              val featureValue = categoriesSortedByCentroid(splitIdx)._1
+              val gain = binAggregates.calculateGain(nodeFeatureOffset, featureValue, lastCategory)
+              (splitIdx, gain)
+            }.maxBy(_._2)
+          val bestFeatureValue = categoriesSortedByCentroid(bestFeatureSplitIndex)._1
           val categoriesForSplit =
             categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
           val bestFeatureSplit =
             new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
-          (bestFeatureSplit, bestFeatureGainStats)
+          (bestFeatureSplit, maxGain, nodeFeatureOffset, bestFeatureValue)
         }
-      }.maxBy(_._2.gain)
-
-    (bestSplit, bestSplitStats)
+      }.maxBy(_._2)
+
+    val leftImpurityCalculator = binAggregates.getImpurityCalculator(
+      bestFeatureOffset, bestSplitIndex)
+    val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
+    val rightImpurityCalculator = parentImpurityCalculator.copy.subtract(
+      leftImpurityCalculator)
+    val bestFeatureGainStats = if (bestGain == Double.MinValue) {
+      ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+    } else {
+      new ImpurityStats(bestGain, parentImpurityCalculator.calculate(),
+        parentImpurityCalculator, leftImpurityCalculator, rightImpurityCalculator)
+    }
+    (bestSplit, bestFeatureGainStats)
   }

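One consequence of the restructuring above: invalid splits are now encoded as a Double.MinValue gain rather than as full ImpurityStats objects, so the per-split loop compares plain doubles and an ImpurityStats is allocated only once, for the winning split (or an invalid-stats placeholder if every candidate was rejected). A self-contained toy illustration of this sentinel pattern; all names here are hypothetical, not from the PR:

object SentinelDemo {
  // Gains for four candidate splits; Double.MinValue marks rejected splits
  // (too few instances per child, or gain below minInfoGain).
  val gains = Array(Double.MinValue, 0.12, 0.31, Double.MinValue)

  def main(args: Array[String]): Unit = {
    // Pair each gain with its split index, then pick the maximum by gain.
    val (bestIdx, bestGain) = gains.zipWithIndex.map(_.swap).maxBy(_._2)
    if (bestGain == Double.MinValue) {
      println("no valid split for this node")  // analogous to getInvalidImpurityStats
    } else {
      println(s"best split $bestIdx with gain $bestGain")  // materialize stats once
    }
  }
}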
@@ -27,7 +27,13 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 @Experimental
 object Entropy extends Impurity {

-  private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
+  private[tree] def log2(x: Double): Double = {
+    if (x == 0) {
+      0.0
+    } else {
+      scala.math.log(x) / scala.math.log(2)
+    }
+  }

   /**
    * :: DeveloperApi ::
@@ -47,10 +53,8 @@ object Entropy extends Impurity {
     var classIndex = 0
     while (classIndex < numClasses) {
       val classCount = counts(classIndex)
-      if (classCount != 0) {
-        val freq = classCount / totalCount
-        impurity -= freq * log2(freq)
-      }
+      val freq = classCount / totalCount
+      impurity -= freq * log2(freq)
       classIndex += 1
     }
     impurity
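The loop above computes the usual multiclass entropy; with the new log2(0) = 0 convention, the explicit classCount != 0 guard becomes unnecessary because the zero-count terms vanish by the standard limit convention:

\[
H \;=\; -\sum_{i=1}^{k} p_i \log_2 p_i, \qquad p_i = \frac{\mathrm{count}_i}{\mathrm{totalCount}}, \qquad 0 \cdot \log_2 0 := 0.
\]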
@@ -76,6 +80,72 @@
   @Since("1.1.0")
   def instance: this.type = this

+  /**
+   * Information gain calculation.
+   * allStats(leftChildOffset until leftChildOffset + statsSize) contains the impurity
+   * information of the left child.
+   * parentStats(parentOffset until parentOffset + statsSize) contains the impurity
+   * information of the parent.
+   *
+   * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param leftChildOffset Start index of stats for the left child.
+   * @param parentStats Flat stats array for the impurity calculation of the parent.
+   * @param parentOffset Start index of stats for the parent.
+   * @param statsSize Size of the stats for the left child and the parent.
+   * @param minInstancesPerNode Minimum number of instances required in each child
+   *                            for the split to be valid.
+   * @param minInfoGain Return Double.MinValue if gain < minInfoGain.
+   * @return Information gain, or Double.MinValue for an invalid split.
+   */
+  override def calculateGain(
+      allStats: Array[Double],
+      leftChildOffset: Int,
+      parentStats: Array[Double],
+      parentOffset: Int,
+      statsSize: Int,
+      minInstancesPerNode: Int,
+      minInfoGain: Double): Double = {
+    var leftCount = 0.0
+    var totalCount = 0.0
+    var i = 0
+    while (i < statsSize) {
+      leftCount += allStats(leftChildOffset + i)
+      totalCount += parentStats(parentOffset + i)
+      i += 1
+    }
+    val rightCount = totalCount - leftCount
+
+    if (leftCount < minInstancesPerNode || rightCount < minInstancesPerNode) {
+      return Double.MinValue
+    }
+
+    var leftImpurity = 0.0
+    var rightImpurity = 0.0
+    var parentImpurity = 0.0
+
+    i = 0
+    while (i < statsSize) {
+      val leftStats = allStats(leftChildOffset + i)
+      val totalStats = parentStats(parentOffset + i)
+
+      val leftFreq = leftStats / leftCount
+      val rightFreq = (totalStats - leftStats) / rightCount
+      val parentFreq = totalStats / totalCount
+
+      leftImpurity -= leftFreq * log2(leftFreq)
+      rightImpurity -= rightFreq * log2(rightFreq)
+      parentImpurity -= parentFreq * log2(parentFreq)
+
+      i += 1
+    }
+    val leftWeighted = leftCount / totalCount * leftImpurity
+    val rightWeighted = rightCount / totalCount * rightImpurity
+    val gain = parentImpurity - leftWeighted - rightWeighted
+
+    if (gain < minInfoGain) {
+      return Double.MinValue
+    }
+    gain
+  }
 }

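A quick numeric sanity check of the new method (a hypothetical call, not a test from this PR): with binary class counts of [8, 2] in the left child and [10, 10] in the parent, the right child is [2, 8]; both children then have entropy H(0.8, 0.2) ≈ 0.7219 while the parent has entropy 1.0, giving a gain of about 0.2781:

// Hypothetical usage of the Entropy.calculateGain added in this PR.
val leftStats = Array(8.0, 2.0)     // left child: 8 of class 0, 2 of class 1
val parentAgg = Array(10.0, 10.0)   // parent: 10 instances of each class
val gain = Entropy.calculateGain(
  allStats = leftStats, leftChildOffset = 0,
  parentStats = parentAgg, parentOffset = 0,
  statsSize = 2, minInstancesPerNode = 1, minInfoGain = 0.0)
// gain = 1.0 - 0.5 * H(0.8, 0.2) - 0.5 * H(0.2, 0.8) ≈ 0.2781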