From 57456d352ae9b45596ac776341ac09199aadc30c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 13 Apr 2016 08:04:55 -0700 Subject: [PATCH 1/2] [SPARK-14565][ML] RandomForest should use parseInt and parseDouble for feature subset size instead of regexes This fix tries to change RandomForest's supported strategies from using regexes to using parseInt and parseDouble, for the purpose of robustness and maintainability. --- .../ml/tree/impl/DecisionTreeMetadata.scala | 9 ++--- .../org/apache/spark/ml/tree/treeParams.scala | 33 ++++++++++++++++--- .../spark/mllib/tree/RandomForest.scala | 3 +- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index c7cde1563fc79..12225b24ee308 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.tree.impl import scala.collection.mutable import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.RandomForestParams import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -184,15 +185,15 @@ private[spark] object DecisionTreeMetadata extends Logging { case _ => featureSubsetStrategy } - val isIntRegex = "^([1-9]\\d*)$".r - val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r val numFeaturesPerNode: Int = _featureSubsetStrategy match { case "all" => numFeatures case "sqrt" => math.sqrt(numFeatures).ceil.toInt case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) case "onethird" => (numFeatures / 3.0).ceil.toInt - case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt - case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt + case RandomForestParams.integerFeatureSubsetStrategy(number) => + if (number > numFeatures) numFeatures else number + case RandomForestParams.doubleFeatureSubsetStrategy(fraction) => + (fraction * numFeatures).ceil.toInt } new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 0767dc17e5562..45552b80c82c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -346,10 +346,12 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", "The number of features to consider for splits at each tree node." + - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + + s", (0.0-1.0], [1-n].", (value: String) => RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) - || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex)) + || RandomForestParams.integerFeatureSubsetStrategy.unapply(value).isDefined + || RandomForestParams.doubleFeatureSubsetStrategy.unapply(value).isDefined) setDefault(featureSubsetStrategy -> "auto") @@ -397,8 +399,31 @@ private[spark] object RandomForestParams { final val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) - // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features) - final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$" + object integerFeatureSubsetStrategy { + def unapply(strategy: String): Option[Int] = try { + val number = strategy.toInt + if (0 < number) { + Some(number) + } else { + None + } + } catch { + case _ : java.lang.NumberFormatException => None + } + } + + object doubleFeatureSubsetStrategy { + def unapply(strategy: String): Option[Double] = try { + val fraction = strategy.toDouble + if (0.0 < fraction && fraction <= 1.0) { + Some(fraction) + } else { + None + } + } catch { + case _ : java.lang.NumberFormatException => None + } + } } private[ml] trait RandomForestClassifierParams diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 26755849ad1a2..082e48e2a4c73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -76,7 +76,8 @@ private class RandomForest ( strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy) - || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex), + || NewRFParams.integerFeatureSubsetStrategy.unapply(featureSubsetStrategy).isDefined + || NewRFParams.doubleFeatureSubsetStrategy.unapply(featureSubsetStrategy).isDefined, s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," + s" (0.0-1.0], [1-n].") From ed346cd0f8b07876f95ae9636f7714e6c5f26c13 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 13 Apr 2016 21:46:24 -0700 Subject: [PATCH 2/2] [SPARK-14565][ML] RandomForest should use parseInt and parseDouble for feature subset size instead of regexes Update to use Try and filter to simplify the code. --- .../ml/tree/impl/DecisionTreeMetadata.scala | 16 +++++++--- .../org/apache/spark/ml/tree/treeParams.scala | 32 +++---------------- .../spark/mllib/tree/RandomForest.scala | 7 ++-- .../ml/tree/impl/RandomForestSuite.scala | 4 +-- 4 files changed, 22 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 12225b24ee308..5f7c40f6071f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.tree.impl import scala.collection.mutable +import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.ml.tree.RandomForestParams @@ -190,10 +191,17 @@ private[spark] object DecisionTreeMetadata extends Logging { case "sqrt" => math.sqrt(numFeatures).ceil.toInt case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) case "onethird" => (numFeatures / 3.0).ceil.toInt - case RandomForestParams.integerFeatureSubsetStrategy(number) => - if (number > numFeatures) numFeatures else number - case RandomForestParams.doubleFeatureSubsetStrategy(fraction) => - (fraction * numFeatures).ceil.toInt + case _ => + Try(_featureSubsetStrategy.toInt).filter(_ > 0).toOption match { + case Some(value) => math.min(value, numFeatures) + case None => + Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match { + case Some(value) => math.ceil(value * numFeatures).toInt + case _ => throw new IllegalArgumentException(s"Supported values:" + + s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" (0.0-1.0], [1-n].") + } + } } new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 45552b80c82c4..55e12acc1c51e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tree +import scala.util.Try + import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -350,8 +352,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { s", (0.0-1.0], [1-n].", (value: String) => RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) - || RandomForestParams.integerFeatureSubsetStrategy.unapply(value).isDefined - || RandomForestParams.doubleFeatureSubsetStrategy.unapply(value).isDefined) + || Try(value.toInt).filter(_ > 0).isSuccess + || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) setDefault(featureSubsetStrategy -> "auto") @@ -398,32 +400,6 @@ private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) - - object integerFeatureSubsetStrategy { - def unapply(strategy: String): Option[Int] = try { - val number = strategy.toInt - if (0 < number) { - Some(number) - } else { - None - } - } catch { - case _ : java.lang.NumberFormatException => None - } - } - - object doubleFeatureSubsetStrategy { - def unapply(strategy: String): Option[Double] = try { - val fraction = strategy.toDouble - if (0.0 < fraction && fraction <= 1.0) { - Some(fraction) - } else { - None - } - } catch { - case _ : java.lang.NumberFormatException => None - } - } } private[ml] trait RandomForestClassifierParams diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 082e48e2a4c73..ca7fb7f51c3fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ +import scala.util.Try import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD @@ -76,10 +77,10 @@ private class RandomForest ( strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy) - || NewRFParams.integerFeatureSubsetStrategy.unapply(featureSubsetStrategy).isDefined - || NewRFParams.doubleFeatureSubsetStrategy.unapply(featureSubsetStrategy).isDefined, + || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess + || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess, s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + - s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," + s" (0.0-1.0], [1-n].") /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 6db9ce150d930..1719f9fab5345 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -440,7 +440,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") for (invalidStrategy <- invalidStrategies) { - intercept[MatchError]{ + intercept[IllegalArgumentException]{ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) } @@ -463,7 +463,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) } for (invalidStrategy <- invalidStrategies) { - intercept[MatchError]{ + intercept[IllegalArgumentException]{ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) }