[REVIEW] MNMG RF broadcast feature #3349
Conversation
@viclafargue Thanks for working on this. I will make sure to review once it's marked ready for review. Also, let me know if you'd like some early feedback.
Thanks @hcho3! I think it should be ready for a first review/discussion. For now, the broadcast feature is implemented for training, but only partially implemented for inference. Additionally, is it preferable to perform the reduction in a distributed or local fashion? Some of the Dask features I am using only work with host arrays. Maybe a better solution would be to assume there's enough device space on the client and to perform the reduction there with GPU acceleration.
Looks reasonable to me.
Have you tested in a MNMG environment?
The train code looks really nice and clean. Per discussion, I think it's worth moving the prediction code over to the predict_proba + average + vote approach for classification, to ensure no change in predictions. I'd really hope we can get it all onto GPU too with dask.array/cupy arrays... if mean isn't working on GPU for the reduction you need, it may be worth conversing with the Dask team.
    def _func_predict(model, input_data, **kwargs):
        X = concatenate(input_data)
        with cuml.using_output_type("numpy"):
            prediction = model.predict(X, **kwargs)
Per discussion, for classification, this should probably be predict_proba, and then we take the per-class average (maybe build a 3D array of shape nrows × nclasses × nworkers and average over the third dimension). We'd also want to weight by the number of trees per worker, since they could differ a bit (e.g. training 101 trees on 10 workers would leave one worker with an extra tree). The same approach (weighted-average reduction) should work for regression, treating it as nclasses=1.
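The weighted-average reduction suggested above can be sketched as follows. This is an illustrative NumPy sketch, not cuML's actual implementation; the names `probas` and `n_trees` are hypothetical, and in the MNMG setting the per-worker arrays would typically be CuPy arrays gathered via Dask.

```python
import numpy as np

def weighted_vote(probas, n_trees):
    """Combine per-worker class probabilities into final class predictions.

    probas:  list of (n_rows, n_classes) arrays, one per worker
    n_trees: list of tree counts per worker, used as weights
    """
    # Stack into the 3D (n_rows, n_classes, n_workers) array discussed above.
    stacked = np.stack(probas, axis=2)
    weights = np.asarray(n_trees, dtype=np.float64)
    # Weighted mean over the worker axis; weights broadcast along the last dim.
    avg = (stacked * weights).sum(axis=2) / weights.sum()
    # Final vote: most probable class per row.
    return avg.argmax(axis=1)
```

For regression the same reduction applies with a single "class" column (nclasses=1), returning the weighted mean itself instead of the argmax.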
Thanks, only MG for now. I still need to do this.
@viclafargue Looks good to me.
Thank you for the review! I'll test the PR in an MNMG setting. Unfortunately, it's a little bit hard to set one up at the moment because of technical problems. I'll keep you updated.
Codecov Report
@@ Coverage Diff @@
## branch-0.19 #3349 +/- ##
===============================================
+ Coverage 79.21% 80.89% +1.67%
===============================================
Files 226 227 +1
Lines 17900 17796 -104
===============================================
+ Hits 14180 14396 +216
+ Misses 3720 3400 -320
Looking good... I just had a few suggestions, then I'll approve quickly. The weighting scheme via reduction looks good, but it may be non-obvious to code readers, so it'd be good to beef up the comments there a bit. A few small test suggestions too. Otherwise great!
Looks great, Victor!
@gpucibot merge
Answers #3343 and #3342
This will provide the following new features to MNMG RF: