diff --git a/tests/cpp/operator/test_cast_current_scaling.cu b/tests/cpp/operator/test_cast_current_scaling.cu
index f7425f0f3..856c24cfc 100644
--- a/tests/cpp/operator/test_cast_current_scaling.cu
+++ b/tests/cpp/operator/test_cast_current_scaling.cu
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -195,6 +197,43 @@ TEST_P(CastCSTestSuite, TestCastCS) {
   );
 }
 
+#ifdef __HIP_PLATFORM_AMD__
+
+TEST(AmaxConsistencyTest, AtomicVsWorkspace) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  std::vector<size_t> shape{256, 1024};
+  const size_t N = product(shape);
+
+  // Input: FP32, Output: FP8 (E4M3) with per-tensor scaling
+  Tensor input("input", shape, DType::kFloat32);
+  Tensor out_atomic("out_atomic", shape, DType::kFloat8E4M3, true, false);
+  Tensor out_ws("out_ws", shape, DType::kFloat8E4M3, true, false);
+
+  fillUniform(&input);
+
+  // Path 1: atomic-based amax (no workspace)
+  nvte_compute_amax(input.data(), out_atomic.data(), 0);
+
+  // Path 2: two-stage amax using workspace
+  std::vector<size_t> ws_shape{N};
+  Tensor workspace("workspace", ws_shape, DType::kFloat32);
+  nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0);
+
+  cudaDeviceSynchronize();
+  auto err = cudaGetLastError();
+  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+  // Compare the resulting amax values
+  float amax_atomic = out_atomic.amax();
+  float amax_ws = out_ws.amax();
+
+  compareResults("amax_consistency", amax_atomic, amax_ws, /*atol=*/0.0f, /*rtol=*/0.0f);
+}
+
+#endif
+
 INSTANTIATE_TEST_SUITE_P(
diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h
index 50fb696ea..5955835bb 100644
--- a/transformer_engine/common/include/transformer_engine/recipe.h
+++ b/transformer_engine/common/include/transformer_engine/recipe.h
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -73,6 +75,12 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
     std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
     float margin, cudaStream_t stream);
 
+#ifdef __HIP_PLATFORM_AMD__
+
+constexpr int amax_kernel_threads = 512;
+
+#endif
+
 /*! \brief Compute an FP8 tensor's amax.
  *
  *  The amax (maximum absolute value) of the input tensor is computed
@@ -84,6 +92,22 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
  */
void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+#ifdef __HIP_PLATFORM_AMD__
+
+/*! \brief Compute an FP8 tensor's amax.
+ *
+ *  The amax (maximum absolute value) of the input tensor is computed
+ *  and written to the amax buffer of the output tensor.
+ *
+ *  \param[in]     input      Input tensor. Must be unquantized.
+ *  \param[in,out] output     Output tensor. Must be an FP8 tensor with per-tensor scaling.
+ *  \param[out]    workspace  Workspace tensor for per-block partial amax values. Must be FP32.
+ *  \param[in]     stream     CUDA stream used for the operation.
+ */
+void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output,
+                                      NVTETensor workspace, cudaStream_t stream);
+
+#endif
+
 /*! \brief Update an FP8 tensor's scale based on its amax.
  *
  *  This is only supported for FP8 tensors with per-tensor scaling.
diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu
index 709ab200f..86592e1d7 100644
--- a/transformer_engine/common/recipe/current_scaling.cu
+++ b/transformer_engine/common/recipe/current_scaling.cu
@@ -26,12 +26,39 @@ using bf16__ = __nv_bfloat16;
 using bf16__ = __hip_bfloat16;
 #endif //__HIP_PLATFORM_AMD__
 
-constexpr int amax_kernel_threads = 512;
+
+#ifdef __HIP_PLATFORM_AMD__
+
+template <int BLOCK_THREADS>
+__global__ void amax_final_reduce(const float* __restrict__ block_amax,
+                                  float* __restrict__ global_amax,
+                                  int num_blocks) {
+  float val = 0.f;
+
+  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
+    val = fmaxf(val, block_amax[i]);
+  }
+
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const float block_max =
+      reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);
+
+  if (threadIdx.x == 0) {
+    *global_amax = block_max;
+  }
+}
+
+#endif
 
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
+#ifdef __HIP_PLATFORM_AMD__
+    void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
+                     const size_t num_aligned_elements) {
+#else
     void amax_kernel(const InputType *input, float *amax, const size_t N,
                      const size_t num_aligned_elements) {
+#endif
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -65,12 +92,23 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
+#ifdef __HIP_PLATFORM_AMD__
+    if (block_amax != nullptr) {
+      // 2-stage: write per-block result
+      block_amax[blockIdx.x] = max;
+    } else {
+      // Atomic path: directly update global amax
+      atomicMaxFloat(amax, max);
+    }
+#else
     atomicMaxFloat(amax, max);
+#endif
   }
 }
 
 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
+                        size_t block_capacity, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   (void)cudaMemsetAsync(amax, 0, sizeof(float), stream);
 
@@ -89,24 +127,54 @@
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
+#ifdef __HIP_PLATFORM_AMD__
+  if (block_capacity < num_blocks)
+    block_amax = nullptr;
+#endif
+
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
+#ifdef __HIP_PLATFORM_AMD__
+      amax_kernel<nvec, true, InputType>
+          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+#else
       amax_kernel<nvec, true, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+#endif
      break;
    case Alignment::SAME_UNALIGNED:
+#ifdef __HIP_PLATFORM_AMD__
+      amax_kernel<nvec, false, InputType>
+          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+#else
      amax_kernel<nvec, false, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+#endif
      break;
    case Alignment::DIFFERENT: {
      // This case is a logic error, since there is only one pointer (input)
      // in the alignment check. Still safe to process without vectorization.
+#ifdef __HIP_PLATFORM_AMD__
+      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
+#else
      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
+#endif
      break;
    }
  }
 
+#ifdef __HIP_PLATFORM_AMD__
+  if (block_amax != nullptr) {
+    constexpr int FINAL_REDUCE_THREADS = 256;
+    dim3 fr_block(FINAL_REDUCE_THREADS);
+    dim3 fr_grid(1);
+
+    amax_final_reduce<FINAL_REDUCE_THREADS>
+        <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
+  }
+#endif
+
  // Check results
  NVTE_CHECK_CUDA(cudaGetLastError());
}
@@ -115,6 +183,12 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
 }  // namespace transformer_engine
 
 void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
+}
+
+void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
+#endif
   NVTE_API_CALL(nvte_compute_amax);
   using namespace transformer_engine;
 
@@ -150,11 +224,31 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
              to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);
 
+#ifdef __HIP_PLATFORM_AMD__
+  // Optional workspace
+  float* block_amax = nullptr;
+  size_t block_capacity = 0;
+
+  if (workspace_ != nullptr) {
+    auto &workspace = *reinterpret_cast<const Tensor *>(workspace_);
+    if (workspace.data.dptr != nullptr) {
+      NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
+                 "Workspace tensor for amax computation must be FP32, got dtype=",
+                 to_string(workspace.data.dtype));
+      block_amax = reinterpret_cast<float *>(workspace.data.dptr);
+      block_capacity = workspace.data.numel();
+    }
+  }
+#endif
+
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
                                reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
+#ifdef __HIP_PLATFORM_AMD__
+                               block_amax, block_capacity,
+#endif
                                stream););  // NOLINT(*)
}
diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp
index 75e8c14fc..e8083b812 100644
--- a/transformer_engine/pytorch/csrc/common.cpp
+++ b/transformer_engine/pytorch/csrc/common.cpp
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -10,6 +12,10 @@
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
 
+#ifdef __HIP_PLATFORM_AMD__
+#include "common/common.h"
+#endif
+
 namespace transformer_engine::pytorch {
 
 std::vector<size_t> getTensorShape(at::Tensor t) {
@@ -277,4 +283,32 @@ int roundup(const int value, const int multiple) {
   return ((value + multiple - 1) / multiple) * multiple;
 }
 
+#ifdef __HIP_PLATFORM_AMD__
+
+inline bool nvte_use_atomic_amax() {
+  const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX");
+  if (env_p && std::string(env_p) == "1")
+    return true;
+  return false;
+}
+
+TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor) {
+  if (nvte_use_atomic_amax() || input_tensor.numel() == 0) {
+    // User chose atomic path, or empty tensor -> no need for workspace
+    return TensorWrapper{};
+  }
+
+  const auto N = input_tensor.numel();
+  constexpr size_t max_blocks_hw = 65535;
+
+  size_t max_blocks = DIVUP(N, static_cast<size_t>(amax_kernel_threads));
+  size_t workspace_blocks = std::min(max_blocks, max_blocks_hw);
+
+  at::Tensor ws = at::empty({static_cast<int64_t>(workspace_blocks)}, at::CUDA(at::kFloat));
+
+  return makeTransformerEngineTensor(ws);
+}
+
+#endif
+
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index c15a1ae3c..13d08f141 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -374,6 +374,9 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
 
 int roundup(const int value, const int multiple);
 
+#ifdef __HIP_PLATFORM_AMD__
+TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor);
+#endif
 }  // namespace transformer_engine::pytorch
 
 namespace std {
diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp
index 189190f68..93a3ef27c 100644
--- a/transformer_engine/pytorch/csrc/extensions/activation.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -36,10 +38,18 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
     auto [te_output_act, out_act] =
         my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
 
+#ifdef __HIP_PLATFORM_AMD__
+    auto workspace = allocate_amax_workspace(te_input);
+#endif
     NVTE_SCOPED_GIL_RELEASE({
       act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
       // use te_output_act as input to the compute amax and find the amax of activated tensor
+#ifdef __HIP_PLATFORM_AMD__
+      nvte_compute_amax_with_workspace(te_output_act.data(), te_output.data(),
+                                       workspace.data(), at::cuda::getCurrentCUDAStream());
+#else
       nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+#endif
     });
 
     // my_quantizer here has to be a Float8CurrentScalingQuantizer
diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp
index 63455e3c0..c4d91bf77 100644
--- a/transformer_engine/pytorch/csrc/extensions/bias.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -49,7 +51,13 @@ std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_qu
   // my_quantizer here has to be a Float8CurrentScalingQuantizer
   auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer.get());
   NVTE_SCOPED_GIL_RELEASE({
+#ifdef __HIP_PLATFORM_AMD__
+    nvte_compute_amax_with_workspace(input_tensor.data(), out_tensor.data(),
+                                     allocate_amax_workspace(input_tensor).data(),
+                                     at::cuda::getCurrentCUDAStream());
+#else
     nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
+#endif
   });
   // check if we need to do amax reudction (depending on model parallel configs)
   if (my_quantizer_cs->with_amax_reduction) {
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 1edbef8cd..a9f799278 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -53,7 +55,13 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
   // my_quantizer here has to be a Float8CurrentScalingQuantizer
   auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
   NVTE_SCOPED_GIL_RELEASE({
+#ifdef __HIP_PLATFORM_AMD__
+    nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
+                                     allocate_amax_workspace(te_input).data(),
+                                     at::cuda::getCurrentCUDAStream());
+#else
     nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+#endif
   });
   // check if we need to do amax reudction (depending on model parallel configs)
   if (my_quantizer_cs->with_amax_reduction) {
diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
index 23e415c40..39960fdc3 100644
--- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
@@ -144,8 +144,14 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
     // my_quantizer here has to be a Float8CurrentScalingQuantizer
     auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
     NVTE_SCOPED_GIL_RELEASE({
+#ifdef __HIP_PLATFORM_AMD__
+      nvte_compute_amax_with_workspace(unquantized_out_cu.data(), out_cu.data(),
+                                       allocate_amax_workspace(unquantized_out_cu).data(),
+                                       at::cuda::getCurrentCUDAStream());
+#else
       nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
                         at::cuda::getCurrentCUDAStream());
+#endif
     });
     // check if we need to do amax reudction (depending on model parallel configs)
     if (my_quantizer_cs->with_amax_reduction) {
@@ -302,8 +308,14 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
     // my_quantizer here has to be a Float8CurrentScalingQuantizer
     auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
     NVTE_SCOPED_GIL_RELEASE({
+#ifdef __HIP_PLATFORM_AMD__
+      nvte_compute_amax_with_workspace(unquantized_out_cu.data(), out_cu.data(),
+                                       allocate_amax_workspace(unquantized_out_cu).data(),
+                                       at::cuda::getCurrentCUDAStream());
+#else
       nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
                         at::cuda::getCurrentCUDAStream());
+#endif
     });
     // check if we need to do amax reudction (depending on model parallel configs)
     if (my_quantizer_cs->with_amax_reduction) {
diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp
index eb4d60bd0..705b57d58 100644
--- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -25,7 +27,13 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
       DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
       amax.data_ptr());
 
+#ifdef __HIP_PLATFORM_AMD__
+  nvte_compute_amax_with_workspace(te_input.data(), fake_te_output.data(),
+                                   allocate_amax_workspace(te_input).data(),
+                                   at::cuda::getCurrentCUDAStream());
+#else
   nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
+#endif
 }
 
 void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer,
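
Note on workspace sizing (a sketch, not part of the patch): the amax kernel writes one partial maximum per block into the workspace, and launch_amax_kernel falls back to the atomic path whenever block_capacity < num_blocks. Since num_blocks is computed from the vectorized element count, DIVUP(numel, amax_kernel_threads) capped at 65535 is a safe upper bound, which is the same arithmetic allocate_amax_workspace uses above. A minimal C++ sketch of that sizing rule, with the helper name amax_workspace_elems being hypothetical and the constants taken from the patch:

    // Sketch only (assumed helper): how many FP32 elements a caller-provided
    // workspace needs so the two-stage reduction is actually taken.
    #include <algorithm>
    #include <cstddef>

    inline std::size_t amax_workspace_elems(std::size_t numel) {
      constexpr std::size_t threads = 512;       // amax_kernel_threads in recipe.h
      constexpr std::size_t max_blocks = 65535;  // grid cap used by launch_amax_kernel
      const std::size_t blocks = (numel + threads - 1) / threads;  // upper bound on launch grid
      return std::min(blocks, max_blocks);       // one float per block
    }

A workspace with fewer elements than this (or setting NVTE_USE_ATOMIC_AMAX=1 on the PyTorch side) keeps the original single-stage atomic behaviour.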