Conversation

@oscarkey (Contributor) commented Sep 4, 2025

Only support parallel evaluation for the "low_memory" and "cache_preprocessing" fit modes for now.

Use multithreading to evaluate the model in parallel for each estimator. I selected multithreading over multiprocessing because our benchmarking shows that for longer datasets we spend almost all of our time in the flash attention kernel, during which the GIL is released. This allows multithreading to work efficiently; it is also less complex than multiprocessing and avoids starting additional processes (which can take a substantial fraction of the inference time).
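To make the approach concrete, here is a minimal sketch of multithreaded per-estimator evaluation across devices. The names `parallel_evaluate` and `evaluate_one` are illustrative only, not the actual tabpfn API:

```python
# Minimal sketch, not the actual tabpfn implementation: one worker thread per
# device; each estimator borrows a device, runs on it, and returns it.
from concurrent.futures import ThreadPoolExecutor
from queue import Queue


def parallel_evaluate(estimators, devices, evaluate_one):
    """Run evaluate_one(estimator, device) for every estimator using threads."""
    device_queue: Queue = Queue()
    for device in devices:
        device_queue.put(device)

    def worker(estimator):
        device = device_queue.get()
        try:
            # The expensive part (the attention kernel) releases the GIL,
            # so threads on different devices genuinely run concurrently.
            return evaluate_one(estimator, device)
        finally:
            device_queue.put(device)

    with ThreadPoolExecutor(max_workers=len(devices)) as pool:
        return list(pool.map(worker, estimators))
```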

Ideally inference.py would be refactored, but I only did minimal refactoring in this PR.

In addition to the tests added in this PR, these changes are also covered by the new consistency tests for each inference mode: #498. Unfortunately, GitHub does not support multi-GPU testing on the CI yet, so the testing of the parallelisation is a bit limited.

I simplified the logic for converting the inputs to tensors and setting the dtype in _prepare_model_inputs():

  • Always call torch.as_tensor(): this only copies if necessary, so there is no need for an if
  • Don't suppress exceptions raised by X_full.float(): the comment says this is to avoid overflow errors, but I can't find any case where .float() would throw an exception, except if X_full is complex (see the sketch below)
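A rough sketch of the simplified conversion described above; the function name and arguments here are illustrative, and the real _prepare_model_inputs() does more than this:

```python
import torch


def _to_float_tensor(X_full, device):
    """Illustrative sketch of the simplified dtype/tensor conversion."""
    # torch.as_tensor only copies when necessary, so there is no need to
    # branch on whether X_full is already a tensor.
    X_tensor = torch.as_tensor(X_full)
    # .float() does not raise on overflow (out-of-range values become inf),
    # so wrapping it in try/except was unnecessary; only complex inputs fail.
    return X_tensor.float().to(device)
```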

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces parallel evaluation of estimators to improve performance. This is achieved by refactoring the inference engines to use a new parallel_evaluate utility, which can leverage multiple devices using multithreading. The core evaluation logic for a single estimator has been extracted into a helper function _evaluate_estimator.

The changes look good overall and are a nice improvement. I've found a few issues:

  • A high-severity bug in parallel_evaluate.py that would cause a crash on non-CUDA devices.
  • A high-severity bug in inference.py where a model type cast is not assigned back to the model.
  • Several medium-severity issues related to code cleanup, leftover TODOs, and API consistency.

Please see the detailed comments for suggestions.

@oscarkey (Contributor, Author) commented Sep 4, 2025

Hey Brendan, I'd love to get your overall thoughts on this draft before I finish it up.

@brendan-priorlabs (Contributor) commented

@oscarkey, would love to. I can give this a proper pass tomorrow if that works!

@brendan-priorlabs (Contributor) left a comment

Hey @oscarkey, thanks for putting this together! Very clean and readable. I left a few minor comments, but this looks solid overall. The only one that might be major is the comment on Manager.Queue. Keep me posted!

@oscarkey force-pushed the ok-multiple-devices branch from 9eb0099 to eb1d626 on September 9, 2025 08:51
@oscarkey force-pushed the ok-multiple-devices-2 branch from 0a04c54 to 187551c on September 10, 2025 07:16
@oscarkey changed the title from "[WIP] Evaluate the estimators in parallel." to "Evaluate the estimators in parallel." on Sep 10, 2025
@oscarkey requested a review from LeoGrin on September 10, 2025 09:51
@oscarkey (Contributor, Author) commented

This is now ready for a full review!

@oscarkey (Contributor, Author) commented

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces parallel evaluation of estimators in the tabpfn library, focusing on the "low_memory" and "cache_preprocessing" fit modes. It leverages multithreading to improve performance, particularly for longer datasets where the flash attention kernel releases the GIL. The changes include refactoring input preparation logic and adding a new parallel_execute module to manage parallel execution across multiple PyTorch devices. Tests have been added to ensure consistency between serial and parallel execution.

@brendan-priorlabs (Contributor) left a comment

LGTM! Thanks, Oscar!

@LeoGrin (Collaborator) left a comment

Looks great to me!

Base automatically changed from ok-multiple-devices to main September 11, 2025 11:40
@oscarkey force-pushed the ok-multiple-devices-2 branch from 90e8ca2 to 74c2ae8 on September 11, 2025 11:43
@oscarkey enabled auto-merge (squash) on September 11, 2025 12:16
@oscarkey merged commit 2745764 into main on Sep 11, 2025
10 checks passed
@oscarkey deleted the ok-multiple-devices-2 branch on September 11, 2025 12:47
oscarkey added a commit that referenced this pull request Nov 12, 2025
* Record copied public PR 484

* Evaluate the estimators in parallel. (#484)

Only support parallel evaluation for the "low_memory" and "cache_preprocessing" fit modes for now.

Use multithreading to evaluate the model in parallel for each estimator. I selected multithreading over multiprocessing because our benchmarking shows that for longer datasets we spend almost all of our time in the flash attention kernel, during which the GIL is released. This allows multithreading to work efficiently; it is also less complex than multiprocessing and avoids starting additional processes (which can take a substantial fraction of the inference time).

Ideally `inference.py` would be refactored, but I only did minimal refactoring in this PR.

In addition to the tests added in this PR, these changes are also covered by the new consistency tests for each inference mode: #498. Unfortunately, GitHub does not support multi-GPU testing on the CI yet, so the testing of the parallelisation is a bit limited.

I simplified the logic for converting the inputs to tensors and setting the dtype in `_prepare_model_inputs()`:
- Always call `torch.as_tensor()`: this only copies if necessary, so there is no need for an `if`
- Don't suppress exceptions raised by `X_full.float()`: the comment says this is to avoid overflow errors, but I can't find any case where `.float()` would throw an exception, except if `X_full` is complex

---------

Co-authored-by: mirror-bot <mirror-bot@users.noreply.github.com>
Co-authored-by: Oscar Key <oscar@priorlabs.ai>