Fix GQA api detection on PyTorch<2.5. #438
Conversation
Revert the detection of `use_gqa` to before #330. In #330 the detection was bypassed to enable support for multi-GPU training, but this breaks TabPFN if PyTorch is older than 2.5. I copied the previous implementation from https://github.com/PriorLabs/TabPFN/blob/17961ba58d6d1c9ec86d0e692909ba306f76d935/src/tabpfn/model/multi_head_attention.py. In a follow-up PR I will fix multi-GPU training again.

Testing: this isn't picked up by the CI because we don't yet have GPU tests enabled, so I tested manually. Before the fix (torch 2.4.1), `python examples/tabpfn_for_multiclass_classification.py` fails with `TypeError: scaled_dot_product_attention() got an unexpected keyword argument 'enable_gqa'`. After the fix (torch 2.4.1 and torch 2.7.1), the example runs and reports ROC AUC 1.0 and accuracy 0.98. Inference speed on 2.7.1 is unaffected by the change.
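For reviewers unfamiliar with the detection being restored, here is a minimal sketch of that kind of runtime probe. The helper names, tensor shapes, and the compute-capability threshold are illustrative assumptions, not the exact code in `multi_head_attention.py`:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention


def _torch_supports_enable_gqa() -> bool:
    """Return True if this torch build accepts the enable_gqa keyword (added in torch 2.5)."""
    x = torch.zeros(1, 1, 1, 8)  # tiny dummy (batch, heads, seq, head_dim) tensor
    try:
        scaled_dot_product_attention(x, x, x, enable_gqa=True)
        return True
    except TypeError:
        # Older torch: the keyword does not exist, so the call raises TypeError.
        return False


def _gpu_capability_ok(min_major: int = 8) -> bool:
    """Illustrative NVIDIA compute-capability gate; the threshold is an assumption."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= min_major


# Only take the torch>=2.5 GQA path when both the API and the hardware support it.
USE_TORCH_2_GQA = _torch_supports_enable_gqa() and _gpu_capability_ok()
```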
Summary of Changes
Hello @oscarkey, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've addressed a compatibility issue with PyTorch versions older than 2.5 by reverting a previous change that forced the use of the enable_gqa flag in scaled_dot_product_attention. This PR restores the dynamic detection of GQA API support, ensuring that the application functions correctly across a wider range of PyTorch versions without encountering TypeError exceptions. The core change involves re-enabling the original logic that checks for enable_gqa support and GPU compute capability.
Highlights
- Revert GQA detection logic: I have reverted the logic that detects whether PyTorch supports the `enable_gqa` argument in `scaled_dot_product_attention`. This change restores the correct detection mechanism, ensuring compatibility with PyTorch versions older than 2.5.
- Restore try-except for `enable_gqa`: I have re-enabled the try-except block that attempts to call `scaled_dot_product_attention` with `enable_gqa=True`. This allows the system to dynamically determine if the PyTorch version in use supports this feature, preventing `TypeError` exceptions on older versions.
- Re-enable NVIDIA compute capability check: I have reinstated the check for NVIDIA GPU compute capability. This check, combined with the `enable_gqa` support detection, accurately determines if `USE_TORCH_2_GQA` should be enabled, ensuring optimal performance on supported hardware while maintaining compatibility.
- Remove hardcoded `USE_TORCH_2_GQA` override: I have removed the temporary hardcoded `USE_TORCH_2_GQA = True` that was introduced as a workaround. This hardcoding caused issues with PyTorch versions that do not support the `enable_gqa` argument, leading to runtime errors (the sketch after this list illustrates how the flag controls which attention call is made).
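As a rough illustration of why the flag matters (a hedged sketch of the general pattern, not the repository's actual attention code): when `USE_TORCH_2_GQA` is true, the grouped key/value heads can be handed to `scaled_dot_product_attention` directly; when it is false, a common fallback is to repeat the key/value heads to match the query heads before the call.

```python
import torch
import torch.nn.functional as F


def grouped_attention(
    q: torch.Tensor,  # (batch, n_heads, seq, head_dim)
    k: torch.Tensor,  # (batch, n_kv_heads, seq, head_dim), n_heads % n_kv_heads == 0
    v: torch.Tensor,  # same shape as k
    use_torch_2_gqa: bool,
) -> torch.Tensor:
    if use_torch_2_gqa:
        # torch >= 2.5: the kernel broadcasts the grouped KV heads itself.
        return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    # Older torch: repeat each KV head so the head counts match, then call without the keyword.
    groups = q.shape[1] // k.shape[1]
    return F.scaled_dot_product_attention(
        q, k.repeat_interleave(groups, dim=1), v.repeat_interleave(groups, dim=1)
    )
```

The hardcoded `USE_TORCH_2_GQA = True` removed in this PR forced the first branch unconditionally, which is what fails on torch older than 2.5.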
Code Review
This pull request correctly fixes an issue where Grouped-Query Attention (GQA) was unconditionally enabled, causing errors on PyTorch versions older than 2.5. The change introduces proper feature detection for GQA support. My review includes two suggestions to improve the implementation: one to make the feature detection more robust by using CUDA tensors for the check, and another to remove a redundant conditional check, improving code clarity and eliminating dead code.
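On the first suggestion, a hedged sketch of what a CUDA-backed probe could look like (the helper name and shapes are assumptions, not the reviewer's actual proposal); the point is that the check then exercises the same backend that inference will use:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention


def _supports_enable_gqa_on_cuda() -> bool:
    """Run the enable_gqa probe on CUDA rather than on CPU tensors."""
    if not torch.cuda.is_available():
        return False
    x = torch.zeros(1, 1, 1, 8, device="cuda")
    try:
        scaled_dot_product_attention(x, x, x, enable_gqa=True)
        return True
    except TypeError:
        return False
```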
LeoGrin left a comment
LGTM when ruff passes
Revert the detection of `use_gqa` to before #330. That PR bypassed the detection to enable support for multi-GPU training, and instead always enabled the `use_gqa` flag. However, this flag is not available if PyTorch is older than 2.5. I reverted the change to the detection by copying the previous implementation from https://github.com/PriorLabs/TabPFN/blob/17961ba58d6d1c9ec86d0e692909ba306f76d935/src/tabpfn/model/multi_head_attention.py.
The bug was reported in #418.
In a follow-up PR I will fix multi-gpu training again.
Public API changes
No API changes.
Testing
This bug isn't picked up by the CI because we don't yet have GPU tests enabled, so I performed manual testing.
Before this PR (torch 2.4.1):
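```
$ python examples/tabpfn_for_multiclass_classification.py
TypeError: scaled_dot_product_attention() got an unexpected keyword argument 'enable_gqa'
```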
After this PR (torch 2.4.1 and torch 2.7.1):
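```
$ python examples/tabpfn_for_multiclass_classification.py
ROC AUC: 1.0
Accuracy 0.98
```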
I benchmarked inference speed before and after this PR on torch 2.7.1, and it was unchanged.