Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-10524][ML] Use the soft prediction to order categories' bins #8734

Closed
wants to merge 8 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging {
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private def binsToBestSplit(
private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
Expand Down Expand Up @@ -740,7 +740,11 @@ private[ml] object RandomForest extends Logging {
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.predict
if (categoryStats.count == 2) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you meant categoryStats.stats.length == 2. categoryStats.count is the count of data points falling into that particular bin. Since we are trying to determine here whether this is regression or binary classification, I think checking if (binAggregates.metadata.isClassification) is more clear.

Additionally, the code under the if and else statements of centroidForCategories is identical except for a single line. It seems cleaner to restructure to something like:

val centroidForCategories = Range(0, numCategories).map { case featureValue =>
  val categoryStats =
    binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
  val centroid = if (categoryStats.count != 0) {
    if (binAggregates.metadata.isMulticlass) {
      // multiclass classification
      categoryStats.calculate()
    } else if (binAggregates.metadata.isClassification) {
      // binary classification
      categoryStats.stats(1)
    } else {
      // regression
      categoryStats.predict
    }
  } else {
    Double.MaxValue
  }
  (featureValue, centroid)
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sethah Thanks. You are right. I didn't read this part of codes thoroughly.

categoryStats.stats(1)
} else {
categoryStats.predict
}
} else {
Double.MaxValue
}
Expand Down
236 changes: 122 additions & 114 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private def binsToBestSplit(
private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
Expand All @@ -808,128 +808,136 @@ object DecisionTree extends Serializable with Logging {
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)

/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
* splits are considered. (With K categories, we consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.calculate()
} else {
Double.MaxValue
}
(featureValue, centroid)
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// the bins are ordered by the centroid of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.predict
} else {
Double.MaxValue
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats =
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)

/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
* splits are considered. (With K categories, we consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.calculate()
} else {
Double.MaxValue
}
(featureValue, centroid)
}
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
if (categoryStats.count == 2) {
categoryStats.stats(1)
} else {
categoryStats.predict
}
} else {
Double.MaxValue
}
(featureValue, centroid)
}
(featureValue, centroid)
}
}

logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

logDebug("Sorted centroids for categorical variable = " +
categoriesSortedByCentroid.mkString(","))
logDebug("Sorted centroids for categorical variable = " +
categoriesSortedByCentroid.mkString(","))

// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
var splitIndex = 0
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
var splitIndex = 0
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val featureValue = categoriesSortedByCentroid(splitIndex)._1
val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
// lastCategory = index of bin with total aggregates for this (node, feature)
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val featureValue = categoriesSortedByCentroid(splitIndex)._1
val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)

(bestSplit, bestSplitStats, predictWithImpurity.get._1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils


Expand Down Expand Up @@ -288,8 +289,12 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.impurity !== -1.0)

// set impurity and predict for child nodes
assert(topNode.leftNode.get.predict.predict === 0.0)
assert(topNode.rightNode.get.predict.predict === 1.0)
if (topNode.leftNode.get.predict.predict === 0.0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happened here? Did the leaf nodes switch because of the internal change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because previously we rank the bins according to hard prediction. So for example the bin of hard prediction 0 is always in front of hard prediction 1. Now the order of bins can be switched.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think now that we are ordering by the correct thing, this change is not necessary. Since we order by proportion falling in outcome class 1, the hard prediction 0 will always be before hard prediction 1.

assert(topNode.rightNode.get.predict.predict === 1.0)
} else {
assert(topNode.leftNode.get.predict.predict === 1.0)
assert(topNode.rightNode.get.predict.predict === 0.0)
}
assert(topNode.leftNode.get.impurity === 0.0)
assert(topNode.rightNode.get.impurity === 0.0)
}
Expand Down Expand Up @@ -331,12 +336,62 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.impurity !== -1.0)

// set impurity and predict for child nodes
assert(topNode.leftNode.get.predict.predict === 0.0)
assert(topNode.rightNode.get.predict.predict === 1.0)
if (topNode.leftNode.get.predict.predict === 0.0) {
assert(topNode.rightNode.get.predict.predict === 1.0)
} else {
assert(topNode.leftNode.get.predict.predict === 1.0)
assert(topNode.rightNode.get.predict.predict === 0.0)
}
assert(topNode.leftNode.get.impurity === 0.0)
assert(topNode.rightNode.get.impurity === 0.0)
}

test("Use soft prediction for binary classification with ordered categorical features") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the goal of this test? I guessed that the goal would be to find a case where ordering by hard predictions produces a different (suboptimal) tree than ordering by soft predictions. However, I did a quick simulation for this dataset and the results I got were the same either way. Just wanted to clarify.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I just want a test case to show it actually order the bins by soft prediction. Although @jkbradley suggested we should use directly binsToBestSplit, but in order to do that, we also need to expose many details of findBestSplits too, e.g., binSeqOp, getNodeToFeatures and partitionAggregates...etc.

val arr = Array(
LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), // left node
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), // right node
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), // left node
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)), // right node
LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 0.0)), // left node
LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.0))) // left node
val input = sc.parallelize(arr)

val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)

val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)

val topNode = Node.emptyNode(nodeIndex = 1)
assert(topNode.predict.predict === Double.MinValue)
assert(topNode.impurity === -1.0)
assert(topNode.isLeaf === false)

val nodesForGroup = Map((0, Array(topNode)))
val treeToNodeToIndexInfo = Map((0, Map(
(topNode.id, new RandomForest.NodeIndexInfo(0, None))
)))
val nodeQueue = new mutable.Queue[(Int, Node)]()
DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please update this test to call binsToBestSplit directly? You can change it to be private[tree] so that it's callable from this test suite.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to call binsToBestSplit directly, we need to expose many details of findBestSplits too, e.g., binSeqOp, getNodeToFeatures and partitionAggregates...etc., because binsToBestSplit needs binAggregates and featuresForNode..etc. as parameters. Is it a good idea?

nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)

// don't enqueue leaf nodes into node queue
assert(nodeQueue.isEmpty)

// set impurity and predict for topNode
assert(topNode.predict.predict !== Double.MinValue)
assert(topNode.impurity !== -1.0)

val impurityForRightNode = Gini.calculate(Array(0.0, 3.0, 1.0), 4.0)

// set impurity and predict for child nodes
assert(topNode.leftNode.get.predict.predict === 0.0)
assert(topNode.rightNode.get.predict.predict === 1.0)
assert(topNode.leftNode.get.impurity ~== 0.44 absTol impurityForRightNode)
assert(topNode.rightNode.get.impurity === 0.0)
}
test("Second level node building with vs. without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)
Expand Down