Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
Expand Down
89 changes: 72 additions & 17 deletions transformer_engine/pytorch/csrc/extensions/normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

#include "common/util/system.h"
#include "extensions.h"
#include "pybind.h"

namespace transformer_engine::pytorch {
std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
py::handle quantizer) {
std::vector<size_t> shape_vec;
for (int i = 0; i < shape.ndim; i++) {
for (size_t i = 0; i < shape.ndim; i++) {
size_t t = shape.data[i];
shape_vec.push_back(t);
}
Expand Down Expand Up @@ -74,6 +75,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
float eps, py::object out, py::handle quantizer,
DType out_dtype, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail;
using namespace transformer_engine::pytorch;
using namespace transformer_engine;

Expand Down Expand Up @@ -107,14 +109,17 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
}

// Determine whether to avoid fused kernel
bool force_unfused_kernel = false;
if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
// TE only supports MXFP8 norm with cuDNN backend
force_unfused_kernel = true;
} else if (N % 128 != 0 || H % 128 != 0) {
// cuDNN norm requires full tile for MXFP8
force_unfused_kernel = true;
bool force_unfused_kernel = true;
if (quantizer.is_none()) {
// No need for separate quantization step if output is unquantized
force_unfused_kernel = false;
} else if (IsFloat8Quantizers(quantizer.ptr())) {
// Always used fused kernel for FP8 delayed scaling
force_unfused_kernel = false;
} else if (IsMXFP8Quantizers(quantizer.ptr())) {
if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
// cuDNN MXFP8 kernel requires full tile
force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
}
}
TensorWrapper unquantized_out_cu;
Expand Down Expand Up @@ -145,6 +150,29 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe

// Quantize output if using unfused kernel
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
// check if we need to do amax reudction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr =
my_quantizer_cs->amax_reduction_group;
// construct torch tesnor from NVTEBasicTensor without reallocating memory
at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
}
nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
at::cuda::getCurrentCUDAStream());
}
Expand Down Expand Up @@ -196,6 +224,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
py::object out, py::handle quantizer,
transformer_engine::DType out_dtype, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail;
using namespace transformer_engine::pytorch;
using namespace transformer_engine;

Expand Down Expand Up @@ -223,14 +252,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
}

// Determine whether to avoid fused kernel
bool force_unfused_kernel = false;
if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
// TE only supports MXFP8 norm with cuDNN backend
force_unfused_kernel = true;
} else if (N % 128 != 0 || H % 128 != 0) {
// cuDNN norm requires full tile for MXFP8
force_unfused_kernel = true;
bool force_unfused_kernel = true;
if (quantizer.is_none()) {
// No need for separate quantization step if output is unquantized
force_unfused_kernel = false;
} else if (IsFloat8Quantizers(quantizer.ptr())) {
// Always used fused kernel for FP8 delayed scaling
force_unfused_kernel = false;
} else if (IsMXFP8Quantizers(quantizer.ptr())) {
if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
// cuDNN MXFP8 kernel requires full tile
force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
}
}
TensorWrapper unquantized_out_cu;
Expand Down Expand Up @@ -261,6 +293,29 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w

// Quantize output if using unfused kernel
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
// check if we need to do amax reudction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr =
my_quantizer_cs->amax_reduction_group;
// construct torch tesnor from NVTEBasicTensor without reallocating memory
at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
}
nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
at::cuda::getCurrentCUDAStream());
}
Expand Down
20 changes: 3 additions & 17 deletions transformer_engine/pytorch/module/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

"""Internal function used by multiple modules."""

import os
from typing import Any, List, Optional, Tuple, Union, Callable
from dataclasses import dataclass
from functools import reduce
Expand All @@ -16,9 +15,6 @@
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..tensor.float8_tensor import Float8Tensor
from ..tensor.mxfp8_tensor import MXFP8Quantizer

_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0")))


def _get_normalization_func(normalization: str, forward: bool):
Expand Down Expand Up @@ -86,26 +82,16 @@ def apply_normalization(

inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)

split_mxfp8_cast = False
if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer):
split_mxfp8_cast = True

output = normalization_func(
return normalization_func(
*inputs,
eps,
None if split_mxfp8_cast else ln_out,
None if split_mxfp8_cast else output_quantizer,
ln_out,
output_quantizer,
TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype,
fwd_ln_sm_margin,
zero_centered_gamma,
)

return (
(output_quantizer.quantize(output[0], out=ln_out), *output[1:])
if split_mxfp8_cast
else output
)


class _NoopCatFunc(torch.autograd.Function):
"""Concatenate tensors, doing a no-op if possible
Expand Down
7 changes: 1 addition & 6 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
prepare_for_saving,
restore_from_saved,
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..cpp_extensions import (
general_gemm,
Expand Down Expand Up @@ -160,11 +160,6 @@ def forward(

# Configure quantizer for normalization output
with_quantized_norm = fp8 and not return_layernorm_output
# for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer
# so we need to set with_quantized_norm to False
if isinstance(input_quantizer, Float8CurrentScalingQuantizer):
with_quantized_norm = False

if with_quantized_norm:
if with_input_all_gather:
input_quantizer.set_usage(rowwise=True, columnwise=False)
Expand Down
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,6 @@ def forward(
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
# high precision layernorm output and output of the linear are returned
with_quantized_norm = fp8 and not return_layernorm_output
if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
with_quantized_norm = False

tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output
Expand Down