Changes from all commits
30 commits
c15d93b  Current scaling: two-stage amax kernel (matthiasdiener, Nov 12, 2025)
51fab36  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener, Nov 13, 2025)
ae35e4c  bugfix graph capture (matthiasdiener, Nov 13, 2025)
77a68a7  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener, Nov 17, 2025)
c0d8e73  outline workspace allocation (matthiasdiener, Nov 17, 2025)
6c3507d  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener, Nov 18, 2025)
3c9de07  Proper allocation of workspace (matthiasdiener, Nov 18, 2025)
91249cc  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener, Nov 19, 2025)
be0e0c8  add a test to compare the accuracy of both amax implementations (matthiasdiener, Nov 19, 2025)
bce34da  add possibility to force using previous (atomic) kernel (matthiasdiener, Nov 19, 2025)
8c388cc  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener, Nov 19, 2025)
6388604  add copyrights (matthiasdiener, Nov 20, 2025)
9e6586f  don't add extra template to kernel (matthiasdiener, Nov 20, 2025)
18292bf  make amax_kernel_threads usable in pytorch (matthiasdiener, Nov 21, 2025)
a389455  update remaining calls to nvte_compute_amax (matthiasdiener, Nov 21, 2025)
d87ab8a  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener, Nov 24, 2025)
fd5dead  additional copyrights (matthiasdiener, Nov 24, 2025)
16d3bf9  avoid workspace allocations if NVTE_USE_ATOMIC_AMAX is set (matthiasdiener, Nov 24, 2025)
50b34aa  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener, Nov 25, 2025)
ef532b1  remove use_block_amax parameter, more cleanups (matthiasdiener, Nov 25, 2025)
f933ef3  Factor workspace allocation into function (matthiasdiener, Nov 25, 2025)
7d4054e  expand test slightly (matthiasdiener, Nov 25, 2025)
63cff98  Revert "expand test slightly" (Nov 25, 2025)
c7d44a7  guard by HIP macro, address review comments (matthiasdiener, Nov 26, 2025)
f92b926  bugfix workspace.data.dptr (matthiasdiener, Nov 26, 2025)
eba552e  various cleanups (matthiasdiener, Nov 26, 2025)
0d6a177  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener, Nov 26, 2025)
8eda427  simplify types in allocate_amax_workspace (matthiasdiener, Nov 26, 2025)
6990928  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener, Dec 1, 2025)
9ee618f  fix indentation (matthiasdiener, Dec 1, 2025)
39 changes: 39 additions & 0 deletions tests/cpp/operator/test_cast_current_scaling.cu
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -195,6 +197,43 @@ TEST_P(CastCSTestSuite, TestCastCS) {
);
}

#ifdef __HIP_PLATFORM_AMD__

TEST(AmaxConsistencyTest, AtomicVsWorkspace) {
using namespace transformer_engine;
using namespace test;

std::vector<size_t> shape{256, 1024};
const size_t N = product(shape);

// Input: FP32, Output: FP8 (E4M3) with per-tensor scaling
Tensor input("input", shape, DType::kFloat32);
Tensor out_atomic("out_atomic", shape, DType::kFloat8E4M3, true, false);
Tensor out_ws("out_ws", shape, DType::kFloat8E4M3, true, false);

fillUniform(&input);

// Path 1: atomic-based amax (no workspace)
nvte_compute_amax(input.data(), out_atomic.data(), 0);

// Path 2: two-stage amax using workspace
std::vector<size_t> ws_shape{N};
Tensor workspace("workspace", ws_shape, DType::kFloat32);
nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0);

cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

// Compare the resulting amax values
float amax_atomic = out_atomic.amax();
float amax_ws = out_ws.amax();

compareResults("amax_consistency", amax_atomic, amax_ws, /*atol=*/0.0f, /*rtol=*/0.0f);
}

#endif



INSTANTIATE_TEST_SUITE_P(
24 changes: 24 additions & 0 deletions transformer_engine/common/include/transformer_engine/recipe.h
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -73,6 +75,12 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
float margin, cudaStream_t stream);

#ifdef __HIP_PLATFORM_AMD__

constexpr int amax_kernel_threads = 512;

#endif

/*! \brief Compute an FP8 tensor's amax.
*
* The amax (maximum absolute value) of the input tensor is computed
@@ -84,6 +92,22 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
*/
void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);

Collaborator: Let's have a brief doc just like the nvte_compute_amax above.

Contributor Author: Done in c7d44a7

#ifdef __HIP_PLATFORM_AMD__

/*! \brief Compute an FP8 tensor's amax.
*
* The amax (maximum absolute value) of the input tensor is computed
* and written to the amax buffer of the output tensor.
*
* \param[in] input Input tensor. Must be unquantized.
* \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling.
* \param[out] workspace Output tensor. Must be FP32.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output, NVTETensor workspace, cudaStream_t stream);

#endif

/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
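For orientation, here is a minimal host-side sketch of how the new entry point can be called; it is not part of the diff. The workspace sizing mirrors the allocate_amax_workspace helper added in common.cpp below (one float per reduction block, capped at the 65535-block launch limit), and make_fp32_tensor is a placeholder for whatever FP32 tensor factory the caller already has.

// Illustrative sketch only, not part of this PR.
void compute_amax_two_stage(const NVTETensor input, NVTETensor output,
                            size_t input_numel, cudaStream_t stream) {
  size_t max_blocks = (input_numel + amax_kernel_threads - 1) / amax_kernel_threads;
  size_t ws_elems = std::min(max_blocks, static_cast<size_t>(65535));
  NVTETensor workspace = make_fp32_tensor(ws_elems);  // placeholder FP32 allocation
  nvte_compute_amax_with_workspace(input, output, workspace, stream);
  // A workspace with no allocated buffer (or calling plain nvte_compute_amax)
  // falls back to the single-stage atomic kernel.
}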
98 changes: 96 additions & 2 deletions transformer_engine/common/recipe/current_scaling.cu
@@ -26,12 +26,39 @@ using bf16__ = __nv_bfloat16;
using bf16__ = __hip_bfloat16;
#endif //__HIP_PLATFORM_AMD__

constexpr int amax_kernel_threads = 512;

Collaborator: Let's guard our ROCm-specific code changes by the macro HIP_PLATFORM_AMD.

Contributor Author: Done in c7d44a7

#ifdef __HIP_PLATFORM_AMD__

template <int BLOCK_THREADS>
__global__ void amax_final_reduce(const float* __restrict__ block_amax,
float* __restrict__ global_amax,
int num_blocks) {
float val = 0.f;

for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
val = fmaxf(val, block_amax[i]);
}

const int warp_id = threadIdx.x / THREADS_PER_WARP;
const float block_max =
reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);

if (threadIdx.x == 0) {
*global_amax = block_max;
}
}

#endif

template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
#ifdef __HIP_PLATFORM_AMD__
void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
Collaborator: Guard the API change so NV upstream can keep their flow.

Contributor Author: Done in c7d44a7

const size_t num_aligned_elements) {
#else
void amax_kernel(const InputType *input, float *amax, const size_t N,
const size_t num_aligned_elements) {
#endif
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
InputType max{0.f};
const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -65,12 +92,23 @@ __launch_bounds__(amax_kernel_threads) __global__
// Reduce amax over block
max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
#ifdef __HIP_PLATFORM_AMD__
if (block_amax != nullptr) {
// 2-stage: write per-block result
block_amax[blockIdx.x] = max;
} else {
// Atomic path: directly update global amax
atomicMaxFloat(amax, max);
}
#else
atomicMaxFloat(amax, max);
#endif
}
}

template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
size_t block_capacity, cudaStream_t stream) {
// Zero out amax so we can update with atomic max
(void)cudaMemsetAsync(amax, 0, sizeof(float), stream);

@@ -89,24 +127,54 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
constexpr size_t max_blocks = 65535;
num_blocks = std::min(num_blocks, max_blocks);

#ifdef __HIP_PLATFORM_AMD__
if (block_capacity < num_blocks)
block_amax = nullptr;
#endif

// Launch kernel
switch (align) {
case Alignment::SAME_ALIGNED:
#ifdef __HIP_PLATFORM_AMD__
amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
#else
amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
#endif
break;
case Alignment::SAME_UNALIGNED:
#ifdef __HIP_PLATFORM_AMD__
amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
#else
amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
#endif
break;
case Alignment::DIFFERENT: {
// This case is a logic error, since there is only one pointer (input)
// in the alignment check. Still safe to process without vectorization.
#ifdef __HIP_PLATFORM_AMD__
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
#else
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
#endif
break;
}
}

#ifdef __HIP_PLATFORM_AMD__
if (block_amax != nullptr) {
constexpr int FINAL_REDUCE_THREADS = 256;
dim3 fr_block(FINAL_REDUCE_THREADS);
dim3 fr_grid(1);

amax_final_reduce<FINAL_REDUCE_THREADS>
<<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
}
#endif

// Check results
NVTE_CHECK_CUDA(cudaGetLastError());
}
@@ -115,6 +183,12 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
} // namespace transformer_engine

void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
}

void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
#endif
NVTE_API_CALL(nvte_compute_amax);
using namespace transformer_engine;

@@ -150,11 +224,31 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
to_string(output.amax.dtype), ")");
CheckOutputTensor(output, "output_compute_amax", true);

#ifdef __HIP_PLATFORM_AMD__
// Optional workspace
float* block_amax = nullptr;
size_t block_capacity = 0;

if (workspace_ != nullptr) {
auto &workspace = *reinterpret_cast<Tensor *>(workspace_);
if (workspace.data.dptr != nullptr) {
NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
"Workspace tensor for amax computation must be FP32, got dtype=",
to_string(workspace.data.dtype));
block_amax = reinterpret_cast<float*>(workspace.data.dptr);
block_capacity = workspace.data.numel();
}
}
#endif

// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
#ifdef __HIP_PLATFORM_AMD__
block_amax, block_capacity,
#endif
stream);); // NOLINT(*)
}

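To make the sizing above concrete, here is the arithmetic for the 256x1024 FP32 tensor used in the new test. This is an illustration only, using the constants defined in this file; it is not code from the PR.

// Illustrative arithmetic, not part of the diff.
constexpr size_t N = 256 * 1024;             // 262144 input elements
constexpr size_t nvec = 32 / sizeof(float);  // 8 elements per vectorized load
constexpr size_t aligned_elems = N / nvec;   // 32768, assuming aligned data
constexpr size_t blocks =
    (aligned_elems + amax_kernel_threads - 1) / amax_kernel_threads;  // 64 blocks
// A 64-element FP32 workspace is enough here. If block_capacity < num_blocks,
// launch_amax_kernel resets block_amax to nullptr and the kernel falls back to
// the atomic update, so an undersized workspace degrades gracefully.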
34 changes: 34 additions & 0 deletions transformer_engine/pytorch/csrc/common.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -10,6 +12,10 @@
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"

#ifdef __HIP_PLATFORM_AMD__
#include "common/common.h"
#endif

namespace transformer_engine::pytorch {

std::vector<size_t> getTensorShape(at::Tensor t) {
@@ -277,4 +283,32 @@ int roundup(const int value, const int multiple) {
return ((value + multiple - 1) / multiple) * multiple;
}

#ifdef __HIP_PLATFORM_AMD__

inline bool nvte_use_atomic_amax() {
const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX");
if (env_p && std::string(env_p) == "1")
return true;
return false;
}

TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor) {
if (nvte_use_atomic_amax() || input_tensor.numel() == 0) {
// User chose atomic path, or empty tensor -> no need for workspace
return TensorWrapper{};
}

const auto N = input_tensor.numel();
constexpr size_t max_blocks_hw = 65535;

size_t max_blocks = DIVUP(N, static_cast<size_t>(amax_kernel_threads));
size_t workspace_blocks = std::min(max_blocks, max_blocks_hw);

at::Tensor ws = at::empty(workspace_blocks, at::CUDA(at::kFloat));

return makeTransformerEngineTensor(ws);
}

#endif

} // namespace transformer_engine::pytorch
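Putting the helper together with the new API, here is a minimal sketch of the pattern used by the extension call sites below. The surrounding wrapper function is hypothetical; the calls themselves match what activation.cpp, bias.cpp, and cast.cpp do in this PR.

// Hypothetical wrapper, for illustration only.
void compute_amax_example(TensorWrapper& te_input, TensorWrapper& te_output) {
#ifdef __HIP_PLATFORM_AMD__
  // Empty when NVTE_USE_ATOMIC_AMAX=1 or the input has no elements;
  // otherwise sized min(ceil(N / amax_kernel_threads), 65535) floats.
  auto workspace = allocate_amax_workspace(te_input);
  nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
                                   workspace.data(), at::cuda::getCurrentCUDAStream());
  // With an empty workspace, nvte_compute_amax_with_workspace detects the
  // missing buffer and falls back to the single-stage atomic kernel.
#else
  nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
#endif
}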
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -374,6 +374,9 @@ std::vector<size_t> convertShape(const NVTEShape& shape);

int roundup(const int value, const int multiple);

#ifdef __HIP_PLATFORM_AMD__
TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor);
#endif
} // namespace transformer_engine::pytorch

namespace std {
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -36,10 +38,18 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
auto [te_output_act, out_act] =
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));

#ifdef __HIP_PLATFORM_AMD__
auto workspace = allocate_amax_workspace(te_input);
#endif
NVTE_SCOPED_GIL_RELEASE({
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
// use te_output_act as input to the compute amax and find the amax of activated tensor
#ifdef __HIP_PLATFORM_AMD__
nvte_compute_amax_with_workspace(te_output_act.data(), te_output.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
#else
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
#endif
});

// my_quantizer here has to be a Float8CurrentScalingQuantizer
8 changes: 8 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -49,7 +51,13 @@ std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_qu
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
#ifdef __HIP_PLATFORM_AMD__
nvte_compute_amax_with_workspace(input_tensor.data(), out_tensor.data(),
allocate_amax_workspace(input_tensor).data(),
at::cuda::getCurrentCUDAStream());
#else
nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
#endif
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
8 changes: 8 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -53,7 +55,13 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
#ifdef __HIP_PLATFORM_AMD__
nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
allocate_amax_workspace(te_input).data(),
at::cuda::getCurrentCUDAStream());
#else
nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
#endif
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {