-
Notifications
You must be signed in to change notification settings - Fork 26
Integrate AITER fused RoPE kernels with fallback to TE native #541
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
suachong
wants to merge
11
commits into
dev
Choose a base branch
from
feat/aiter-fused-rope
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
8b05a6a
Integrate AITER fused RoPE kernels with fallback to TE native
suachong 277d40c
Address PR #541 review feedback from ipanfilo
suachong 08ea73e
remove nvidia header
suachong 6bf7634
Add local testing instructions for AITER RoPE tests
suachong 6eb19fe
Add Dockerfile and README for local AITER RoPE testing
suachong c46d806
Remove aiter_rope_test directory from branch
suachong 51cd242
Preserve upstream NVIDIA copyright header in rope.py
suachong 7d4cb24
Address PR #541 review: raise RuntimeError instead of silent fallback
suachong 3f06411
Update transformer_engine/pytorch/attention/rope.py
suachong 69bceaa
Address PR #541 review feedback from Micky774 and wangye805
suachong ec7fc13
Use bare import for IS_HIP_EXTENSION in test file
suachong File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
|
ipanfilo marked this conversation as resolved.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| # This file was modified for portability to AMDGPU | ||
| # Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. | ||
|
suachong marked this conversation as resolved.
|
||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
@@ -11,6 +13,29 @@ | |
| import transformer_engine_torch as tex | ||
| from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat | ||
|
|
||
| try: | ||
| from torch.utils.cpp_extension import IS_HIP_EXTENSION | ||
| except ImportError: | ||
| IS_HIP_EXTENSION = False | ||
|
Comment on lines
+16
to
+19
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to guard |
||
|
|
||
| _aiter_rope_fwd = None | ||
| _aiter_rope_bwd = None | ||
| _HAVE_AITER_ROPE = False | ||
|
|
||
| if IS_HIP_EXTENSION: | ||
| import os # pylint: disable=wrong-import-order,wrong-import-position | ||
| if os.environ.get("NVTE_USE_AITER_ROPE", "0") == "1": | ||
| try: | ||
| from aiter.ops.rope import ( # pylint: disable=import-error | ||
| rope_fwd as _aiter_rope_fwd, | ||
| rope_bwd as _aiter_rope_bwd, | ||
| ) | ||
| _HAVE_AITER_ROPE = True | ||
| except Exception as _aiter_import_err: # pylint: disable=broad-except | ||
| raise RuntimeError( | ||
| "NVTE_USE_AITER_ROPE=1 but AITER fused RoPE import failed." | ||
| ) from _aiter_import_err | ||
|
|
||
|
|
||
| __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] | ||
|
|
||
|
|
@@ -118,6 +143,23 @@ class FusedRoPEFunc(torch.autograd.Function): | |
| the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def has_aiter_rope(): | ||
| """Return whether AITER RoPE kernels are available.""" | ||
| return _HAVE_AITER_ROPE | ||
|
|
||
| @staticmethod | ||
| def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions): | ||
| """Check if we can dispatch to AITER's faster rope kernel.""" | ||
| return ( | ||
| _HAVE_AITER_ROPE | ||
| and tensor_format == "sbhd" | ||
| and not interleaved | ||
| and cu_seqlens is None | ||
| and cp_size == 1 | ||
| and start_positions is None | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def forward( | ||
| ctx, | ||
|
|
@@ -139,38 +181,62 @@ def forward( | |
| "bshd", | ||
| "thd", | ||
| ), f"Unsupported tensor_format: {tensor_format}." | ||
| output = tex.fused_rope_forward( | ||
| t, | ||
| freqs, | ||
| start_positions, | ||
| QKVFormat[tensor_format], | ||
| interleaved, | ||
| cu_seqlens, | ||
| cp_size, | ||
| cp_rank, | ||
|
|
||
| use_aiter = FusedRoPEFunc._can_use_aiter( | ||
| tensor_format, interleaved, cu_seqlens, cp_size, start_positions | ||
| ) | ||
|
|
||
| if use_aiter: | ||
| rotate_style = 1 if interleaved else 0 | ||
| output = _aiter_rope_fwd( | ||
| t, freqs, rotate_style, | ||
| False, # reuse_freqs_front_part | ||
| False, # nope_first | ||
| ) | ||
| else: | ||
| output = tex.fused_rope_forward( | ||
| t, | ||
| freqs, | ||
| start_positions, | ||
| QKVFormat[tensor_format], | ||
| interleaved, | ||
| cu_seqlens, | ||
| cp_size, | ||
| cp_rank, | ||
| ) | ||
|
|
||
| ctx.save_for_backward(freqs, cu_seqlens, start_positions) | ||
| ctx.tensor_format = tensor_format | ||
| ctx.cp_size = cp_size | ||
| ctx.cp_rank = cp_rank | ||
| ctx.interleaved = interleaved | ||
| ctx.use_aiter = use_aiter | ||
|
|
||
| return output | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: | ||
| """Fused RoPE backward.""" | ||
| freqs, cu_seqlens, start_positions = ctx.saved_tensors | ||
| grad_input = tex.fused_rope_backward( | ||
| grad_output, | ||
| freqs, | ||
| start_positions, | ||
| QKVFormat[ctx.tensor_format], | ||
| ctx.interleaved, | ||
| cu_seqlens, | ||
| ctx.cp_size, | ||
| ctx.cp_rank, | ||
| ) | ||
|
|
||
| if ctx.use_aiter: | ||
| rotate_style = 1 if ctx.interleaved else 0 | ||
| grad_input = _aiter_rope_bwd( | ||
| grad_output, freqs, rotate_style, | ||
| False, # reuse_freqs_front_part | ||
| False, # nope_first | ||
| ) | ||
| else: | ||
| grad_input = tex.fused_rope_backward( | ||
| grad_output, | ||
| freqs, | ||
| start_positions, | ||
| QKVFormat[ctx.tensor_format], | ||
| ctx.interleaved, | ||
| cu_seqlens, | ||
| ctx.cp_size, | ||
| ctx.cp_rank, | ||
| ) | ||
|
|
||
| return grad_input, None, None, None, None, None, None, None, None | ||
|
|
||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.