117 changes: 117 additions & 0 deletions tests/pytorch/test_fused_rope.py
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -11,6 +13,12 @@
    apply_fused_qkv_rotary_pos_emb,
)

from torch.utils.cpp_extension import IS_HIP_EXTENSION

if IS_HIP_EXTENSION:
    import unittest.mock as mock
    from transformer_engine.pytorch.attention.rope import FusedRoPEFunc


# Gradient is a broadcasted scalar
def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
@@ -495,3 +503,112 @@ def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_wi
        atol=1e-8,
        rtol=1e-8,
    )


# AITER RoPE tests require:
# 1. ROCm environment with `aiter` installed (pip install amd-aiter)
# 2. NVTE_USE_AITER_ROPE=1 environment variable set before running
# Example:
# NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_matches_te_fused -v
# NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_can_use_guard -v


@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER RoPE requires ROCm")
@pytest.mark.skipif(
    IS_HIP_EXTENSION and not FusedRoPEFunc.has_aiter_rope(),
    reason="AITER RoPE not available",
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096, 8192])
@pytest.mark.parametrize("hidden_size", [64, 128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_aiter_rope_matches_te_fused(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    loss_func: Callable,
) -> None:
    """
    When AITER dispatch is active (sbhd, non-interleaved, cp_size=1, no cu_seqlens,
    no start_positions), verify output and gradients match the TE fused kernel.
    """

    device = torch.device("cuda")
    batch_size, head_num = 2, 64
    tensor_format = "sbhd"
    interleaved = False
    cp_size = 1
    cp_rank = 0

    t = torch.rand(
        (seq_length, batch_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    emb = rotary_pos_emb(seq_length)

    output_aiter = apply_rotary_pos_emb(
        t,
        emb,
        tensor_format=tensor_format,
        interleaved=interleaved,
        fused=True,
        cp_size=cp_size,
        cp_rank=cp_rank,
    )
    loss_aiter = loss_func(output_aiter)
    loss_aiter.backward()
    grad_aiter = t.grad.detach().clone()
    t.grad = None

    with mock.patch.object(FusedRoPEFunc, "_can_use_aiter", return_value=False):
        output_te = apply_rotary_pos_emb(
            t,
            emb,
            tensor_format=tensor_format,
            interleaved=interleaved,
            fused=True,
            cp_size=cp_size,
            cp_rank=cp_rank,
        )
    loss_te = loss_func(output_te)
    loss_te.backward()
    grad_te = t.grad.detach().clone()
    t.grad = None

    torch.testing.assert_close(output_aiter, output_te)
    torch.testing.assert_close(grad_aiter, grad_te)


@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER RoPE requires ROCm")
@pytest.mark.parametrize(
    "tensor_format,interleaved,cu_seqlens,cp_size,start_positions,expected",
    [
        ("sbhd", False, None, 1, None, True),
        ("bshd", False, None, 1, None, False),
        ("sbhd", True, None, 1, None, False),
        ("sbhd", False, torch.tensor([0, 10]), 1, None, False),
        ("sbhd", False, None, 2, None, False),
        ("sbhd", False, None, 1, torch.tensor([0]), False),
    ],
)
def test_aiter_rope_can_use_guard(
    tensor_format: str,
    interleaved: bool,
    cu_seqlens,
    cp_size: int,
    start_positions,
    expected: bool,
) -> None:
    """Unit test the _can_use_aiter guard, rejecting each unsupported input in turn."""
    if not FusedRoPEFunc.has_aiter_rope() and expected:
        pytest.skip("AITER not available; guard always returns False for True cases")
    result = FusedRoPEFunc._can_use_aiter(
        tensor_format, interleaved, cu_seqlens, cp_size, start_positions
    )
    assert result == expected
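
For context, here is a minimal sketch of how the AITER path exercised above is enabled from user code. It assumes a ROCm build of Transformer Engine with aiter installed; the import paths and call shapes mirror the tests rather than documenting an official API.

# Sketch only: enable AITER RoPE dispatch (assumes ROCm + aiter installed).
# NVTE_USE_AITER_ROPE must be set before the rope module is first imported.
import os
os.environ["NVTE_USE_AITER_ROPE"] = "1"

import torch
from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

seq_len, batch, heads, dim = 2048, 2, 64, 128
t = torch.rand(seq_len, batch, heads, dim, dtype=torch.bfloat16, device="cuda")
emb = RotaryPositionEmbedding(dim)(seq_len)

# sbhd, non-interleaved, cp_size=1, no cu_seqlens/start_positions -> AITER kernel
out = apply_rotary_pos_emb(t, emb, tensor_format="sbhd", fused=True)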
104 changes: 85 additions & 19 deletions transformer_engine/pytorch/attention/rope.py
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -11,6 +13,29 @@
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat

try:
    from torch.utils.cpp_extension import IS_HIP_EXTENSION
except ImportError:
    IS_HIP_EXTENSION = False
Contributor comment on lines +16 to +19:
No need to guard
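
If the guard is dropped as the reviewer suggests, the hunk reduces to a plain import; a sketch of the suggested change (not part of this diff):

from torch.utils.cpp_extension import IS_HIP_EXTENSION

Recent PyTorch builds define IS_HIP_EXTENSION in torch.utils.cpp_extension on both CUDA and ROCm, so the ImportError fallback should be unreachable there.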


_aiter_rope_fwd = None
_aiter_rope_bwd = None
_HAVE_AITER_ROPE = False

if IS_HIP_EXTENSION:
    import os  # pylint: disable=wrong-import-order,wrong-import-position

    if os.environ.get("NVTE_USE_AITER_ROPE", "0") == "1":
        try:
            from aiter.ops.rope import (  # pylint: disable=import-error
                rope_fwd as _aiter_rope_fwd,
                rope_bwd as _aiter_rope_bwd,
            )

            _HAVE_AITER_ROPE = True
        except Exception as _aiter_import_err:  # pylint: disable=broad-except
            raise RuntimeError(
                "NVTE_USE_AITER_ROPE=1 but AITER fused RoPE import failed."
            ) from _aiter_import_err


__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]

@@ -118,6 +143,23 @@ class FusedRoPEFunc(torch.autograd.Function):
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def has_aiter_rope():
        """Return whether AITER RoPE kernels are available."""
        return _HAVE_AITER_ROPE

    @staticmethod
    def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions):
        """Check if we can dispatch to AITER's faster RoPE kernel."""
        return (
            _HAVE_AITER_ROPE
            and tensor_format == "sbhd"
            and not interleaved
            and cu_seqlens is None
            and cp_size == 1
            and start_positions is None
        )

    @staticmethod
    def forward(
        ctx,
@@ -139,38 +181,62 @@ def forward(
"bshd",
"thd",
), f"Unsupported tensor_format: {tensor_format}."
        output = tex.fused_rope_forward(
            t,
            freqs,
            start_positions,
            QKVFormat[tensor_format],
            interleaved,
            cu_seqlens,
            cp_size,
            cp_rank,

        use_aiter = FusedRoPEFunc._can_use_aiter(
            tensor_format, interleaved, cu_seqlens, cp_size, start_positions
        )

        if use_aiter:
            rotate_style = 1 if interleaved else 0
            output = _aiter_rope_fwd(
                t, freqs, rotate_style,
                False,  # reuse_freqs_front_part
                False,  # nope_first
            )
        else:
            output = tex.fused_rope_forward(
                t,
                freqs,
                start_positions,
                QKVFormat[tensor_format],
                interleaved,
                cu_seqlens,
                cp_size,
                cp_rank,
            )

        ctx.save_for_backward(freqs, cu_seqlens, start_positions)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved
        ctx.use_aiter = use_aiter

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Fused RoPE backward."""
        freqs, cu_seqlens, start_positions = ctx.saved_tensors
        grad_input = tex.fused_rope_backward(
            grad_output,
            freqs,
            start_positions,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            cu_seqlens,
            ctx.cp_size,
            ctx.cp_rank,
        )

        if ctx.use_aiter:
            rotate_style = 1 if ctx.interleaved else 0
            grad_input = _aiter_rope_bwd(
                grad_output, freqs, rotate_style,
                False,  # reuse_freqs_front_part
                False,  # nope_first
            )
        else:
            grad_input = tex.fused_rope_backward(
                grad_output,
                freqs,
                start_positions,
                QKVFormat[ctx.tensor_format],
                ctx.interleaved,
                cu_seqlens,
                ctx.cp_size,
                ctx.cp_rank,
            )

        return grad_input, None, None, None, None, None, None, None, None
