Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-3724][ML] RandomForest: More options for feature subset size. #11989

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
case _ => featureSubsetStrategy
}

val isIntRegex = "^([1-9]\\d*)$".r
val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
val numFeaturesPerNode: Int = _featureSubsetStrategy match {
case "all" => numFeatures
case "sqrt" => math.sqrt(numFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numFeatures / 3.0).ceil.toInt
case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
}

new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features)
* - "(0.0-1.0]": use the specified fraction of features
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we can simply consolidate the doc into something like:

- "n": when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features.

* - "n": use n features, for integer 0 < n <= (number of features)
* (default = "auto")
*
* These various settings are based on the following references:
Expand All @@ -346,7 +348,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
"The number of features to consider for splits at each tree node." +
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
(value: String) =>
RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
|| value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))

setDefault(featureSubsetStrategy -> "auto")

Expand Down Expand Up @@ -393,6 +396,9 @@ private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)

// The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features)
final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
}

private[ml] trait RandomForestClassifierParams
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,16 @@ import org.apache.spark.util.Utils
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
* Supported values: "auto", "all", "sqrt", "log2", "onethird".
* Supported numerical values: "(0.0-1.0]", "[1-n]".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* If a real value "(0.0-1.0]" is set, this parameter specifies
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly here, we can say something like

If a real value "n" in the range (0, 1.0] is set, this ...
If an integer value "n" in the range (1, num features) is set, this ...

* the fraction of features in each subset.
* If an integer value "n" is set, this parameter specifies
* the number of features used in each subset,
* for integer 0 < n <= (number of features).
* @param seed Random seed for bootstrapping and choosing feature subsets.
*/
private class RandomForest (
Expand All @@ -70,9 +76,11 @@ private class RandomForest (

strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
|| featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
s" (0.0-1.0], [1-n].")

/**
* Method to train a decision tree model over an RDD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ public void runDT() {
for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
rf.setFeatureSubsetStrategy(".1");
rf.setFeatureSubsetStrategy(".10");
rf.setFeatureSubsetStrategy("0.10");
rf.setFeatureSubsetStrategy("0.1");
rf.setFeatureSubsetStrategy("0.9");
rf.setFeatureSubsetStrategy("1.0");
rf.setFeatureSubsetStrategy("1");
rf.setFeatureSubsetStrategy("100");
rf.setFeatureSubsetStrategy("1000");
RandomForestClassificationModel model = rf.fit(dataFrame);

model.transform(dataFrame);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ public void runDT() {
for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
rf.setFeatureSubsetStrategy(".1");
rf.setFeatureSubsetStrategy(".10");
rf.setFeatureSubsetStrategy("0.10");
rf.setFeatureSubsetStrategy("0.1");
rf.setFeatureSubsetStrategy("0.9");
rf.setFeatureSubsetStrategy("1.0");
rf.setFeatureSubsetStrategy("1");
rf.setFeatureSubsetStrategy("100");
rf.setFeatureSubsetStrategy("1000");
RandomForestRegressionModel model = rf.fit(dataFrame);

model.transform(dataFrame);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
import org.apache.spark.mllib.tree.model.RandomForestModel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why this import was added. It can be removed.

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.collection.OpenHashMap
Expand Down Expand Up @@ -422,13 +423,27 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
checkFeatureSubsetStrategy(numTrees = 1, "log2",
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "0.1", (0.1 * numFeatures).ceil.toInt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a particular reason these test cases differ from the Java ones? I notice in the Java tests we also test some of the regex like ".1" and ".10" and "0.10".

I'm wondering if we shouldn't just have a couple test cases for the regex edge cases here (to ensure it gets translated correctly).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MLnick That might be because of the repetitive lines to copy around. I consolidated the test cases so that it is easier to track what are covered. Let me know if additional test cases are needed.

checkFeatureSubsetStrategy(numTrees = 1, "0.5", (0.5 * numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "1.0", (1.0 * numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "1", 1)
checkFeatureSubsetStrategy(numTrees = 1, "2", 2)
checkFeatureSubsetStrategy(numTrees = 1, numFeatures.toString, numFeatures)
checkFeatureSubsetStrategy(numTrees = 1, (numFeatures * 2).toString, numFeatures)

checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "log2",
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "0.1", (0.1 * numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "0.5", (0.5 * numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "1.0", (1.0 * numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "1", 1)
checkFeatureSubsetStrategy(numTrees = 2, "2", 2)
checkFeatureSubsetStrategy(numTrees = 2, numFeatures.toString, numFeatures)
checkFeatureSubsetStrategy(numTrees = 2, (numFeatures * 2).toString, numFeatures)
}

test("Binary classification with continuous features: subsampling features") {
Expand Down