From e20c7ea1edb9ee5e54bf1f29a332920aadc99eec Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Tue, 5 May 2026 01:29:50 -0400 Subject: [PATCH] Fix build on Pytorch 2.11 (#16505) --- build_tools/hipify/custom_map.json | 1 - .../pytorch/csrc/extensions/attention.cpp | 11 ++++++++++- transformer_engine/pytorch/csrc/extensions/gemm.cpp | 13 ++++++++++--- .../pytorch/csrc/extensions/normalization.cpp | 11 +++++++++-- 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index fe3820f85..f9c452120 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -11,7 +11,6 @@ "__nv_fp8_e5m2" : "te_hip_fp8_e5m2", "__nv_fp8_e4m3" : "te_hip_fp8_e4m3", "cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA", - "at::cuda::CUDAGuard" : "at::hip::HIPGuardMasqueradingAsCUDA", "__nv_fp4_e2m1" : "__hip_fp4_e2m1", "__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1", "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index b455e0375..b80c191bc 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -8,6 +10,13 @@ #include "common.h" #include "pybind.h" +#include <ATen/cuda/CUDAGuard.h> +#if USE_ROCM && TORCH_VERSION_MINOR < 11 +using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA; +#else +using TECUDAGuard = at::cuda::CUDAGuard; +#endif + namespace { constexpr int block_size = 512; @@ -111,7 +120,7 @@ std::vector<py::object> fused_attn_fwd( // Ensure that cuDNN handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(cu_seqlens_q.device()); + TECUDAGuard device_guard(cu_seqlens_q.device()); auto none = py::none(); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 941b88e36..6898ce387 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -20,6 +20,13 @@ #include "transformer_engine/transformer_engine.h" #include "util.h" +#include <ATen/cuda/CUDAGuard.h> +#if USE_ROCM && TORCH_VERSION_MINOR < 11 +using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA; +#else +using TECUDAGuard = at::cuda::CUDAGuard; +#endif + namespace { void* get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) { @@ -100,7 +107,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans // Ensure that cublasLt handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(workspace.device()); + TECUDAGuard device_guard(workspace.device()); // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); @@ -388,7 +395,7 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, // Ensure that cublasLt handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. 
- at::cuda::CUDAGuard device_guard(workspace.device()); + TECUDAGuard device_guard(workspace.device()); // TODO: Handle scaling modes NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; @@ -442,7 +449,7 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( // Ensure that cublasLt handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(workspace[0].device()); + TECUDAGuard device_guard(workspace[0].device()); void* output_data_ptr = nullptr; if (single_output) { diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index b78982d4d..8f8eed2c3 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -10,6 +10,13 @@ #include "common/util/system.h" #include "pybind.h" +#include <ATen/cuda/CUDAGuard.h> +#if USE_ROCM && TORCH_VERSION_MINOR < 11 +using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA; +#else +using TECUDAGuard = at::cuda::CUDAGuard; +#endif + namespace transformer_engine::pytorch { std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, @@ -69,7 +76,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe // Ensure that cuDNN handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device()); + TECUDAGuard device_guard(input.cast<at::Tensor>().device()); // Input and param tensors auto none = py::none(); @@ -319,7 +326,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w // Ensure that cuDNN handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. 
- at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device()); + TECUDAGuard device_guard(input.cast<at::Tensor>().device()); // Input and param tensors auto none = py::none();