[PyTorch] Cleanup cudnn-frontend requirements for fused grouped MLP #2948

Merged
ksivaman merged 3 commits into NVIDIA:main from ksivaman:cudnn_fe_1_23_0_requirement on May 1, 2026
@ksivaman
Member

Description

Require cudnn-frontend version 1.23.0 for the fused grouped MLP path.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Require cudnn-frontend version 1.23.0 for all features of the fused grouped MLP path.
  • Introduce the NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP environment variable to switch to the cuBLAS grouped GEMM for wgrad.
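
As a rough illustration of the opt-out described above, the sketch below reads the environment variable and picks a wgrad backend. Only the variable name comes from this PR; the helper names `wgrad_cutedsl_disabled` and `select_wgrad_backend`, the backend labels, and the integer parsing with a default of 0 are assumptions made for the example, not the actual Transformer Engine code.

```python
import os

# Name taken from the PR description; everything else here is illustrative.
_ENV_VAR = "NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP"


def wgrad_cutedsl_disabled() -> bool:
    """Return True when the opt-out variable is set to an integer >= 1."""
    return int(os.getenv(_ENV_VAR, "0")) >= 1


def select_wgrad_backend() -> str:
    """Pick a backend label (hypothetical labels) for the wgrad grouped GEMM."""
    return "cublas" if wgrad_cutedsl_disabled() else "cutedsl"
```

Under these assumptions, launching a job with `NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP=1` would route wgrad to the cuBLAS path while leaving the rest of the fused grouped MLP path untouched.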

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 April 30, 2026 23:04
@greptile-apps
Contributor

greptile-apps Bot commented Apr 30, 2026

Greptile Summary

This PR simplifies the cudnn-frontend version gating for the fused grouped MLP path by consolidating two identical >=1.23.0 version checks into a single _cudnn_frontend_version_supported() helper and moving the check into each op's is_supported(). It also removes the runtime signature-introspection machinery (is_fc1_bias_supported, is_fc2_bias_supported, _dglu_wrapper_has_generate_dbias_arg) and replaces the wgrad version gate with an env-var opt-out (NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP).

  • grouped_gemm_wgrad_kernel() is the only kernel not exercised inside the try/except ImportError block in BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(). If grouped_gemm_wgrad_wrapper_sm100 is absent despite the version check passing, the op reports itself as supported but the first backward pass raises an unhandled ImportError.
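
For readers unfamiliar with the signature-introspection pattern that the summary says was removed, a minimal sketch follows. It shows how a helper could test whether a wrapper accepts a given keyword argument. The function `wrapper_accepts_arg` and the two toy wrappers are hypothetical; the removed helpers (such as `_dglu_wrapper_has_generate_dbias_arg`) may have been implemented differently.

```python
import inspect


def wrapper_accepts_arg(fn, arg_name: str) -> bool:
    """Return True if `fn` declares a parameter named `arg_name`."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Some builtins and extension callables expose no signature.
        return False
    return arg_name in params


# Toy wrappers standing in for two library versions (hypothetical).
def old_dglu_wrapper(a, b):
    return a, b


def new_dglu_wrapper(a, b, generate_dbias=False):
    return a, b, generate_dbias
```

A single minimum-version requirement makes this kind of per-argument probing unnecessary, which is the simplification the PR aims for.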

Confidence Score: 3/5

Safe to merge if cudnn-frontend 1.23.0 always ships grouped_gemm_wgrad_wrapper_sm100; the guard gap in is_supported() is a latent runtime failure on atypical builds.

One P1 finding: grouped_gemm_wgrad_kernel() is not covered by the is_supported() try/except, unlike the other two kernels. On any build where the symbol is missing the op silently claims support but crashes on the first backward pass. All other changes are clean refactoring.

transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py — specifically the is_supported() method and grouped_gemm_wgrad_kernel() interaction.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/ops/_common.py | Merges two identical version-check helpers into one _cudnn_frontend_version_supported(), removes per-activation and per-bias-feature guards from fuse_grouped_mlp_ops, moving that responsibility to each op's is_supported(). |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Adds version check to is_supported(), removes is_fc1/fc2_bias_supported introspection, and unconditionally includes bias_tensor (possibly None) in fc2_quant_kwargs; functional change assumes 1.23.0 always accepts bias_tensor=None. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Adds version check to is_supported(), removes is_fc1_bias_supported and _dglu_wrapper_has_generate_dbias_arg introspection, replaces version gate on grouped_gemm_wgrad_kernel with an env-var opt-out; grouped_gemm_wgrad_kernel is not covered by the is_supported try/except, creating a potential unhandled ImportError in the backward pass. |
| tests/pytorch/test_fusible_ops.py | Replaces per-activation version guards with the unified _cudnn_frontend_version_supported() check; per-activation skip blocks are removed since is_supported() now enforces the version requirement globally. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fuse_grouped_mlp_ops called] --> B{fused_op_cls.is_supported?}
    B -- No --> C[Return ops unchanged]
    B -- Yes --> D{recipe.mxfp8?}
    D -- No --> C
    D -- Yes --> E[Sliding-window pattern match FC1 + ScaledGLU + FC2]
    E --> F{Pattern matches?}
    F -- No --> G[Emit ops individually]
    F -- Yes --> H[Create FusedOp]

    subgraph is_supported checks
        S1{NVTE_CUTEDSL_FUSED_GROUPED_MLP > 0}
        S2{SM100 compute capability}
        S3{_cudnn_frontend_version_supported >= 1.23.0}
        S4{grouped_gemm_dglu + quant imports OK, wgrad NOT checked}
        S1 --> S2 --> S3 --> S4
    end

    subgraph wgrad backward
        W1[grouped_gemm_wgrad_kernel called]
        W2{NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP >= 1?}
        W3[Return None, cublas fallback]
        W4[Import grouped_gemm_wgrad_wrapper_sm100, no try/except]
        W1 --> W2
        W2 -- Yes --> W3
        W2 -- No --> W4
    end
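
The sliding-window step in the flowchart above can be sketched roughly as follows. This is a toy model only: the `Op` dataclass, the string kinds, and `fuse_window` are hypothetical stand-ins, and the real fuse_grouped_mlp_ops works on Transformer Engine op objects and performs additional checks.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Op:
    """Hypothetical op record; `kind` stands in for the real op class."""
    kind: str


def fuse_window(
    ops: List[Op],
    pattern: Tuple[str, ...] = ("fc1", "scaled_glu", "fc2"),
) -> List[Op]:
    """Replace each run of ops matching `pattern` with one fused op."""
    out: List[Op] = []
    i = 0
    while i < len(ops):
        window = ops[i : i + len(pattern)]
        if tuple(op.kind for op in window) == pattern:
            out.append(Op("fused_grouped_mlp"))  # pattern matched: fuse
            i += len(pattern)
        else:
            out.append(ops[i])  # no match: emit the op individually
            i += 1
    return out
```

For example, a sequence of kinds `norm, fc1, scaled_glu, fc2, dropout` would come out as `norm, fused_grouped_mlp, dropout`, while a sequence without the full three-op pattern is returned unchanged.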

Comments Outside Diff (1)

  1. transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py, lines 297-312

    P1 grouped_gemm_wgrad_kernel not guarded in is_supported

    grouped_gemm_dglu_kernel() and grouped_gemm_quant_kernel() are both called inside the try/except ImportError block in is_supported(), so a missing symbol causes the op to report itself as unsupported. grouped_gemm_wgrad_kernel() is not included in that guard. If cudnn-frontend 1.23.0 is installed (so the version check passes) but grouped_gemm_wgrad_wrapper_sm100 is absent (e.g. a partial or non-standard build), is_supported() returns True but the first backward pass raises an unhandled ImportError at the call sites on lines 638 and 734.

    Adding it to the try/except in is_supported() closes the gap:

            try:
                cls.grouped_gemm_dglu_kernel()
                cls.grouped_gemm_quant_kernel()
                cls.grouped_gemm_wgrad_kernel()
            except ImportError:
                return False

    Note that grouped_gemm_wgrad_kernel() already returns None safely when the env-var opt-out is set, so calling it here does not change the opt-out behavior.
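
To make the gap concrete, here is a self-contained toy reproduction. `FakeBackwardOp` and its `_load` method are hypothetical stand-ins for the real op and its kernel loaders; the point is only to show how an availability check that skips one loader can report support while a later call still raises ImportError.

```python
class FakeBackwardOp:
    """Hypothetical stand-in for the fused backward op (not the real class)."""

    # Pretend the installed package ships dglu and quant but not wgrad.
    available = {"dglu", "quant"}

    @classmethod
    def _load(cls, name: str) -> str:
        """Mimic a lazy kernel import that fails for missing symbols."""
        if name not in cls.available:
            raise ImportError(f"{name} wrapper is not available")
        return f"<{name} kernel>"

    @classmethod
    def is_supported_partial(cls) -> bool:
        """Checks only two of the three kernels, like the flagged code."""
        try:
            cls._load("dglu")
            cls._load("quant")
        except ImportError:
            return False
        return True

    @classmethod
    def is_supported_full(cls) -> bool:
        """Also probes wgrad, as the review comment suggests."""
        try:
            cls._load("dglu")
            cls._load("quant")
            cls._load("wgrad")
        except ImportError:
            return False
        return True
```

With the wgrad symbol missing, the partial check returns True and the failure surfaces only when `_load("wgrad")` is first called, whereas the full check returns False up front.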

Reviews (1): Last reviewed commit: "Merge branch 'main' into cudnn_fe_1_23_0..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
Member

@timmoon10 timmoon10 left a comment

LGTM

    except PackageNotFoundError:
        return False

def _cudnn_frontend_version_supported() -> bool:
Member

Really this is specific to ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 and BackwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8. If we add more fused ops that depend on the cuDNN frontend, there's no reason they will have the same requirements.

It would be more general to have a function that returned the cuDNN FE version as an Optional[tuple[int, ...]], and then the fused ops could decide for themselves whether it's supported.
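
One possible shape for that suggestion is sketched below. The function name `cudnn_frontend_version`, the distribution name `nvidia-cudnn-frontend`, the release parsing, and the `ExampleFusedOp` class are all assumptions for illustration; none of them are taken from the merged code.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def cudnn_frontend_version(
    dist_name: str = "nvidia-cudnn-frontend",  # assumed distribution name
) -> Optional[Tuple[int, ...]]:
    """Return the installed version as a tuple of ints, or None if absent."""
    try:
        raw = version(dist_name)
    except PackageNotFoundError:
        return None
    release = []
    for piece in raw.split("."):
        if not piece.isdigit():
            break  # stop at suffixes such as "post1" or "dev0"
        release.append(int(piece))
    return tuple(release)


class ExampleFusedOp:
    """Hypothetical fused op that owns its own minimum version."""

    min_cudnn_fe: Tuple[int, ...] = (1, 23, 0)

    @classmethod
    def supports(cls, fe_version: Optional[Tuple[int, ...]]) -> bool:
        """Decide support from a version tuple (None means not installed)."""
        return fe_version is not None and fe_version >= cls.min_cudnn_fe

    @classmethod
    def is_supported(cls) -> bool:
        return cls.supports(cudnn_frontend_version())
```

With this split, a future fused op could declare a different `min_cudnn_fe` without touching the shared helper.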

Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
@ksivaman ksivaman merged commit 0e9020d into NVIDIA:main May 1, 2026
10 of 12 checks passed
@ksivaman ksivaman deleted the cudnn_fe_1_23_0_requirement branch May 1, 2026 06:23
ksivaman added a commit that referenced this pull request May 1, 2026
…#2948)

* Switch to cuDNN-FE min version 1.23.0 to enable fused grouped MLP

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026