Evaluate the estimators in parallel. #484
Conversation
Code Review
This pull request introduces parallel evaluation of estimators to improve performance. This is achieved by refactoring the inference engines to use a new parallel_evaluate utility, which can leverage multiple devices using multithreading. The core evaluation logic for a single estimator has been extracted into a helper function _evaluate_estimator.
The changes look good overall and are a nice improvement. I've found a few issues:
- A high-severity bug in parallel_evaluate.py that would cause a crash on non-CUDA devices.
- A high-severity bug in inference.py where a model type cast is not assigned back to the model.
- Several medium-severity issues related to code cleanup, leftover TODOs, and API consistency.
Please see the detailed comments for suggestions.
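To make the two high-severity findings concrete, here is a hypothetical sketch; the diff under review is not quoted in this thread, so the function and variable names below are invented for illustration and are not taken from the tabpfn code:

```python
# Hypothetical sketch of the two fixes suggested above; names are invented.
import torch

def maybe_synchronize(device: torch.device) -> None:
    # torch.cuda.synchronize raises if CUDA is not available, so CUDA-only
    # calls must be guarded to keep CPU (and other) devices working.
    if device.type == "cuda":
        torch.cuda.synchronize(device)

def cast_weights(weights: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Unlike nn.Module.to(), a plain Tensor dtype cast is out-of-place:
    # the result must be assigned back (or returned), otherwise calling
    # weights.to(dtype) on its own leaves `weights` unchanged.
    weights = weights.to(dtype)
    return weights
```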
Hey Brendan, I'd love to get your overall thoughts on this draft before I finish it up.
@oscarkey, would love to. I can give this a proper pass tomorrow if that works!
brendan-priorlabs
left a comment
Hey @oscarkey, thanks for putting this together! Very clean and readable. I left a few minor comments, but this looks solid overall. The only one that might be major is the one on Manager.Queue. Keep me posted!
9eb0099 to eb1d626
0a04c54 to 187551c
This is now ready for a full review!
/gemini review
Code Review
This pull request introduces parallel evaluation of estimators in the tabpfn library, focusing on the "low_memory" and "cache_preprocessing" fit modes. It leverages multithreading to improve performance, particularly for longer datasets where the flash attention kernel releases the GIL. The changes include refactoring input preparation logic and adding a new parallel_execute module to manage parallel execution across multiple PyTorch devices. Tests have been added to ensure consistency between serial and parallel execution.
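As a rough sketch of the approach described above (not the actual parallel_execute API; parallel_evaluate_sketch, evaluate_one, and the round-robin device assignment are illustrative assumptions), per-estimator evaluation can be fanned out over a thread pool with one worker per device:

```python
# Minimal sketch of multithreaded per-estimator evaluation across devices.
# This is illustrative only, not the tabpfn implementation.
from concurrent.futures import ThreadPoolExecutor
from itertools import cycle
from typing import Callable, Sequence

import torch

def parallel_evaluate_sketch(
    estimators: Sequence[object],
    devices: Sequence[torch.device],
    evaluate_one: Callable[[object, torch.device], torch.Tensor],
) -> list[torch.Tensor]:
    # One worker thread per device: PyTorch releases the GIL inside its
    # kernels (e.g. attention), so the threads can genuinely overlap.
    device_iter = cycle(devices)
    with ThreadPoolExecutor(max_workers=len(devices)) as pool:
        futures = [
            pool.submit(evaluate_one, est, next(device_iter))
            for est in estimators
        ]
        return [f.result() for f in futures]
```

Results come back in estimator order, and any exception raised inside a worker surfaces via result(), which keeps error handling close to the serial path.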
brendan-priorlabs
left a comment
LGTM! Thanks, Oscar!
LeoGrin
left a comment
Looks great to me!
90e8ca2 to 74c2ae8
* Record copied public PR 484 * Evaluate the estimators in parallel. (#484) Co-authored-by: mirror-bot <mirror-bot@users.noreply.github.com> Co-authored-by: Oscar Key <oscar@priorlabs.ai>
Only support parallel evaluation for the "low_memory" and "cache_preprocessing" fit modes for now.
Use multithreading to evaluate the model in parallel for each estimator. I selected multithreading over multiprocessing because our benchmarking shows that for longer datasets we spend almost all our time in the flash attention kernel, during which time the GIL is released. This allows multithreading to work efficiently, and it is less complex and avoids starting additional processes (which can take a substantial fraction of the inference time).
Ideally inference.py would be refactored, but I only did minimal refactoring in this PR.
In addition to the tests added in this PR, these changes are also covered by the new consistency tests for each inference mode: #498. Unfortunately, GitHub does not support multi-GPU testing on the CI yet, so the testing of the parallelisation is a bit limited.
I simplified the logic for converting the inputs to tensors and setting the dtype in _prepare_model_inputs() (sketched below):
- Always call torch.as_tensor(): this only copies if necessary, so no if is needed.
- Don't suppress exceptions raised by X_full.float(): the comment says this is to avoid overflow errors, but I can't find any case where .float() would throw an exception, except if X_full is complex.
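A minimal sketch of the simplified conversion, keeping the variable name X_full from the description above; the rest of _prepare_model_inputs() is not shown here, so treat this as illustrative:

```python
# Small sketch of the simplification described above; only X_full comes
# from the PR description, the sample data is made up.
import numpy as np
import torch

X_full = np.asarray([[1.0, 2.0], [3.0, 4.0]])

# torch.as_tensor only copies when it has to (e.g. it can share memory with
# a numpy array of a compatible dtype), so no isinstance check is needed.
X_full = torch.as_tensor(X_full)

# Set the dtype directly; per the PR description, the previous try/except
# around this cast was unnecessary.
X_full = X_full.float()
```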