diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt index 6071f9083..8abccddf2 100644 --- a/benchmarks/cpp/CMakeLists.txt +++ b/benchmarks/cpp/CMakeLists.txt @@ -82,7 +82,8 @@ function(add_te_benchmark TARGET_NAME SOURCE_FILE) ) endfunction() -add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp) +add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp) add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp) add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp) -add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp) +add_te_benchmark(bench_group_quantize_mxfp8 cast/bench_group_quantize_mxfp8.cpp) +add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp) diff --git a/benchmarks/cpp/cast/bench_group_quantize_mxfp8.cpp b/benchmarks/cpp/cast/bench_group_quantize_mxfp8.cpp new file mode 100644 index 000000000..f293bf1c2 --- /dev/null +++ b/benchmarks/cpp/cast/bench_group_quantize_mxfp8.cpp @@ -0,0 +1,390 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include +#include +#include +#include + +#include "benchmark_utils.h" + +#include "transformer_engine/cast_hip.h" +#include "transformer_engine/transformer_engine_hip.h" + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +constexpr int kPadMultiple = 128; + +static std::vector generate_routed_tokens(int total_tokens, int num_experts, + std::mt19937 &rng, bool skewed, + double zipf_s = 0.7) { + std::vector weights(num_experts); + if (skewed) { + for (int i = 0; i < num_experts; i++) { + weights[i] = 1.0 / std::pow(i + 1, zipf_s); + } + } else { + std::fill(weights.begin(), weights.end(), 1.0); + } + double sum = std::accumulate(weights.begin(), weights.end(), 0.0); + for (auto &w : weights) { + w /= sum; + } + std::shuffle(weights.begin(), weights.end(), rng); + + std::discrete_distribution dist(weights.begin(), weights.end()); + + std::vector tokens(num_experts, 0); + for (int i = 0; i < total_tokens; i++) { + tokens[dist(rng)]++; + } + + for (auto &t : tokens) { + t = std::max(kPadMultiple, ((t + kPadMultiple - 1) / kPadMultiple) * kPadMultiple); + } + return tokens; +} + +template +static void BM_GroupQuantizeMXFP8(benchmark::State &state) { + const int num_experts = state.range(0); + const int cols = state.range(1); + const int total_tokens = state.range(2); + const int skewed = state.range(3); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? DType::kFloat8E4M3 : DType::kFloat8E5M2; + + unsigned int seed = std::random_device{}(); + uint64_t config_hash = seed ^ (uint64_t(num_experts) * 2654435761ULL) + ^ (uint64_t(cols) * 40503ULL) + ^ (uint64_t(total_tokens) * 12345ULL); + std::mt19937 rng(config_hash); + + std::vector token_counts = generate_routed_tokens(total_tokens, num_experts, rng, skewed); + + int min_tok = *std::min_element(token_counts.begin(), token_counts.end()); + int max_tok = *std::max_element(token_counts.begin(), token_counts.end()); + int sum_tok = std::accumulate(token_counts.begin(), token_counts.end(), 0); + int avg_tok = sum_tok / num_experts; + + size_t total_elements = 0; + std::vector first_dims_h(num_experts); + std::vector offsets_h(num_experts + 1); + offsets_h[0] = 0; + for (int i = 0; i < num_experts; i++) { + first_dims_h[i] = token_counts[i]; + total_elements += static_cast(token_counts[i]) * cols; + offsets_h[i + 1] = static_cast(total_elements); + } + + size_t total_rowwise_scales = 0, total_colwise_scales = 0; + for (int i = 0; i < num_experts; i++) { + if (USE_ROWWISE) total_rowwise_scales += token_counts[i] * ((cols + 31) / 32); + if (USE_COLWISE) total_colwise_scales += ((token_counts[i] + 31) / 32) * cols; + } + + void *in_data_d = nullptr, *out_data_rw_d = nullptr, *out_data_cw_d = nullptr; + void *scales_rw_d = nullptr, *scales_cw_d = nullptr; + int64_t *first_dims_d = nullptr, *offsets_d = nullptr; + float *amax_d = nullptr; + + HIP_CHECK(hipMalloc(&in_data_d, total_elements * sizeof(IType))); + if (USE_ROWWISE) { + HIP_CHECK(hipMalloc(&out_data_rw_d, total_elements * sizeof(OType))); + HIP_CHECK(hipMalloc(&scales_rw_d, total_rowwise_scales)); + } + if (USE_COLWISE) { + HIP_CHECK(hipMalloc(&out_data_cw_d, total_elements * sizeof(OType))); + HIP_CHECK(hipMalloc(&scales_cw_d, total_colwise_scales)); + } + HIP_CHECK(hipMalloc(&amax_d, sizeof(float))); + HIP_CHECK(hipMalloc(&first_dims_d, num_experts * sizeof(int64_t))); + HIP_CHECK(hipMalloc(&offsets_d, (num_experts + 1) * sizeof(int64_t))); + HIP_CHECK(hipMemcpy(first_dims_d, first_dims_h.data(), num_experts * sizeof(int64_t), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(offsets_d, offsets_h.data(), (num_experts + 1) * sizeof(int64_t), hipMemcpyHostToDevice)); + + std::vector logical_shape_vec = {static_cast(sum_tok), static_cast(cols)}; + NVTEShape logical_shape = nvte_make_shape(logical_shape_vec.data(), 2); + NVTEShape first_dims_shape; + first_dims_shape.ndim = 1; + first_dims_shape.data[0] = num_experts; + NVTEShape offsets_shape; + offsets_shape.ndim = 1; + offsets_shape.data[0] = num_experts + 1; + + NVTEGroupedTensor in_gt = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_experts, logical_shape); + NVTEGroupedTensor out_gt = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_experts, logical_shape); + + NVTEBasicTensor in_bt = {in_data_d, std::is_same_v ? kNVTEFloat32 : + (std::is_same_v ? kNVTEBFloat16 : kNVTEFloat16), logical_shape}; + nvte_set_grouped_tensor_param(in_gt, kNVTEGroupedRowwiseData, &in_bt, sizeof(in_bt)); + + NVTEBasicTensor fd_bt = {first_dims_d, kNVTEInt64, first_dims_shape}; + NVTEBasicTensor off_bt = {offsets_d, kNVTEInt64, offsets_shape}; + nvte_set_grouped_tensor_param(in_gt, kNVTEGroupedFirstDims, &fd_bt, sizeof(fd_bt)); + nvte_set_grouped_tensor_param(in_gt, kNVTEGroupedTensorOffsets, &off_bt, sizeof(off_bt)); + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedFirstDims, &fd_bt, sizeof(fd_bt)); + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedTensorOffsets, &off_bt, sizeof(off_bt)); + + if (USE_ROWWISE) { + NVTEBasicTensor rw_data_bt = {out_data_rw_d, std::is_same_v ? kNVTEFloat8E4M3 : kNVTEFloat8E5M2, logical_shape}; + std::vector scales_rw_shape = {total_rowwise_scales}; + NVTEShape scales_rw_nvshape = nvte_make_shape(scales_rw_shape.data(), 1); + NVTEBasicTensor rw_scales_bt = {scales_rw_d, kNVTEFloat8E8M0, scales_rw_nvshape}; + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedRowwiseData, &rw_data_bt, sizeof(rw_data_bt)); + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedRowwiseScaleInv, &rw_scales_bt, sizeof(rw_scales_bt)); + } + if (USE_COLWISE) { + NVTEBasicTensor cw_data_bt = {out_data_cw_d, std::is_same_v ? kNVTEFloat8E4M3 : kNVTEFloat8E5M2, logical_shape}; + std::vector scales_cw_shape = {total_colwise_scales}; + NVTEShape scales_cw_nvshape = nvte_make_shape(scales_cw_shape.data(), 1); + NVTEBasicTensor cw_scales_bt = {scales_cw_d, kNVTEFloat8E8M0, scales_cw_nvshape}; + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedColumnwiseData, &cw_data_bt, sizeof(cw_data_bt)); + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedColumnwiseScaleInv, &cw_scales_bt, sizeof(cw_scales_bt)); + } + + NVTEBasicTensor amax_bt = {amax_d, kNVTEFloat32, nvte_make_shape(std::vector{1}.data(), 1)}; + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedAmax, &amax_bt, sizeof(amax_bt)); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_group_quantize(in_gt, out_gt, stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + size_t bytes_read = total_elements * sizeof(IType); + size_t bytes_write = total_elements * sizeof(OType) * ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + set_bytes_processed(state, bytes_read + bytes_write + total_rowwise_scales + total_colwise_scales); + + state.counters["experts"] = num_experts; + state.counters["cols"] = cols; + state.counters["avg_tok"] = avg_tok; + state.counters["min_tok"] = min_tok; + state.counters["max_tok"] = max_tok; + + nvte_destroy_grouped_tensor(in_gt); + nvte_destroy_grouped_tensor(out_gt); + hipFree(in_data_d); + if (out_data_rw_d) hipFree(out_data_rw_d); + if (out_data_cw_d) hipFree(out_data_cw_d); + if (scales_rw_d) hipFree(scales_rw_d); + if (scales_cw_d) hipFree(scales_cw_d); + hipFree(amax_d); + hipFree(first_dims_d); + hipFree(offsets_d); + HIP_CHECK(hipStreamDestroy(stream)); +} + +template +static void BM_MultiQuantizeMXFP8(benchmark::State &state) { + const int num_experts = state.range(0); + const int cols = state.range(1); + const int total_tokens = state.range(2); + const int skewed = state.range(3); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? DType::kFloat8E4M3 : DType::kFloat8E5M2; + + unsigned int seed = std::random_device{}(); + uint64_t config_hash = seed ^ (uint64_t(num_experts) * 2654435761ULL) + ^ (uint64_t(cols) * 40503ULL) + ^ (uint64_t(total_tokens) * 12345ULL); + std::mt19937 rng(config_hash); + + std::vector token_counts = generate_routed_tokens(total_tokens, num_experts, rng, skewed); + + int min_tok = *std::min_element(token_counts.begin(), token_counts.end()); + int max_tok = *std::max_element(token_counts.begin(), token_counts.end()); + int sum_tok = std::accumulate(token_counts.begin(), token_counts.end(), 0); + int avg_tok = sum_tok / num_experts; + + size_t total_elements = 0; + for (int i = 0; i < num_experts; i++) + total_elements += static_cast(token_counts[i]) * cols; + + std::vector in_ptrs(num_experts), out_rw_ptrs(num_experts), out_cw_ptrs(num_experts); + std::vector scales_rw_ptrs(num_experts), scales_cw_ptrs(num_experts); + std::vector amax_ptrs(num_experts); + std::vector nvte_inputs(num_experts), nvte_outputs(num_experts); + + NVTEDType nvte_itype = std::is_same_v ? kNVTEFloat32 : + (std::is_same_v ? kNVTEBFloat16 : kNVTEFloat16); + NVTEDType nvte_otype = std::is_same_v ? kNVTEFloat8E4M3 : kNVTEFloat8E5M2; + + for (int i = 0; i < num_experts; i++) { + size_t rows = token_counts[i]; + size_t elts = rows * cols; + size_t rw_scales = rows * ((cols + 31) / 32); + size_t cw_scales = ((rows + 31) / 32) * cols; + + HIP_CHECK(hipMalloc(&in_ptrs[i], elts * sizeof(IType))); + if (USE_ROWWISE) { + HIP_CHECK(hipMalloc(&out_rw_ptrs[i], elts * sizeof(OType))); + HIP_CHECK(hipMalloc(&scales_rw_ptrs[i], rw_scales)); + } + if (USE_COLWISE) { + HIP_CHECK(hipMalloc(&out_cw_ptrs[i], elts * sizeof(OType))); + HIP_CHECK(hipMalloc(&scales_cw_ptrs[i], cw_scales)); + } + HIP_CHECK(hipMalloc(&amax_ptrs[i], sizeof(float))); + + std::vector shape_vec = {rows, static_cast(cols)}; + NVTEShape shape = nvte_make_shape(shape_vec.data(), 2); + + nvte_inputs[i] = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING); + NVTEBasicTensor in_bt = {in_ptrs[i], nvte_itype, shape}; + nvte_set_tensor_param(&nvte_inputs[i], kNVTERowwiseData, &in_bt); + + nvte_outputs[i] = nvte_create_tensor(NVTE_MXFP8_1D_SCALING); + if (USE_ROWWISE) { + NVTEBasicTensor rw_bt = {out_rw_ptrs[i], nvte_otype, shape}; + std::vector srw = {rw_scales}; + NVTEBasicTensor srw_bt = {scales_rw_ptrs[i], kNVTEFloat8E8M0, nvte_make_shape(srw.data(), 1)}; + nvte_set_tensor_param(&nvte_outputs[i], kNVTERowwiseData, &rw_bt); + nvte_set_tensor_param(&nvte_outputs[i], kNVTERowwiseScaleInv, &srw_bt); + } + if (USE_COLWISE) { + NVTEBasicTensor cw_bt = {out_cw_ptrs[i], nvte_otype, shape}; + std::vector scw = {cw_scales}; + NVTEBasicTensor scw_bt = {scales_cw_ptrs[i], kNVTEFloat8E8M0, nvte_make_shape(scw.data(), 1)}; + nvte_set_tensor_param(&nvte_outputs[i], kNVTEColumnwiseData, &cw_bt); + nvte_set_tensor_param(&nvte_outputs[i], kNVTEColumnwiseScaleInv, &scw_bt); + } + NVTEBasicTensor amax_bt = {amax_ptrs[i], kNVTEFloat32, nvte_make_shape(std::vector{1}.data(), 1)}; + nvte_set_tensor_param(&nvte_outputs[i], kNVTEAmax, &amax_bt); + } + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + nvte_multi_quantize_mxfp8(num_experts, nvte_inputs.data(), nvte_outputs.data(), stream); + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + size_t bytes_read = total_elements * sizeof(IType); + size_t bytes_write = total_elements * sizeof(OType) * ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + size_t total_rw_scales = 0, total_cw_scales = 0; + for (int i = 0; i < num_experts; i++) { + if (USE_ROWWISE) total_rw_scales += token_counts[i] * ((cols + 31) / 32); + if (USE_COLWISE) total_cw_scales += ((token_counts[i] + 31) / 32) * cols; + } + set_bytes_processed(state, bytes_read + bytes_write + total_rw_scales + total_cw_scales); + + state.counters["experts"] = num_experts; + state.counters["cols"] = cols; + state.counters["avg_tok"] = avg_tok; + state.counters["min_tok"] = min_tok; + state.counters["max_tok"] = max_tok; + + for (int i = 0; i < num_experts; i++) { + nvte_destroy_tensor(nvte_inputs[i]); + nvte_destroy_tensor(nvte_outputs[i]); + HIP_CHECK(hipFree(in_ptrs[i])); + if (out_rw_ptrs[i]) HIP_CHECK(hipFree(out_rw_ptrs[i])); + if (out_cw_ptrs[i]) HIP_CHECK(hipFree(out_cw_ptrs[i])); + if (scales_rw_ptrs[i]) HIP_CHECK(hipFree(scales_rw_ptrs[i])); + if (scales_cw_ptrs[i]) HIP_CHECK(hipFree(scales_cw_ptrs[i])); + HIP_CHECK(hipFree(amax_ptrs[i])); + } + HIP_CHECK(hipStreamDestroy(stream)); +} + +// experts, cols, total_tokens, skewed +#define MOE_BALANCED \ + ->Args({128, 4096, 65536, 0}) /* Qwen3 H=4096 */ \ + ->Args({128, 1536, 65536, 0}) /* Qwen3 I=1536 */ \ + ->Args({256, 7168, 131072, 0}) /* DeepSeek H=7168 */ \ + ->Args({256, 2048, 131072, 0}) /* DeepSeek I=2048 */ + +#define MOE_SKEWED \ + ->Args({128, 4096, 65536, 1}) /* Qwen3 H=4096 */ \ + ->Args({128, 1536, 65536, 1}) /* Qwen3 I=1536 */ \ + ->Args({256, 7168, 131072, 1}) /* DeepSeek H=7168 */ \ + ->Args({256, 2048, 131072, 1}) /* DeepSeek I=2048 */ + +#define REGISTER_GROUP_QUANTIZE(ITYPE, OTYPE, INAME, ONAME) \ + BENCHMARK_TEMPLATE(BM_GroupQuantizeMXFP8, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_GroupQuantizeMXFP8/rowwise/" INAME "_" ONAME) \ + MOE_BALANCED MOE_SKEWED \ + ->Unit(benchmark::kMicrosecond) ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GroupQuantizeMXFP8, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_GroupQuantizeMXFP8/colwise/" INAME "_" ONAME) \ + MOE_BALANCED MOE_SKEWED \ + ->Unit(benchmark::kMicrosecond) ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GroupQuantizeMXFP8, ITYPE, OTYPE, 32, 32) \ + ->Name("BM_GroupQuantizeMXFP8/both/" INAME "_" ONAME) \ + MOE_BALANCED MOE_SKEWED \ + ->Unit(benchmark::kMicrosecond) ->UseManualTime(); + +REGISTER_GROUP_QUANTIZE(hip_bfloat16, fp8_e4m3, "BF16", "E4M3") + +#define REGISTER_MULTI_QUANTIZE(ITYPE, OTYPE, INAME, ONAME) \ + BENCHMARK_TEMPLATE(BM_MultiQuantizeMXFP8, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_MultiQuantizeMXFP8/rowwise/" INAME "_" ONAME) \ + MOE_BALANCED MOE_SKEWED \ + ->Unit(benchmark::kMicrosecond) ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_MultiQuantizeMXFP8, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_MultiQuantizeMXFP8/colwise/" INAME "_" ONAME) \ + MOE_BALANCED MOE_SKEWED \ + ->Unit(benchmark::kMicrosecond) ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_MultiQuantizeMXFP8, ITYPE, OTYPE, 32, 32) \ + ->Name("BM_MultiQuantizeMXFP8/both/" INAME "_" ONAME) \ + MOE_BALANCED MOE_SKEWED \ + ->Unit(benchmark::kMicrosecond) ->UseManualTime(); + +REGISTER_MULTI_QUANTIZE(hip_bfloat16, fp8_e4m3, "BF16", "E4M3") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/run_benchmarks.sh b/benchmarks/cpp/run_benchmarks.sh index 05f7f853e..40775ebf1 100755 --- a/benchmarks/cpp/run_benchmarks.sh +++ b/benchmarks/cpp/run_benchmarks.sh @@ -23,10 +23,11 @@ main() { echo -e "\n[2/3] Running benchmarks..." BENCHMARKS=( - "bench_quantize_mxfp8_fused" + "bench_casttranspose" "bench_dequantize_mxfp8" "bench_gated_mxfp8" - "bench_casttranspose" + "bench_group_quantize_mxfp8" + "bench_quantize_mxfp8_fused" ) FAILED_BENCHMARKS=() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0f42b1ef3..75856baa0 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -47,7 +47,8 @@ if(USE_ROCM) list(APPEND test_cuda_sources test_dequantize_nvfp4.cu test_cublaslt_gemm.cu - test_cast_mxfp4_transpose.cu) + test_cast_mxfp4_transpose.cu + test_multi_quantize_mxfp8.cu) TE_GetHipifiedSources("${test_cuda_sources}" ${CMAKE_CURRENT_SOURCE_DIR} test_hip_sources) message("${message_line}") message(STATUS "test_operator hipified sources: ${test_hip_sources}") diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 9f2523cb6..2f0a356bb 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -282,10 +282,17 @@ void performTest(const ProcessingMethod processing_method, const size_t unpadded_colwise_blocks_Y = divide_round_up(M, 32); const size_t unpadded_colwise_blocks_X = K; +#ifdef __HIP_PLATFORM_AMD__ + rowwise_scales_first_dim[t] = unpadded_rowwise_blocks_Y; + rowwise_scales_last_dim[t] = unpadded_rowwise_blocks_X; + colwise_scales_first_dim[t] = unpadded_colwise_blocks_Y; + colwise_scales_last_dim[t] = unpadded_colwise_blocks_X; +#else rowwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_Y, 128); rowwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4); colwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_Y, 4); colwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128); +#endif const size_t rowwise_sfs = rowwise_scales_first_dim[t] * rowwise_scales_last_dim[t]; const size_t colwise_sfs = colwise_scales_first_dim[t] * colwise_scales_last_dim[t]; @@ -566,22 +573,27 @@ void performTest(const ProcessingMethod processing_method, size_t mismatches_scales = 0; #ifdef USE_ROCM std::vector mismatches_scales_indices; -#endif - compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(), - 1, rowwise_sfs_num, rowwise_sfs_num, -#ifdef USE_ROCM - mismatches_scales_indices, -#endif - mismatches_scales, scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); - -#ifdef USE_ROCM + for (size_t t = 0; t < num_tensors; t++) { + compare_scaling_factors("rowwise_scales", + out_scales_rowwise_h.data() + rowwise_scales_offset[t], + out_scales_rowwise_ref.data() + rowwise_scales_offset[t], + rowwise_scales_first_dim[t], rowwise_scales_last_dim[t], + rowwise_scales_last_dim[t], + mismatches_scales_indices, + mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + } if (::testing::Test::HasFatalFailure()) return; adjust_ref_for_e8m0_scale_error("rowwise_scales", mismatches_scales_indices, out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(), rowwise_sfs_num, rows, cols, true, out_data_rowwise_ref.data(), otype); mismatches_scales = 0; +#else + compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(), + 1, rowwise_sfs_num, rowwise_sfs_num, + mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); #endif const size_t mismatches_elts = 32 * mismatches_scales; @@ -596,22 +608,27 @@ void performTest(const ProcessingMethod processing_method, size_t mismatches_scales = 0; #ifdef USE_ROCM std::vector mismatches_scales_indices; -#endif - compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(), - 1, colwise_sfs_num, colwise_sfs_num, -#ifdef USE_ROCM - mismatches_scales_indices, -#endif - mismatches_scales, scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); - -#ifdef USE_ROCM + for (size_t t = 0; t < num_tensors; t++) { + compare_scaling_factors("colwise_scales", + out_scales_colwise_h.data() + colwise_scales_offset[t], + out_scales_colwise_ref.data() + colwise_scales_offset[t], + colwise_scales_first_dim[t], colwise_scales_last_dim[t], + colwise_scales_last_dim[t], + mismatches_scales_indices, + mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + } if (::testing::Test::HasFatalFailure()) return; adjust_ref_for_e8m0_scale_error("colwise_scales", mismatches_scales_indices, out_scales_colwise_h.data(), out_scales_colwise_ref.data(), colwise_sfs_num, rows, cols, false, out_data_colwise_ref.data(), otype); mismatches_scales = 0; +#else + compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(), + 1, colwise_sfs_num, colwise_sfs_num, + mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); #endif const size_t mismatches_elts = 32 * mismatches_scales; @@ -703,10 +720,12 @@ class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam >> {}; TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { +#ifndef __HIP_PLATFORM_AMD__ // Skip tests for pre-Blackwell architectures if (getDeviceComputeCapability() < blackwellComputeCapability) { GTEST_SKIP(); } +#endif using namespace transformer_engine; using namespace test; diff --git a/tests/cpp/operator/test_multi_quantize_mxfp8.cu b/tests/cpp/operator/test_multi_quantize_mxfp8.cu new file mode 100644 index 000000000..67d3e55b3 --- /dev/null +++ b/tests/cpp/operator/test_multi_quantize_mxfp8.cu @@ -0,0 +1,158 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void performTest(const std::vector> &tensor_dims, + bool rowwise, bool colwise) { + const DType itype = TypeInfo::dtype; + const DType otype = TypeInfo::dtype; + const size_t num_tensors = tensor_dims.size(); + + std::vector inputs, outputs_multi, outputs_ref; + + for (size_t i = 0; i < num_tensors; i++) { + auto [rows, cols] = tensor_dims[i]; + inputs.emplace_back("input_" + std::to_string(i), + std::vector{rows, cols}, itype); + outputs_multi.emplace_back("output_multi_" + std::to_string(i), + std::vector{rows, cols}, otype, + rowwise, colwise, NVTE_MXFP8_1D_SCALING); + outputs_ref.emplace_back("output_ref_" + std::to_string(i), + std::vector{rows, cols}, otype, + rowwise, colwise, NVTE_MXFP8_1D_SCALING); + fillUniform(&inputs.back()); + } + + std::vector nvte_inputs, nvte_outputs_multi; + for (auto &t : inputs) { + nvte_inputs.push_back(t.data()); + } + for (auto &t : outputs_multi) { + nvte_outputs_multi.push_back(t.data()); + } + + nvte_multi_quantize_mxfp8(num_tensors, nvte_inputs.data(), + nvte_outputs_multi.data(), 0); + + for (size_t i = 0; i < num_tensors; i++) { + if (tensor_dims[i].first > 0 && tensor_dims[i].second > 0) + nvte_quantize(inputs[i].data(), outputs_ref[i].data(), 0); + } + cudaDeviceSynchronize(); + + for (size_t i = 0; i < num_tensors; i++) { + auto [rows, cols] = tensor_dims[i]; + if (rows == 0 || cols == 0) continue; + if (rowwise) { + auto *multi_data = outputs_multi[i].rowwise_cpu_dptr(); + auto *ref_data = outputs_ref[i].rowwise_cpu_dptr(); + for (size_t j = 0; j < rows * cols; j++) { + ASSERT_EQ(static_cast(multi_data[j]), + static_cast(ref_data[j])) + << "Mismatch at tensor " << i << " element " << j; + } + auto *multi_scales = outputs_multi[i].rowwise_cpu_scale_inv_ptr(); + auto *ref_scales = outputs_ref[i].rowwise_cpu_scale_inv_ptr(); + size_t num_scales = rows * ((cols + 31) / 32); + for (size_t j = 0; j < num_scales; j++) { + ASSERT_EQ(multi_scales[j], ref_scales[j]) + << "Scale mismatch at tensor " << i << " scale " << j; + } + } + if (colwise) { + auto *multi_data = outputs_multi[i].columnwise_cpu_dptr(); + auto *ref_data = outputs_ref[i].columnwise_cpu_dptr(); + for (size_t j = 0; j < rows * cols; j++) { + ASSERT_EQ(static_cast(multi_data[j]), + static_cast(ref_data[j])) + << "Colwise mismatch at tensor " << i << " element " << j; + } + } + } +} + +std::vector> getTestDims(int config) { + switch (config) { + case 0: + return {{128, 128}, + {256, 256}, + {128, 512}, + {512, 128}}; + case 1: + return {{128, 4096}, + {128, 4096}, + {384, 4096}, + {256, 4096}, + {256, 4096}}; + case 2: + return {{0, 128}, + {128, 256}, + {256, 128}}; + default: + return {}; + } +} + +enum ScalingMode { + Rowwise = 0, + Colwise = 1, + Both = 2 +}; + +class MultiQuantizeMXFP8TestSuite + : public ::testing::TestWithParam< + std::tuple> {}; + +TEST_P(MultiQuantizeMXFP8TestSuite, Test) { + auto [itype, otype, config, mode] = GetParam(); + auto dims = getTestDims(config); + bool rowwise = (mode == Rowwise || mode == Both); + bool colwise = (mode == Colwise || mode == Both); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(itype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(otype, OType, + performTest(dims, rowwise, colwise); + ) + ) +} + +static const char *scalingModeName(ScalingMode mode) { + switch (mode) { + case Rowwise: return "rowwise"; + case Colwise: return "colwise"; + case Both: return "both"; + default: return "unknown"; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, MultiQuantizeMXFP8TestSuite, + ::testing::Combine( + ::testing::Values(DType::kBFloat16, DType::kFloat16, DType::kFloat32), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::Values(0, 1, 2), + ::testing::Values(Rowwise, Colwise, Both)), + [](const testing::TestParamInfo &info) { + return test::typeName(std::get<0>(info.param)) + "_" + + test::typeName(std::get<1>(info.param)) + "_config" + + std::to_string(std::get<2>(info.param)) + "_" + + scalingModeName(std::get<3>(info.param)); + }); + +} // namespace diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 4e7e3c4da..34037702c 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -30,6 +30,20 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea dispatch::quantize_fwd_helper(input, output, nullptr, stream); } +#ifdef __HIP_PLATFORM_AMD__ +void nvte_multi_quantize_mxfp8(size_t num_tensors, const NVTETensor *input_list, + NVTETensor *output_list, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_quantize_mxfp8); + using namespace transformer_engine; + std::vector input_list_, output_list_; + for (size_t i = 0; i < num_tensors; i++) { + input_list_.push_back(convertNVTETensorCheck(input_list[i])); + output_list_.push_back(convertNVTETensorCheck(output_list[i])); + } + dispatch::multi_quantize_mxfp8(input_list_, output_list_, stream); +} +#endif + void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize); diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 579caee06..3bb04529c 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -23,8 +23,8 @@ #ifdef __HIP_PLATFORM_AMD__ #include "../fp8/rocm_cast.cuh" #endif -#include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" +#include "../mxfp8/group_quantize_mxfp8.cuh" #ifdef __HIP_PLATFORM_AMD__ #include "../mxfp4/quantize_mxfp4.cuh" #endif //#ifdef __HIP_PLATFORM_AMD__ @@ -509,6 +509,76 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe } } +#ifdef __HIP_PLATFORM_AMD__ +inline void multi_quantize_mxfp8(const std::vector &input_list, + std::vector &output_list, cudaStream_t stream) { + const size_t num_tensors = input_list.size(); + if (num_tensors == 0) return; + + DType itype = input_list[0]->data.dtype; + DType otype = output_list[0]->dtype(); + const bool use_rowwise = output_list[0]->has_data(); + const bool use_colwise = output_list[0]->has_columnwise_data(); + + constexpr size_t CDY = 64; // tile height (rows) + constexpr size_t CDX = 64; // tile width (cols) + constexpr size_t TPC = 128; // threads per block + + mxfp8::quantize_kernel::MultiQuantizeMXFP8Args args; + args.num_tensors = 0; + args.block_range[0] = 0; + int tiles_x = 0; + + for (size_t i = 0; i < num_tensors; i++) { + const int rows = input_list[i]->data.shape[0]; + const int cols = input_list[i]->data.shape[1]; + const int row_tiles = DIVUP(static_cast(rows), CDY); + const int col_tiles = DIVUP(static_cast(cols), CDX); + if (col_tiles > tiles_x) { + tiles_x = col_tiles; + } + const int pos = args.num_tensors; + + args.input_list[pos] = input_list[i]->data.dptr; + args.output_rowwise_list[pos] = use_rowwise ? output_list[i]->data.dptr : nullptr; + args.output_colwise_list[pos] = use_colwise ? output_list[i]->columnwise_data.dptr : nullptr; + args.scales_rowwise_list[pos] = use_rowwise ? output_list[i]->scale_inv.dptr : nullptr; + args.scales_colwise_list[pos] = use_colwise ? output_list[i]->columnwise_scale_inv.dptr : nullptr; + args.amax_list[pos] = reinterpret_cast(output_list[i]->amax.dptr); + args.rows_list[pos] = rows; + args.cols_list[pos] = cols; + args.block_range[pos + 1] = args.block_range[pos] + row_tiles; + args.num_tensors++; + } + + if (args.num_tensors == 0) return; + + bool is_aligned = true; + for (int i = 0; i < args.num_tensors; i++) { + if (args.cols_list[i] % (32 * typeToSize(itype)) != 0) { + is_aligned = false; + break; + } + } + + const dim3 grid(tiles_x, args.block_range[args.num_tensors]); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(itype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(otype, OType, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH((use_colwise ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH((use_rowwise ? 32 : 1), SCALE_DIM_X, + if (is_aligned) { + mxfp8::quantize_kernel::multi_quantize_mxfp8_kernel + <<>>(args); + } else { + mxfp8::quantize_kernel::multi_quantize_mxfp8_kernel + <<>>(args); + } + )); + )); + NVTE_CHECK_CUDA(cudaGetLastError()); +} +#endif // __HIP_PLATFORM_AMD__ + } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index b56e28968..444b5cda3 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -787,7 +787,113 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const Tensor *noop, GroupedTensor *output, GroupedTensor *dbias, Tensor *workspace, cudaStream_t stream) { #ifdef __HIP_PLATFORM_AMD__ - NVTE_ERROR("group_quantize is not supported on ROCm yet."); + using namespace quantize_kernel; + CheckNoopTensor(*noop, "cast_noop"); + + const bool use_rowwise_scaling = output->has_data(); + const bool use_colwise_scaling = output->has_columnwise_data(); + NVTE_CHECK(use_rowwise_scaling || use_colwise_scaling, + "Either rowwise or columnwise output data need to be allocated."); + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Number of input and output tensors must be same."); + NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + const bool is_single_tensor = output->all_same_shape() || output->all_same_last_dim(); + + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + const size_t num_tensors = input->num_tensors; + const size_t elts_total = first_logical_dim * last_logical_dim; + + const bool use_large_chunks = (elts_total > 32 * 1024 * 1024); + const size_t CHUNK_DIM_Y = use_large_chunks ? 128 : 64; + const size_t CHUNK_DIM_X = use_large_chunks ? 128 : 64; + const size_t THREADS_PER_CHUNK = use_large_chunks ? 256 : 128; + + size_t blocks_X, blocks_Y; + if (is_single_tensor) { + blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); + blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + } else { + blocks_Y = 1; + blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + } + const dim3 grid(blocks_X, blocks_Y); + + const int64_t *const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(output->first_dims.dptr); + const int64_t *const last_dims_ptr = reinterpret_cast(output->last_dims.dptr); + + e8m0_t *const scales_rowwise_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + const size_t scale_stride_rowwise = use_rowwise_scaling + ? DIVUP(last_logical_dim, size_t{32}) : 1; + const size_t scale_stride_colwise = use_colwise_scaling ? last_logical_dim : 1; + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input->dtype(), "DBias must have the same type as input."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + const size_t dbias_workspace_rows = blocks_Y; + const size_t dbias_workspace_cols = last_logical_dim; + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_workspace_rows, dbias_workspace_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_large_chunks, USE_LARGE_CHUNKS, + + constexpr size_t CDY = USE_LARGE_CHUNKS ? 128 : 64; + constexpr size_t CDX = USE_LARGE_CHUNKS ? 128 : 64; + constexpr size_t TPC = USE_LARGE_CHUNKS ? 256 : 128; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + !(last_logical_dim % (32 * sizeof(IType))), IS_ALIGNED, + grouped_quantize_mxfp8_kernel + <<>>( + reinterpret_cast(input->data.dptr), + (IS_DACT) ? reinterpret_cast(activations->data.dptr) : nullptr, + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr, + use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) : nullptr, + scales_rowwise_ptr, scales_colwise_ptr, + noop_ptr, workspace_ptr, amax_ptr, + first_logical_dim, last_logical_dim, + scale_stride_rowwise, scale_stride_colwise, + num_tensors, offsets_ptr, first_dims_ptr, last_dims_ptr, + first_logical_dim); + NVTE_CHECK_CUDA(cudaGetLastError()); + ))))); + + if constexpr (IS_DBIAS) { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input->dtype(), IType, + common::ShapeRepresentation shape_rep = output->all_same_shape() + ? common::ShapeRepresentation::SAME_BOTH_DIMS : common::ShapeRepresentation::VARYING_FIRST_DIM; + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, + CDY, stream); + ); + } + ); // NOLINT(*) #else using namespace group_quantize_kernel; diff --git a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index fee3db3e0..145248e07 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -189,7 +189,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if (row_valid && col_valid) { if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { - out_vec.store_to(&output_act_rowwise[row * output_cols + col_start]); + reinterpret_cast*>(&out_vec)->nt_store( + &output_act_rowwise[row * output_cols + col_start]); } else { #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD; j++) { @@ -223,7 +224,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if (row_valid && col_valid) { if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { - out_vec.store_to(&output_gate_rowwise[row * output_cols + col_start]); + reinterpret_cast*>(&out_vec)->nt_store( + &output_gate_rowwise[row * output_cols + col_start]); } else { #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD; j++) { @@ -346,7 +348,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if (row_valid && col_valid) { if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { - out_vec.store_to(&output_act_rowwise[row * output_cols + col_start]); + reinterpret_cast*>(&out_vec)->nt_store( + &output_act_rowwise[row * output_cols + col_start]); } else { #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD; j++) { @@ -380,7 +383,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if (row_valid && col_valid) { if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { - out_vec.store_to(&output_gate_rowwise[row * output_cols + col_start]); + reinterpret_cast*>(&out_vec)->nt_store( + &output_gate_rowwise[row * output_cols + col_start]); } else { #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD; j++) { @@ -398,9 +402,29 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scales_rowwise[scale_idx] = biased_exp; } } + + { + Vec cached_act, cached_gate; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + cached_act.data.elt[j] = static_cast(computed_act[j]); + } + cached_act.store_to(&in_act_sh[shmem_base]); + if constexpr (IS_DGATED) { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + cached_gate.data.elt[j] = static_cast(computed_gate[j]); + } + cached_gate.store_to(&in_gate_sh[shmem_base]); + } + } } if constexpr (USE_COLWISE_SCALING) { + if constexpr (USE_ROWWISE_SCALING) { + __syncthreads(); + } + const bool col_out_of_bounds = (chunk_offset_X + tid_colwise_X >= cols); const size_t row_base = chunk_it_offset_y; const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; @@ -411,22 +435,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float thread_Y_mx_block_amax = 0.0f; float thread_Y_mx_block_amax_gate = 0.0f; - // Compute activation and accumulate column amax for (int stage = 0; stage < BUFFER_STAGES_NUM_COLWISE; ++stage) { const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_COLWISE; const int shmem_offset_y = tid_colwise_Y + stage_offset_Y; const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + tid_colwise_X; - float act_elt = static_cast(in_act_sh[shmem_idx]); - float gate_elt = static_cast(in_gate_sh[shmem_idx]); - float grad_elt = 0.0f; - if constexpr (IS_DGATED) { - grad_elt = static_cast(in_grad_sh[shmem_idx]); + if constexpr (USE_ROWWISE_SCALING) { + after_dact_reg[stage] = static_cast(in_act_sh[shmem_idx]); + if constexpr (IS_DGATED) { + after_dgate_reg[stage] = static_cast(in_gate_sh[shmem_idx]); + } + } else { + float act_elt = static_cast(in_act_sh[shmem_idx]); + float gate_elt = static_cast(in_gate_sh[shmem_idx]); + float grad_elt = 0.0f; + if constexpr (IS_DGATED) { + grad_elt = static_cast(in_grad_sh[shmem_idx]); + } + compute_gated_activation( + act_elt, gate_elt, grad_elt, p, after_dact_reg[stage], after_dgate_reg[stage]); } - compute_gated_activation( - act_elt, gate_elt, grad_elt, p, after_dact_reg[stage], after_dgate_reg[stage]); - __builtin_assume(thread_Y_mx_block_amax >= 0); thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); if constexpr (IS_DGATED) { diff --git a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index 7a9a0d696..6f2224d29 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -13,6 +13,21 @@ constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; +constexpr int kMultiQuantizeMXFP8MaxTensors = 256; + +struct MultiQuantizeMXFP8Args { + const void *input_list[kMultiQuantizeMXFP8MaxTensors]; + void *output_rowwise_list[kMultiQuantizeMXFP8MaxTensors]; + void *output_colwise_list[kMultiQuantizeMXFP8MaxTensors]; + void *scales_rowwise_list[kMultiQuantizeMXFP8MaxTensors]; + void *scales_colwise_list[kMultiQuantizeMXFP8MaxTensors]; + float *amax_list[kMultiQuantizeMXFP8MaxTensors]; + int rows_list[kMultiQuantizeMXFP8MaxTensors]; + int cols_list[kMultiQuantizeMXFP8MaxTensors]; + int block_range[kMultiQuantizeMXFP8MaxTensors + 1]; + int num_tensors; +}; + constexpr size_t ELEMS_PER_THREAD = 16; constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported @@ -20,501 +35,164 @@ constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported typedef short mxfp8_v2i16_t __attribute__((ext_vector_type(2))); #endif // #ifdef HAS_CVT_4xFLOAT8 +// Single tensor quantize template + size_t CHUNK_DIM_Y = 64, size_t CHUNK_DIM_X = 64, size_t THREADS_PER_CHUNK = 64> __global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_mxfp8_kernel( - const IType *input_ptr, - const IType *act_input_ptr, - OType *output_rowwise, - OType *output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + const IType *input_ptr, const IType *act_input_ptr, + OType *output_rowwise, OType *output_colwise, + e8m0_t *scales_rowwise, e8m0_t *scales_colwise, + const float *noop, float *const dbias_workspace, float *amax_ptr, + size_t rows, size_t cols, size_t scale_stride_rowwise, size_t scale_stride_colwise) { if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { if (noop != nullptr && noop[0] == 1.0f) return; } - constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; - - constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - - constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; - constexpr size_t SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; - constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; - - constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; - constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; - constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; - - constexpr size_t BUFF_STAGES_NUM = MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; - constexpr size_t ITERATIONS = CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; - - constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = CHUNK_DIM_Y; - constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; - constexpr size_t SCALES_COLWISE_PER_BLOCK_X = CHUNK_DIM_X; - - constexpr size_t THREADS_PER_SCALE_X_ROWWISE = - DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 - constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 - constexpr size_t VECTOR_WIDTH_IN = ROCM_VEC_BYTES / sizeof(IType); // BF16/FP16: 8, FP32: 4 - constexpr size_t VECTOR_WIDTH_OUT = ROCM_VEC_BYTES / sizeof(OType); // FP8: 16 - - const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * CHUNK_DIM_X; - const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; - const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; - const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; - const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; - - const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; - const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; - const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; - - const int thread_offset_Y = tid_rowwise_Y; - const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; - - const int dbias_rowwise_offset_Y = blockIdx.y + tid_rowwise_Y; - const int dbias_rowwise_block_offset_X = block_offset_X + thread_offset_X_rowwise; - const int dbias_colwise_offset_Y = blockIdx.y; - const int dbias_colwise_block_offset_X = block_offset_X + tid_colwise_X; - const int dbias_stride = cols; + const int block_id_Y = blockIdx.y; + const int block_id_X = blockIdx.x; + const int dbias_y_offset = blockIdx.y; +#include "rocm_quantize_mxfp8_body.inc" +} - Vec partial_dbias_rowwise; - float partial_dbias_colwise = 0; - if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - partial_dbias_rowwise.clear(); - } +// Grouped quantize (contiguous buffer + offsets) +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + grouped_quantize_mxfp8_kernel( + const IType *input_ptr, const IType *act_input_ptr, + OType *output_rowwise, OType *output_colwise, + e8m0_t *scales_rowwise, e8m0_t *scales_colwise, + const float *noop, float *const dbias_workspace, float *amax_ptr, + size_t rows, size_t cols, size_t scale_stride_rowwise, size_t scale_stride_colwise, + const size_t num_tensors, const int64_t *const offsets_ptr, + const int64_t *const first_dims_ptr, const int64_t *const last_dims_ptr, + const size_t first_logical_dim) { + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + if (noop != nullptr && noop[0] == 1.0f) return; } - - float block_amax = 0; - - constexpr size_t ROWS_PER_THREAD = CHUNK_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; - - if constexpr (USE_ROWWISE_SCALING && !USE_COLWISE_SCALING) { - const size_t col_start = block_offset_X + thread_offset_X_rowwise; - const bool col_valid = (col_start < cols); -#pragma unroll - for (size_t r = 0; r < ROWS_PER_THREAD; r++) { - const size_t row = block_offset_Y + tid_rowwise_Y + r * THREADS_PER_CHUNK_Y_ROWWISE; - const bool row_valid = (row < rows); - - Vec in; - Vec act_in; - - if (row_valid && col_valid) { - if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { - in.load_from(&input_ptr[row * cols + col_start]); - if constexpr (IS_DACT) { - act_in.load_from(&act_input_ptr[row * cols + col_start]); - } - } else { -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - in.data.elt[j] = (col_start + j < cols) ? input_ptr[row * cols + col_start + j] - : static_cast(0); - } - if constexpr (IS_DACT) { -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - act_in.data.elt[j] = (col_start + j < cols) ? act_input_ptr[row * cols + col_start + j] - : static_cast(0); - } - } - } - } - - float thread_amax = 0; - float in_compute[ELEMS_PER_THREAD]; - -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); - float elt = static_cast(in.data.elt[j]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[j]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - if (!out_of_bounds) { - partial_dbias_rowwise.data.elt[j] += elt; - } - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - in_compute[j] = elt; - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); - const e8m0_t biased_exponent = - ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); - - { - constexpr size_t SCALES_PER_GROUP = THREADS_PER_CHUNK_X_ROWWISE / THREADS_PER_SCALE_X_ROWWISE; - static_assert(SCALES_PER_GROUP < 4 || SCALES_PER_GROUP % 4 == 0, - "SCALES_PER_GROUP must be < 4 or a multiple of 4"); - uint32_t my_scale = static_cast(biased_exponent); - if constexpr (SCALES_PER_GROUP >= 4) { -#pragma unroll - for (int g = 0; g < SCALES_PER_GROUP / 4; g++) { - uint32_t s0 = __shfl_down(my_scale, (g*4+0) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); - uint32_t s1 = __shfl_down(my_scale, (g*4+1) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); - uint32_t s2 = __shfl_down(my_scale, (g*4+2) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); - uint32_t s3 = __shfl_down(my_scale, (g*4+3) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); - uint32_t packed = (s0 & 0xFF) | ((s1 & 0xFF) << 8) | ((s2 & 0xFF) << 16) | ((s3 & 0xFF) << 24); - if (tid_rowwise_X == 0 && row_valid && col_valid) { - const int scale_idx = row * scale_stride_rowwise + scales_rowwise_block_offset_X; - reinterpret_cast(&scales_rowwise[scale_idx])[g] = packed; - } - } - } else { - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { - const int scale_idx = - row * scale_stride_rowwise + - scales_rowwise_block_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; - scales_rowwise[scale_idx] = biased_exponent; - } - } - } - - Vec out_c; -#ifdef HAS_CVT_4xFLOAT8 - { - const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); - union { - uint32_t packed[ELEMS_PER_THREAD / 4]; - mxfp8_v2i16_t v2i16[ELEMS_PER_THREAD / 4]; - } cvt_out{}; -#pragma unroll - for (int p = 0; p < ELEMS_PER_THREAD / 4; p++) { - cvt_out.packed[p] = rocm_cvt_4xfloat8( - in_compute[p*4+0], in_compute[p*4+1], - in_compute[p*4+2], in_compute[p*4+3], cvt_scale); - } - memcpy(out_c.data.elt, cvt_out.packed, ELEMS_PER_THREAD * sizeof(OType)); - } -#else - { - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); - } - } -#endif // #ifdef HAS_CVT_4xFLOAT8 - - if (row_valid && col_valid) { - if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { - out_c.store_to(&output_rowwise[row * cols + col_start]); - } else { -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - if (col_start + j < cols) { - output_rowwise[row * cols + col_start + j] = out_c.data.elt[j]; - } - } - } + int block_id_Y = blockIdx.y; + int block_id_X = blockIdx.x; + size_t tensor_id; + size_t tensor_base_elts; + + if (last_dims_ptr == nullptr) { + const size_t global_row = static_cast(blockIdx.y) * CHUNK_DIM_Y; + size_t current_offset = global_row * cols; + if (offsets_ptr == nullptr) { + const size_t rows_per_tensor = first_logical_dim / num_tensors; + tensor_id = global_row / rows_per_tensor; + tensor_base_elts = tensor_id * rows_per_tensor * cols; + } else { + size_t lo = 1, hi = num_tensors; + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + if (static_cast(offsets_ptr[mid]) <= current_offset) lo = mid + 1; + else hi = mid; } + tensor_id = lo - 1; + tensor_base_elts = static_cast(offsets_ptr[tensor_id]); } - } - - if constexpr (USE_COLWISE_SCALING) { - alignas(128) __shared__ IType in_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; - alignas(128) __shared__ IType act_in_sh[IS_DACT ? SHMEM_DIM_Y : 1][IS_DACT ? SHMEM_DIM_X : 1]; - alignas(128) __shared__ OType out_colwise_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; - - const size_t col = block_offset_X + tid_colwise_X; - const bool col_valid_colwise = (col < cols); - -#pragma unroll - for (int iter = 0; iter < ITERATIONS; iter++) { - const size_t row_base = block_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - - if constexpr (IS_DACT) { - copy_2d_to_shared( - &act_in_sh[0][0], act_input_ptr, - block_offset_X, row_base, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + rows = (first_dims_ptr != nullptr) + ? static_cast(first_dims_ptr[tensor_id]) : first_logical_dim / num_tensors; + const size_t tensor_base_rows = tensor_base_elts / cols; + block_id_Y = blockIdx.y - static_cast(tensor_base_rows / CHUNK_DIM_Y); + } else { + size_t block_tile = static_cast(blockIdx.x); + size_t tiles_before = 0; + tensor_base_elts = 0; + for (tensor_id = 0; tensor_id < num_tensors; tensor_id++) { + size_t t_rows = first_dims_ptr ? static_cast(first_dims_ptr[tensor_id]) : first_logical_dim; + size_t t_cols = static_cast(last_dims_ptr[tensor_id]); + size_t t_tiles = DIVUP(t_rows, CHUNK_DIM_Y) * DIVUP(t_cols, CHUNK_DIM_X); + if (block_tile < tiles_before + t_tiles) { + rows = t_rows; + cols = t_cols; + size_t local_tile = block_tile - tiles_before; + size_t tiles_x = DIVUP(t_cols, CHUNK_DIM_X); + block_id_Y = static_cast(local_tile / tiles_x); + block_id_X = static_cast(local_tile % tiles_x); + break; } - copy_2d_to_shared( - &in_sh[0][0], input_ptr, - block_offset_X, row_base, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - __syncthreads(); - - if constexpr (USE_ROWWISE_SCALING) { - const size_t col_start = block_offset_X + thread_offset_X_rowwise; - const bool col_valid = (col_start < cols); - -#pragma unroll - for (int stage = 0; stage < BUFF_STAGES_NUM; stage++) { - const int shmem_y = thread_offset_Y + stage * THREADS_PER_CHUNK_Y_ROWWISE; - const size_t row = row_base + shmem_y; - const bool row_valid = (row < rows); - - Vec in; - Vec act_in; - in.load_from(&in_sh[shmem_y][thread_offset_X_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_y][thread_offset_X_rowwise]); - } - - float thread_amax = 0; - float in_compute[ELEMS_PER_THREAD]; - -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); - float elt = static_cast(in.data.elt[j]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[j]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - if (!out_of_bounds) { - partial_dbias_rowwise.data.elt[j] += elt; - } - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - in_compute[j] = elt; - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); - const e8m0_t biased_exponent = - ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); - - { - constexpr size_t SCALES_PER_GROUP = THREADS_PER_CHUNK_X_ROWWISE / THREADS_PER_SCALE_X_ROWWISE; - static_assert(SCALES_PER_GROUP < 4 || SCALES_PER_GROUP % 4 == 0, - "SCALES_PER_GROUP must be < 4 or a multiple of 4"); - uint32_t my_scale = static_cast(biased_exponent); - if constexpr (SCALES_PER_GROUP >= 4) { -#pragma unroll - for (int g = 0; g < SCALES_PER_GROUP / 4; g++) { - uint32_t s0 = __shfl_down(my_scale, (g*4+0) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); - uint32_t s1 = __shfl_down(my_scale, (g*4+1) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); - uint32_t s2 = __shfl_down(my_scale, (g*4+2) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); - uint32_t s3 = __shfl_down(my_scale, (g*4+3) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); - uint32_t packed = (s0 & 0xFF) | ((s1 & 0xFF) << 8) | ((s2 & 0xFF) << 16) | ((s3 & 0xFF) << 24); - if (tid_rowwise_X == 0 && row_valid && col_valid) { - const int scale_idx = row * scale_stride_rowwise + scales_rowwise_block_offset_X; - reinterpret_cast(&scales_rowwise[scale_idx])[g] = packed; - } - } - } else { - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { - const int scale_idx = row * scale_stride_rowwise + - scales_rowwise_block_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; - scales_rowwise[scale_idx] = biased_exponent; - } - } - } - - Vec out_c; -#ifdef HAS_CVT_4xFLOAT8 - { - const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); - union { - uint32_t packed[ELEMS_PER_THREAD / 4]; - mxfp8_v2i16_t v2i16[ELEMS_PER_THREAD / 4]; - } cvt_out{}; -#pragma unroll - for (int p = 0; p < ELEMS_PER_THREAD / 4; p++) { - cvt_out.packed[p] = rocm_cvt_4xfloat8( - in_compute[p*4+0], in_compute[p*4+1], - in_compute[p*4+2], in_compute[p*4+3], cvt_scale); - } - memcpy(out_c.data.elt, cvt_out.packed, ELEMS_PER_THREAD * sizeof(OType)); - } -#else - { - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); - } - } -#endif // #ifdef HAS_CVT_4xFLOAT8 - - if (row_valid && col_valid) { - if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { - out_c.store_to(&output_rowwise[row * cols + col_start]); - } else { -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - if (col_start + j < cols) { - output_rowwise[row * cols + col_start + j] = out_c.data.elt[j]; - } - } - } - } - } - } - - if (threadIdx.x < CHUNK_DIM_X) { - float in_compute[SCALE_DIM_Y]; - float amax = 0; - -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; i++) { - const size_t row = row_base + i; - const bool out_of_bounds = (!col_valid_colwise || row >= rows); - - float elt = static_cast(in_sh[i][tid_colwise_X]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[i][tid_colwise_X]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if (!out_of_bounds) { - partial_dbias_colwise += elt; - } - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - in_compute[i] = elt; - if (!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(amax >= 0); - block_amax = fmaxf(block_amax, amax); - - const e8m0_t biased_exponent = ptx::float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); - - if (col_valid_colwise && row_base < rows) { - const int scale_idx = - (scales_colwise_block_offset_Y + iter) * scale_stride_colwise + col; - scales_colwise[scale_idx] = biased_exponent; - } - -#ifdef HAS_CVT_4xFLOAT8 - { - const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; i += 2) { - uint32_t packed = rocm_cvt_4xfloat8( - in_compute[i], in_compute[i+1], 0.0f, 0.0f, cvt_scale); - OType val0, val1; - memcpy(&val0, &packed, sizeof(OType)); - memcpy(&val1, reinterpret_cast(&packed) + 1, sizeof(OType)); - out_colwise_sh[i][tid_colwise_X] = val0; - if (i + 1 < SCALE_DIM_Y) { - out_colwise_sh[i+1][tid_colwise_X] = val1; - } - } - } -#else - { - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; i++) { - out_colwise_sh[i][tid_colwise_X] = - static_cast(in_compute[i] * block_scale_inverse); - } - } -#endif // #ifdef HAS_CVT_4xFLOAT8 - } - - __syncthreads(); - - bulk_tensor_2d_shared_to_global( - &out_colwise_sh[0][0], output_colwise, - block_offset_X, row_base, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - - __syncthreads(); + tiles_before += t_tiles; + tensor_base_elts += t_rows * t_cols; } + if (tensor_id >= num_tensors) return; } - if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; - constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; - __shared__ float shmem_partial_dbias_rowwise[Y][X][ELEMS_PER_THREAD]; - - if (tid_rowwise_Y > 0) { - partial_dbias_rowwise.store_to( - &shmem_partial_dbias_rowwise[tid_rowwise_Y - 1][tid_rowwise_X]); - } - __syncthreads(); - - if (tid_rowwise_Y == 0) { - Vec other_row_dbias; - const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_block_offset_X; - const int left_bound = dbias_rowwise_block_offset_X; - const int right_bound = dbias_rowwise_block_offset_X + ELEMS_PER_THREAD - 1; - -#pragma unroll - for (int i = 0; i < Y; i++) { - other_row_dbias.load_from(&shmem_partial_dbias_rowwise[i][tid_rowwise_X]); -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - partial_dbias_rowwise.data.elt[j] += other_row_dbias.data.elt[j]; - } - } - - if (right_bound < cols) { - partial_dbias_rowwise.store_to(&dbias_workspace[dbias_offset]); - } else if (left_bound < cols && right_bound >= cols) { - const int in_bound_elts_count = cols - left_bound; - partial_dbias_rowwise.store_to_elts(&dbias_workspace[dbias_offset], 0, - in_bound_elts_count); - } - } - } else { - if (threadIdx.x < CHUNK_DIM_X) { - const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_block_offset_X; - const bool col_out_of_bounds = (dbias_colwise_block_offset_X >= cols); - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias_colwise; - } - } + if (static_cast(block_id_Y) * CHUNK_DIM_Y >= rows) return; + if (static_cast(block_id_X) * CHUNK_DIM_X >= cols) return; + + input_ptr += tensor_base_elts; + if (act_input_ptr) act_input_ptr += tensor_base_elts; + if (output_rowwise) output_rowwise += tensor_base_elts; + if (output_colwise) output_colwise += tensor_base_elts; + + scale_stride_rowwise = DIVUP(cols, SCALE_DIM_X); + scale_stride_colwise = cols; + + if (last_dims_ptr == nullptr) { + const size_t tensor_base_rows = tensor_base_elts / cols; + if (scales_rowwise) scales_rowwise += tensor_base_rows * scale_stride_rowwise; + if (scales_colwise) scales_colwise += (tensor_base_rows / SCALE_DIM_Y) * scale_stride_colwise; + } else { + size_t rowwise_offset = 0, colwise_offset = 0; + for (size_t t = 0; t < tensor_id; t++) { + size_t t_rows = first_dims_ptr ? static_cast(first_dims_ptr[t]) : first_logical_dim; + size_t t_cols = static_cast(last_dims_ptr[t]); + rowwise_offset += t_rows * DIVUP(t_cols, SCALE_DIM_X); + colwise_offset += DIVUP(t_rows, SCALE_DIM_Y) * t_cols; } + if (scales_rowwise) scales_rowwise += rowwise_offset; + if (scales_colwise) scales_colwise += colwise_offset; } - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - block_amax = reduce_max(block_amax, warp_id); - } + const int dbias_y_offset = blockIdx.y; +#include "rocm_quantize_mxfp8_body.inc" +} - if (threadIdx.x == 0 && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); +// Multi-tensor quantize (per-tensor pointers) +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + multi_quantize_mxfp8_kernel(MultiQuantizeMXFP8Args args) { + const int row_tile = blockIdx.y; + int lo = 0, hi = args.num_tensors - 1; + while (lo < hi) { + int mid = lo + (hi - lo) / 2; + if (args.block_range[mid + 1] <= row_tile) lo = mid + 1; + else hi = mid; } + const int tensor_id = lo; + const size_t rows = args.rows_list[tensor_id]; + const size_t cols = args.cols_list[tensor_id]; + if (rows == 0) return; + + const int block_id_Y = row_tile - args.block_range[tensor_id]; + const int block_id_X = blockIdx.x; + if (static_cast(block_id_Y) * CHUNK_DIM_Y >= rows) return; + if (static_cast(block_id_X) * CHUNK_DIM_X >= cols) return; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + using ParamOP = Empty; + constexpr float (*OP)(float, const Empty &) = nullptr; + + const IType *input_ptr = reinterpret_cast(args.input_list[tensor_id]); + const IType *act_input_ptr = nullptr; + OType *output_rowwise = reinterpret_cast(args.output_rowwise_list[tensor_id]); + OType *output_colwise = reinterpret_cast(args.output_colwise_list[tensor_id]); + e8m0_t *scales_rowwise = reinterpret_cast(args.scales_rowwise_list[tensor_id]); + e8m0_t *scales_colwise = reinterpret_cast(args.scales_colwise_list[tensor_id]); + float *dbias_workspace = nullptr; + float *amax_ptr = args.amax_list[tensor_id]; + const size_t scale_stride_rowwise = DIVUP(cols, SCALE_DIM_X); + const size_t scale_stride_colwise = cols; + const int dbias_y_offset = block_id_Y; +#include "rocm_quantize_mxfp8_body.inc" } + diff --git a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8_body.inc b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8_body.inc new file mode 100644 index 000000000..404e1b6e0 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8_body.inc @@ -0,0 +1,502 @@ +/************************************************************************* + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + + constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; + constexpr size_t SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; + constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; + + constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; + constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; + constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; + + constexpr size_t BUFF_STAGES_NUM = MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; + constexpr size_t ITERATIONS = CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; + + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + constexpr size_t VECTOR_WIDTH_IN = ROCM_VEC_BYTES / sizeof(IType); // BF16/FP16: 8, FP32: 4 + constexpr size_t VECTOR_WIDTH_OUT = ROCM_VEC_BYTES / sizeof(OType); // FP8: 16 + + const int block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const int block_offset_X = block_id_X * CHUNK_DIM_X; + const int scales_rowwise_block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const int scales_rowwise_block_offset_X = block_id_X * SCALES_ROWWISE_PER_BLOCK_X; + const int scales_colwise_block_offset_Y = block_id_Y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = block_id_X * CHUNK_DIM_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + + const int dbias_rowwise_offset_Y = dbias_y_offset + tid_rowwise_Y; + const int dbias_rowwise_block_offset_X = block_offset_X + thread_offset_X_rowwise; + const int dbias_colwise_offset_Y = dbias_y_offset; + const int dbias_colwise_block_offset_X = block_offset_X + tid_colwise_X; + const int dbias_stride = cols; + + Vec partial_dbias_rowwise; + float partial_dbias_colwise = 0; + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + partial_dbias_rowwise.clear(); + } + } + + float block_amax = 0; + + constexpr size_t ROWS_PER_THREAD = CHUNK_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; + + if constexpr (USE_ROWWISE_SCALING && !USE_COLWISE_SCALING) { + const size_t col_start = block_offset_X + thread_offset_X_rowwise; + const bool col_valid = (col_start < cols); +#pragma unroll + for (size_t r = 0; r < ROWS_PER_THREAD; r++) { + const size_t row = block_offset_Y + tid_rowwise_Y + r * THREADS_PER_CHUNK_Y_ROWWISE; + const bool row_valid = (row < rows); + + Vec in; + Vec act_in; + + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + in.load_from(&input_ptr[row * cols + col_start]); + if constexpr (IS_DACT) { + act_in.load_from(&act_input_ptr[row * cols + col_start]); + } + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + in.data.elt[j] = (col_start + j < cols) ? input_ptr[row * cols + col_start + j] + : static_cast(0); + } + if constexpr (IS_DACT) { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + act_in.data.elt[j] = (col_start + j < cols) ? act_input_ptr[row * cols + col_start + j] + : static_cast(0); + } + } + } + } + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); + float elt = static_cast(in.data.elt[j]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + if (!out_of_bounds) { + partial_dbias_rowwise.data.elt[j] += elt; + } + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute[j] = elt; + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = + ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); + + { + constexpr size_t SCALES_PER_GROUP = THREADS_PER_CHUNK_X_ROWWISE / THREADS_PER_SCALE_X_ROWWISE; + static_assert(SCALES_PER_GROUP < 4 || SCALES_PER_GROUP % 4 == 0, + "SCALES_PER_GROUP must be < 4 or a multiple of 4"); + uint32_t my_scale = static_cast(biased_exponent); + if constexpr (SCALES_PER_GROUP >= 4) { +#pragma unroll + for (int g = 0; g < SCALES_PER_GROUP / 4; g++) { + uint32_t s0 = __shfl_down(my_scale, (g*4+0) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s1 = __shfl_down(my_scale, (g*4+1) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s2 = __shfl_down(my_scale, (g*4+2) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s3 = __shfl_down(my_scale, (g*4+3) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t packed = (s0 & 0xFF) | ((s1 & 0xFF) << 8) | ((s2 & 0xFF) << 16) | ((s3 & 0xFF) << 24); + if (tid_rowwise_X == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + scales_rowwise_block_offset_X; + reinterpret_cast(&scales_rowwise[scale_idx])[g] = packed; + } + } + } else { + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = + row * scale_stride_rowwise + + scales_rowwise_block_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exponent; + } + } + } + + Vec out_c; +#ifdef HAS_CVT_4xFLOAT8 + { + const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); + union { + uint32_t packed[ELEMS_PER_THREAD / 4]; + mxfp8_v2i16_t v2i16[ELEMS_PER_THREAD / 4]; + } cvt_out{}; +#pragma unroll + for (int p = 0; p < ELEMS_PER_THREAD / 4; p++) { + cvt_out.packed[p] = rocm_cvt_4xfloat8( + in_compute[p*4+0], in_compute[p*4+1], + in_compute[p*4+2], in_compute[p*4+3], cvt_scale); + } + memcpy(out_c.data.elt, cvt_out.packed, ELEMS_PER_THREAD * sizeof(OType)); + } +#else + { + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + } +#endif // #ifdef HAS_CVT_4xFLOAT8 + + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + reinterpret_cast*>(&out_c)->nt_store( + &output_rowwise[row * cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_rowwise[row * cols + col_start + j] = out_c.data.elt[j]; + } + } + } + } + } + } + + if constexpr (USE_COLWISE_SCALING) { + alignas(128) __shared__ IType in_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(128) __shared__ IType act_in_sh[IS_DACT ? SHMEM_DIM_Y : 1][IS_DACT ? SHMEM_DIM_X : 1]; + alignas(128) __shared__ OType out_colwise_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; + + const size_t col = block_offset_X + tid_colwise_X; + const bool col_valid_colwise = (col < cols); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; iter++) { + const size_t row_base = block_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + if constexpr (IS_DACT) { + copy_2d_to_shared( + &act_in_sh[0][0], act_input_ptr, + block_offset_X, row_base, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + } + copy_2d_to_shared( + &in_sh[0][0], input_ptr, + block_offset_X, row_base, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + __syncthreads(); + + if constexpr (USE_ROWWISE_SCALING) { + const size_t col_start = block_offset_X + thread_offset_X_rowwise; + const bool col_valid = (col_start < cols); + +#pragma unroll + for (int stage = 0; stage < BUFF_STAGES_NUM; stage++) { + const int shmem_y = thread_offset_Y + stage * THREADS_PER_CHUNK_Y_ROWWISE; + const size_t row = row_base + shmem_y; + const bool row_valid = (row < rows); + + Vec in; + Vec act_in; + in.load_from(&in_sh[shmem_y][thread_offset_X_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_y][thread_offset_X_rowwise]); + } + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); + float elt = static_cast(in.data.elt[j]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + if (!out_of_bounds) { + partial_dbias_rowwise.data.elt[j] += elt; + } + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute[j] = elt; + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = + ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); + + { + constexpr size_t SCALES_PER_GROUP = THREADS_PER_CHUNK_X_ROWWISE / THREADS_PER_SCALE_X_ROWWISE; + static_assert(SCALES_PER_GROUP < 4 || SCALES_PER_GROUP % 4 == 0, + "SCALES_PER_GROUP must be < 4 or a multiple of 4"); + uint32_t my_scale = static_cast(biased_exponent); + if constexpr (SCALES_PER_GROUP >= 4) { +#pragma unroll + for (int g = 0; g < SCALES_PER_GROUP / 4; g++) { + uint32_t s0 = __shfl_down(my_scale, (g*4+0) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s1 = __shfl_down(my_scale, (g*4+1) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s2 = __shfl_down(my_scale, (g*4+2) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s3 = __shfl_down(my_scale, (g*4+3) * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t packed = (s0 & 0xFF) | ((s1 & 0xFF) << 8) | ((s2 & 0xFF) << 16) | ((s3 & 0xFF) << 24); + if (tid_rowwise_X == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + scales_rowwise_block_offset_X; + reinterpret_cast(&scales_rowwise[scale_idx])[g] = packed; + } + } + } else { + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_block_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exponent; + } + } + } + + Vec out_c; +#ifdef HAS_CVT_4xFLOAT8 + { + const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); + union { + uint32_t packed[ELEMS_PER_THREAD / 4]; + mxfp8_v2i16_t v2i16[ELEMS_PER_THREAD / 4]; + } cvt_out{}; +#pragma unroll + for (int p = 0; p < ELEMS_PER_THREAD / 4; p++) { + cvt_out.packed[p] = rocm_cvt_4xfloat8( + in_compute[p*4+0], in_compute[p*4+1], + in_compute[p*4+2], in_compute[p*4+3], cvt_scale); + } + memcpy(out_c.data.elt, cvt_out.packed, ELEMS_PER_THREAD * sizeof(OType)); + } +#else + { + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + } +#endif // #ifdef HAS_CVT_4xFLOAT8 + + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + reinterpret_cast*>(&out_c)->nt_store( + &output_rowwise[row * cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_rowwise[row * cols + col_start + j] = out_c.data.elt[j]; + } + } + } + } + + if constexpr (IS_ACT || IS_DACT) { + Vec cached; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + cached.data.elt[j] = static_cast(in_compute[j]); + } + cached.store_to(&in_sh[shmem_y][thread_offset_X_rowwise]); + } + } + } + + if constexpr (IS_ACT || IS_DACT) { + __syncthreads(); + } + + if (threadIdx.x < CHUNK_DIM_X) { + float in_compute[SCALE_DIM_Y]; + float amax = 0; + +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; i++) { + const size_t row = row_base + i; + const bool out_of_bounds = (!col_valid_colwise || row >= rows); + + float elt = static_cast(in_sh[i][tid_colwise_X]); + constexpr bool ACT_CACHED = USE_ROWWISE_SCALING && (IS_ACT || IS_DACT); + if constexpr (!ACT_CACHED) { + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[i][tid_colwise_X]); + elt *= OP(act_in_elt, {}); + } + } + if constexpr (IS_DBIAS) { + if (!out_of_bounds) { + partial_dbias_colwise += elt; + } + } + if constexpr (!ACT_CACHED) { + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + } + in_compute[i] = elt; + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(amax >= 0); + block_amax = fmaxf(block_amax, amax); + + const e8m0_t biased_exponent = ptx::float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + + if (col_valid_colwise && row_base < rows) { + const int scale_idx = + (scales_colwise_block_offset_Y + iter) * scale_stride_colwise + col; + scales_colwise[scale_idx] = biased_exponent; + } + +#ifdef HAS_CVT_4xFLOAT8 + { + const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; i += 2) { + uint32_t packed = rocm_cvt_4xfloat8( + in_compute[i], in_compute[i+1], 0.0f, 0.0f, cvt_scale); + OType val0, val1; + memcpy(&val0, &packed, sizeof(OType)); + memcpy(&val1, reinterpret_cast(&packed) + 1, sizeof(OType)); + out_colwise_sh[i][tid_colwise_X] = val0; + if (i + 1 < SCALE_DIM_Y) { + out_colwise_sh[i+1][tid_colwise_X] = val1; + } + } + } +#else + { + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; i++) { + out_colwise_sh[i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); + } + } +#endif // #ifdef HAS_CVT_4xFLOAT8 + } + + __syncthreads(); + + bulk_tensor_2d_shared_to_global( + &out_colwise_sh[0][0], output_colwise, + block_offset_X, row_base, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + + __syncthreads(); + } + } + + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; + constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; + __shared__ float shmem_partial_dbias_rowwise[Y][X][ELEMS_PER_THREAD]; + + if (tid_rowwise_Y > 0) { + partial_dbias_rowwise.store_to( + &shmem_partial_dbias_rowwise[tid_rowwise_Y - 1][tid_rowwise_X]); + } + __syncthreads(); + + if (tid_rowwise_Y == 0) { + Vec other_row_dbias; + const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_block_offset_X; + const int left_bound = dbias_rowwise_block_offset_X; + const int right_bound = dbias_rowwise_block_offset_X + ELEMS_PER_THREAD - 1; + +#pragma unroll + for (int i = 0; i < Y; i++) { + other_row_dbias.load_from(&shmem_partial_dbias_rowwise[i][tid_rowwise_X]); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + partial_dbias_rowwise.data.elt[j] += other_row_dbias.data.elt[j]; + } + } + + if (right_bound < cols) { + partial_dbias_rowwise.store_to(&dbias_workspace[dbias_offset]); + } else if (left_bound < cols && right_bound >= cols) { + const int in_bound_elts_count = cols - left_bound; + partial_dbias_rowwise.store_to_elts(&dbias_workspace[dbias_offset], 0, + in_bound_elts_count); + } + } + } else { + if (threadIdx.x < CHUNK_DIM_X) { + const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_block_offset_X; + const bool col_out_of_bounds = (dbias_colwise_block_offset_X >= cols); + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias_colwise; + } + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + block_amax = reduce_max(block_amax, warp_id); + } + + if (threadIdx.x == 0 && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } diff --git a/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh index 81dc46a85..6a881ffcc 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh @@ -7,6 +7,7 @@ #pragma once #include "../../util/vectorized_pointwise.h" +#include "../../util/rocm_device_utils.cuh" namespace transformer_engine { // These 2d copy functions replace TMA tensormap async copies for AMD GPUs. @@ -64,17 +65,21 @@ __device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T * size_t g_row = g_start_row + l_y; size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - const T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedLoader shared_loader(current_sh_row_base_ptr, chunk_dim_x); - - T* current_g_row_base_ptr = g_ptr + g_row * g_stride; - VectorizedStorer global_storer(current_g_row_base_ptr, total_cols); - - shared_loader.load(l_x_vec, chunk_dim_x); - if (g_row < total_rows) { - global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; - global_storer.store(g_col_primitive_start / N_VEC, total_cols); + const T *sh_row = sh_ptr_base + l_y * chunk_dim_x; + T *g_row_ptr = g_ptr + g_row * g_stride; + + if (ALIGNED_ACCESS || g_col_primitive_start + N_VEC <= total_cols) { + NTVec v; + v.load(sh_row + l_x_vec * N_VEC); + v.nt_store(g_row_ptr + g_col_primitive_start); + } else { + for (int i = 0; i < N_VEC; i++) { + if (g_col_primitive_start + i < total_cols) { + g_row_ptr[g_col_primitive_start + i] = sh_row[l_x_vec * N_VEC + i]; + } + } + } } } } diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 755052d6d..aaf581fc7 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -101,6 +101,19 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); +#ifdef __HIP_PLATFORM_AMD__ +/*! \brief Fused multi-tensor MXFP8 quantize. Quantizes multiple tensors in a single kernel launch. + * Each tensor can have different shapes. Output tensors are written to per-tensor pointers. + * + * \param[in] num_tensors Number of tensors to quantize. + * \param[in] input_list Array of input tensors. + * \param[in,out] output_list Array of output MXFP8 tensors. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_quantize_mxfp8(size_t num_tensors, const NVTETensor *input_list, + NVTETensor *output_list, cudaStream_t stream); +#endif + /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9ed5502b7..fb28c9834 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -291,6 +291,20 @@ void multi_tensor_quantize_impl(const std::vector &input_list, } // Launch TE kernel +#ifdef USE_ROCM + if (num_tensors > 0 && detail::IsMXFP8Quantizers(quantizer_py_list[0].ptr())) { + std::vector nvte_input_list, nvte_output_list; + for (size_t i = 0; i < num_tensors; i++) { + nvte_input_list.push_back(input_list[i].data()); + nvte_output_list.push_back(output_list[i].data()); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_multi_quantize_mxfp8(nvte_input_list.size(), nvte_input_list.data(), + nvte_output_list.data(), at::cuda::getCurrentCUDAStream()); + }); + return; + } +#endif if (with_fused_kernel) { // Fused kernel for multi-tensor quantize std::vector nvte_tensor_input_list; diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 63a460276..0d5700bac 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -228,6 +228,16 @@ def get_scale_shape( Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout """ + if IS_HIP_EXTENSION: + if columnwise: + return ( + math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, + shape[-1], + ) + return ( + math.prod(shape[:-1]), + shape[-1] // MXFP8_BLOCK_SCALING_SIZE, + ) if columnwise: # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]] # with padding to multiples of [4, 128]