Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions benchmarks/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ function(add_te_benchmark TARGET_NAME SOURCE_FILE)
)
endfunction()

add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp)
add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp)
add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp)
add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp)
add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp)
add_te_benchmark(bench_group_quantize_mxfp8 cast/bench_group_quantize_mxfp8.cpp)
add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp)
390 changes: 390 additions & 0 deletions benchmarks/cpp/cast/bench_group_quantize_mxfp8.cpp

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions benchmarks/cpp/run_benchmarks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ main() {
echo -e "\n[2/3] Running benchmarks..."

BENCHMARKS=(
"bench_quantize_mxfp8_fused"
"bench_casttranspose"
"bench_dequantize_mxfp8"
"bench_gated_mxfp8"
"bench_casttranspose"
"bench_group_quantize_mxfp8"
"bench_quantize_mxfp8_fused"
)

FAILED_BENCHMARKS=()
Expand Down
3 changes: 2 additions & 1 deletion tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ if(USE_ROCM)
list(APPEND test_cuda_sources
test_dequantize_nvfp4.cu
test_cublaslt_gemm.cu
test_cast_mxfp4_transpose.cu)
test_cast_mxfp4_transpose.cu
test_multi_quantize_mxfp8.cu)
TE_GetHipifiedSources("${test_cuda_sources}" ${CMAKE_CURRENT_SOURCE_DIR} test_hip_sources)
message("${message_line}")
message(STATUS "test_operator hipified sources: ${test_hip_sources}")
Expand Down
59 changes: 39 additions & 20 deletions tests/cpp/operator/test_cast_mxfp8_grouped.cu
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,17 @@ void performTest(const ProcessingMethod processing_method,
const size_t unpadded_colwise_blocks_Y = divide_round_up(M, 32);
const size_t unpadded_colwise_blocks_X = K;

#ifdef __HIP_PLATFORM_AMD__
rowwise_scales_first_dim[t] = unpadded_rowwise_blocks_Y;
rowwise_scales_last_dim[t] = unpadded_rowwise_blocks_X;
colwise_scales_first_dim[t] = unpadded_colwise_blocks_Y;
colwise_scales_last_dim[t] = unpadded_colwise_blocks_X;
#else
rowwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_Y, 128);
rowwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4);
colwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_Y, 4);
colwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128);
#endif

const size_t rowwise_sfs = rowwise_scales_first_dim[t] * rowwise_scales_last_dim[t];
const size_t colwise_sfs = colwise_scales_first_dim[t] * colwise_scales_last_dim[t];
Expand Down Expand Up @@ -566,22 +573,27 @@ void performTest(const ProcessingMethod processing_method,
size_t mismatches_scales = 0;
#ifdef USE_ROCM
std::vector<size_t> mismatches_scales_indices;
#endif
compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(),
1, rowwise_sfs_num, rowwise_sfs_num,
#ifdef USE_ROCM
mismatches_scales_indices,
#endif
mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);

#ifdef USE_ROCM
for (size_t t = 0; t < num_tensors; t++) {
compare_scaling_factors("rowwise_scales",
out_scales_rowwise_h.data() + rowwise_scales_offset[t],
out_scales_rowwise_ref.data() + rowwise_scales_offset[t],
rowwise_scales_first_dim[t], rowwise_scales_last_dim[t],
rowwise_scales_last_dim[t],
mismatches_scales_indices,
mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
}
if (::testing::Test::HasFatalFailure()) return;
adjust_ref_for_e8m0_scale_error("rowwise_scales", mismatches_scales_indices,
out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(),
rowwise_sfs_num, rows, cols, true,
out_data_rowwise_ref.data(), otype);
mismatches_scales = 0;
#else
compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(),
1, rowwise_sfs_num, rowwise_sfs_num,
mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
#endif
const size_t mismatches_elts = 32 * mismatches_scales;

Expand All @@ -596,22 +608,27 @@ void performTest(const ProcessingMethod processing_method,
size_t mismatches_scales = 0;
#ifdef USE_ROCM
std::vector<size_t> mismatches_scales_indices;
#endif
compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(),
1, colwise_sfs_num, colwise_sfs_num,
#ifdef USE_ROCM
mismatches_scales_indices,
#endif
mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);

#ifdef USE_ROCM
for (size_t t = 0; t < num_tensors; t++) {
compare_scaling_factors("colwise_scales",
out_scales_colwise_h.data() + colwise_scales_offset[t],
out_scales_colwise_ref.data() + colwise_scales_offset[t],
colwise_scales_first_dim[t], colwise_scales_last_dim[t],
colwise_scales_last_dim[t],
mismatches_scales_indices,
mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
}
if (::testing::Test::HasFatalFailure()) return;
adjust_ref_for_e8m0_scale_error("colwise_scales", mismatches_scales_indices,
out_scales_colwise_h.data(), out_scales_colwise_ref.data(),
colwise_sfs_num, rows, cols, false,
out_data_colwise_ref.data(), otype);
mismatches_scales = 0;
#else
compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(),
1, colwise_sfs_num, colwise_sfs_num,
mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
#endif
const size_t mismatches_elts = 32 * mismatches_scales;

Expand Down Expand Up @@ -703,10 +720,12 @@ class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam
>> {};

TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
#ifndef __HIP_PLATFORM_AMD__
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
#endif

using namespace transformer_engine;
using namespace test;
Expand Down
158 changes: 158 additions & 0 deletions tests/cpp/operator/test_multi_quantize_mxfp8.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*************************************************************************
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

using namespace transformer_engine;
using namespace test;

namespace {

template <typename IType, typename OType>
void performTest(const std::vector<std::pair<size_t, size_t>> &tensor_dims,
bool rowwise, bool colwise) {
const DType itype = TypeInfo<IType>::dtype;
const DType otype = TypeInfo<OType>::dtype;
const size_t num_tensors = tensor_dims.size();

std::vector<Tensor> inputs, outputs_multi, outputs_ref;

for (size_t i = 0; i < num_tensors; i++) {
auto [rows, cols] = tensor_dims[i];
inputs.emplace_back("input_" + std::to_string(i),
std::vector<size_t>{rows, cols}, itype);
outputs_multi.emplace_back("output_multi_" + std::to_string(i),
std::vector<size_t>{rows, cols}, otype,
rowwise, colwise, NVTE_MXFP8_1D_SCALING);
outputs_ref.emplace_back("output_ref_" + std::to_string(i),
std::vector<size_t>{rows, cols}, otype,
rowwise, colwise, NVTE_MXFP8_1D_SCALING);
fillUniform(&inputs.back());
}

std::vector<NVTETensor> nvte_inputs, nvte_outputs_multi;
for (auto &t : inputs) {
nvte_inputs.push_back(t.data());
}
for (auto &t : outputs_multi) {
nvte_outputs_multi.push_back(t.data());
}

nvte_multi_quantize_mxfp8(num_tensors, nvte_inputs.data(),
nvte_outputs_multi.data(), 0);

for (size_t i = 0; i < num_tensors; i++) {
if (tensor_dims[i].first > 0 && tensor_dims[i].second > 0)
nvte_quantize(inputs[i].data(), outputs_ref[i].data(), 0);
}
cudaDeviceSynchronize();

for (size_t i = 0; i < num_tensors; i++) {
auto [rows, cols] = tensor_dims[i];
if (rows == 0 || cols == 0) continue;
if (rowwise) {
auto *multi_data = outputs_multi[i].rowwise_cpu_dptr<OType>();
auto *ref_data = outputs_ref[i].rowwise_cpu_dptr<OType>();
for (size_t j = 0; j < rows * cols; j++) {
ASSERT_EQ(static_cast<uint8_t>(multi_data[j]),
static_cast<uint8_t>(ref_data[j]))
<< "Mismatch at tensor " << i << " element " << j;
}
auto *multi_scales = outputs_multi[i].rowwise_cpu_scale_inv_ptr<uint8_t>();
auto *ref_scales = outputs_ref[i].rowwise_cpu_scale_inv_ptr<uint8_t>();
size_t num_scales = rows * ((cols + 31) / 32);
for (size_t j = 0; j < num_scales; j++) {
ASSERT_EQ(multi_scales[j], ref_scales[j])
<< "Scale mismatch at tensor " << i << " scale " << j;
}
}
if (colwise) {
auto *multi_data = outputs_multi[i].columnwise_cpu_dptr<OType>();
auto *ref_data = outputs_ref[i].columnwise_cpu_dptr<OType>();
for (size_t j = 0; j < rows * cols; j++) {
ASSERT_EQ(static_cast<uint8_t>(multi_data[j]),
static_cast<uint8_t>(ref_data[j]))
<< "Colwise mismatch at tensor " << i << " element " << j;
}
}
}
}

std::vector<std::pair<size_t, size_t>> getTestDims(int config) {
switch (config) {
case 0:
return {{128, 128},
{256, 256},
{128, 512},
{512, 128}};
case 1:
return {{128, 4096},
{128, 4096},
{384, 4096},
{256, 4096},
{256, 4096}};
case 2:
return {{0, 128},
{128, 256},
{256, 128}};
default:
return {};
}
}

enum ScalingMode {
Rowwise = 0,
Colwise = 1,
Both = 2
};

class MultiQuantizeMXFP8TestSuite
: public ::testing::TestWithParam<
std::tuple<transformer_engine::DType, transformer_engine::DType, int, ScalingMode>> {};

TEST_P(MultiQuantizeMXFP8TestSuite, Test) {
auto [itype, otype, config, mode] = GetParam();
auto dims = getTestDims(config);
bool rowwise = (mode == Rowwise || mode == Both);
bool colwise = (mode == Colwise || mode == Both);

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(itype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(otype, OType,
performTest<IType, OType>(dims, rowwise, colwise);
)
)
}

static const char *scalingModeName(ScalingMode mode) {
switch (mode) {
case Rowwise: return "rowwise";
case Colwise: return "colwise";
case Both: return "both";
default: return "unknown";
}
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest, MultiQuantizeMXFP8TestSuite,
::testing::Combine(
::testing::Values(DType::kBFloat16, DType::kFloat16, DType::kFloat32),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::Values(0, 1, 2),
::testing::Values(Rowwise, Colwise, Both)),
[](const testing::TestParamInfo<MultiQuantizeMXFP8TestSuite::ParamType> &info) {
return test::typeName(std::get<0>(info.param)) + "_" +
test::typeName(std::get<1>(info.param)) + "_config" +
std::to_string(std::get<2>(info.param)) + "_" +
scalingModeName(std::get<3>(info.param));
});

} // namespace
14 changes: 14 additions & 0 deletions transformer_engine/common/cast/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}

#ifdef __HIP_PLATFORM_AMD__
void nvte_multi_quantize_mxfp8(size_t num_tensors, const NVTETensor *input_list,
NVTETensor *output_list, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_quantize_mxfp8);
using namespace transformer_engine;
std::vector<Tensor *> input_list_, output_list_;
for (size_t i = 0; i < num_tensors; i++) {
input_list_.push_back(convertNVTETensorCheck(input_list[i]));
output_list_.push_back(convertNVTETensorCheck(output_list[i]));
}
dispatch::multi_quantize_mxfp8(input_list_, output_list_, stream);
}
#endif

void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize);
Expand Down
Loading
Loading