diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py
index 50624df9e..86e2afeaa 100644
--- a/tests/pytorch/test_fused_rope.py
+++ b/tests/pytorch/test_fused_rope.py
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -11,6 +13,12 @@
     apply_fused_qkv_rotary_pos_emb,
 )
 
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+
+if IS_HIP_EXTENSION:
+    import unittest.mock as mock
+    from transformer_engine.pytorch.attention.rope import FusedRoPEFunc
+
 
 # Gradient is a broadcasted scalar
 def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
@@ -495,3 +503,112 @@ def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_wi
         atol=1e-8,
         rtol=1e-8,
     )
+
+
+# AITER RoPE tests require:
+# 1. ROCm environment with `aiter` installed (pip install amd-aiter)
+# 2. NVTE_USE_AITER_ROPE=1 environment variable set before running
+# Example:
+#   NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_matches_te_fused -v
+#   NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_can_use_guard -v
+
+
+@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER RoPE requires ROCm")
+@pytest.mark.skipif(
+    IS_HIP_EXTENSION and not FusedRoPEFunc.has_aiter_rope(),
+    reason="AITER RoPE not available",
+)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("seq_length", [2048, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [64, 128, 256])
+@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
+@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
+def test_aiter_rope_matches_te_fused(
+    dtype: torch.dtype,
+    seq_length: int,
+    hidden_size: int,
+    rotary_percent: float,
+    loss_func: Callable,
+) -> None:
+    """
+    When AITER dispatch is active (sbhd, non-interleaved, cp_size=1, no cu_seqlens,
+    no start_positions), verify output and gradients match the TE fused kernel.
+ """ + + device = torch.device("cuda") + batch_size, head_num = 2, 64 + tensor_format = "sbhd" + interleaved = False + cp_size = 1 + cp_rank = 0 + + t = torch.rand( + (seq_length, batch_size, head_num, hidden_size), + dtype=dtype, + device=device, + ) + t.requires_grad = True + + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) + emb = rotary_pos_emb(seq_length) + + output_aiter = apply_rotary_pos_emb( + t, + emb, + tensor_format=tensor_format, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ) + loss_aiter = loss_func(output_aiter) + loss_aiter.backward() + grad_aiter = t.grad.detach().clone() + t.grad = None + + with mock.patch.object(FusedRoPEFunc, "_can_use_aiter", return_value=False): + output_te = apply_rotary_pos_emb( + t, + emb, + tensor_format=tensor_format, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ) + loss_te = loss_func(output_te) + loss_te.backward() + grad_te = t.grad.detach().clone() + t.grad = None + + torch.testing.assert_close(output_aiter, output_te) + torch.testing.assert_close(grad_aiter, grad_te) + + +@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER RoPE requires ROCm") +@pytest.mark.parametrize( + "tensor_format,interleaved,cu_seqlens,cp_size,start_positions,expected", + [ + ("sbhd", False, None, 1, None, True), + ("bshd", False, None, 1, None, False), + ("sbhd", True, None, 1, None, False), + ("sbhd", False, torch.tensor([0, 10]), 1, None, False), + ("sbhd", False, None, 2, None, False), + ("sbhd", False, None, 1, torch.tensor([0]), False), + ], +) +def test_aiter_rope_can_use_guard( + tensor_format: str, + interleaved: bool, + cu_seqlens, + cp_size: int, + start_positions, + expected: bool, +) -> None: + """Unit test the _can_use_aiter guard logic exhaustively.""" + if not FusedRoPEFunc.has_aiter_rope() and expected: + pytest.skip("AITER not available — guard always returns False for True cases") + result = FusedRoPEFunc._can_use_aiter( + tensor_format, interleaved, cu_seqlens, cp_size, start_positions + ) + assert result == expected diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 77ad57ed8..7a83dab0a 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,6 +13,29 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat +try: + from torch.utils.cpp_extension import IS_HIP_EXTENSION +except ImportError: + IS_HIP_EXTENSION = False + +_aiter_rope_fwd = None +_aiter_rope_bwd = None +_HAVE_AITER_ROPE = False + +if IS_HIP_EXTENSION: + import os # pylint: disable=wrong-import-order,wrong-import-position + if os.environ.get("NVTE_USE_AITER_ROPE", "0") == "1": + try: + from aiter.ops.rope import ( # pylint: disable=import-error + rope_fwd as _aiter_rope_fwd, + rope_bwd as _aiter_rope_bwd, + ) + _HAVE_AITER_ROPE = True + except Exception as _aiter_import_err: # pylint: disable=broad-except + raise RuntimeError( + "NVTE_USE_AITER_ROPE=1 but AITER fused RoPE import failed." 
+            ) from _aiter_import_err
+
 __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]
 
 
@@ -118,6 +143,23 @@ class FusedRoPEFunc(torch.autograd.Function):
     the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
     """
 
+    @staticmethod
+    def has_aiter_rope():
+        """Return whether AITER RoPE kernels are available."""
+        return _HAVE_AITER_ROPE
+
+    @staticmethod
+    def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions):
+        """Check if we can dispatch to AITER's faster RoPE kernel."""
+        return (
+            _HAVE_AITER_ROPE
+            and tensor_format == "sbhd"
+            and not interleaved
+            and cu_seqlens is None
+            and cp_size == 1
+            and start_positions is None
+        )
+
     @staticmethod
     def forward(
         ctx,
@@ -139,21 +181,36 @@ def forward(
             "sbhd",
             "bshd",
             "thd",
         ), f"Unsupported tensor_format: {tensor_format}."
-        output = tex.fused_rope_forward(
-            t,
-            freqs,
-            start_positions,
-            QKVFormat[tensor_format],
-            interleaved,
-            cu_seqlens,
-            cp_size,
-            cp_rank,
+
+        use_aiter = FusedRoPEFunc._can_use_aiter(
+            tensor_format, interleaved, cu_seqlens, cp_size, start_positions
         )
+
+        if use_aiter:
+            rotate_style = 1 if interleaved else 0  # guard admits only non-interleaved, so 0 here
+            output = _aiter_rope_fwd(
+                t, freqs, rotate_style,
+                False,  # reuse_freqs_front_part
+                False,  # nope_first
+            )
+        else:
+            output = tex.fused_rope_forward(
+                t,
+                freqs,
+                start_positions,
+                QKVFormat[tensor_format],
+                interleaved,
+                cu_seqlens,
+                cp_size,
+                cp_rank,
+            )
+
         ctx.save_for_backward(freqs, cu_seqlens, start_positions)
         ctx.tensor_format = tensor_format
         ctx.cp_size = cp_size
         ctx.cp_rank = cp_rank
         ctx.interleaved = interleaved
+        ctx.use_aiter = use_aiter
 
         return output
@@ -161,16 +218,25 @@ def forward(
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         """Fused RoPE backward."""
         freqs, cu_seqlens, start_positions = ctx.saved_tensors
-        grad_input = tex.fused_rope_backward(
-            grad_output,
-            freqs,
-            start_positions,
-            QKVFormat[ctx.tensor_format],
-            ctx.interleaved,
-            cu_seqlens,
-            ctx.cp_size,
-            ctx.cp_rank,
-        )
+
+        if ctx.use_aiter:
+            rotate_style = 1 if ctx.interleaved else 0  # guard admits only non-interleaved, so 0 here
+            grad_input = _aiter_rope_bwd(
+                grad_output, freqs, rotate_style,
+                False,  # reuse_freqs_front_part
+                False,  # nope_first
+            )
+        else:
+            grad_input = tex.fused_rope_backward(
+                grad_output,
+                freqs,
+                start_positions,
+                QKVFormat[ctx.tensor_format],
+                ctx.interleaved,
+                cu_seqlens,
+                ctx.cp_size,
+                ctx.cp_rank,
+            )
 
         return grad_input, None, None, None, None, None, None, None, None
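
For anyone who wants to exercise the new dispatch path outside pytest, here is a minimal sketch. It assumes a ROCm build of Transformer Engine with the `aiter` package installed; it uses only names introduced by this diff (`NVTE_USE_AITER_ROPE`, `FusedRoPEFunc.has_aiter_rope`) plus the existing `RotaryPositionEmbedding`/`apply_rotary_pos_emb` API that the tests above already rely on. Shapes and dtype are illustrative.

```python
import os

# The diff reads NVTE_USE_AITER_ROPE once, at module import time, so it must
# be set before transformer_engine.pytorch.attention.rope is first imported.
os.environ["NVTE_USE_AITER_ROPE"] = "1"

import torch
from transformer_engine.pytorch.attention.rope import (
    FusedRoPEFunc,
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

# False here means the aiter import was skipped (non-ROCm build) or failed.
assert FusedRoPEFunc.has_aiter_rope(), "AITER RoPE kernels not available"

# Illustrative sizes: (seq, batch, heads, head_dim) in sbhd layout.
t = torch.rand(2048, 2, 64, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
emb = RotaryPositionEmbedding(128)(2048)

# sbhd + non-interleaved + cp_size=1 with no cu_seqlens/start_positions is
# exactly the combination _can_use_aiter accepts, so this call dispatches to
# the AITER kernel; any other combination falls back to tex.fused_rope_*.
out = apply_rotary_pos_emb(
    t, emb, tensor_format="sbhd", interleaved=False, fused=True, cp_size=1, cp_rank=0
)
out.sum().backward()  # backward follows the ctx.use_aiter flag recorded in forward
```

Because the gate is evaluated at import time, flipping the environment variable after the module has loaded has no effect, which is why the test instructions set it on the pytest command line rather than inside the tests.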