1 change: 0 additions & 1 deletion build_tools/hipify/custom_map.json
@@ -11,7 +11,6 @@
"__nv_fp8_e5m2" : "te_hip_fp8_e5m2",
"__nv_fp8_e4m3" : "te_hip_fp8_e4m3",
"cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA",
-"at::cuda::CUDAGuard" : "at::hip::HIPGuardMasqueradingAsCUDA",
"__nv_fp4_e2m1" : "__hip_fp4_e2m1",
"__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1",
"__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1",
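For context on the mapping removed above: the hipify custom map rewrites CUDA identifiers to their ROCm spellings while the sources are hipified, so dropping the `at::cuda::CUDAGuard` entry leaves that identifier untouched and lets each source file pick the guard type explicitly via the `TECUDAGuard` alias introduced below. A minimal sketch of that substitution idea, assuming the map is applied as plain text replacement; `apply_custom_map` is a hypothetical helper, not TE code, and the real tooling in build_tools/hipify may work differently.

// Hedged sketch only. Entries mirror two of the remaining custom_map.json
// mappings; "at::cuda::CUDAGuard" is intentionally absent after this PR.
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

std::string apply_custom_map(std::string src) {
  const std::vector<std::pair<std::string, std::string>> map = {
      {"cuda::getCurrentCUDAStream", "hip::getCurrentHIPStreamMasqueradingAsCUDA"},
      {"__nv_fp8_e4m3", "te_hip_fp8_e4m3"},
  };
  for (const auto& [from, to] : map) {
    for (std::size_t pos = src.find(from); pos != std::string::npos;
         pos = src.find(from, pos + to.size())) {
      src.replace(pos, from.size(), to);  // substitute the CUDA spelling in place
    }
  }
  return src;
}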
11 changes: 10 additions & 1 deletion transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
+* This file was modified for portability to AMDGPU
+* Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -8,6 +10,13 @@
#include "common.h"
#include "pybind.h"

+#include <torch/version.h>
+#if USE_ROCM && TORCH_VERSION_MINOR < 11
+using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA;
+#else
+using TECUDAGuard = at::cuda::CUDAGuard;
+#endif
+
namespace {

constexpr int block_size = 512;
@@ -111,7 +120,7 @@ std::vector<py::object> fused_attn_fwd(
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
-at::cuda::CUDAGuard device_guard(cu_seqlens_q.device());
+TECUDAGuard device_guard(cu_seqlens_q.device());

auto none = py::none();

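On ROCm builds against a PyTorch whose minor version is below 11, the alias above resolves to the HIP guard; otherwise it resolves to at::cuda::CUDAGuard. Either way it is an RAII device switch, roughly as in this sketch; run_on_tensor_device is a hypothetical helper and the exact guard header may differ between CUDA and ROCm builds, so treat it as illustration rather than PR code.

// Hedged sketch of the guard pattern used in fused_attn_fwd above.
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>  // guard header on CUDA builds; ROCm builds may differ

void run_on_tensor_device(const at::Tensor& t) {
  // Make t's device current for this scope, overriding any earlier
  // torch.cuda.set_device() call from the user side; the previous device
  // is restored when the guard goes out of scope.
  c10::cuda::CUDAGuard device_guard(t.device());
  // ... create cuDNN / cublasLt handles or launch kernels for t here ...
}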
13 changes: 10 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -20,6 +20,13 @@
#include "transformer_engine/transformer_engine.h"
#include "util.h"

+#include <torch/version.h>
+#if USE_ROCM && TORCH_VERSION_MINOR < 11
+using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA;
+#else
+using TECUDAGuard = at::cuda::CUDAGuard;
+#endif
+
namespace {

void* get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) {
@@ -100,7 +107,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
-at::cuda::CUDAGuard device_guard(workspace.device());
+TECUDAGuard device_guard(workspace.device());

// Input tensors
NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
@@ -388,7 +395,7 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
-at::cuda::CUDAGuard device_guard(workspace.device());
+TECUDAGuard device_guard(workspace.device());

// TODO: Handle scaling modes
NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
@@ -442,7 +449,7 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
-at::cuda::CUDAGuard device_guard(workspace[0].device());
+TECUDAGuard device_guard(workspace[0].device());

void* output_data_ptr = nullptr;
if (single_output) {
11 changes: 9 additions & 2 deletions transformer_engine/pytorch/csrc/extensions/normalization.cpp
@@ -10,6 +10,13 @@
#include "common/util/system.h"
#include "pybind.h"

+#include <torch/version.h>
+#if USE_ROCM && TORCH_VERSION_MINOR < 11
+using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA;
+#else
+using TECUDAGuard = at::cuda::CUDAGuard;
+#endif
+
namespace transformer_engine::pytorch {

std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
@@ -69,7 +76,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
-at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
+TECUDAGuard device_guard(input.cast<at::Tensor>().device());

// Input and param tensors
auto none = py::none();
@@ -319,7 +326,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
-at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
+TECUDAGuard device_guard(input.cast<at::Tensor>().device());

// Input and param tensors
auto none = py::none();
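The gate used in all three files keys off TORCH_VERSION_MINOR from <torch/version.h>. A standalone probe of the same condition, for illustration only: the output strings are made up, USE_ROCM is assumed to come from the ROCm build, and an undefined macro evaluates to 0 in #if, so CUDA builds take the else branch.

// Hedged sketch: the same compile-time selection, isolated from the PR.
#include <cstdio>
#include <torch/version.h>

int main() {
  std::printf("torch minor version: %d\n", TORCH_VERSION_MINOR);
#if USE_ROCM && TORCH_VERSION_MINOR < 11
  std::printf("TECUDAGuard -> at::hip::HIPGuardMasqueradingAsCUDA\n");
#else
  std::printf("TECUDAGuard -> at::cuda::CUDAGuard\n");
#endif
  return 0;
}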