Enable NVFP4 grouped MLP GLU RHT amax path #3073

Merged
timmoon10 merged 9 commits into NVIDIA:main from sraman-rgb:nvfp4-grouped-mlp-glu-rht-amax on Jun 5, 2026

Conversation

@sraman-rgb

Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on Jun 1, 2026
@greptile-apps

greptile-apps Bot commented Jun 1, 2026

Contributor

Greptile Summary

This PR enables the NVFP4 grouped MLP GLU RHT amax path by adding two new C++ entry points (nvfp4_quantize_with_amax and nvfp4_group_quantize_with_amax) that accept externally precomputed per-group amaxes and handle distributed allreduce before quantization. A new fused grouped_gemm_glu_hadamard_kernel is wired into ForwardGroupedMLP_CuTeGEMMGLU so the FC1 GLU+RHT pass can emit amaxes directly to FC2 input quantization.
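To make the "GLU + RHT + amax" idea concrete, here is a minimal pure-Python sketch of computing an amax after a gated activation and a randomized Hadamard transform. It is not the fused kernel from this PR: the SiLU gate, the block size of 16, the orthonormal scaling, and every function name below are assumptions chosen for illustration only.

```python
import math
import random

def hadamard_16(block):
    """Fast Walsh-Hadamard transform of a length-16 list, scaled by 1/sqrt(16)."""
    assert len(block) == 16
    out = list(block)
    h = 1
    while h < 16:
        for start in range(0, 16, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(16)
    return [v * scale for v in out]

def glu_silu(gate, up):
    """Gated linear unit with a SiLU gate (assumed; the PR does not name the activation)."""
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]

def post_rht_amax(row, signs):
    """Apply a random sign flip plus Hadamard transform per 16-element block, return amax."""
    assert len(row) % 16 == 0 and len(signs) == 16
    amax = 0.0
    for start in range(0, len(row), 16):
        block = [s * v for s, v in zip(signs, row[start:start + 16])]
        amax = max(amax, max(abs(v) for v in hadamard_16(block)))
    return amax

# Hypothetical FC1 output for one token: gate and up halves of width 32.
random.seed(0)
gate = [random.uniform(-2.0, 2.0) for _ in range(32)]
up = [random.uniform(-2.0, 2.0) for _ in range(32)]
signs = [random.choice((-1.0, 1.0)) for _ in range(16)]

activated = glu_silu(gate, up)
print("post-RHT amax:", post_rht_amax(activated, signs))
```

Producing the amax in the same pass as the activation means the next quantization step does not need a separate reduction over the tensor, which is the motivation the summary above describes.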

  • New C++ quantize-with-amax functions: allreduce is correctly placed before cast kernels in both paths; empty-rank distributed scenarios are handled by the compute_amax=false branch in quantize_impl.
  • quantize_impl refactor: reduce_amaxes lambda fires at a single site before all cast kernels, fixing ordering for both the existing and new paths.
  • Python integration: _group_quantize_with_amax_for_grouped_mlp dispatches to the correct C++ path; _use_tmem_post_rht_amax is properly cached with lru_cache.
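The ordering described in the bullets above (reduce amaxes across ranks first, cast second) can be illustrated with a toy simulation. This sketch uses plain Python lists in place of ranks and a simple integer rounding quantizer in place of the NVFP4 cast kernels; `local_amax`, `allreduce_max`, and `toy_quantize` are hypothetical helpers, not Transformer Engine functions.

```python
def local_amax(values):
    """Absolute maximum of one rank's data; an empty rank contributes 0.0."""
    return max((abs(v) for v in values), default=0.0)

def allreduce_max(per_rank_values):
    """Simulated max-allreduce: every rank receives the global maximum."""
    global_max = max(per_rank_values)
    return [global_max] * len(per_rank_values)

def toy_quantize(values, amax, qmax=6.0):
    """Toy symmetric quantizer: map [-amax, amax] onto [-qmax, qmax] and round."""
    if amax == 0.0:
        return [0 for _ in values]
    scale = qmax / amax
    return [round(v * scale) for v in values]

# Three simulated ranks; the middle one received no tokens.
ranks = [[0.5, -2.0], [], [3.0, 1.0]]

# Step 1: each rank computes (or is handed) its local amax.
local_amaxes = [local_amax(r) for r in ranks]

# Step 2: the reduction runs on every rank, including the empty one.
global_amaxes = allreduce_max(local_amaxes)

# Step 3: only now does each rank cast, using the shared amax.
quantized = [toy_quantize(r, a) for r, a in zip(ranks, global_amaxes)]

print("local amaxes: ", local_amaxes)
print("global amaxes:", global_amaxes)
print("quantized:    ", quantized)
```

If step 3 ran before step 2, each rank would scale by its own local amax and the ranks would disagree on the scale. If the empty rank skipped step 2, a real collective would wait on it indefinitely.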

Confidence Score: 5/5

The new quantize-with-amax paths correctly place allreduce before cast kernels in both the single-tensor and grouped variants, and empty-rank distributed scenarios are handled properly for all new code paths.

All new C++ and Python code introduced in this PR correctly implements the amax-reduction-before-quantization ordering. The compute_amax=false branch in quantize_impl calls reduce_amaxes() before the cast kernels for both non-empty and empty inputs. The grouped path calls allreduce_nvfp4_amax_tensors explicitly before group_quantize_nvfp4_impl. Pre-existing behavior for the regular quantize path (empty-rank early exit without allreduce) is unchanged.

No files require special attention; the most complex logic is in cast.cpp and quantizer.cpp, both of which correctly sequence allreduce before quantization in all new code paths.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | Adds nvfp4_quantize_with_amax and nvfp4_group_quantize_with_amax; allreduce is correctly placed before the cast in both paths, and empty-rank handling is correct for the new compute_amax=false path. |
| transformer_engine/pytorch/csrc/quantizer.cpp | Refactors quantize_impl to accept compute_amax flag; reduce_amaxes lambda is now called before cast kernels for both flag values. Empty ranks with compute_amax=false now participate in the collective. |
| transformer_engine/pytorch/ops/_common.py | Adds _group_quantize_with_amax_for_grouped_mlp and extracts _wrap_single_nvfp4_as_grouped; the num_groups==1 path now uses nvfp4_quantize_with_amax with correct allreduce before cast. |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Adds grouped_gemm_glu_hadamard_kernel to ForwardGroupedMLP_CuTeGEMMGLU and integrates the new fused GLU+RHT+amax kernel path; _use_tmem_post_rht_amax is properly cached via lru_cache. |
| transformer_engine/pytorch/csrc/common.h | Moves quantize_impl from private to public on NVFP4Quantizer to allow cast.cpp to call it directly with compute_amax=false. |
| transformer_engine/pytorch/csrc/extensions.h | Declares two new public entry points nvfp4_quantize_with_amax and nvfp4_group_quantize_with_amax; signatures are consistent with implementations. |
| transformer_engine/pytorch/csrc/extensions/pybind.cpp | Exposes the two new quantization-with-amax functions to Python via pybind11; argument names and default values match the C++ signatures. |
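For readers unfamiliar with the lru_cache pattern mentioned for _use_tmem_post_rht_amax, the following sketch shows how a feature check can be evaluated once and reused. The function name `use_fast_path` and the environment variable `HYPOTHETICAL_FAST_PATH` are invented for illustration; the actual condition checked in the PR is not shown on this page.

```python
import functools
import os

@functools.lru_cache(maxsize=None)
def use_fast_path() -> bool:
    """Decide once whether the optional path is enabled, then reuse the answer."""
    # Hypothetical switch; a real check might also query device capability.
    return os.environ.get("HYPOTHETICAL_FAST_PATH", "0") == "1"

os.environ["HYPOTHETICAL_FAST_PATH"] = "1"
first = use_fast_path()   # evaluated and cached

os.environ["HYPOTHETICAL_FAST_PATH"] = "0"
second = use_fast_path()  # served from the cache, environment is not re-read

print(first, second, use_fast_path.cache_info())
```

The trade-off is that changes to the underlying condition after the first call are ignored until `use_fast_path.cache_clear()` is called, which is usually acceptable for checks that cannot change within a process.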

Reviews (8): Last reviewed commit: "Merge branch 'main' into nvfp4-grouped-m..."

Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated

@vthumbe1503 vthumbe1503 left a comment

Collaborator


Mostly LGTM. Left a few comments on code duplication and other minor issues.

Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/pybind.cpp Outdated
sraman-rgb and others added 8 commits June 4, 2026 13:38
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
@sraman-rgb sraman-rgb force-pushed the nvfp4-grouped-mlp-glu-rht-amax branch from 22b9b0e to d3533df Compare June 4, 2026 21:13
@timmoon10

Member

/te-ci pytorch

@timmoon10 timmoon10 left a comment

Member


LGTM, pending CI

@timmoon10 timmoon10 merged commit 3f64073 into NVIDIA:main Jun 5, 2026
20 of 25 checks passed


3 participants