Skip to content

Commit

Permalink
Use the soft prediction to order categories' bins.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Sep 13, 2015
1 parent 5f46444 commit 84260ca
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ private[ml] object RandomForest extends Logging {
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.predict
categoryStats.calculate()
} else {
Double.MaxValue
}
Expand Down
230 changes: 117 additions & 113 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -811,128 +811,132 @@ object DecisionTree extends Serializable with Logging {
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)

/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
* splits are considered. (With K categories, we consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.calculate()
} else {
Double.MaxValue
}
(featureValue, centroid)
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// the bins are ordered by the centroid of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.predict
} else {
Double.MaxValue
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats =
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)

/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
* splits are considered. (With K categories, we consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.calculate()
} else {
Double.MaxValue
}
(featureValue, centroid)
}
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.calculate()
} else {
Double.MaxValue
}
(featureValue, centroid)
}
(featureValue, centroid)
}
}

logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

logDebug("Sorted centroids for categorical variable = " +
categoriesSortedByCentroid.mkString(","))
logDebug("Sorted centroids for categorical variable = " +
categoriesSortedByCentroid.mkString(","))

// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
var splitIndex = 0
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
var splitIndex = 0
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val featureValue = categoriesSortedByCentroid(splitIndex)._1
val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
// lastCategory = index of bin with total aggregates for this (node, feature)
val lastCategory = categoriesSortedByCentroid.last._1
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val featureValue = categoriesSortedByCentroid(splitIndex)._1
val leftChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)

(bestSplit, bestSplitStats, predictWithImpurity.get._1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils


Expand Down Expand Up @@ -294,8 +295,12 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.impurity !== -1.0)

// set impurity and predict for child nodes
assert(topNode.leftNode.get.predict.predict === 0.0)
assert(topNode.rightNode.get.predict.predict === 1.0)
if (topNode.leftNode.get.predict.predict === 0.0) {
assert(topNode.rightNode.get.predict.predict === 1.0)
} else {
assert(topNode.leftNode.get.predict.predict === 1.0)
assert(topNode.rightNode.get.predict.predict === 0.0)
}
assert(topNode.leftNode.get.impurity === 0.0)
assert(topNode.rightNode.get.impurity === 0.0)
}
Expand Down Expand Up @@ -337,12 +342,62 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.impurity !== -1.0)

// set impurity and predict for child nodes
assert(topNode.leftNode.get.predict.predict === 0.0)
assert(topNode.rightNode.get.predict.predict === 1.0)
if (topNode.leftNode.get.predict.predict === 0.0) {
assert(topNode.rightNode.get.predict.predict === 1.0)
} else {
assert(topNode.leftNode.get.predict.predict === 1.0)
assert(topNode.rightNode.get.predict.predict === 0.0)
}
assert(topNode.leftNode.get.impurity === 0.0)
assert(topNode.rightNode.get.impurity === 0.0)
}

// Runs a single findBestSplits pass on a depth-1 Gini classification tree built from six
// hand-written points, then checks the prediction and impurity recorded on the root node
// and on its two children.
test("Use soft prediction for binary classification with ordered categorical features") {
// Six labeled points with three features each. The trailing "left node" / "right node"
// notes record which child of the root the author expects each point to reach.
val arr = Array(
LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), // left node
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), // right node
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), // left node
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)), // right node
LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 0.0)), // left node
LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.0))) // left node
val input = sc.parallelize(arr)

// Binary classification, Gini impurity, a single level of splitting.
// categoricalFeaturesInfo = Map(0 -> 3) marks feature 0 as categorical; presumably the 3 is
// its number of categories (values 0, 1, 2 appear in the data above) -- confirm against
// the Strategy documentation.
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)

// Convert the raw points to the binned / bagged representation passed to findBestSplits.
// NOTE(review): the positional arguments (1.0, 1, false) are presumably subsampling rate,
// number of trees and "with replacement" -- confirm against BaggedPoint.convertToBaggedRDD.
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)

// A freshly created empty node is expected to carry placeholder values: prediction
// Double.MinValue, impurity -1.0, and not marked as a leaf.
val topNode = Node.emptyNode(nodeIndex = 1)
assert(topNode.predict.predict === Double.MinValue)
assert(topNode.impurity === -1.0)
assert(topNode.isLeaf === false)

// Both maps are keyed by 0 (presumably the tree index -- confirm) and describe the single
// node to split in this pass.
val nodesForGroup = Map((0, Array(topNode)))
val treeToNodeToIndexInfo = Map((0, Map(
(topNode.id, new RandomForest.NodeIndexInfo(0, None))
)))
val nodeQueue = new mutable.Queue[(Int, Node)]()
DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)

// don't enqueue leaf nodes into node queue
assert(nodeQueue.isEmpty)

// set impurity and predict for topNode
// (i.e. the placeholder values asserted above must have been overwritten)
assert(topNode.predict.predict !== Double.MinValue)
assert(topNode.impurity !== -1.0)

// Gini impurity computed from the count array (0.0, 3.0, 1.0) with a total of 4.0.
val impurityForRightNode = Gini.calculate(Array(0.0, 3.0, 1.0), 4.0)

// set impurity and predict for child nodes
assert(topNode.leftNode.get.predict.predict === 0.0)
assert(topNode.rightNode.get.predict.predict === 1.0)
// NOTE(review): the expected value here is the literal 0.44 and `impurityForRightNode` is
// passed as the absTol *tolerance*, even though it is named after the right node and this
// assertion is about the left node. This looks unintended (expected value and tolerance
// may be swapped or mis-named) -- confirm the intended expectation.
assert(topNode.leftNode.get.impurity ~== 0.44 absTol impurityForRightNode)
assert(topNode.rightNode.get.impurity === 0.0)
}
test("Second level node building with vs. without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)
Expand Down Expand Up @@ -442,7 +497,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rootNode = DecisionTree.train(rdd, strategy).topNode

val split = rootNode.split.get
assert(split.categories === List(1.0))
assert(split.categories === List(0.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)

Expand Down Expand Up @@ -471,7 +526,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

val split = rootNode.split.get
assert(split.categories.length === 1)
assert(split.categories.contains(1.0))
assert(split.categories.contains(0.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)

Expand Down

0 comments on commit 84260ca

Please sign in to comment.