Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-29232][ML] Update the parameter maps of the DecisionTreeRegression/Classification Models #26154

Closed
wants to merge 3 commits into from

Conversation

huaxingao
Copy link
Contributor

What changes were proposed in this pull request?

The trees (Array[DecisionTreeRegressionModel]) in RandomForestRegressionModel only contains the default parameter value. Need to update the parameter maps for these trees.
Same issues in RandomForestClassifier, GBTClassifier and GBTRegressor

Why are the changes needed?

User wants to access each individual tree and build the trees back up for the random forest estimator. This doesn't work because trees don't have the correct parameter values

Does this PR introduce any user-facing change?

Yes. Now the trees in RandomForestRegressionModel, RandomForestClassifier, GBTClassifier and GBTRegressor have the correct parameter values.

How was this patch tested?

Add tests

@@ -203,6 +203,7 @@ class GBTClassifier @Since("1.4.0") (
} else {
GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
}
baseLearners.map(tree => copyValues(tree))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe baseLearners.foreach(copyValues) ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I will change to baseLearners.foreach(copyValues(_)) because copyValues takes two arguments.

@@ -143,6 +143,7 @@ class RandomForestClassifier @Since("1.4.0") (
val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
trees.map(tree => copyValues(tree))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trees.foreach(copyValues)

@@ -130,6 +130,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])
trees.map(tree => copyValues(tree))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@zhengruifeng
Copy link
Contributor

zhengruifeng commented Oct 18, 2019

One more question:
copyValues suppose that
Warning: This implicitly assumes that this [[Params]] instance and the target instance share the same set of default Params.

Is this satisfied?

@SparkQA
Copy link

SparkQA commented Oct 18, 2019

Test build #112244 has finished for PR 26154 at commit 3689884.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@huaxingao
Copy link
Contributor Author

copyValues suppose that
Warning: This implicitly assumes that this [[Params]] instance and the target instance share the same set of default Params.

Is this satisfied?

The trees in RandomForestRegressionModel is an Array of DecisionTreeRegressionModel. DecisionTreeRegressionModel doesn't have exact the same params as RandomForestRegressor, but for the params they share, it seems to me that the underlying trees and RandomForestRegressor should have the same values.

@SparkQA
Copy link

SparkQA commented Oct 18, 2019

Test build #112248 has finished for PR 26154 at commit 2423489.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng
Copy link
Contributor

Merged to master, thanks @huaxingao

@huaxingao
Copy link
Contributor Author

Thanks a lot! @zhengruifeng

@huaxingao huaxingao deleted the spark-29232 branch October 22, 2019 14:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
4 participants