[SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc #3374
Conversation
val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
It was `SquaredError` before.
Thanks for fixing this. I am taking a look at it.
Here are my findings. I added two more test cases with numIterations = 100.

| numIterations | learningRate | subsamplingRate | metric |
|---|---|---|---|
| 10 | 1.0 | 1.0 | 0.8400000000000005 |
| 100 | 1.0 | 1.0 | 0.5344090056285183 |
| 10 | 0.1 | 1.0 | 0.08399999999999984 |
| 10 | 1.0 | 0.75 | 0.8102205882352937 |
| 100 | 1.0 | 0.75 | 0.565608647936787 |
| 10 | 0.1 | 0.75 | 0.11179411764705861 |
A learning rate of 1.0 doesn't work very well, especially with a low number of iterations. Our default learning rate is 0.1, which should be fine.
Suggestion: we remove the learningRate = 1 option from the absolute error test. I can do more testing to check which settings work well for our GBT model and include them as part of the documentation. During the documentation phase I will also compare with scikit-learn to see how much additional loss we get relative to an ideal implementation.
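The effect behind these numbers can be reproduced outside Spark. Below is a toy, pure-Python sketch (not the MLlib implementation; the data and the depth-1 stump learner are invented for illustration) of boosting with absolute error: each weak learner is fit to the loss gradients (the signs of the residuals) and its leaves output plain gradient means, so with learningRate = 1.0 the full step can overshoot and oscillate, while a smaller rate damps the steps.

```python
# Toy gradient boosting for absolute error. Each iteration fits a depth-1
# regression stump to the negative gradients (signs of the residuals) and
# takes a step of size learning_rate * leaf mean -- no per-leaf line search.

def stump_fit(x, g):
    """Fit a depth-1 regression stump to targets g by exhaustive split search."""
    best = None
    for s in sorted(set(x))[:-1]:
        left = [gi for xi, gi in zip(x, g) if xi <= s]
        right = [gi for xi, gi in zip(x, g) if xi > s]
        lm, rm = sum(left) / len(left), sum(right) / len(right)
        sse = sum((gi - lm) ** 2 for xi, gi in zip(x, g) if xi <= s) + \
              sum((gi - rm) ** 2 for xi, gi in zip(x, g) if xi > s)
        if best is None or sse < best[0]:
            best = (sse, s, lm, rm)
    _, s, lm, rm = best
    return lambda xi: lm if xi <= s else rm

def boost(x, y, num_iterations, learning_rate):
    """Additive model F_m = F_{m-1} + learning_rate * h_m."""
    F = [0.0] * len(x)
    for _ in range(num_iterations):
        # Negative gradient of absolute error: sign of the residual.
        g = [1.0 if yi > fi else -1.0 for yi, fi in zip(y, F)]
        tree = stump_fit(x, g)
        F = [fi + learning_rate * tree(xi) for fi, xi in zip(F, x)]
    return F

x = [float(i) for i in range(20)]
y = [xi / 19.0 for xi in x]  # a smooth target in [0, 1]
for lr in (1.0, 0.1):
    F = boost(x, y, num_iterations=10, learning_rate=lr)
    mae = sum(abs(yi - fi) for yi, fi in zip(y, F)) / len(y)
    # On this data, lr = 1.0 oscillates with period 2 and ends back at
    # F = 0 (MAE 0.5); lr = 0.1 creeps toward y and does much better.
    print(lr, round(mae, 3))
```

This mirrors the table above only qualitatively; the real test uses MLlib trees and a different dataset.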
cc: @jkbradley
@manishamde Thanks for checking this test! Let's fix it in a separate PR. We are going to cut a release candidate and I hope we can update the API before that. Let me know when you finish a pass, and I will update the PR following your suggestions.
Test build #23643 has started for PR 3374 at commit
Will we have to rename
@manishamde The current impl is attached to trees. Even if we rename it back to
@mengxr The plan to move to the mllib.ensemble namespace with a new class sounds good to me.
Should the
@@ -45,146 +43,92 @@ import org.apache.spark.storage.StorageLevel
* but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
Should now read something like this: "but tree predictions are not computed accurately for LogLoss or AbsoluteError loss functions, since they use the mean of the samples at each leaf node of the decision tree".
cc: @jkbradley
@manishamde The current explanation is correct for the original Gradient Boosting algorithm, which uses weak hypothesis weights and is oblivious to the weak learner being used. Your suggested explanation is really for TreeBoost, Friedman's improvement to the original algorithm which is specialized for trees (which we should add at some point but isn't what we're claiming to have now, I'd say).
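To make the distinction concrete, here is a small from-scratch sketch (names and data are illustrative, not from the Spark code base): the original gradient boosting algorithm scales each whole weak hypothesis by a single line-search weight rho_m, whereas TreeBoost re-optimizes each tree leaf for the loss, e.g. the median of the residuals in a leaf for absolute error.

```python
# Original algorithm: one scalar weight per iteration, found by line search.
# For squared error this has the closed form rho = <r, h> / <h, h>.
def line_search_weight(residuals, h):
    """argmin_rho sum_i (r_i - rho * h_i)^2, i.e. one weight for the whole tree."""
    num = sum(r * hi for r, hi in zip(residuals, h))
    den = sum(hi * hi for hi in h)
    return num / den

# TreeBoost: a separate optimal value per leaf. For absolute error the
# per-leaf optimum is the median of the residuals falling into that leaf.
def treeboost_leaf_value(leaf_residuals):
    s = sorted(leaf_residuals)
    n = len(s)
    return s[n // 2] if n % 2 else 0.5 * (s[n // 2 - 1] + s[n // 2])

residuals = [0.5, -1.2, 0.9, 0.4]   # made-up residuals r_i
h = [0.4, -1.0, 1.0, 0.5]           # made-up weak-hypothesis outputs h(x_i)
print(line_search_weight(residuals, h))       # one global scale for the tree
print(treeboost_leaf_value([0.5, 0.9, 0.4]))  # one value for a single leaf
```

The PR keeps the first formulation's documentation, since MLlib does not claim to implement TreeBoost.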
@jkbradley Agree. Having said that, I am not sure whether the predictions change based on the loss function for other weak learners such as LR. Let's refine this later.
Test build #23643 has finished for PR 3374 at commit
Test PASSed.
Completed my pass. LGTM! 👍
Test build #23662 has started for PR 3374 at commit
* Currently, gradients are computed correctly for the available loss functions,
* but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
* Running with those losses will likely behave reasonably, but lacks the same guarantees.
* but tree predictions are not computed correctly for LogLoss or AbsoluteError since they
(copying comment here since it was on an outdated diff)
The original explanation is correct for the original Gradient Boosting algorithm, which uses weak hypothesis weights and is oblivious to the weak learner being used. This updated explanation is really for TreeBoost, Friedman's improvement to the original algorithm which is specialized for trees (which we should add at some point but isn't what we're claiming to have now, I'd say). So I think the original explanation is more accurate since we do not claim to implement TreeBoost.
Agree.
Reverted the changes.
@mengxr Thanks for the updates! Just added a few small comments. Other than those, LGTM.
Test build #23663 has started for PR 3374 at commit
Test build #23662 has finished for PR 3374 at commit
Test PASSed.
Test build #23663 has finished for PR 3374 at commit
Test PASSed.
@manishamde @jkbradley Thanks! Merged into master and branch-1.2.
There are some inconsistencies in the gradient boosting APIs. The target is a general boosting meta-algorithm, but the implementation is attached to trees. This was partially due to the delay of SPARK-1856. But for the 1.2 release, we should make the APIs consistent.

1. WeightedEnsembleModel -> private[tree] TreeEnsembleModel and renamed members accordingly.
2. GradientBoosting -> GradientBoostedTrees
3. Add RandomForestModel and GradientBoostedTreesModel and hide CombiningStrategy
4. Slightly refactored TreeEnsembleModel (Vote takes weights into consideration.)
5. Remove `trainClassifier` and `trainRegressor` from `GradientBoostedTrees` because they are the same as `train`
6. Rename class `train` method to `run` because it hides the static methods with the same name in Java. Deprecated `DecisionTree.train` class method.
7. Simplify BoostingStrategy and make sure the input strategy is not modified. Users should put algo and numClasses in treeStrategy. We create ensembleStrategy inside boosting.
8. Fix a bug in GradientBoostedTreesSuite with AbsoluteError
9. doc updates

manishamde jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #3374 from mengxr/SPARK-4486 and squashes the following commits:

7097251 [Xiangrui Meng] address joseph's comments
98dea09 [Xiangrui Meng] address manish's comments
4aae3b7 [Xiangrui Meng] add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy
ea4c467 [Xiangrui Meng] fix unit tests
751da4e [Xiangrui Meng] rename class method train -> run
19030a5 [Xiangrui Meng] update boosting public APIs

(cherry picked from commit 15cacc8)
Signed-off-by: Xiangrui Meng <meng@databricks.com>