From 22177c9a58d8fce6588aeeb08b3afac6677e74a8 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 11 Mar 2025 23:35:10 +0000 Subject: [PATCH 1/9] Do not suppress MXFP8 norm in Python wrapper func Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/_common.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index cd18808465..721fb4a04a 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -86,26 +86,16 @@ def apply_normalization( inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) - split_mxfp8_cast = False - if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer): - split_mxfp8_cast = True - - output = normalization_func( + return normalization_func( *inputs, eps, - None if split_mxfp8_cast else ln_out, - None if split_mxfp8_cast else output_quantizer, + ln_out, + output_quantizer, TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype, fwd_ln_sm_margin, zero_centered_gamma, ) - return ( - (output_quantizer.quantize(output[0], out=ln_out), *output[1:]) - if split_mxfp8_cast - else output - ) - class _NoopCatFunc(torch.autograd.Function): """Concatenate tensors, doing a no-op if possible From e09ed7efa553043d1a03ef0ef68e1a24e4fe2227 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 12 Mar 2025 00:22:17 +0000 Subject: [PATCH 2/9] Support FP8 current scaling in tex norm functions Signed-off-by: Tim Moon --- .../pytorch/csrc/extensions/normalization.cpp | 82 +++++++++++++++---- .../pytorch/module/layernorm_linear.py | 6 -- .../pytorch/module/layernorm_mlp.py | 2 - 3 files changed, 66 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index bb011faf98..e9a138ff07 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -107,14 +107,17 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } // Determine whether to avoid fused kernel - bool force_unfused_kernel = false; - if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { - if (!transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { - // TE only supports MXFP8 norm with cuDNN backend - force_unfused_kernel = true; - } else if (N % 128 != 0 || H % 128 != 0) { - // cuDNN norm requires full tile for MXFP8 - force_unfused_kernel = true; + bool force_unfused_kernel = true; + if (quantizer.is_none()) { + // No need for separate quantization step if output is unquantized + force_unfused_kernel = false; + } else if (detail::IsFloat8Quantizers(quantizer.ptr())) { + // Always used fused kernel for FP8 delayed scaling + force_unfused_kernel = false; + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + if (transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { + // cuDNN MXFP8 kernel requires full tile + force_unfused_kernel = N % 128 == 0 && H % 128 == 0; } } TensorWrapper unquantized_out_cu; @@ -145,6 +148,28 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream()); + // check if we need to do amax reudction (depending on model parallel configs) + if (my_quantizer_cs->with_amax_reduction) { + c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; + // construct torch tesnor from NVTEBasicTensor without reallocating memory + at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + std::vector tensors = {amax_tensor_torch}; + // allreduce amax tensor + c10d::AllreduceOptions allreduce_opts; + allreduce_opts.reduceOp = c10d::ReduceOp::MAX; + process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + te_output.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, at::cuda::getCurrentCUDAStream()); } @@ -223,14 +248,17 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } // Determine whether to avoid fused kernel - bool force_unfused_kernel = false; - if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { - if (!transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { - // TE only supports MXFP8 norm with cuDNN backend - force_unfused_kernel = true; - } else if (N % 128 != 0 || H % 128 != 0) { - // cuDNN norm requires full tile for MXFP8 - force_unfused_kernel = true; + bool force_unfused_kernel = true; + if (quantizer.is_none()) { + // No need for separate quantization step if output is unquantized + force_unfused_kernel = false; + } else if (detail::IsFloat8Quantizers(quantizer.ptr())) { + // Always used fused kernel for FP8 delayed scaling + force_unfused_kernel = false; + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + if (transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { + // cuDNN MXFP8 kernel requires full tile + force_unfused_kernel = N % 128 == 0 && H % 128 == 0; } } TensorWrapper unquantized_out_cu; @@ -261,6 +289,28 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream()); + // check if we need to do amax reudction (depending on model parallel configs) + if (my_quantizer_cs->with_amax_reduction) { + c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; + // construct torch tesnor from NVTEBasicTensor without reallocating memory + at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + std::vector tensors = {amax_tensor_torch}; + // allreduce amax tensor + c10d::AllreduceOptions allreduce_opts; + allreduce_opts.reduceOp = c10d::ReduceOp::MAX; + process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + te_output.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7571b17c1f..a576f24ed3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -56,7 +56,6 @@ restore_from_saved, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpp_extensions import ( general_gemm, @@ -160,11 +159,6 @@ def forward( # Configure quantizer for normalization output with_quantized_norm = fp8 and not return_layernorm_output - # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer - # so we need to set with_quantized_norm to False - if isinstance(input_quantizer, Float8CurrentScalingQuantizer): - with_quantized_norm = False - if with_quantized_norm: if with_input_all_gather: input_quantizer.set_usage(rowwise=True, columnwise=False) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9bb76cb391..5a911a913d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -214,8 +214,6 @@ def forward( # for return_layernorm_output: layernorm output = High precision, then cast to FP8 # high precision layernorm output and output of the linear are returned with_quantized_norm = fp8 and not return_layernorm_output - if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): - with_quantized_norm = False tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output From 01a4d046939c8505b5b0aeb70322f8c88c89262a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 12 Mar 2025 00:29:03 +0000 Subject: [PATCH 3/9] Use single envvar to enable cuDNN MXFP8 norm kernels Signed-off-by: Tim Moon --- qa/L0_pytorch_unittest/test.sh | 2 +- transformer_engine/pytorch/csrc/extensions/normalization.cpp | 4 ++-- transformer_engine/pytorch/module/_common.py | 2 -- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index e2fe2c0200..2277dc76a6 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -14,7 +14,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 -NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1 +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1 diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index e9a138ff07..447c0740f2 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -115,7 +115,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Always used fused kernel for FP8 delayed scaling force_unfused_kernel = false; } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // cuDNN MXFP8 kernel requires full tile force_unfused_kernel = N % 128 == 0 && H % 128 == 0; } @@ -256,7 +256,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Always used fused kernel for FP8 delayed scaling force_unfused_kernel = false; } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // cuDNN MXFP8 kernel requires full tile force_unfused_kernel = N % 128 == 0 && H % 128 == 0; } diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 721fb4a04a..549eae53ea 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -18,8 +18,6 @@ from ..tensor.float8_tensor import Float8Tensor from ..tensor.mxfp8_tensor import MXFP8Quantizer -_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0"))) - def _get_normalization_func(normalization: str, forward: bool): fwd_normalization_funcs = { From 5051c4f991c6c81ccda19f21245488111f80b2ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Mar 2025 00:30:34 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/extensions/normalization.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 447c0740f2..7b1af3eead 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -150,13 +150,14 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe if (force_unfused_kernel) { if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); + auto my_quantizer_cs = static_cast(my_quantizer.get()); nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream()); // check if we need to do amax reudction (depending on model parallel configs) if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; + c10::intrusive_ptr process_group_ptr = + my_quantizer_cs->amax_reduction_group; // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; std::vector tensors = {amax_tensor_torch}; // allreduce amax tensor c10d::AllreduceOptions allreduce_opts; @@ -291,13 +292,14 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w if (force_unfused_kernel) { if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); + auto my_quantizer_cs = static_cast(my_quantizer.get()); nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream()); // check if we need to do amax reudction (depending on model parallel configs) if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; + c10::intrusive_ptr process_group_ptr = + my_quantizer_cs->amax_reduction_group; // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; std::vector tensors = {amax_tensor_torch}; // allreduce amax tensor c10d::AllreduceOptions allreduce_opts; From 12f1aa865b25aa97616258b1eb5f5931a7fa2301 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 12 Mar 2025 17:56:48 +0000 Subject: [PATCH 5/9] Debug compilation error Signed-off-by: Tim Moon --- .../pytorch/csrc/extensions/normalization.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 7b1af3eead..0cc679d48f 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -6,12 +6,13 @@ #include "common/util/system.h" #include "extensions.h" +#include "pybind.h" namespace transformer_engine::pytorch { std::pair createOutputTensor(const NVTEShape &shape, DType dtype, py::handle quantizer) { std::vector shape_vec; - for (int i = 0; i < shape.ndim; i++) { + for (size_t i = 0; i < shape.ndim; i++) { size_t t = shape.data[i]; shape_vec.push_back(t); } @@ -74,6 +75,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe float eps, py::object out, py::handle quantizer, DType out_dtype, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch::detail; using namespace transformer_engine::pytorch; using namespace transformer_engine; @@ -111,10 +113,10 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe if (quantizer.is_none()) { // No need for separate quantization step if output is unquantized force_unfused_kernel = false; - } else if (detail::IsFloat8Quantizers(quantizer.ptr())) { + } else if (IsFloat8Quantizers(quantizer.ptr())) { // Always used fused kernel for FP8 delayed scaling force_unfused_kernel = false; - } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + } else if (IsMXFP8Quantizers(quantizer.ptr())) { if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // cuDNN MXFP8 kernel requires full tile force_unfused_kernel = N % 128 == 0 && H % 128 == 0; @@ -148,7 +150,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream()); @@ -222,6 +224,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w py::object out, py::handle quantizer, transformer_engine::DType out_dtype, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch::detail; using namespace transformer_engine::pytorch; using namespace transformer_engine; @@ -253,10 +256,10 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w if (quantizer.is_none()) { // No need for separate quantization step if output is unquantized force_unfused_kernel = false; - } else if (detail::IsFloat8Quantizers(quantizer.ptr())) { + } else if (IsFloat8Quantizers(quantizer.ptr())) { // Always used fused kernel for FP8 delayed scaling force_unfused_kernel = false; - } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + } else if (IsMXFP8Quantizers(quantizer.ptr())) { if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // cuDNN MXFP8 kernel requires full tile force_unfused_kernel = N % 128 == 0 && H % 128 == 0; @@ -290,7 +293,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // my_quantizer here has to be a Float8CurrentScalingQuantizer auto my_quantizer_cs = static_cast(my_quantizer.get()); nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream()); From c82bdb60972adeaf32dc7371917bde35a48754be Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 14 Mar 2025 00:44:33 +0000 Subject: [PATCH 6/9] Fix compilation error Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/extensions/normalization.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 0cc679d48f..5fd042b599 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -171,7 +171,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - te_output.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); } nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, at::cuda::getCurrentCUDAStream()); @@ -314,7 +314,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - te_output.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); } nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, at::cuda::getCurrentCUDAStream()); From f6188561b9d06447bd00271d2ec2e1f7e663b84e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 14 Mar 2025 00:55:31 +0000 Subject: [PATCH 7/9] Fix full-tile requirement for MXFP8 norm kernels Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/extensions/normalization.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 5fd042b599..cbdeee5b48 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -119,7 +119,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } else if (IsMXFP8Quantizers(quantizer.ptr())) { if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // cuDNN MXFP8 kernel requires full tile - force_unfused_kernel = N % 128 == 0 && H % 128 == 0; + force_unfused_kernel = N % 128 != 0 || H % 128 != 0; } } TensorWrapper unquantized_out_cu; @@ -262,7 +262,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } else if (IsMXFP8Quantizers(quantizer.ptr())) { if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // cuDNN MXFP8 kernel requires full tile - force_unfused_kernel = N % 128 == 0 && H % 128 == 0; + force_unfused_kernel = N % 128 != 0 || H % 128 != 0; } } TensorWrapper unquantized_out_cu; From be0afe5cbe252f631aa6709e5b858c92df584cd0 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 14 Mar 2025 00:59:54 +0000 Subject: [PATCH 8/9] Remove unused imports Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/_common.py | 2 -- transformer_engine/pytorch/module/layernorm_mlp.py | 1 - 2 files changed, 3 deletions(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 549eae53ea..c2b525ab55 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -4,7 +4,6 @@ """Internal function used by multiple modules.""" -import os from typing import Any, List, Optional, Tuple, Union, Callable from dataclasses import dataclass from functools import reduce @@ -16,7 +15,6 @@ from ..constants import TE_DType from ..utils import get_default_init_method from ..tensor.float8_tensor import Float8Tensor -from ..tensor.mxfp8_tensor import MXFP8Quantizer def _get_normalization_func(normalization: str, forward: bool): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 87ab4bf5af..f54c7a925c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -60,7 +60,6 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.quantized_tensor import ( QuantizedTensor, Quantizer, From 192f4de30adc3120b964fcd1f29bbacf94d88701 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 15 Mar 2025 02:46:46 +0000 Subject: [PATCH 9/9] Add missing imports Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/layernorm_linear.py | 1 + transformer_engine/pytorch/module/layernorm_mlp.py | 1 + 2 files changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 82adc29946..6be1078dc5 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -55,6 +55,7 @@ prepare_for_saving, restore_from_saved, ) +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 572905a571..51e8905223 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -60,6 +60,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.quantized_tensor import ( QuantizedTensor, Quantizer,