diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 732f0a16d1..03bffce751 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -26,7 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index bb011faf98..cbdeee5b48 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -6,12 +6,13 @@ #include "common/util/system.h" #include "extensions.h" +#include "pybind.h" namespace transformer_engine::pytorch { std::pair createOutputTensor(const NVTEShape &shape, DType dtype, py::handle quantizer) { std::vector shape_vec; - for (int i = 0; i < shape.ndim; i++) { + for (size_t i = 0; i < shape.ndim; i++) { size_t t = shape.data[i]; shape_vec.push_back(t); } @@ -74,6 +75,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe float eps, py::object out, py::handle quantizer, DType out_dtype, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch::detail; using namespace transformer_engine::pytorch; using namespace transformer_engine; @@ -107,14 +109,17 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } // Determine whether to avoid fused kernel - bool force_unfused_kernel = false; - if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { - if (!transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { - // TE only supports MXFP8 norm with cuDNN backend - force_unfused_kernel = true; - } else if (N % 128 != 0 || H % 128 != 0) { - // cuDNN norm requires full tile for MXFP8 - force_unfused_kernel = true; + bool force_unfused_kernel = true; + if (quantizer.is_none()) { + // No need for separate quantization step if output is unquantized + force_unfused_kernel = false; + } else if (IsFloat8Quantizers(quantizer.ptr())) { + // Always used fused kernel for FP8 delayed scaling + force_unfused_kernel = false; + } else if (IsMXFP8Quantizers(quantizer.ptr())) { + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // cuDNN MXFP8 kernel requires full tile + force_unfused_kernel = N % 128 != 0 || H % 128 != 0; } } TensorWrapper unquantized_out_cu; @@ -145,6 +150,29 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream()); + // check if we need to do amax reudction (depending on model parallel configs) + if (my_quantizer_cs->with_amax_reduction) { + c10::intrusive_ptr process_group_ptr = + my_quantizer_cs->amax_reduction_group; + // construct torch tesnor from NVTEBasicTensor without reallocating memory + at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; + std::vector tensors = {amax_tensor_torch}; + // allreduce amax tensor + c10d::AllreduceOptions allreduce_opts; + allreduce_opts.reduceOp = c10d::ReduceOp::MAX; + process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, at::cuda::getCurrentCUDAStream()); } @@ -196,6 +224,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w py::object out, py::handle quantizer, transformer_engine::DType out_dtype, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch::detail; using namespace transformer_engine::pytorch; using namespace transformer_engine; @@ -223,14 +252,17 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } // Determine whether to avoid fused kernel - bool force_unfused_kernel = false; - if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { - if (!transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { - // TE only supports MXFP8 norm with cuDNN backend - force_unfused_kernel = true; - } else if (N % 128 != 0 || H % 128 != 0) { - // cuDNN norm requires full tile for MXFP8 - force_unfused_kernel = true; + bool force_unfused_kernel = true; + if (quantizer.is_none()) { + // No need for separate quantization step if output is unquantized + force_unfused_kernel = false; + } else if (IsFloat8Quantizers(quantizer.ptr())) { + // Always used fused kernel for FP8 delayed scaling + force_unfused_kernel = false; + } else if (IsMXFP8Quantizers(quantizer.ptr())) { + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // cuDNN MXFP8 kernel requires full tile + force_unfused_kernel = N % 128 != 0 || H % 128 != 0; } } TensorWrapper unquantized_out_cu; @@ -261,6 +293,29 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream()); + // check if we need to do amax reudction (depending on model parallel configs) + if (my_quantizer_cs->with_amax_reduction) { + c10::intrusive_ptr process_group_ptr = + my_quantizer_cs->amax_reduction_group; + // construct torch tesnor from NVTEBasicTensor without reallocating memory + at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; + std::vector tensors = {amax_tensor_torch}; + // allreduce amax tensor + c10d::AllreduceOptions allreduce_opts; + allreduce_opts.reduceOp = c10d::ReduceOp::MAX; + process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); + } nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index cd18808465..c2b525ab55 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -4,7 +4,6 @@ """Internal function used by multiple modules.""" -import os from typing import Any, List, Optional, Tuple, Union, Callable from dataclasses import dataclass from functools import reduce @@ -16,9 +15,6 @@ from ..constants import TE_DType from ..utils import get_default_init_method from ..tensor.float8_tensor import Float8Tensor -from ..tensor.mxfp8_tensor import MXFP8Quantizer - -_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0"))) def _get_normalization_func(normalization: str, forward: bool): @@ -86,26 +82,16 @@ def apply_normalization( inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) - split_mxfp8_cast = False - if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer): - split_mxfp8_cast = True - - output = normalization_func( + return normalization_func( *inputs, eps, - None if split_mxfp8_cast else ln_out, - None if split_mxfp8_cast else output_quantizer, + ln_out, + output_quantizer, TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype, fwd_ln_sm_margin, zero_centered_gamma, ) - return ( - (output_quantizer.quantize(output[0], out=ln_out), *output[1:]) - if split_mxfp8_cast - else output - ) - class _NoopCatFunc(torch.autograd.Function): """Concatenate tensors, doing a no-op if possible diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4d4d5ca78b..fb7af6022c 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -55,9 +55,9 @@ prepare_for_saving, restore_from_saved, ) +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpp_extensions import ( general_gemm, @@ -160,11 +160,6 @@ def forward( # Configure quantizer for normalization output with_quantized_norm = fp8 and not return_layernorm_output - # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer - # so we need to set with_quantized_norm to False - if isinstance(input_quantizer, Float8CurrentScalingQuantizer): - with_quantized_norm = False - if with_quantized_norm: if with_input_all_gather: input_quantizer.set_usage(rowwise=True, columnwise=False) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f20c95c0fc..09f70ebcb0 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -212,8 +212,6 @@ def forward( # for return_layernorm_output: layernorm output = High precision, then cast to FP8 # high precision layernorm output and output of the linear are returned with_quantized_norm = fp8 and not return_layernorm_output - if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): - with_quantized_norm = False tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output