[PyTorch] Disable Flash Attention backend in Userbuffers tests#2399

Merged
ksivaman merged 2 commits into NVIDIA:main from timmoon10:tmoon/debug-ub-on-ampere on Nov 19, 2025

Conversation

@timmoon10
Collaborator

Description

We have experienced some Userbuffers test failures on A100s, apparently because the Flash Attention backward pass introduces numerical errors. These tests are primarily intended to exercise the linear layers rather than attention, so as a quick fix I've disabled the Flash Attention backend.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Disable Flash Attention backend in Userbuffers tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 added the bug (Something isn't working) and 2.10.0 labels on Nov 18, 2025
Member

@ksivaman ksivaman left a comment


LGTM as a fix for green CI. @cyanguwa, we should document this and try to root-cause it.

@greptile-apps
Contributor

greptile-apps Bot commented Nov 18, 2025

Greptile Summary

  • Disables Flash Attention backend in Userbuffers layer tests by setting NVTE_FLASH_ATTN=0 to prevent numerical errors on A100s
  • The fix is appropriately scoped to _run_layer_with_overlap function only, affecting tests for Linear, LayerNormLinear, LayerNormMLP, MultiheadAttention, and TransformerLayer

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The change is minimal and surgical: it only adds an environment variable to disable Flash Attention in the specific tests that were experiencing numerical errors. The fix properly sets and unsets the environment variable, maintaining clean test isolation. The scope is appropriately limited to layer tests (not low-level GEMM tests), which aligns with the PR description stating this is a quick fix for Userbuffers tests that are primarily intended to test linear layers, not attention.
  • No files require special attention

Important Files Changed

  • tests/pytorch/distributed/test_comm_gemm_overlap.py: Added the NVTE_FLASH_ATTN=0 environment variable to disable the Flash Attention backend in layer overlap tests and avoid numerical errors

Sequence Diagram

sequenceDiagram
    participant Test as test_layers_with_overlap_*
    participant Helper as _run_layer_with_overlap
    participant Env as Environment
    participant Subprocess as run_layer_with_overlap.py
    
    Test->>Helper: Call with layer parameters
    Helper->>Env: Set "NVTE_FLASH_ATTN=0"
    Helper->>Env: Set "PYTORCH_JIT=0"
    Helper->>Env: Set "NVTE_TORCH_COMPILE=0"
    Helper->>Env: Set "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0"
    Helper->>Subprocess: Run test with modified environment
    Subprocess-->>Helper: Return test result
    Helper->>Env: Unset "NVTE_FLASH_ATTN"
    Helper->>Env: Unset "PYTORCH_JIT"
    Helper->>Env: Unset "NVTE_TORCH_COMPILE"
    Helper->>Env: Unset "NVTE_ALLOW_NONDETERMINISTIC_ALGO"
    Helper-->>Test: Return success/failure
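The set-and-unset flow in the diagram above can be sketched as a small Python helper. This is an illustrative sketch, not the actual test code: the function name `run_layer_test` and the exact variable handling are assumptions, while the real helper is `_run_layer_with_overlap` in tests/pytorch/distributed/test_comm_gemm_overlap.py.

```python
import os
import subprocess
import sys

# Environment variables the diagram shows being set before the subprocess
# runs. NVTE_FLASH_ATTN=0 (the fix in this PR) disables the Flash Attention
# backend; the others force deterministic, non-compiled execution.
_TEST_ENV = {
    "NVTE_FLASH_ATTN": "0",
    "PYTORCH_JIT": "0",
    "NVTE_TORCH_COMPILE": "0",
    "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",
}

def run_layer_test(cmd):
    """Run a test subprocess with the variables above set, then restore them.

    Hypothetical stand-in for _run_layer_with_overlap; returns True on success.
    """
    saved = {key: os.environ.get(key) for key in _TEST_ENV}
    os.environ.update(_TEST_ENV)
    try:
        result = subprocess.run(cmd, env=os.environ.copy())
        return result.returncode == 0
    finally:
        # Unset (or restore) each variable to keep later tests isolated.
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old

if __name__ == "__main__":
    # The child process sees NVTE_FLASH_ATTN=0; the parent environment is
    # restored afterwards.
    ok = run_layer_test(
        [sys.executable, "-c",
         "import os; assert os.environ['NVTE_FLASH_ATTN'] == '0'"]
    )
    print("passed" if ok else "failed")
```

Restoring rather than simply deleting the variables in the `finally` block is slightly more defensive than a plain unset, since it preserves any value a user had already exported.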

Contributor

@greptile-apps greptile-apps Bot left a comment


1 file reviewed, no comments


@timmoon10
Collaborator Author

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps Bot left a comment


1 file reviewed, no comments


Collaborator

@KshitijLakhani KshitijLakhani left a comment


LGTM
Agree with @ksivaman on documenting it. On the JAX side, we've mostly been sprinkling TODOs for such things. I wonder if @cyanguwa would prefer something like that or a different approach?

@ksivaman ksivaman merged commit e6da012 into NVIDIA:main Nov 19, 2025
24 of 31 checks passed
@timmoon10 timmoon10 deleted the tmoon/debug-ub-on-ampere branch November 20, 2025 05:01
KshitijLakhani pushed a commit that referenced this pull request Nov 20, 2025
Disable Flash attention in Userbuffers tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>