[SPARK-29232][ML] Update the parameter maps of the DecisionTreeRegression/Classification Models #26154
@@ -203,6 +203,7 @@ class GBTClassifier @Since("1.4.0") (
} else {
GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
}
baseLearners.map(tree => copyValues(tree))
maybe baseLearners.foreach(copyValues)?
Thanks! I will change to baseLearners.foreach(copyValues(_)) because copyValues takes two arguments.
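The point about arity can be illustrated with a minimal, Spark-free sketch. The `Tree` and `Forest` classes below are hypothetical stand-ins, not Spark's classes; the only thing they borrow is the shape of `copyValues`, assumed here to be a target plus a second parameter with a default value. Under that assumption, passing the bare method name to `foreach` does not compile, while the placeholder form does.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a tree model: starts with a default parameter value.
final class Tree(val uid: String) {
  val params: mutable.Map[String, Any] = mutable.Map("maxDepth" -> 5)
}

// Hypothetical stand-in for the estimator that owns the configured values.
final class Forest(settings: Map[String, Any]) {
  // Two parameters, the second with a default, mirroring the shape discussed above.
  def copyValues(to: Tree, extra: Map[String, Any] = Map.empty): Tree = {
    (settings ++ extra).foreach { case (k, v) => to.params(k) = v }
    to
  }

  def fit(numTrees: Int): Array[Tree] = {
    val trees = Array.tabulate(numTrees)(i => new Tree(s"tree_$i"))
    // trees.foreach(copyValues)  // does not compile: the method expands to a two-argument function
    trees.foreach(copyValues(_))  // placeholder form lets the default for `extra` apply
    trees
  }
}

object CopyValuesDemo {
  def main(args: Array[String]): Unit = {
    val trees = new Forest(Map("maxDepth" -> 3)).fit(2)
    trees.foreach(t => println(s"${t.uid}: maxDepth=${t.params("maxDepth")}"))
  }
}
```

Using `foreach` rather than `map` also makes it explicit that the call is made for its side effect on each tree and that the returned collection is not needed.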
@@ -143,6 +143,7 @@ class RandomForestClassifier @Since("1.4.0") (
val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
trees.map(tree => copyValues(tree))
trees.foreach(copyValues)
@@ -130,6 +130,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])
trees.map(tree => copyValues(tree))
ditto
One more question: Is this satisfied?
Test build #112244 has finished for PR 26154 at commit
The trees in
Test build #112248 has finished for PR 26154 at commit
Merged to master, thanks @huaxingao
Thanks a lot! @zhengruifeng
What changes were proposed in this pull request?
The trees (Array[DecisionTreeRegressionModel]) in RandomForestRegressionModel only contain the default parameter values. The parameter maps of these trees need to be updated. The same issue exists in RandomForestClassifier, GBTClassifier and GBTRegressor.
Why are the changes needed?
A user may want to access each individual tree and build the trees back up for the random forest estimator. This doesn't work because the trees don't have the correct parameter values.
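A rough way to observe the behaviour described here is to fit a small forest with a non-default setting and read the same parameter back from each member tree. The sketch below assumes Spark MLlib is on the classpath and a local session is acceptable; the toy data, the object name `TreeParamsCheck` and the helper `treeDepths` are made up for illustration. With the change in this PR, each tree would be expected to report the value set on the estimator rather than the default.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.sql.SparkSession

object TreeParamsCheck {
  // Fits a small forest and returns the maxDepth reported by each member tree.
  def treeDepths(spark: SparkSession, maxDepth: Int): Array[Int] = {
    import spark.implicits._
    // Toy data, invented for this sketch.
    val df = Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(1.0, 0.0)),
      (2.0, Vectors.dense(2.0, 1.0)),
      (3.0, Vectors.dense(3.0, 0.0))
    ).toDF("label", "features")

    val model = new RandomForestRegressor()
      .setMaxDepth(maxDepth)
      .setNumTrees(3)
      .setSeed(42L)
      .fit(df)

    // Each element of model.trees is a DecisionTreeRegressionModel.
    model.trees.map(_.getMaxDepth)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("tree-params")
      .getOrCreate()
    try println(treeDepths(spark, maxDepth = 3).mkString(", "))
    finally spark.stop()
  }
}
```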
Does this PR introduce any user-facing change?
Yes. Now the trees in RandomForestRegressionModel, RandomForestClassifier, GBTClassifier and GBTRegressor have the correct parameter values.
How was this patch tested?
Add tests