[PyTorch] Cleanup cudnn-frontend requirements for fused grouped MLP#2948
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile SummaryThis PR simplifies the cudnn-frontend version gating for the fused grouped MLP path by consolidating two identical
Confidence Score: 3/5Safe to merge if cudnn-frontend 1.23.0 always ships One P1 finding:
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fuse_grouped_mlp_ops called] --> B{fused_op_cls.is_supported?}
B -- No --> C[Return ops unchanged]
B -- Yes --> D{recipe.mxfp8?}
D -- No --> C
D -- Yes --> E[Sliding-window pattern match FC1 + ScaledGLU + FC2]
E --> F{Pattern matches?}
F -- No --> G[Emit ops individually]
F -- Yes --> H[Create FusedOp]
subgraph is_supported checks
S1{NVTE_CUTEDSL_FUSED_GROUPED_MLP > 0}
S2{SM100 compute capability}
S3{_cudnn_frontend_version_supported >= 1.23.0}
S4{grouped_gemm_dglu + quant imports OK, wgrad NOT checked}
S1 --> S2 --> S3 --> S4
end
subgraph wgrad backward
W1[grouped_gemm_wgrad_kernel called]
W2{NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP >= 1?}
W3[Return None, cublas fallback]
W4[Import grouped_gemm_wgrad_wrapper_sm100, no try/except]
W1 --> W2
W2 -- Yes --> W3
W2 -- No --> W4
end
|
| except PackageNotFoundError: | ||
| return False | ||
|
|
||
| def _cudnn_frontend_version_supported() -> bool: |
There was a problem hiding this comment.
Really this is specific to ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 andf BackwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8. If we add more fused ops that depend on the cuDNN frontend, there's no reason it will have the same requirements.
It would be more general to have a function that returned the cuDNN FE version as an Optional[tuple[int, ...]], and then the fused ops could decide for themselves whether it's supported.
…#2948) * Switch to cuDNN-FE min version 1.23.0 to enable fused grouped MLP Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix tests Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…NVIDIA#2948) * Switch to cuDNN-FE min version 1.23.0 to enable fused grouped MLP Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix tests Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Description
Require
cudnn-frontendversion1.23.0to fused grouped MLP path.Type of change
Changes
cudnn-frontendversion1.23.0to fused grouped MLP path for all features.NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLPenvvar to switch to cublas grouped GEMM for wgrad.Checklist: