[SPARK-3724][ML] RandomForest: More options for feature subset size. #11989
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,10 +55,16 @@ import org.apache.spark.util.Utils | |
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. | ||
* @param featureSubsetStrategy Number of features to consider for splits at each node. | ||
* Supported values: "auto", "all", "sqrt", "log2", "onethird". | ||
* Supported numerical values: "(0.0-1.0]", "[1-n]". | ||
* If "auto" is set, this parameter is set based on numTrees: | ||
* if numTrees == 1, set to "all"; | ||
* if numTrees > 1 (forest) set to "sqrt" for classification and | ||
* to "onethird" for regression. | ||
* If a real value "(0.0-1.0]" is set, this parameter specifies | ||
Similarly here, we can say something like
|
||
* the fraction of features in each subset. | ||
* If an integer value "n" is set, this parameter specifies | ||
* the number of features used in each subset, | ||
* for integer 0 < n <= (number of features). | ||
* @param seed Random seed for bootstrapping and choosing feature subsets. | ||
*/ | ||
private class RandomForest ( | ||
|
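The doc comment above describes how a `featureSubsetStrategy` string maps to a per-node feature count. The following is a minimal sketch of that mapping, not Spark's actual implementation: the object `FeatureSubsetSketch`, the method `numFeaturesPerNode`, and the `isClassification` flag are hypothetical names introduced for illustration. The rounding and clamping rules are inferred from the test expectations in this PR (for example, an integer larger than the feature count resolves to the feature count), and parsing here uses `toInt`/`toDouble` rather than the PR's regex, so it is somewhat more permissive.

```scala
import scala.util.Try

// Hypothetical helper for illustration only; not part of Spark's API.
object FeatureSubsetSketch {
  def numFeaturesPerNode(
      strategy: String,
      numTrees: Int,
      numFeatures: Int,
      isClassification: Boolean): Int = {
    // "auto" depends on whether we train a single tree or a forest.
    val resolved = strategy match {
      case "auto" =>
        if (numTrees == 1) "all"
        else if (isClassification) "sqrt"
        else "onethird"
      case other => other
    }
    resolved match {
      case "all" => numFeatures
      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
      case "onethird" => (numFeatures / 3.0).ceil.toInt
      case s =>
        // Integer strings are tried first, so "1" means one feature,
        // while "1.0" falls through to the fractional branch (all features).
        Try(s.toInt).filter(_ > 0).toOption
          .map(n => math.min(n, numFeatures)) // clamp to the feature count
          .orElse(
            Try(s.toDouble).filter(f => f > 0.0 && f <= 1.0).toOption
              .map(f => (f * numFeatures).ceil.toInt))
          .getOrElse(
            throw new IllegalArgumentException(s"Unsupported featureSubsetStrategy: $s"))
    }
  }
}
```

Trying the integer parse before the fractional one is what keeps `"1"` (one feature) distinct from `"1.0"` (every feature), matching the two cases exercised in the test suite.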
@@ -70,9 +76,11 @@ private class RandomForest ( | |
|
||
strategy.assertValid() | ||
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") | ||
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), | ||
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy) | ||
|| featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex), | ||
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + | ||
s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.") | ||
s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," + | ||
s" (0.0-1.0], [1-n].") | ||
|
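The `require` above accepts either a named strategy or a string matching `NewRFParams.supportedFeatureSubsetStrategiesRegex`, whose definition is not shown in this diff. As a rough sketch of that kind of check, the pattern below is an assumption written for illustration: it accepts positive integers, fractions in (0.0, 1.0) with or without a leading zero, and `1.0`. The names `FeatureSubsetValidationSketch`, `namedStrategies`, `numericPattern`, and `isSupported` are hypothetical.

```scala
// Hypothetical validation sketch; the regex is an assumption, not the PR's actual pattern.
object FeatureSubsetValidationSketch {
  val namedStrategies: Set[String] = Set("auto", "all", "sqrt", "log2", "onethird")

  // Three alternatives:
  //   [1-9]\d*          positive integer, e.g. "7"
  //   0?\.\d*[1-9]\d*   fraction strictly between 0 and 1, e.g. "0.1", ".10"
  //   1\.0+             exactly one, written as a real value, e.g. "1.0"
  val numericPattern: String = """[1-9]\d*|0?\.\d*[1-9]\d*|1\.0+"""

  // String.matches requires the whole string to match, so no anchors are needed.
  def isSupported(strategy: String): Boolean =
    namedStrategies.contains(strategy) || strategy.matches(numericPattern)
}
```

Under this assumed pattern, `"0"`, `"0.0"`, and `"1.5"` are rejected, which corresponds to the documented ranges `(0.0-1.0]` and `[1-n]`.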
||
/** | ||
* Method to train a decision tree model over an RDD | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint | |
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} | ||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} | ||
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} | ||
import org.apache.spark.mllib.tree.model.RandomForestModel | ||
Not sure why this import was added. It can be removed.
||
import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
import org.apache.spark.mllib.util.TestingUtils._ | ||
import org.apache.spark.util.collection.OpenHashMap | ||
|
@@ -422,13 +423,27 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { | |
checkFeatureSubsetStrategy(numTrees = 1, "log2", | ||
(math.log(numFeatures) / math.log(2)).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 1, "0.1", (0.1 * numFeatures).ceil.toInt) | ||
Is there a particular reason these test cases differ from the Java ones? I notice that the Java tests also cover some of the regex variants such as ".1", ".10", and "0.10". I'm wondering if we shouldn't just have a couple of test cases for the regex edge cases here (to ensure it gets translated correctly).

Hi @MLnick, that might be because of the repetitive lines to copy around. I consolidated the test cases so that it is easier to track what is covered. Let me know if additional test cases are needed.
||
checkFeatureSubsetStrategy(numTrees = 1, "0.5", (0.5 * numFeatures).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 1, "1.0", (1.0 * numFeatures).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 1, "1", 1) | ||
checkFeatureSubsetStrategy(numTrees = 1, "2", 2) | ||
checkFeatureSubsetStrategy(numTrees = 1, numFeatures.toString, numFeatures) | ||
checkFeatureSubsetStrategy(numTrees = 1, (numFeatures * 2).toString, numFeatures) | ||
|
||
checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) | ||
checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "log2", | ||
(math.log(numFeatures) / math.log(2)).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "0.1", (0.1 * numFeatures).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "0.5", (0.5 * numFeatures).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "1.0", (1.0 * numFeatures).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "1", 1) | ||
checkFeatureSubsetStrategy(numTrees = 2, "2", 2) | ||
checkFeatureSubsetStrategy(numTrees = 2, numFeatures.toString, numFeatures) | ||
checkFeatureSubsetStrategy(numTrees = 2, (numFeatures * 2).toString, numFeatures) | ||
} | ||
|
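The review discussion above asks about alternative spellings of the same fraction (".1", ".10", "0.10"). One way to cover them, sketched here under assumptions, is a small table-driven check that every spelling yields the same feature count. `FractionSpellingSketch`, `expectedCount`, and the feature count of 50 are hypothetical and stand in for the suite's own `checkFeatureSubsetStrategy` helper, whose body is not shown in this diff.

```scala
// Hypothetical table-driven sketch; not the suite's actual helper.
object FractionSpellingSketch {
  // Mirrors the expectation used in the tests above: ceil(fraction * numFeatures).
  def expectedCount(fraction: String, numFeatures: Int): Int =
    (fraction.toDouble * numFeatures).ceil.toInt

  // Different spellings of the same fraction.
  val spellings: Seq[String] = Seq("0.1", ".1", ".10", "0.10")

  // All spellings should collapse to a single distinct count.
  def distinctCounts(numFeatures: Int): Seq[Int] =
    spellings.map(expectedCount(_, numFeatures)).distinct
}
```

If `distinctCounts` ever returned more than one value, that would indicate the spellings are not being parsed to the same fraction.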
||
test("Binary classification with continuous features: subsampling features") { | ||
|
I'm wondering if we can simply consolidate the doc into something like: