-
Notifications
You must be signed in to change notification settings - Fork 28.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-3724][ML] RandomForest: More options for feature subset size.
## What changes were proposed in this pull request? This PR adds support for more options for the feature subset size in the RandomForest implementation. Previously, RandomForest only supported "auto", "all", "sqrt", "log2", and "onethird". This PR allows any given value to be supported, to enable model search. In this PR, `featureSubsetStrategy` can be passed as: a) a real number in the range `(0.0, 1.0]` that represents the fraction of the number of features in each subset, or b) an integer number (`>0`) that represents the number of features in each subset. ## How was this patch tested? Two tests, `JavaRandomForestClassifierSuite` and `JavaRandomForestRegressorSuite`, have been updated to check the additional options for params in this PR. An additional test has been added to `org.apache.spark.mllib.tree.RandomForestSuite` to cover the cases in this PR. Author: Yong Tang <yong.tang.github@outlook.com> Closes #11989 from yongtang/SPARK-3724.
- Loading branch information
Showing
6 changed files
with
95 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -426,12 +426,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { | |
(math.log(numFeatures) / math.log(2)).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) | ||
|
||
val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0") | ||
for (strategy <- realStrategies) { | ||
val expected = (strategy.toDouble * numFeatures).ceil.toInt | ||
checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) | ||
} | ||
|
||
val integerStrategies = Array("1", "10", "100", "1000", "10000") | ||
for (strategy <- integerStrategies) { | ||
val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
yongtang
Author
Contributor
|
||
checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) | ||
} | ||
|
||
val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") | ||
for (invalidStrategy <- invalidStrategies) { | ||
intercept[MatchError]{ | ||
val metadata = | ||
DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) | ||
} | ||
} | ||
|
||
checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) | ||
checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "log2", | ||
(math.log(numFeatures) / math.log(2)).ceil.toInt) | ||
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) | ||
|
||
for (strategy <- realStrategies) { | ||
val expected = (strategy.toDouble * numFeatures).ceil.toInt | ||
checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) | ||
} | ||
|
||
for (strategy <- integerStrategies) { | ||
This comment has been minimized.
Sorry, something went wrong. |
||
val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures | ||
checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) | ||
} | ||
for (invalidStrategy <- invalidStrategies) { | ||
intercept[MatchError]{ | ||
val metadata = | ||
DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) | ||
} | ||
} | ||
} | ||
|
||
test("Binary classification with continuous features: subsampling features") { | ||
|
@yongtang @MLnick Shouldn't this line use `math.min`? It could be `val expected = math.min(strategy.toInt, numFeatures)`.