Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 2 additions & 17 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import (
_nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu,
_cudnn_frontend_version_supported,
)

from transformer_engine.pytorch.ops.fused import (
Expand Down Expand Up @@ -3638,10 +3638,7 @@ def test_grouped_mlp(
quantization == "mxfp8"
and dtype in (torch.bfloat16, torch.float16)
and glu_interleave_size == 32
and (
activation != "scaled_clamped_qgeglu"
or _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
)
and _cudnn_frontend_version_supported()
):
if te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
forward_ops = module._module_groups[0]._forward_ops
Expand Down Expand Up @@ -3744,12 +3741,6 @@ def test_grouped_mlp_single_weight_numerics(
pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system")
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system")
if activation == "scaled_clamped_qgeglu" and not (
_nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
):
pytest.skip(
"ScaledClampedQGeGLU fused grouped MLP requires nvidia-cudnn-frontend >= 1.23.0"
)

split_sizes = [split_alignment * (i + 1) for i in range(group_size)]
random.shuffle(split_sizes)
Expand Down Expand Up @@ -4106,12 +4097,6 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8(
pytest.skip("MXFP8 fused grouped MLP is not supported on this system")
if dtype not in (torch.bfloat16, torch.float16):
pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16")
if activation == "scaled_clamped_qgeglu" and not (
_nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
):
pytest.skip(
"ScaledClampedQGeGLU fused grouped MLP requires nvidia-cudnn-frontend >= 1.23.0"
)

split_sizes = [split_alignment * (i + 1) for i in range(group_size)]
random.shuffle(split_sizes)
Expand Down
28 changes: 4 additions & 24 deletions transformer_engine/pytorch/ops/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,11 @@


@functools.lru_cache(maxsize=1)
def _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() -> bool:
"""Check cuDNN FE min version with fixed numerics for qgeglu."""
try:
return PkgVersion(get_pkg_version("nvidia-cudnn-frontend")) >= PkgVersion("1.23.0")
except PackageNotFoundError:
return False

def _cudnn_frontend_version_supported() -> bool:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really this is specific to ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 andf BackwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8. If we add more fused ops that depend on the cuDNN frontend, there's no reason it will have the same requirements.

It would be more general to have a function that returned the cuDNN FE version as an Optional[tuple[int, ...]], and then the fused ops could decide for themselves whether it's supported.

"""Check cuDNN frontend is at least 1.23.0.

@functools.lru_cache(maxsize=1)
def _nvidia_cudnn_frontend_supports_wgrad() -> bool:
"""Check cuDNN FE min version for grouped GEMM wgrad kernel."""
All grouped MLP fused-kernel features require cuDNN frontend 1.23.0.
"""
try:
return PkgVersion(get_pkg_version("nvidia-cudnn-frontend")) >= PkgVersion("1.23.0")
except PackageNotFoundError:
Expand Down Expand Up @@ -140,8 +134,6 @@ def fuse_grouped_mlp_ops(
constructor accepting ``fc1``, ``glu_op``, ``fc2`` keyword args. The
``glu_op`` must be :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledSwiGLU`
or :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledClampedQGeGLU`.
May also expose ``is_fc1_bias_supported()`` and/or
``is_fc2_bias_supported()`` classmethods for bias eligibility.

Returns
-------
Expand All @@ -159,13 +151,6 @@ def fuse_grouped_mlp_ops(
if recipe is None or not recipe.mxfp8():
return ops

fc1_bias_ok = (
not hasattr(fused_op_cls, "is_fc1_bias_supported") or fused_op_cls.is_fc1_bias_supported()
)
fc2_bias_ok = (
not hasattr(fused_op_cls, "is_fc2_bias_supported") or fused_op_cls.is_fc2_bias_supported()
)

out = []
window, ops = ops[:3], ops[3:]
while len(window) == 3:
Expand All @@ -179,7 +164,6 @@ def fuse_grouped_mlp_ops(
matches_pattern = False
elif isinstance(window[1], ScaledClampedQGeGLU) and (
abs(window[1]._clamped.alpha - 1.702) > 0.001
or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
):
matches_pattern = False
elif window[0].num_groups != window[2].num_groups:
Expand All @@ -193,10 +177,6 @@ def fuse_grouped_mlp_ops(
matches_pattern = False
elif window[1].glu_interleave_size != 32:
matches_pattern = False
elif window[0].has_bias and not fc1_bias_ok:
matches_pattern = False
elif window[2].has_bias and not fc2_bias_ok:
matches_pattern = False

if matches_pattern:
op = fused_op_cls(
Expand Down
36 changes: 8 additions & 28 deletions transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@
from __future__ import annotations
from collections.abc import Callable
import functools
import inspect
import math
import os
from typing import Optional

import torch

import transformer_engine_torch as tex
from ...module.base import get_dummy_wgrad
from ...quantization import Recipe
from ...tensor.grouped_tensor import GroupedTensor
from ...tensor.mxfp8_tensor import MXFP8Quantizer
Expand All @@ -25,13 +23,13 @@
from ..fuser import register_backward_fusion
from ..op import FusedOperation, FusibleOperation, OperationContext
from .._common import (
_nvidia_cudnn_frontend_supports_wgrad,
_cudnn_frontend_version_supported,
fuse_grouped_mlp_ops,
maybe_dequantize,
validate_grouped_mlp_dims,
)
from ...cpp_extensions import general_grouped_gemm_for_grouped_tensor
from ...module.base import _2X_ACC_WGRAD
from ...module.base import _2X_ACC_WGRAD, get_dummy_wgrad
from ...triton.grouped_dbias_dscales import compute_grouped_dbias_dscales


Expand Down Expand Up @@ -109,20 +107,6 @@ def _cudnn_compute_wgrad(
)


@functools.lru_cache(maxsize=1)
def _dglu_wrapper_has_generate_dbias_arg() -> bool:
"""True if cudnn-frontend SM100 dGLU wrapper accepts ``generate_dbias``."""
try:
from cudnn import grouped_gemm_dglu_wrapper_sm100 # pylint: disable=import-outside-toplevel
except ImportError:
return False
try:
params = inspect.signature(grouped_gemm_dglu_wrapper_sm100).parameters
except (TypeError, ValueError):
return False
return "generate_dbias" in params


def _compute_grad_params(
fc_op,
ctx,
Expand Down Expand Up @@ -300,10 +284,11 @@ def grouped_gemm_quant_kernel(cls) -> Callable:
@functools.lru_cache(maxsize=None)
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
"""CuTe DSL kernel for grouped GEMM wgrad on SM100+.
Returns ``None`` when the cuDNN front-end package is older than
1.23.0.

Returns ``None`` when the environment variable
``NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP`` is set to ``1``.
"""
if not _nvidia_cudnn_frontend_supports_wgrad():
if int(os.environ.get("NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP", "0")) >= 1:
Comment thread
timmoon10 marked this conversation as resolved.
return None
from cudnn import grouped_gemm_wgrad_wrapper_sm100 # pylint: disable=no-name-in-module

Expand All @@ -317,20 +302,15 @@ def is_supported(cls) -> bool:
return False
if get_device_compute_capability()[0] != 10:
return False
if not _cudnn_frontend_version_supported():
return False
try:
cls.grouped_gemm_dglu_kernel()
cls.grouped_gemm_quant_kernel()
except ImportError:
return False
return True

@classmethod
def is_fc1_bias_supported(cls) -> bool:
"""Whether cudnn-frontend exposes ``generate_dbias`` on the dGLU SM100 wrapper (FC1 bias grad only)."""
if not cls.is_supported():
return False
return _dglu_wrapper_has_generate_dbias_arg()

def __init__(
self,
*,
Expand Down
43 changes: 4 additions & 39 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from __future__ import annotations
from collections.abc import Callable, Iterable
import functools
import inspect
import os
from typing import Any, Optional

Expand All @@ -24,6 +23,7 @@
from ..fuser import register_forward_fusion
from ..op import FusedOperation, FusibleOperation, OperationContext
from .._common import (
_cudnn_frontend_version_supported,
fuse_grouped_mlp_ops,
is_quantized_tensor,
maybe_dequantize,
Expand Down Expand Up @@ -76,49 +76,15 @@ def is_supported(cls) -> bool:
return False
if get_device_compute_capability()[0] != 10:
return False
if not _cudnn_frontend_version_supported():
return False
try:
cls.grouped_gemm_glu_kernel()
cls.grouped_gemm_quant_kernel()
except ImportError:
return False
return True

@classmethod
@functools.lru_cache(maxsize=1)
def is_fc1_bias_supported(cls) -> bool:
"""Whether cudnn-frontend exposes ``bias_tensor`` on the grouped GEMM GLU SM100 wrapper (FC1)."""
if not cls.is_supported():
return False
try:
from cudnn import (
grouped_gemm_glu_wrapper_sm100,
) # pylint: disable=import-outside-toplevel
except ImportError:
return False
try:
params = inspect.signature(grouped_gemm_glu_wrapper_sm100).parameters
except (TypeError, ValueError):
return False
return "bias_tensor" in params

@classmethod
@functools.lru_cache(maxsize=1)
def is_fc2_bias_supported(cls) -> bool:
"""Whether cudnn-frontend exposes ``bias_tensor`` on the grouped GEMM Quant SM100 wrapper (FC2)."""
if not cls.is_supported():
return False
try:
from cudnn import (
grouped_gemm_quant_wrapper_sm100,
) # pylint: disable=import-outside-toplevel
except ImportError:
return False
try:
params = inspect.signature(grouped_gemm_quant_wrapper_sm100).parameters
except (TypeError, ValueError):
return False
return "bias_tensor" in params

def __init__(
self,
*,
Expand Down Expand Up @@ -433,6 +399,7 @@ def fuser_forward(
"sfa_tensor": fc1_kernel_out["sfd_row_tensor"],
"padded_offsets": split_points,
"alpha_tensor": alpha_tensor.float(),
"bias_tensor": fc2_bias_packed,
"norm_const_tensor": None,
"prob_tensor": fc2_scales_tensor,
"acc_dtype": torch.float32,
Expand All @@ -442,8 +409,6 @@ def fuser_forward(
"current_stream": current_stream,
"use_dynamic_sched": True,
}
if self.is_fc2_bias_supported():
fc2_quant_kwargs["bias_tensor"] = fc2_bias_packed

if fc2_op.single_grouped_weight:
# Clone and swizzle scales for GEMM (original stays unmodified for save_for_backward)
Expand Down
Loading