Conversation

Collaborator

@xrennvidia xrennvidia commented Mar 5, 2025

Description

With cuDNN fused attention, the loss curves of CP>1 and CP=1 do not match when batch size > 1; this is caused by a non-contiguous dout. This PR fixes it.

cuDNN fused attention and Tri Dao's flash attention have different requirements on tensor memory format. Tri Dao's attention only requires the last dim to be contiguous. The TE integration of cuDNN attention specifies a qkv_format (bshd, sbhd, thd) and requires tensors to be contiguous in that format, which is a stronger requirement than Tri Dao's attention. This is why the issue only shows up with cuDNN attention; a minimal illustration of the difference is sketched below.
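
For illustration, here is a minimal PyTorch sketch (plain PyTorch, not TransformerEngine internals; the shapes and the sequence-dim slicing are hypothetical) of a dout view that satisfies Tri Dao's requirement but not the bshd requirement of the cuDNN path when batch size > 1:

```python
import torch

b, s, h, d = 2, 8, 4, 16                 # batch, sequence, heads, head_dim (bshd)
dout_full = torch.randn(b, s, h, d)

# A per-rank sequence chunk, as context parallelism can produce, is a view whose
# last dim is still contiguous, which is enough for Tri Dao's flash attention ...
dout_chunk = dout_full[:, : s // 2]
print(dout_chunk.stride(-1) == 1)        # True: last dim is contiguous

# ... but it is not contiguous in the full bshd layout that the cuDNN fused
# attention path expects (the view only becomes non-contiguous when batch size > 1).
print(dout_chunk.is_contiguous())        # False

# Making dout contiguous before the fused attention backward removes the mismatch.
dout_chunk = dout_chunk.contiguous()
print(dout_chunk.is_contiguous())        # True
```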

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
@xrennvidia xrennvidia requested a review from cyanguwa March 5, 2025 23:55
Collaborator

@cyanguwa cyanguwa left a comment


This comes from a bug where CP=2 fused attention diverges from flash attention. The A2A path converges because it already calls .contiguous() on dout.
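
As a hypothetical sketch of that guard (names are illustrative, not TransformerEngine's actual internals), the handling presumably amounts to something like:

```python
import torch

def _ensure_contiguous_dout(dout: torch.Tensor) -> torch.Tensor:
    # cuDNN fused attention needs dout contiguous in the declared qkv_format,
    # so materialize a contiguous copy only when the incoming view is not.
    return dout if dout.is_contiguous() else dout.contiguous()
```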

@xrennvidia
Collaborator Author

/te-ci pytorch L1

@cyanguwa cyanguwa added the 2.2.0 label Mar 6, 2025
@xrennvidia
Collaborator Author

All CP tests passed in the CI pipeline (on both Hopper and Blackwell), so merging this.

@xrennvidia xrennvidia merged commit e1c4f51 into NVIDIA:main Mar 6, 2025
11 checks passed
@xrennvidia xrennvidia deleted the xren/cp_fix_dout branch March 6, 2025 19:41
negvet pushed a commit to negvet/TransformerEngine that referenced this pull request Mar 18, 2025
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>