Remove uncessary ctype being passed to GroupedGEMMQuant kernel#2922

Merged
vthumbe1503 merged 3 commits into
NVIDIA:mainfrom
vthumbe1503:remove_c_type
Apr 24, 2026

Conversation

@vthumbe1503
Collaborator

@vthumbe1503 vthumbe1503 commented Apr 24, 2026

Description

We don't depend on the intermediate C tensor that the grouped_gemm_quant kernel can optionally output for fused MoE blocks.

For future:
Removing this c_type would eliminate the unnecessary intermediate high-precision C tensor output by the cuDNN kernel, although we would need to move to a newer cuDNN release to see any real memory benefit.

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci pytorch

@greptile-apps
Contributor

greptile-apps Bot commented Apr 24, 2026

Greptile Summary

This PR removes the c_dtype keyword argument from calls to grouped_gemm_quant_wrapper_sm100 (FC2 forward GEMM and FC1 backward dgrad GEMM) where the intermediate high-precision C tensor is not consumed, and updates the corresponding test. The fc1_glu_kwargs in the forward pass intentionally retains c_dtype since the fused SwiGLU kernel still requires it. As noted in the description, the actual memory benefit will only materialize with a newer cuDNN release.

Confidence Score: 5/5

Safe to merge — minimal, focused change with no functional regressions.

The PR makes three symmetrical one-line deletions across two implementation files and a test. The cuDNN wrapper accepts c_dtype as an optional parameter (defaulting to no intermediate C allocation), so omitting it is valid. The FC1 forward path deliberately keeps c_dtype, and the test is updated to match. No logic is altered.
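To illustrate why dropping an optional keyword is safe, here is a minimal sketch. It assumes a pure-Python stand-in named `grouped_gemm_quant_stub`, which is hypothetical and does not reproduce the real `grouped_gemm_quant_wrapper_sm100` signature; the `d_dtype` parameter and the dtype strings are likewise placeholders. The point is only that a kwargs dict without `c_dtype` falls back to the default and no intermediate C is produced.

```python
from typing import Optional


# Hypothetical stand-in for a grouped GEMM + quantize wrapper.
# This is NOT the real grouped_gemm_quant_wrapper_sm100 API.
def grouped_gemm_quant_stub(a, b, *, d_dtype: str, c_dtype: Optional[str] = None) -> dict:
    """Multiply two small matrices; return D, plus C only when c_dtype is given."""
    inner, cols = len(b), len(b[0])
    acc = [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]
    out = {"d": acc, "d_dtype": d_dtype}
    if c_dtype is not None:
        # Materialize the intermediate high-precision copy only on request.
        out["c"] = [row[:] for row in acc]
        out["c_dtype"] = c_dtype
    return out


# Placeholder dtype names, chosen for illustration only.
common_kwargs = {"d_dtype": "float8_e4m3"}
fc1_glu_kwargs = {**common_kwargs, "c_dtype": "bfloat16"}  # keeps c_dtype
fc2_quant_kwargs = dict(common_kwargs)                      # c_dtype removed

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
fc1_out = grouped_gemm_quant_stub(a, b, **fc1_glu_kwargs)
fc2_out = grouped_gemm_quant_stub(a, b, **fc2_quant_kwargs)
```

In this sketch `fc1_out` carries a `"c"` entry while `fc2_out` does not, and the `"d"` result is identical in both calls.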

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Removes c_dtype from fc2_quant_kwargs dict used when calling grouped_gemm_quant_wrapper_sm100; fc1_glu_kwargs intentionally retains c_dtype since the fused SwiGLU kernel uses the intermediate C tensor. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Removes c_dtype from fc1_dgrad_kwargs used in the FC1 dgrad GEMM; the intermediate C tensor is not consumed in this backward path. |
| tests/pytorch/test_fusible_ops.py | Updates the test_grouped_gemm_quant_cute_matches_mxfp8_quantized test to match the updated kernel call signature by dropping the c_dtype argument. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[ForwardGroupedMLP] -->|fc1_glu_kwargs — keeps c_dtype| B[grouped_gemm_quant_wrapper_sm100\nFC1 SwiGLU GEMM]
    A -->|fc2_quant_kwargs — c_dtype removed| C[grouped_gemm_quant_wrapper_sm100\nFC2 GEMM]
    D[BackwardGroupedMLP] -->|fc1_dgrad_kwargs — c_dtype removed| E[grouped_gemm_quant_wrapper_sm100\nFC1 dgrad GEMM]
    B -->|intermediate C tensor still allocated| F[SwiGLU activation]
    C -->|no intermediate C tensor| G[FC2 output D]
    E -->|no intermediate C tensor| H[grad_input]

Reviews (2): Last reviewed commit: "Merge pull request #4 from ksivaman/pr_f..." | Re-trigger Greptile

Member

@ksivaman ksivaman left a comment


This shouldn't be merged as a standalone commit, as it requires changes from cudnn-FE to infer this type correctly. Once those changes are released, this should be part of a bigger PR that raises the minimum required FE version for enabling the fusion and also removes all of the other cumbersome FE version checks for the features.
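A version gate of the kind mentioned above could be sketched as follows. This is an assumption-laden illustration, not code from the repository: `parse_version` and `build_gemm_kwargs` are hypothetical helpers, and the minimum version is passed in as a parameter because the thread does not state what the required minimum is.

```python
import re


def parse_version(version: str) -> tuple:
    """Turn a dotted version string such as '1.22.0' into a tuple of ints."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first component with no leading digits
        parts.append(int(match.group()))
    return tuple(parts)


def build_gemm_kwargs(fe_version: str, min_version: str, c_dtype: str) -> dict:
    """Pass c_dtype explicitly only when the frontend is older than min_version.

    Hypothetical helper: min_version is a caller-supplied placeholder.
    """
    kwargs = {}
    if parse_version(fe_version) < parse_version(min_version):
        # Older frontend: assume the type cannot be inferred, so pass it.
        kwargs["c_dtype"] = c_dtype
    return kwargs
```

Comparing tuples of integers rather than raw strings avoids ordering "1.9.3" after "1.22.0".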

timmoon10
timmoon10 previously approved these changes Apr 24, 2026
Member

@timmoon10 timmoon10 left a comment


LGTM

@vthumbe1503
Collaborator Author

@ksivaman if your concern is that the title isn't in line with what the change achieves, then that makes sense, and I have changed the title.

But this is a completely harmless change, and I would argue a clean syntactic one, since c_type isn't even needed by the kernel and we are not using the intermediate C. So I don't think it makes sense to block merging this to main and to 2.15, given that it might reduce the testing effort for a lot of folks.

@vthumbe1503 vthumbe1503 changed the title Remove ctype to eliminate memory usage from the cudnn kernel Remove uncessary ctype being passed to GroupedGEMMQuant kernel Apr 24, 2026
ksivaman
ksivaman previously approved these changes Apr 24, 2026
Member

@ksivaman ksivaman left a comment


After offline discussion, this has been tested with cudnnFE 1.22.0 and the title is accurate. LGTM

ksivaman and others added 2 commits April 25, 2026 01:39
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Remove c_dtype from fusible ops test
@vthumbe1503 vthumbe1503 dismissed stale reviews from ksivaman and timmoon10 via 6f02dc5 April 24, 2026 20:25
@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 merged commit 9ad2e7b into NVIDIA:main Apr 24, 2026
20 of 24 checks passed
KshitijLakhani pushed a commit that referenced this pull request Apr 27, 2026
* remove ctype to eliminate memory usage from the cudnn kernel

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Remove c_dtype from fusible ops test

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
…A#2922)

* remove ctype to eliminate memory usage from the cudnn kernel

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Remove c_dtype from fusible ops test

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

3 participants