From 84260ca52bc037f2074bc062c7d5b88e1657314a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 13 Sep 2015 23:27:08 +0800 Subject: [PATCH 1/5] Use the soft prediction to order categories' bins. --- .../spark/ml/tree/impl/RandomForest.scala | 2 +- .../spark/mllib/tree/DecisionTree.scala | 230 +++++++++--------- .../spark/mllib/tree/DecisionTreeSuite.scala | 67 ++++- 3 files changed, 179 insertions(+), 120 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 4ac51a475474a..f261935847322 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -776,7 +776,7 @@ private[ml] object RandomForest extends Logging { val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) val centroid = if (categoryStats.count != 0) { - categoryStats.predict + categoryStats.calculate() } else { Double.MaxValue } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4a77d4adcd865..0612cab81edce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -811,128 +811,132 @@ object DecisionTree extends Serializable with Logging { // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. 
- // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIdx, gainStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIndex, gainStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), 
bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val numBins = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines which - * splits are considered. (With K categories, we consider K - 1 possible splits.) - * - * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = if (binAggregates.metadata.isMulticlass) { - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - Range(0, numBins).map { case featureValue => - val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - categoryStats.calculate() - } else { - Double.MaxValue - } - (featureValue, centroid) + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + splitIndex += 1 } - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // the bins are ordered by the centroid of their corresponding labels. - Range(0, numBins).map { case featureValue => - val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - categoryStats.predict - } else { - Double.MaxValue + // Find best split. 
+ val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { case splitIdx => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) + (splitIdx, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (binAggregates.metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val (leftChildOffset, rightChildOffset) = + binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = + binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else { + // Ordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numBins = binAggregates.metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) 
+ * + * centroidForCategories is a list: (category, centroid) + */ + val centroidForCategories = if (binAggregates.metadata.isMulticlass) { + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + Range(0, numBins).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.calculate() + } else { + Double.MaxValue + } + (featureValue, centroid) + } + } else { // regression or binary classification + // For categorical variables in regression and binary classification, + // the bins are ordered by the impurity of their corresponding labels. + Range(0, numBins).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.calculate() + } else { + Double.MaxValue + } + (featureValue, centroid) } - (featureValue, centroid) } - } - logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - logDebug("Sorted centroids for categorical variable = " + - categoriesSortedByCentroid.mkString(",")) + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. 
- var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) + (bestFeatureSplit, bestFeatureGainStats) } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. 
- val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIndex, gainStats) - }.maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) - (bestFeatureSplit, bestFeatureGainStats) - } }.maxBy(_._2.gain) (bestSplit, bestSplitStats, predictWithImpurity.get._1) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 356d957f15909..ceb1a1e32e15c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils @@ -294,8 +295,12 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.impurity !== -1.0) // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - 
assert(topNode.rightNode.get.predict.predict === 1.0) + if (topNode.leftNode.get.predict.predict === 0.0) { + assert(topNode.rightNode.get.predict.predict === 1.0) + } else { + assert(topNode.leftNode.get.predict.predict === 1.0) + assert(topNode.rightNode.get.predict.predict === 0.0) + } assert(topNode.leftNode.get.impurity === 0.0) assert(topNode.rightNode.get.impurity === 0.0) } @@ -337,12 +342,62 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.impurity !== -1.0) // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) + if (topNode.leftNode.get.predict.predict === 0.0) { + assert(topNode.rightNode.get.predict.predict === 1.0) + } else { + assert(topNode.leftNode.get.predict.predict === 1.0) + assert(topNode.rightNode.get.predict.predict === 0.0) + } assert(topNode.leftNode.get.impurity === 0.0) assert(topNode.rightNode.get.impurity === 0.0) } + test("Use soft prediction for binary classification with ordered categorical features") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), // left node + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), // right node + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), // left node + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)), // right node + LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 0.0)), // left node + LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.0))) // left node + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = 
Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + val impurityForRightNode = Gini.calculate(Array(0.0, 3.0, 1.0), 4.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity ~== 0.44 absTol impurityForRightNode) + assert(topNode.rightNode.get.impurity === 0.0) + } test("Second level node building with vs. 
without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) @@ -442,7 +497,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get - assert(split.categories === List(1.0)) + assert(split.categories === List(0.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) @@ -471,7 +526,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val split = rootNode.split.get assert(split.categories.length === 1) - assert(split.categories.contains(1.0)) + assert(split.categories.contains(0.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) From 2c3235054d2d2a98f9f88ae67243bd1ba89eca1e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 20 Jan 2016 12:09:45 +0800 Subject: [PATCH 2/5] Instead of calculate(), we should call prob() to get soft prediction. --- .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 4 ++-- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 38b6bc97411e2..234142659e74c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging { * @param binAggregates Bin statistics. 
* @return tuple for best split: (Split, information gain, prediction at node) */ - private def binsToBestSplit( + private[tree] def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], @@ -740,7 +740,7 @@ private[ml] object RandomForest extends Logging { val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) val centroid = if (categoryStats.count != 0) { - categoryStats.calculate() + categoryStats.prob(categoryStats.predict) } else { Double.MaxValue } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index d50a8e99bc1c6..b387d6ff70ca5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging { * @param binAggregates Bin statistics. * @return tuple for best split: (Split, information gain, prediction at node) */ - private def binsToBestSplit( + private[tree] def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], @@ -885,7 +885,7 @@ object DecisionTree extends Serializable with Logging { val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) val centroid = if (categoryStats.count != 0) { - categoryStats.calculate() + categoryStats.prob(categoryStats.predict) } else { Double.MaxValue } From cd252142a29e6083648bdec6a0646f2b6921b603 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 21 Jan 2016 14:49:24 +0800 Subject: [PATCH 3/5] For comments. 
--- .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 6 +++++- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 6 +++++- .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 4 ++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 234142659e74c..23caeca75feee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -740,7 +740,11 @@ private[ml] object RandomForest extends Logging { val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) val centroid = if (categoryStats.count != 0) { - categoryStats.prob(categoryStats.predict) + if (categoryStats.count == 2) { + categoryStats.stats(1) + } else { + categoryStats.predict + } } else { Double.MaxValue } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b387d6ff70ca5..08a6319eeef7e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -885,7 +885,11 @@ object DecisionTree extends Serializable with Logging { val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) val centroid = if (categoryStats.count != 0) { - categoryStats.prob(categoryStats.predict) + if (categoryStats.count == 2) { + categoryStats.stats(1) + } else { + categoryStats.predict + } } else { Double.MaxValue } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 25708b7fb0d43..2a2a9c3c4e2c5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -491,7 +491,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get - assert(split.categories === List(0.0)) + assert(split.categories === List(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) @@ -520,7 +520,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val split = rootNode.split.get assert(split.categories.length === 1) - assert(split.categories.contains(0.0)) + assert(split.categories.contains(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) From 5c44e23cac423ad57e8e602c227dfa7a0f7b822c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 23 Jan 2016 16:08:44 +0800 Subject: [PATCH 4/5] For comments. --- .../spark/ml/tree/impl/RandomForest.scala | 46 ++++++++----------- .../spark/mllib/tree/DecisionTree.scala | 43 +++++++---------- .../spark/mllib/tree/DecisionTreeSuite.scala | 16 ++----- 3 files changed, 41 insertions(+), 64 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 23caeca75feee..ea733d577a5fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -720,36 +720,30 @@ private[ml] object RandomForest extends Logging { * * centroidForCategories is a list: (category, centroid) */ - val centroidForCategories = if (binAggregates.metadata.isMulticlass) { - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. 
- Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { + val centroidForCategories = Range(0, numCategories).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + if (binAggregates.metadata.isMulticlass) { + // multiclass classification + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. categoryStats.calculate() + } else if (binAggregates.metadata.isClassification) { + // binary classification + // For categorical variables in binary classification, + // the bins are ordered by the count of class 1. + categoryStats.stats(1) } else { - Double.MaxValue - } - (featureValue, centroid) - } - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // the bins are ordered by the centroid of their corresponding labels. - Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - if (categoryStats.count == 2) { - categoryStats.stats(1) - } else { - categoryStats.predict - } - } else { - Double.MaxValue + // regression + // For categorical variables in regression, + // the bins are ordered by the prediction.
+ categoryStats.predict } - (featureValue, centroid) + } else { + Double.MaxValue } + (featureValue, centroid) } logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 08a6319eeef7e..51235a23711a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -865,36 +865,27 @@ object DecisionTree extends Serializable with Logging { * * centroidForCategories is a list: (category, centroid) */ - val centroidForCategories = if (binAggregates.metadata.isMulticlass) { - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - Range(0, numBins).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { + val centroidForCategories = Range(0, numBins).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + if (binAggregates.metadata.isMulticlass) { + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. categoryStats.calculate() + } else if (binAggregates.metadata.isClassification) { + // For categorical variables in binary classification, + // the bins are ordered by the count of class 1. + categoryStats.stats(1) } else { - Double.MaxValue - } - (featureValue, centroid) - } - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // the bins are ordered by the impurity of their corresponding labels. 
- Range(0, numBins).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - if (categoryStats.count == 2) { - categoryStats.stats(1) - } else { - categoryStats.predict - } - } else { - Double.MaxValue + // For categorical variables in regression, + // the bins are ordered by the prediction. + categoryStats.predict } - (featureValue, centroid) + } else { + Double.MaxValue } + (featureValue, centroid) } logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2a2a9c3c4e2c5..0dd83cfd89583 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -289,12 +289,8 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.impurity !== -1.0) // set impurity and predict for child nodes - if (topNode.leftNode.get.predict.predict === 0.0) { - assert(topNode.rightNode.get.predict.predict === 1.0) - } else { - assert(topNode.leftNode.get.predict.predict === 1.0) - assert(topNode.rightNode.get.predict.predict === 0.0) - } + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) assert(topNode.leftNode.get.impurity === 0.0) assert(topNode.rightNode.get.impurity === 0.0) } @@ -336,12 +332,8 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.impurity !== -1.0) // set impurity and predict for child nodes - if (topNode.leftNode.get.predict.predict === 0.0) { - assert(topNode.rightNode.get.predict.predict === 1.0) - } else { - assert(topNode.leftNode.get.predict.predict === 1.0) - assert(topNode.rightNode.get.predict.predict === 0.0) - } + 
assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) assert(topNode.leftNode.get.impurity === 0.0) assert(topNode.rightNode.get.impurity === 0.0) } From c10872b97bbfd1ae2abb168ca3e4cf24f18865cb Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 9 Feb 2016 15:25:56 -0800 Subject: [PATCH 5/5] fixed unit test --- .../DecisionTreeClassifierSuite.scala | 36 ++++++++++- .../spark/mllib/tree/DecisionTreeSuite.scala | 61 +++++++------------ 2 files changed, 57 insertions(+), 40 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index fda2711fed0fd..baf6b9083900f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val model = dt.fit(df) } + test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. 
+ val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) + val data = sc.parallelize(arr) + val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) + + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(1) + .setMaxBins(3) + val model = dt.fit(df) + model.rootNode match { + case n: InternalNode => + n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + } + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 0dd83cfd89583..dca8ea815aa6a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -339,51 +339,34 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. 
val arr = Array( - LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), // left node - LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), // right node - LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), // left node - LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)), // right node - LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 0.0)), // left node - LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.0))) // left node + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) val input = sc.parallelize(arr) + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. 
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - val impurityForRightNode = Gini.calculate(Array(0.0, 3.0, 1.0), 4.0) + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity ~== 0.44 absTol impurityForRightNode) - assert(topNode.rightNode.get.impurity === 0.0) + val model = new DecisionTree(strategy).run(input) + model.topNode.split.get match { + case Split(_, _, _, categories: List[Double]) => + assert(categories === List(1.0)) + } } + test("Second level node building with vs. without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000)