Conversation


@oscarkey oscarkey commented Aug 5, 2025

Revert the detection of use_gqa to before #330.
That PR bypassed the detection in order to support multi-GPU training, and instead always enabled the use_gqa flag. However, this flag is not available if PyTorch is older than 2.5. I reverted the detection by copying the previous implementation from
https://github.com/PriorLabs/TabPFN/blob/17961ba58d6d1c9ec86d0e692909ba306f76d935/src/tabpfn/model/multi_head_attention.py.

The bug was reported in #418.

In a follow-up PR I will fix multi-GPU training again.
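
For context, the restored detection is roughly the following sketch (hedged: the helper names and the compute-capability threshold are illustrative, not the exact code from the linked file). It probes whether scaled_dot_product_attention accepts enable_gqa, and only enables the fast path when the probe succeeds and a sufficiently recent NVIDIA GPU is present:

import torch
import torch.nn.functional as F

def _supports_enable_gqa() -> bool:
    # Probe the kwarg: PyTorch < 2.5 raises TypeError because
    # scaled_dot_product_attention does not accept enable_gqa.
    q = k = v = torch.zeros(1, 1, 1, 8)
    try:
        F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    except TypeError:
        return False
    return True

def _gpu_supports_gqa() -> bool:
    # Illustrative threshold: assume Ampere (compute capability 8.x) or newer.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

USE_TORCH_2_GQA = _supports_enable_gqa() and _gpu_supports_gqa()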

Public API changes

No API changes.

Testing

This bug wasn't picked up by the CI because we don't yet have GPU tests enabled. Hence I performed manual testing.

Before this PR (torch 2.4.1):

$ python examples/tabpfn_for_multiclass_classification.py
TypeError: scaled_dot_product_attention() got an unexpected keyword argument 'enable_gqa'

After this PR (torch 2.4.1 and torch 2.7.1):

$ python examples/tabpfn_for_multiclass_classification.py
ROC AUC: 1.0
Accuracy 0.98

I benchmarked inference speed before and after this PR on torch 2.7.1, and it was unchanged.
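
For reference, the failure can be reproduced outside TabPFN with a minimal call (a sketch; the tensor shapes are arbitrary):

import torch
import torch.nn.functional as F

q = k = v = torch.zeros(1, 1, 4, 8)  # (batch, heads, seq, head_dim)

# torch 2.4.x: raises TypeError: scaled_dot_product_attention() got an
# unexpected keyword argument 'enable_gqa'
# torch >= 2.5: returns a (1, 1, 4, 8) tensor
F.scaled_dot_product_attention(q, k, v, enable_gqa=True)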

Revert the detection of "use_gqa" to before
#330.
In this PR the detection was bypassed to enable support for multi-gpu
training, but this breaks TabPFN if Pytorch is older than 2.5.
I copied the previous implementation from
https://github.com/PriorLabs/TabPFN/blob/17961ba58d6d1c9ec86d0e692909ba306f76d935/src/tabpfn/model/multi_head_attention.py.

In a follow-up PR I will fix multi-gpu training again.

Testing: This isn't picked up by the CI because we don't yet have GPU
tests enabled.

Before fix  (torch 2.4.1):
$ python examples/tabpfn_for_multiclass_classification.py
TypeError: scaled_dot_product_attention() got an unexpected keyword
argument 'enable_gqa'

After fix (torch 2.4.1 and torch 2.7.1):
$ python examples/tabpfn_for_multiclass_classification.py
ROC AUC: 1.0
Accuracy 0.98

Inference speed on 2.7.1 is unaffected by the change.
@oscarkey oscarkey requested review from LeoGrin and noahho August 5, 2025 16:42

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @oscarkey, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've addressed a compatibility issue with PyTorch versions older than 2.5 by reverting a previous change that forced the use of the enable_gqa flag in scaled_dot_product_attention. This PR restores the dynamic detection of GQA API support, ensuring that the application functions correctly across a wider range of PyTorch versions without encountering TypeError exceptions. The core change involves re-enabling the original logic that checks for enable_gqa support and GPU compute capability.

Highlights

  • Revert GQA detection logic: I have reverted the logic that detects whether PyTorch supports the enable_gqa argument in scaled_dot_product_attention. This change restores the correct detection mechanism, ensuring compatibility with PyTorch versions older than 2.5.
  • Restore try-except for enable_gqa: I have re-enabled the try-except block that attempts to call scaled_dot_product_attention with enable_gqa=True. This allows the system to dynamically determine if the PyTorch version in use supports this feature, preventing TypeError exceptions on older versions.
  • Re-enable NVIDIA compute capability check: I have reinstated the check for NVIDIA GPU compute capability. This check, combined with the enable_gqa support detection, accurately determines if USE_TORCH_2_GQA should be enabled, ensuring optimal performance on supported hardware while maintaining compatibility.
  • Remove hardcoded USE_TORCH_2_GQA override: I have removed the temporary hardcoded USE_TORCH_2_GQA = True that was introduced as a workaround. This hardcoding caused issues with PyTorch versions that do not support the enable_gqa argument, leading to runtime errors. (A sketch of how the flag gates the attention call is shown after this list.)
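
To illustrate what the flag gates (a sketch under assumptions, not TabPFN's actual attention code): with enable_gqa=True, PyTorch >= 2.5 broadcasts the key/value heads across query-head groups inside the kernel; without it, the heads must be expanded manually before the call:

import torch
import torch.nn.functional as F

def attention(q, k, v, *, use_gqa: bool) -> torch.Tensor:
    # q: (batch, n_q_heads, seq, head_dim); k and v may carry fewer heads.
    if use_gqa:
        # PyTorch >= 2.5 broadcasts k/v heads across query-head groups.
        return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    # Fallback for older PyTorch: replicate k/v heads before the call.
    n_rep = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(q, k, v)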


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly fixes an issue where Grouped-Query Attention (GQA) was unconditionally enabled, causing errors on PyTorch versions older than 2.5. The change introduces proper feature detection for GQA support. My review includes two suggestions to improve the implementation: one to make the feature detection more robust by using CUDA tensors for the check, and another to remove a redundant conditional check to improve code clarity and remove dead code.
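
The first suggestion might look like the following sketch (hypothetical helper, not the reviewer's actual diff): run the probe on the device that will execute attention, since a CPU-only probe merely proves the Python signature accepts the keyword:

import torch
import torch.nn.functional as F

def _supports_enable_gqa_on(device: str) -> bool:
    # Probing on the target device also catches backend-specific failures
    # that a CPU signature check would miss.
    q = k = v = torch.zeros(1, 1, 1, 8, device=device)
    try:
        F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    except TypeError:
        return False
    return True

device = "cuda" if torch.cuda.is_available() else "cpu"
supported = _supports_enable_gqa_on(device)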


@LeoGrin LeoGrin left a comment

LGTM when ruff passes

@oscarkey oscarkey enabled auto-merge (squash) August 6, 2025 08:01
@oscarkey oscarkey disabled auto-merge August 6, 2025 08:01
@oscarkey oscarkey enabled auto-merge (squash) August 6, 2025 08:12
@oscarkey oscarkey merged commit 90a6652 into main Aug 6, 2025
8 checks passed
@oscarkey oscarkey deleted the ok-fix-gqa-check branch August 6, 2025 08:29
@oscarkey oscarkey linked an issue Aug 6, 2025 that may be closed by this pull request
oscarkey added a commit that referenced this pull request Nov 12, 2025
* Record copied public PR 438

* Fix GQA api detection on PyTorch<2.5. (#438)

---------

Co-authored-by: mirror-bot <mirror-bot@users.noreply.github.com>
Co-authored-by: Oscar Key <oscar.t.key@gmail.com>

Development

Successfully merging this pull request may close these issues.

PyTorch compatibility issue: enable_gqa detection fails
