From 27eeaf7579428a8de180dbe08fb30d1e30228b33 Mon Sep 17 00:00:00 2001
From: Eric Denovitzer
Date: Mon, 9 Feb 2015 10:39:07 -0500
Subject: [PATCH] NumSplits for categorical value cannot be larger than
 possible number of subsets of categories for that variable. When choosing a
 subset of subsets, make this random.

---
Review revision: generateRandomIntegers now refills the drawn slot with the
value held by the last live slot, so it can no longer return duplicates.
The blob ids on the "index" lines predate this revision.

 .../spark/mllib/tree/DecisionTree.scala       | 42 ++++++++++++++++++-
 .../tree/impl/DecisionTreeMetadata.scala      |  2 +-
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 73e7e32c6db31..dd5418e1c03ee 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
 
 import org.apache.spark.annotation.Experimental
 
@@ -1044,9 +1045,10 @@ object DecisionTree extends Serializable with Logging {
           splits(featureIndex) = new Array[Split](numSplits)
           bins(featureIndex) = new Array[Bin](numBins)
           var splitIndex = 0
+          val randomSubsetsAsIntegers = generateRandomIntegers(featureArity, numSplits)
           while (splitIndex < numSplits) {
             val categories: List[Double] =
-              extractMultiClassCategories(splitIndex + 1, featureArity)
+              extractMultiClassCategories(randomSubsetsAsIntegers(splitIndex), featureArity)
             splits(featureIndex)(splitIndex) =
               new Split(featureIndex, Double.MinValue, Categorical, categories)
             bins(featureIndex)(splitIndex) = {
@@ -1084,6 +1086,44 @@ object DecisionTree extends Serializable with Logging {
     }
   }
 
+  /**
+   * Generate a list of random integers between 1..2^n-2 without duplicates. Used to choose
+   * random subsets based on each number's binary representation, excluding the empty set
+   * and the full set.
+   *
+   * Sampling is a sparse Fisher-Yates shuffle over the virtual array 1..2^n-2: only slots
+   * whose content differs from their index are kept in a map.
+   *
+   * @param maxFeatureValue arity n of the categorical feature
+   * @param splits number of distinct integers to draw; must not exceed 2^n-2
+   * @return the drawn integers; exactly 1..2^n-2 in ascending order when splits == 2^n-2
+   */
+  private[tree] def generateRandomIntegers(
+      maxFeatureValue: Int,
+      splits: Int): List[Int] = {
+    var maxValue = (1 << maxFeatureValue) - 2
+    var integers = List[Int]()
+    if (maxValue == splits) {
+      integers = (1 to maxValue).toList
+    } else {
+      var selectedToUnselected = Map[Int, Int]()
+      var selected = 0
+      val rand = new Random()
+      while (selected < splits) {
+        val randomInt = rand.nextInt(maxValue) + 1
+        // Slot `randomInt` holds its own index unless an earlier draw refilled it.
+        integers = selectedToUnselected.getOrElse(randomInt, randomInt) :: integers
+        // Refill the drawn slot with whatever the last live slot currently holds; storing
+        // the bare index `maxValue` could hand out an already-returned value again.
+        val lastSlotValue = selectedToUnselected.getOrElse(maxValue, maxValue)
+        selectedToUnselected += (randomInt -> lastSlotValue)
+        maxValue -= 1
+        selected += 1
+      }
+    }
+    integers
+  }
+
   /**
    * Nested method to extract list of eligible categories given an index. It extracts the
    * position of ones in a binary representation of the input. If binary
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 951733fada6be..18db6ea8a6182 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -71,7 +71,7 @@ private[tree] class DecisionTreeMetadata(
    * For ordered features, there is 1 more bin than split.
    */
   def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
-    numBins(featureIndex) >> 1
+    Math.min(numBins(featureIndex) >> 1, (1 << featureArity(featureIndex)) - 2)
   } else {
     numBins(featureIndex) - 1
   }