From 8b05a6a8ac133f558804a35af1b7e04f65e75e3d Mon Sep 17 00:00:00 2001
From: Su Ann Chong
Date: Tue, 14 Apr 2026 19:39:06 -0500
Subject: [PATCH 01/11] Integrate AITER fused RoPE kernels with fallback to TE
 native

Add an optional AITER RoPE dispatch path in FusedRoPEFunc for improved
performance on ROCm/AMD GPUs. When aiter is installed and the input meets
the supported subset (sbhd format, non-interleaved, no context parallelism,
no packed sequences, no start_positions), the forward and backward passes
dispatch to aiter.ops.rope.rope_fwd / rope_bwd.

Fallback to the existing tex.fused_rope_forward / tex.fused_rope_backward
is automatic for all other configurations and when AITER is not available.
A new env var NVTE_USE_AITER_ROPE (default "1") allows explicit opt-out.
The AITER import is gated behind IS_HIP_EXTENSION to avoid unnecessary
import attempts on CUDA systems.

Add unit tests for AITER-vs-TE numerical parity, guard logic coverage,
env var disable behavior, and fallback on unsupported configurations.

Tested in MLPerf GPT-OSS-20B MoE pretraining on MI355X (8x GPU).
A usage sketch of the final dispatch behavior follows the last patch in
this series.

Signed-off-by: Su Ann Chong
Made-with: Cursor
---
 tests/pytorch/test_fused_rope.py             | 100 +++++++++++++++++
 transformer_engine/pytorch/attention/rope.py | 112 +++++++++++++++----
 2 files changed, 193 insertions(+), 19 deletions(-)

diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py
index 50624df9e..7239144f0 100644
--- a/tests/pytorch/test_fused_rope.py
+++ b/tests/pytorch/test_fused_rope.py
@@ -3,12 +3,15 @@
 # See LICENSE for license information.
 from typing import Callable, Tuple, Union, List
 import math
+import unittest.mock as mock
 import torch
 import pytest
 from transformer_engine.pytorch.attention.rope import (
+    FusedRoPEFunc,
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
     apply_fused_qkv_rotary_pos_emb,
+    _HAVE_AITER_ROPE,
 )
 
 
@@ -495,3 +498,100 @@ def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_wi
         atol=1e-8,
         rtol=1e-8,
     )
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("seq_length", [2048, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [64, 128, 256])
+@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
+@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
+def test_aiter_rope_matches_te_fused(
+    dtype: torch.dtype,
+    seq_length: int,
+    hidden_size: int,
+    rotary_percent: float,
+    loss_func: Callable,
+) -> None:
+    """
+    When AITER dispatch is active (sbhd, non-interleaved, cp_size=1, no cu_seqlens,
+    no start_positions), verify output and gradients match the TE fused kernel.
+ """ + if not _HAVE_AITER_ROPE: + pytest.skip("AITER RoPE not available") + + device = torch.device("cuda:0") + batch_size, head_num = 2, 64 + tensor_format = "sbhd" + interleaved = False + cp_size = 1 + cp_rank = 0 + + t = torch.rand( + (seq_length, batch_size, head_num, hidden_size), + dtype=dtype, + device=device, + ) + t.requires_grad = True + + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) + emb = rotary_pos_emb(seq_length) + + output_aiter = apply_rotary_pos_emb( + t, + emb, + tensor_format=tensor_format, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ) + loss_aiter = loss_func(output_aiter) + loss_aiter.backward() + grad_aiter = t.grad.detach().clone() + t.grad = None + + with mock.patch.object(FusedRoPEFunc, "_can_use_aiter", return_value=False): + output_te = apply_rotary_pos_emb( + t, + emb, + tensor_format=tensor_format, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ) + loss_te = loss_func(output_te) + loss_te.backward() + grad_te = t.grad.detach().clone() + t.grad = None + + torch.testing.assert_close(output_aiter, output_te) + torch.testing.assert_close(grad_aiter, grad_te) + + +@pytest.mark.parametrize( + "tensor_format,interleaved,cu_seqlens,cp_size,start_positions,expected", + [ + ("sbhd", False, None, 1, None, True), + ("bshd", False, None, 1, None, False), + ("sbhd", True, None, 1, None, False), + ("sbhd", False, torch.tensor([0, 10]), 1, None, False), + ("sbhd", False, None, 2, None, False), + ("sbhd", False, None, 1, torch.tensor([0]), False), + ], +) +def test_aiter_rope_can_use_guard( + tensor_format: str, + interleaved: bool, + cu_seqlens, + cp_size: int, + start_positions, + expected: bool, +) -> None: + """Unit test the _can_use_aiter guard logic exhaustively.""" + if not _HAVE_AITER_ROPE and expected: + pytest.skip("AITER not available — guard always returns False for True cases") + result = FusedRoPEFunc._can_use_aiter( + tensor_format, interleaved, cu_seqlens, cp_size, start_positions + ) + assert result == expected diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 77ad57ed8..68697a016 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -5,12 +5,50 @@ """ Rotary Position Embedding implementation of different types along with helper functions """ +import logging +import os from typing import Optional, Tuple, Union, List import torch import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat +logger = logging.getLogger(__name__) + +_aiter_rope_fwd = None +_aiter_rope_bwd = None +_HAVE_AITER_ROPE = False +_USE_AITER_ROPE = os.environ.get("NVTE_USE_AITER_ROPE", "1") == "1" + +if _USE_AITER_ROPE: + try: + from torch.utils.cpp_extension import IS_HIP_EXTENSION + except ImportError: + IS_HIP_EXTENSION = False + + if IS_HIP_EXTENSION: + try: + from aiter.ops.rope import ( # pylint: disable=import-error + rope_fwd as _aiter_rope_fwd, + rope_bwd as _aiter_rope_bwd, + ) + _HAVE_AITER_ROPE = True + except Exception as _aiter_import_err: # pylint: disable=broad-except + _HAVE_AITER_ROPE = False + logger.info( + "AITER fused RoPE import failed (%s: %s). " + "Falling back to TE native kernels. 
" + "Set NVTE_USE_AITER_ROPE=0 to silence this message.", + type(_aiter_import_err).__name__, + _aiter_import_err, + ) + +if _HAVE_AITER_ROPE: + logger.info("Using AITER fused RoPE kernels (aiter.ops.rope)") +else: + _reason = "disabled via NVTE_USE_AITER_ROPE=0" if not _USE_AITER_ROPE else "not available" + logger.debug("AITER RoPE not active (%s), using TE native kernels", _reason) + __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] @@ -118,6 +156,18 @@ class FusedRoPEFunc(torch.autograd.Function): the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern. """ + @staticmethod + def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions): + """Check if we can dispatch to AITER's faster rope kernel.""" + return ( + _HAVE_AITER_ROPE + and tensor_format == "sbhd" + and not interleaved + and cu_seqlens is None + and cp_size == 1 + and start_positions is None + ) + @staticmethod def forward( ctx, @@ -139,21 +189,36 @@ def forward( "bshd", "thd", ), f"Unsupported tensor_format: {tensor_format}." - output = tex.fused_rope_forward( - t, - freqs, - start_positions, - QKVFormat[tensor_format], - interleaved, - cu_seqlens, - cp_size, - cp_rank, + + use_aiter = FusedRoPEFunc._can_use_aiter( + tensor_format, interleaved, cu_seqlens, cp_size, start_positions ) + + if use_aiter: + rotate_style = 1 if interleaved else 0 + output = _aiter_rope_fwd( + t, freqs, rotate_style, + False, # reuse_freqs_front_part + False, # nope_first + ) + else: + output = tex.fused_rope_forward( + t, + freqs, + start_positions, + QKVFormat[tensor_format], + interleaved, + cu_seqlens, + cp_size, + cp_rank, + ) + ctx.save_for_backward(freqs, cu_seqlens, start_positions) ctx.tensor_format = tensor_format ctx.cp_size = cp_size ctx.cp_rank = cp_rank ctx.interleaved = interleaved + ctx.use_aiter = use_aiter return output @@ -161,16 +226,25 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: """Fused RoPE backward.""" freqs, cu_seqlens, start_positions = ctx.saved_tensors - grad_input = tex.fused_rope_backward( - grad_output, - freqs, - start_positions, - QKVFormat[ctx.tensor_format], - ctx.interleaved, - cu_seqlens, - ctx.cp_size, - ctx.cp_rank, - ) + + if ctx.use_aiter: + rotate_style = 1 if ctx.interleaved else 0 + grad_input = _aiter_rope_bwd( + grad_output, freqs, rotate_style, + False, # reuse_freqs_front_part + False, # nope_first + ) + else: + grad_input = tex.fused_rope_backward( + grad_output, + freqs, + start_positions, + QKVFormat[ctx.tensor_format], + ctx.interleaved, + cu_seqlens, + ctx.cp_size, + ctx.cp_rank, + ) return grad_input, None, None, None, None, None, None, None, None From 277d40c51e423fdc9d642131e8c818809ddb807e Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 20 Apr 2026 14:43:45 -0500 Subject: [PATCH 02/11] Address PR #541 review feedback from ipanfilo - Add AMD copyright header to rope.py - Check IS_HIP_EXTENSION first, guard all AITER code behind it - Use logger.warning for AITER import failures instead of logger.info - Log AITER version (via aiter._version) on successful import - Default NVTE_USE_AITER_ROPE to "0" (opt-in) since CI cannot test it - Expose _HAVE_AITER_ROPE via FusedRoPEFunc.has_aiter_rope() method - Use @pytest.mark.skipif decorator instead of inline pytest.skip() Signed-off-by: Su Ann Chong Made-with: Cursor --- tests/pytorch/test_fused_rope.py | 8 +-- transformer_engine/pytorch/attention/rope.py | 63 ++++++++++++-------- 2 
 2 files changed, 43 insertions(+), 28 deletions(-)

diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py
index 7239144f0..3d057f9eb 100644
--- a/tests/pytorch/test_fused_rope.py
+++ b/tests/pytorch/test_fused_rope.py
@@ -11,7 +11,6 @@
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
     apply_fused_qkv_rotary_pos_emb,
-    _HAVE_AITER_ROPE,
 )
 
 
@@ -500,6 +499,9 @@ def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_wi
     )
 
 
+@pytest.mark.skipif(
+    not FusedRoPEFunc.has_aiter_rope(), reason="AITER RoPE not available"
+)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("seq_length", [2048, 4096, 8192])
 @pytest.mark.parametrize("hidden_size", [64, 128, 256])
@@ -516,8 +518,6 @@ def test_aiter_rope_matches_te_fused(
     When AITER dispatch is active (sbhd, non-interleaved, cp_size=1, no cu_seqlens,
     no start_positions), verify output and gradients match the TE fused kernel.
     """
-    if not _HAVE_AITER_ROPE:
-        pytest.skip("AITER RoPE not available")
 
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
@@ -589,7 +589,7 @@ def test_aiter_rope_can_use_guard(
     expected: bool,
 ) -> None:
     """Unit test the _can_use_aiter guard logic exhaustively."""
-    if not _HAVE_AITER_ROPE and expected:
+    if not FusedRoPEFunc.has_aiter_rope() and expected:
         pytest.skip("AITER not available — guard always returns False for True cases")
     result = FusedRoPEFunc._can_use_aiter(
         tensor_format, interleaved, cu_seqlens, cp_size, start_positions
diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py
index 68697a016..bbd6bcba8 100644
--- a/transformer_engine/pytorch/attention/rope.py
+++ b/transformer_engine/pytorch/attention/rope.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
 #
 # See LICENSE for license information.
 
@@ -15,38 +16,47 @@
 
 logger = logging.getLogger(__name__)
 
+try:
+    from torch.utils.cpp_extension import IS_HIP_EXTENSION
+except ImportError:
+    IS_HIP_EXTENSION = False
+
 _aiter_rope_fwd = None
 _aiter_rope_bwd = None
 _HAVE_AITER_ROPE = False
-_USE_AITER_ROPE = os.environ.get("NVTE_USE_AITER_ROPE", "1") == "1"
+_USE_AITER_ROPE = os.environ.get("NVTE_USE_AITER_ROPE", "0") == "1"
 
-if _USE_AITER_ROPE:
+if IS_HIP_EXTENSION and _USE_AITER_ROPE:
     try:
-        from torch.utils.cpp_extension import IS_HIP_EXTENSION
-    except ImportError:
-        IS_HIP_EXTENSION = False
-
-    if IS_HIP_EXTENSION:
-        try:
-            from aiter.ops.rope import (  # pylint: disable=import-error
-                rope_fwd as _aiter_rope_fwd,
-                rope_bwd as _aiter_rope_bwd,
-            )
-            _HAVE_AITER_ROPE = True
-        except Exception as _aiter_import_err:  # pylint: disable=broad-except
-            _HAVE_AITER_ROPE = False
-            logger.info(
-                "AITER fused RoPE import failed (%s: %s). "
-                "Falling back to TE native kernels. "
-                "Set NVTE_USE_AITER_ROPE=0 to silence this message.",
-                type(_aiter_import_err).__name__,
-                _aiter_import_err,
-            )
+        from aiter.ops.rope import (  # pylint: disable=import-error
+            rope_fwd as _aiter_rope_fwd,
+            rope_bwd as _aiter_rope_bwd,
+        )
+        _HAVE_AITER_ROPE = True
+    except Exception as _aiter_import_err:  # pylint: disable=broad-except
+        _HAVE_AITER_ROPE = False
+        logger.warning(
+            "AITER fused RoPE import failed (%s: %s). "
+            "Falling back to TE native kernels. "
" + "Set NVTE_USE_AITER_ROPE=0 to silence this message.", + type(_aiter_import_err).__name__, + _aiter_import_err, + ) if _HAVE_AITER_ROPE: - logger.info("Using AITER fused RoPE kernels (aiter.ops.rope)") + _aiter_version = "unknown" + try: + from aiter._version import version as _aiter_version # pylint: disable=import-error + except Exception: # pylint: disable=broad-except + pass + logger.info("Using AITER fused RoPE kernels (aiter.ops.rope, version=%s)", _aiter_version) else: - _reason = "disabled via NVTE_USE_AITER_ROPE=0" if not _USE_AITER_ROPE else "not available" + if not IS_HIP_EXTENSION: + _reason = "not on ROCm" + elif not _USE_AITER_ROPE: + _reason = "disabled via NVTE_USE_AITER_ROPE=0" + else: + _reason = "not available" logger.debug("AITER RoPE not active (%s), using TE native kernels", _reason) @@ -156,6 +166,11 @@ class FusedRoPEFunc(torch.autograd.Function): the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern. """ + @staticmethod + def has_aiter_rope(): + """Return whether AITER RoPE kernels are available.""" + return _HAVE_AITER_ROPE + @staticmethod def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions): """Check if we can dispatch to AITER's faster rope kernel.""" From 08ea73e6d9fbdf3152e405eadd047bdb2d6ad163 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 20 Apr 2026 14:46:21 -0500 Subject: [PATCH 03/11] remove nvidia header --- transformer_engine/pytorch/attention/rope.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index bbd6bcba8..03edaa4ff 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -1,4 +1,3 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. From 6bf76341189a8abf89da3325c9fbdb36f6e93211 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 20 Apr 2026 14:54:00 -0500 Subject: [PATCH 04/11] Add local testing instructions for AITER RoPE tests Signed-off-by: Su Ann Chong Made-with: Cursor --- tests/pytorch/test_fused_rope.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 3d057f9eb..44c5b9857 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -499,6 +499,14 @@ def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_wi ) +# AITER RoPE tests require: +# 1. ROCm environment with `aiter` installed (pip install amd-aiter) +# 2. NVTE_USE_AITER_ROPE=1 environment variable set before running +# Example: +# NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_matches_te_fused -v +# NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_can_use_guard -v + + @pytest.mark.skipif( not FusedRoPEFunc.has_aiter_rope(), reason="AITER RoPE not available" ) From 6eb19fe2257632c233ad5fc419839063883f4e04 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Mon, 20 Apr 2026 17:06:41 -0500 Subject: [PATCH 05/11] Add Dockerfile and README for local AITER RoPE testing Provides a containerized way to test the AITER fused RoPE integration on ROCm systems, since CI cannot test this feature. 
Signed-off-by: Su Ann Chong
Made-with: Cursor
---
 tests/pytorch/aiter_rope_test/Dockerfile | 46 ++++++++++++++++++++
 tests/pytorch/aiter_rope_test/README.md  | 55 ++++++++++++++++++++++++
 2 files changed, 101 insertions(+)
 create mode 100644 tests/pytorch/aiter_rope_test/Dockerfile
 create mode 100644 tests/pytorch/aiter_rope_test/README.md

diff --git a/tests/pytorch/aiter_rope_test/Dockerfile b/tests/pytorch/aiter_rope_test/Dockerfile
new file mode 100644
index 000000000..583915b50
--- /dev/null
+++ b/tests/pytorch/aiter_rope_test/Dockerfile
@@ -0,0 +1,46 @@
+# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+#
+# Dockerfile for locally testing AITER fused RoPE integration.
+#
+# Build:
+#   docker build -t te-aiter-rope-test -f tests/pytorch/aiter_rope_test/Dockerfile .
+#
+# Run (requires ROCm GPU):
+#   docker run --rm --device /dev/kfd --device /dev/dri --group-add video \
+#     te-aiter-rope-test
+
+ARG BASE_DOCKER=rocm/pytorch:latest
+FROM $BASE_DOCKER
+
+ARG TE_BRANCH=feat/aiter-fused-rope
+ARG GPU_ARCHS="gfx942;gfx950"
+
+ENV NVTE_ROCM_ARCH=${GPU_ARCHS}
+
+WORKDIR /workspace
+
+RUN pip install setuptools wheel pybind11 ninja pandas pytest psutil
+
+# Install AITER from PyPI
+RUN pip install amd-aiter
+
+# Clone and install TransformerEngine from the PR branch
+RUN git clone --branch ${TE_BRANCH} --recursive \
+    https://github.com/ROCm/TransformerEngine.git /workspace/TransformerEngine
+
+WORKDIR /workspace/TransformerEngine
+RUN pip install -e .
+
+# Install the PyTorch extension
+WORKDIR /workspace/TransformerEngine/transformer_engine/pytorch
+RUN python setup.py develop
+
+WORKDIR /workspace/TransformerEngine
+
+ENV NVTE_USE_AITER_ROPE=1
+
+CMD ["pytest", "tests/pytorch/test_fused_rope.py::test_aiter_rope_matches_te_fused", \
+     "tests/pytorch/test_fused_rope.py::test_aiter_rope_can_use_guard", \
+     "-v", "--tb=short"]
diff --git a/tests/pytorch/aiter_rope_test/README.md b/tests/pytorch/aiter_rope_test/README.md
new file mode 100644
index 000000000..212f5db60
--- /dev/null
+++ b/tests/pytorch/aiter_rope_test/README.md
@@ -0,0 +1,55 @@
+# AITER Fused RoPE Local Testing
+
+This directory provides a Dockerfile for testing the AITER fused RoPE integration
+on a ROCm system. CI cannot test this feature since it depends on the `aiter` package
+and ROCm hardware.
+
+## Prerequisites
+
+- Docker with ROCm support
+- AMD GPU with ROCm drivers installed (e.g. MI250X, MI300X, MI355X)
+
+## Quick Start
+
+From the **repository root**:
+
+```bash
+# Build the test image
+docker build -t te-aiter-rope-test \
+  -f tests/pytorch/aiter_rope_test/Dockerfile .
+
+# Run the AITER RoPE tests
+docker run --rm --device /dev/kfd --device /dev/dri --group-add video \
+  te-aiter-rope-test
+```
+
+## Run All RoPE Tests (regression check)
+
+```bash
+docker run --rm --device /dev/kfd --device /dev/dri --group-add video \
+  te-aiter-rope-test \
+  pytest tests/pytorch/test_fused_rope.py -v --tb=short
+```
+
+## Interactive Debugging
+
+```bash
+docker run --rm -it --device /dev/kfd --device /dev/dri --group-add video \
+  te-aiter-rope-test /bin/bash
+
+# Inside the container:
+NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_matches_te_fused -v
+NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_can_use_guard -v
+```
+
+## Customization
+
+Override build args as needed:
+
+```bash
+docker build -t te-aiter-rope-test \
+  --build-arg BASE_DOCKER=rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.5.1 \
+  --build-arg TE_BRANCH=feat/aiter-fused-rope \
+  --build-arg GPU_ARCHS="gfx942" \
+  -f tests/pytorch/aiter_rope_test/Dockerfile .
+```

From c46d8063b247bb23e3c3bd5f0ec2e34b9e06c726 Mon Sep 17 00:00:00 2001
From: Su Ann Chong
Date: Tue, 21 Apr 2026 11:21:20 -0500
Subject: [PATCH 06/11] Remove aiter_rope_test directory from branch

Local testing infrastructure, not intended for the repository.

Signed-off-by: Su Ann Chong
Made-with: Cursor
---
 tests/pytorch/aiter_rope_test/Dockerfile | 46 --------------------
 tests/pytorch/aiter_rope_test/README.md  | 55 ------------------------
 2 files changed, 101 deletions(-)
 delete mode 100644 tests/pytorch/aiter_rope_test/Dockerfile
 delete mode 100644 tests/pytorch/aiter_rope_test/README.md

diff --git a/tests/pytorch/aiter_rope_test/Dockerfile b/tests/pytorch/aiter_rope_test/Dockerfile
deleted file mode 100644
index 583915b50..000000000
--- a/tests/pytorch/aiter_rope_test/Dockerfile
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
-#
-# See LICENSE for license information.
-#
-# Dockerfile for locally testing AITER fused RoPE integration.
-#
-# Build:
-#   docker build -t te-aiter-rope-test -f tests/pytorch/aiter_rope_test/Dockerfile .
-#
-# Run (requires ROCm GPU):
-#   docker run --rm --device /dev/kfd --device /dev/dri --group-add video \
-#     te-aiter-rope-test
-
-ARG BASE_DOCKER=rocm/pytorch:latest
-FROM $BASE_DOCKER
-
-ARG TE_BRANCH=feat/aiter-fused-rope
-ARG GPU_ARCHS="gfx942;gfx950"
-
-ENV NVTE_ROCM_ARCH=${GPU_ARCHS}
-
-WORKDIR /workspace
-
-RUN pip install setuptools wheel pybind11 ninja pandas pytest psutil
-
-# Install AITER from PyPI
-RUN pip install amd-aiter
-
-# Clone and install TransformerEngine from the PR branch
-RUN git clone --branch ${TE_BRANCH} --recursive \
-    https://github.com/ROCm/TransformerEngine.git /workspace/TransformerEngine
-
-WORKDIR /workspace/TransformerEngine
-RUN pip install -e .
-
-# Install the PyTorch extension
-WORKDIR /workspace/TransformerEngine/transformer_engine/pytorch
-RUN python setup.py develop
-
-WORKDIR /workspace/TransformerEngine
-
-ENV NVTE_USE_AITER_ROPE=1
-
-CMD ["pytest", "tests/pytorch/test_fused_rope.py::test_aiter_rope_matches_te_fused", \
-     "tests/pytorch/test_fused_rope.py::test_aiter_rope_can_use_guard", \
-     "-v", "--tb=short"]
diff --git a/tests/pytorch/aiter_rope_test/README.md b/tests/pytorch/aiter_rope_test/README.md
deleted file mode 100644
index 212f5db60..000000000
--- a/tests/pytorch/aiter_rope_test/README.md
+++ /dev/null
@@ -1,55 +0,0 @@
-# AITER Fused RoPE Local Testing
-
-This directory provides a Dockerfile for testing the AITER fused RoPE integration
-on a ROCm system. CI cannot test this feature since it depends on the `aiter` package
-and ROCm hardware.
-
-## Prerequisites
-
-- Docker with ROCm support
-- AMD GPU with ROCm drivers installed (e.g. MI250X, MI300X, MI355X)
-
-## Quick Start
-
-From the **repository root**:
-
-```bash
-# Build the test image
-docker build -t te-aiter-rope-test \
-  -f tests/pytorch/aiter_rope_test/Dockerfile .
-
-# Run the AITER RoPE tests
-docker run --rm --device /dev/kfd --device /dev/dri --group-add video \
-  te-aiter-rope-test
-```
-
-## Run All RoPE Tests (regression check)
-
-```bash
-docker run --rm --device /dev/kfd --device /dev/dri --group-add video \
-  te-aiter-rope-test \
-  pytest tests/pytorch/test_fused_rope.py -v --tb=short
-```
-
-## Interactive Debugging
-
-```bash
-docker run --rm -it --device /dev/kfd --device /dev/dri --group-add video \
-  te-aiter-rope-test /bin/bash
-
-# Inside the container:
-NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_matches_te_fused -v
-NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_can_use_guard -v
-```
-
-## Customization
-
-Override build args as needed:
-
-```bash
-docker build -t te-aiter-rope-test \
-  --build-arg BASE_DOCKER=rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.5.1 \
-  --build-arg TE_BRANCH=feat/aiter-fused-rope \
-  --build-arg GPU_ARCHS="gfx942" \
-  -f tests/pytorch/aiter_rope_test/Dockerfile .
-```

From 51cd242e86055f812cb7db43a07b0d2742f501a6 Mon Sep 17 00:00:00 2001
From: Su Ann Chong
Date: Tue, 21 Apr 2026 11:25:15 -0500
Subject: [PATCH 07/11] Preserve upstream NVIDIA copyright header in rope.py

Follow existing convention: add AMD copyright above the NVIDIA header
with "modified for portability to AMDGPU" note, rather than replacing it.

Signed-off-by: Su Ann Chong
Made-with: Cursor
---
 transformer_engine/pytorch/attention/rope.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py
index 03edaa4ff..dbfeb4c38 100644
--- a/transformer_engine/pytorch/attention/rope.py
+++ b/transformer_engine/pytorch/attention/rope.py
@@ -1,4 +1,6 @@
+# This file was modified for portability to AMDGPU
 # Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
From 7d4cb24d2763f6a1e1a284a7250cccb017e238ef Mon Sep 17 00:00:00 2001
From: Su Ann Chong
Date: Thu, 23 Apr 2026 12:49:03 -0500
Subject: [PATCH 08/11] Address PR #541 review: raise RuntimeError instead of
 silent fallback

- Replace logger.warning with RuntimeError when NVTE_USE_AITER_ROPE=1 but
  AITER import fails, making the failure explicit instead of silently
  falling back to TE native kernels
- Remove all diagnostic logging (version info, reason tracking) to reduce
  maintenance burden and stay synchronized with upstream

Signed-off-by: Su Ann Chong
Made-with: Cursor
---
 transformer_engine/pytorch/attention/rope.py | 31 +++----------------
 1 file changed, 4 insertions(+), 27 deletions(-)

diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py
index dbfeb4c38..747902e14 100644
--- a/transformer_engine/pytorch/attention/rope.py
+++ b/transformer_engine/pytorch/attention/rope.py
@@ -7,7 +7,6 @@
 """
 Rotary Position Embedding implementation of different types along with helper functions
 """
-import logging
 import os
 from typing import Optional, Tuple, Union, List
 import torch
@@ -15,8 +14,6 @@
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
 
-logger = logging.getLogger(__name__)
-
 try:
     from torch.utils.cpp_extension import IS_HIP_EXTENSION
 except ImportError:
     IS_HIP_EXTENSION = False
@@ -35,30 +32,10 @@
         )
         _HAVE_AITER_ROPE = True
     except Exception as _aiter_import_err:  # pylint: disable=broad-except
-        _HAVE_AITER_ROPE = False
-        logger.warning(
-            "AITER fused RoPE import failed (%s: %s). "
-            "Falling back to TE native kernels. "
-            "Set NVTE_USE_AITER_ROPE=0 to silence this message.",
-            type(_aiter_import_err).__name__,
-            _aiter_import_err,
-        )
-
-if _HAVE_AITER_ROPE:
-    _aiter_version = "unknown"
-    try:
-        from aiter._version import version as _aiter_version  # pylint: disable=import-error
-    except Exception:  # pylint: disable=broad-except
-        pass
-    logger.info("Using AITER fused RoPE kernels (aiter.ops.rope, version=%s)", _aiter_version)
-else:
-    if not IS_HIP_EXTENSION:
-        _reason = "not on ROCm"
-    elif not _USE_AITER_ROPE:
-        _reason = "disabled via NVTE_USE_AITER_ROPE=0"
-    else:
-        _reason = "not available"
-    logger.debug("AITER RoPE not active (%s), using TE native kernels", _reason)
+        raise RuntimeError(
+            f"NVTE_USE_AITER_ROPE=1 but AITER fused RoPE import failed: "
+            f"{type(_aiter_import_err).__name__}: {_aiter_import_err}"
+        ) from _aiter_import_err
 
 
 __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]

From 3f06411fe1f84b69b6296ce18d8f770e99b27a8d Mon Sep 17 00:00:00 2001
From: Su Ann Chong
Date: Thu, 23 Apr 2026 11:16:40 -0700
Subject: [PATCH 09/11] Update transformer_engine/pytorch/attention/rope.py

Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com>
---
 transformer_engine/pytorch/attention/rope.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py
index 747902e14..ee946feb6 100644
--- a/transformer_engine/pytorch/attention/rope.py
+++ b/transformer_engine/pytorch/attention/rope.py
@@ -33,8 +33,7 @@
         _HAVE_AITER_ROPE = True
     except Exception as _aiter_import_err:  # pylint: disable=broad-except
         raise RuntimeError(
-            f"NVTE_USE_AITER_ROPE=1 but AITER fused RoPE import failed: "
-            f"{type(_aiter_import_err).__name__}: {_aiter_import_err}"
+            f"NVTE_USE_AITER_ROPE=1 but AITER fused RoPE import failed."
         ) from _aiter_import_err

From 69bceaafe2a4e0144491a1ee24ff4b9ea7c74893 Mon Sep 17 00:00:00 2001
From: Su Ann Chong
Date: Thu, 23 Apr 2026 14:53:46 -0500
Subject: [PATCH 10/11] Address PR #541 review feedback from Micky774 and
 wangye805

- Guard `import os` and env var check under IS_HIP_EXTENSION in rope.py
  to minimize upstream diff
- Add AMD copyright header to test_fused_rope.py
- Guard `unittest.mock` and `FusedRoPEFunc` imports behind IS_HIP_EXTENSION
- Add IS_HIP_EXTENSION skipif guard to all AITER test functions
- Use torch.device("cuda") instead of hardcoding cuda:0 in AITER test

Signed-off-by: Su Ann Chong
Made-with: Cursor
---
 tests/pytorch/test_fused_rope.py             | 20 ++++++++++++---
 transformer_engine/pytorch/attention/rope.py | 26 ++++++++++----------
 2 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py
index 44c5b9857..9fbaed8b4 100644
--- a/tests/pytorch/test_fused_rope.py
+++ b/tests/pytorch/test_fused_rope.py
@@ -1,18 +1,27 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 from typing import Callable, Tuple, Union, List
 import math
-import unittest.mock as mock
 import torch
 import pytest
 from transformer_engine.pytorch.attention.rope import (
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
     apply_fused_qkv_rotary_pos_emb,
 )
 
+try:
+    from torch.utils.cpp_extension import IS_HIP_EXTENSION
+except ImportError:
+    IS_HIP_EXTENSION = False
+
+if IS_HIP_EXTENSION:
+    import unittest.mock as mock
+    from transformer_engine.pytorch.attention.rope import FusedRoPEFunc
+
 
 # Gradient is a broadcasted scalar
 def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
@@ -507,8 +516,10 @@
 #   NVTE_USE_AITER_ROPE=1 pytest tests/pytorch/test_fused_rope.py::test_aiter_rope_can_use_guard -v
 
 
+@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER RoPE requires ROCm")
 @pytest.mark.skipif(
-    not FusedRoPEFunc.has_aiter_rope(), reason="AITER RoPE not available"
+    IS_HIP_EXTENSION and not FusedRoPEFunc.has_aiter_rope(),
+    reason="AITER RoPE not available",
 )
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("seq_length", [2048, 4096, 8192])
@@ -527,7 +538,7 @@ def test_aiter_rope_matches_te_fused(
     When AITER dispatch is active (sbhd, non-interleaved, cp_size=1, no cu_seqlens,
     no start_positions), verify output and gradients match the TE fused kernel.
""" - device = torch.device("cuda:0") + device = torch.device("cuda") batch_size, head_num = 2, 64 tensor_format = "sbhd" interleaved = False @@ -577,6 +588,7 @@ def test_aiter_rope_matches_te_fused( torch.testing.assert_close(grad_aiter, grad_te) +@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER RoPE requires ROCm") @pytest.mark.parametrize( "tensor_format,interleaved,cu_seqlens,cp_size,start_positions,expected", [ diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index ee946feb6..7a83dab0a 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -7,7 +7,6 @@ """ Rotary Position Embedding implementation of different types along with helper functions """ -import os from typing import Optional, Tuple, Union, List import torch @@ -22,19 +21,20 @@ _aiter_rope_fwd = None _aiter_rope_bwd = None _HAVE_AITER_ROPE = False -_USE_AITER_ROPE = os.environ.get("NVTE_USE_AITER_ROPE", "0") == "1" -if IS_HIP_EXTENSION and _USE_AITER_ROPE: - try: - from aiter.ops.rope import ( # pylint: disable=import-error - rope_fwd as _aiter_rope_fwd, - rope_bwd as _aiter_rope_bwd, - ) - _HAVE_AITER_ROPE = True - except Exception as _aiter_import_err: # pylint: disable=broad-except - raise RuntimeError( - f"NVTE_USE_AITER_ROPE=1 but AITER fused RoPE import failed." - ) from _aiter_import_err +if IS_HIP_EXTENSION: + import os # pylint: disable=wrong-import-order,wrong-import-position + if os.environ.get("NVTE_USE_AITER_ROPE", "0") == "1": + try: + from aiter.ops.rope import ( # pylint: disable=import-error + rope_fwd as _aiter_rope_fwd, + rope_bwd as _aiter_rope_bwd, + ) + _HAVE_AITER_ROPE = True + except Exception as _aiter_import_err: # pylint: disable=broad-except + raise RuntimeError( + "NVTE_USE_AITER_ROPE=1 but AITER fused RoPE import failed." + ) from _aiter_import_err __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] From ec7fc13ad62466a7acb1d343b7e94782b1d7b709 Mon Sep 17 00:00:00 2001 From: Su Ann Chong Date: Thu, 23 Apr 2026 16:20:10 -0500 Subject: [PATCH 11/11] Use bare import for IS_HIP_EXTENSION in test file Follow repo convention: import IS_HIP_EXTENSION directly from torch.utils.cpp_extension without try/except guard, consistent with all other test modules. Signed-off-by: Su Ann Chong Made-with: Cursor --- tests/pytorch/test_fused_rope.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 9fbaed8b4..86e2afeaa 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -13,10 +13,7 @@ apply_fused_qkv_rotary_pos_emb, ) -try: - from torch.utils.cpp_extension import IS_HIP_EXTENSION -except ImportError: - IS_HIP_EXTENSION = False +from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: import unittest.mock as mock