Skip to content

Commit

Permalink
Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategor…
Browse files Browse the repository at this point in the history
…icalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). That upper bound applies to unordered categorical features, not ordered ones. For ordered features, the bin index should range over 0 until the arity of the feature (i.e., its number of categories; valid values are 0 to arity - 1).

Added new test to DecisionTreeSuite to catch this: "regression stump with categorical variables of arity 2"

Bug fix: Modified upper bound discussed above.

Also: Small improvements to coding style in DecisionTree.
  • Loading branch information
jkbradley committed Aug 1, 2014
1 parent 78f2af5 commit 225822f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 18 deletions.
45 changes: 27 additions & 18 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ object DecisionTree extends Serializable with Logging {
val bin = binForFeatures(mid)
val lowThreshold = bin.lowSplit.threshold
val highThreshold = bin.highSplit.threshold
if ((lowThreshold < feature) && (highThreshold >= feature)){
if ((lowThreshold < feature) && (highThreshold >= feature)) {
return mid
}
else if (lowThreshold >= feature) {
Expand All @@ -522,28 +522,36 @@ object DecisionTree extends Serializable with Logging {
}

/**
* Sequential search helper method to find bin for categorical feature.
* Sequential search helper method to find bin for categorical feature
* (for classification and regression).
*/
def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
val featureValue = labeledPoint.features(featureIndex)
var binIndex = 0
while (binIndex < numCategoricalBins) {
while (binIndex < featureCategories) {
val bin = bins(featureIndex)(binIndex)
val categories = bin.highSplit.categories
val features = labeledPoint.features
if (categories.contains(features(featureIndex))) {
if (categories.contains(featureValue)) {
return binIndex
}
binIndex += 1
}
// No bin's category list contained the value. If the value lies outside
// [0, featureCategories), report it as bad input data so the caller sees
// which feature and which data point are at fault.
if (featureValue < 0 || featureValue >= featureCategories) {
  throw new IllegalArgumentException(
    "DecisionTree given invalid data:" +
    s" Feature $featureIndex is categorical with values in" +
    // The closing brace of the set notation was missing from this message.
    s" {0,...,${featureCategories - 1}}," +
    s" but a data point gives it value $featureValue.\n" +
    " Bad data point: " + labeledPoint.toString)
}
-1
}

if (isFeatureContinuous) {
// Perform binary search for finding bin for continuous features.
val binIndex = binarySearchForBins()
if (binIndex == -1){
if (binIndex == -1) {
throw new UnknownError("no bin was found for continuous variable.")
}
binIndex
Expand All @@ -555,10 +563,10 @@ object DecisionTree extends Serializable with Logging {
if (isUnorderedFeature) {
sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
} else {
sequentialBinSearchForOrderedCategoricalFeatureInClassification()
sequentialBinSearchForOrderedCategoricalFeature()
}
}
if (binIndex == -1){
if (binIndex == -1) {
throw new UnknownError("no bin was found for categorical variable.")
}
binIndex
Expand Down Expand Up @@ -642,11 +650,12 @@ object DecisionTree extends Serializable with Logging {
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
// Update the left or right count for one bin.
val aggShift = numClasses * numBins * numFeatures * nodeIndex
val aggIndex
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
val labelInt = label.toInt
agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
val aggIndex =
numClasses * numBins * numFeatures * nodeIndex +
numClasses * numBins * featureIndex +
numClasses * arr(arrIndex).toInt +
label.toInt
agg(aggIndex) += 1
}

/**
Expand Down Expand Up @@ -1127,7 +1136,7 @@ object DecisionTree extends Serializable with Logging {
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
var featureIndex = 0
while (featureIndex < numFeatures) {
if (isMulticlassClassificationWithCategoricalFeatures){
if (isMulticlassClassificationWithCategoricalFeatures) {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
Expand Down Expand Up @@ -1393,7 +1402,7 @@ object DecisionTree extends Serializable with Logging {

// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures){
while (featureIndex < numFeatures) {
// Check whether the feature is continuous.
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
Expand Down Expand Up @@ -1513,7 +1522,7 @@ object DecisionTree extends Serializable with Logging {
if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
for (index <- 1 until numBins - 1){
for (index <- 1 until numBins - 1) {
val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
Continuous, Double.MinValue)
bins(featureIndex)(index) = bin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(accuracy >= requiredAccuracy)
}

/**
 * Asserts that the mean squared error of `model` over `input` does not
 * exceed `requiredMSE`.
 *
 * For an empty `input` the mean is 0.0 / 0 = NaN, so the assertion fails.
 *
 * @param model tree whose `predict` is applied to each point's features
 * @param input labeled points; each label is the expected prediction
 * @param requiredMSE largest mean squared error that is accepted
 */
def validateRegressor(
    model: DecisionTreeModel,
    input: Seq[LabeledPoint],
    requiredMSE: Double): Unit = {
  // Compute each residual once and square it, in input order.
  val squaredError = input.map { point =>
    val residual = model.predict(point.features) - point.label
    residual * residual
  }.sum
  val mse = squaredError / input.length
  // Include the measured value so a failing run shows how far off the model was.
  assert(mse <= requiredMSE,
    s"validateRegressor calculated MSE $mse but required $requiredMSE.")
}

test("split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
Expand Down Expand Up @@ -454,6 +466,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.impurity > 0.2)
}

test("regression stump with categorical variables of arity 2") {
  // Regression case for bin lookup on categorical features:
  // features 0 and 1 are both declared categorical with arity 2.
  val points = DecisionTreeSuite.generateCategoricalDataPoints()
  assert(points.length === 1000)

  val pointsRdd = sc.parallelize(points)
  val binaryCategoricalFeatures = Map(0 -> 2, 1 -> 2)
  val regressionStrategy = new Strategy(
    Regression,
    Variance,
    maxDepth = 2,
    maxBins = 100,
    categoricalFeaturesInfo = binaryCategoricalFeatures)

  val trainedModel = DecisionTree.train(pointsRdd, regressionStrategy)

  // The fitted tree must reproduce every training label exactly (MSE of 0.0)
  // using a single split: one root plus two leaves.
  validateRegressor(trainedModel, points, 0.0)
  assert(trainedModel.numNodes === 3)
  assert(trainedModel.depth === 1)
}

test("stump with fixed label 0 for Gini") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
Expand Down

0 comments on commit 225822f

Please sign in to comment.