
[JAX] Collective GEMM with FP8 and MXFP8 support #2740

Merged

phu0ngng merged 23 commits into NVIDIA:main from phu0ngng:cgemm_fp8 on Mar 13, 2026

Conversation

@phu0ngng (Collaborator) commented on Mar 5, 2026

Description

This PR extends JAX Collective GEMM support to the DelayedScalingFP8, CurrentScalingFP8, and MXFP8 quantization recipes.
Unit tests for these recipes are added, and the test infrastructure in the collective GEMM tests is cleaned up along the way.

Note that Collective GEMM + MXFP8 requires all dimensions of the GEMM operands to be divisible by 128.
Additionally, in the CGEMM + MXFP8 + AllGather case, the block scales are still all-gathered on the critical path, unlike the quantized data, whose gathering overlaps with the computation.
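
For context, a minimal sketch of how these recipes are constructed and how the 128-divisibility constraint could be checked up front. The recipe classes exist in transformer_engine.common.recipe; the shape check itself is a hypothetical illustration, not this PR's actual validation code:

# Hedged sketch: constructing the three supported recipes and pre-checking the
# MXFP8 collective-GEMM shape constraint. `check_mxfp8_cgemm_shapes` is a
# hypothetical helper, not part of the PR.
from transformer_engine.common.recipe import (
    DelayedScaling,        # DelayedScalingFP8
    Float8CurrentScaling,  # CurrentScalingFP8
    MXFP8BlockScaling,     # MXFP8
)

recipe = MXFP8BlockScaling()

def check_mxfp8_cgemm_shapes(*shapes):
    """Collective GEMM + MXFP8 needs every operand dim divisible by 128."""
    for shape in shapes:
        for dim in shape:
            assert dim % 128 == 0, f"dim {dim} is not divisible by 128"

check_mxfp8_cgemm_shapes((256, 512), (512, 1024))  # passes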

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng changed the title [JAX] CGEMM + FP8 [JAX] CGEMM + FP8MXFP8 Mar 10, 2026
@phu0ngng phu0ngng changed the title [JAX] CGEMM + FP8MXFP8 [JAX] CGEMM + FP8/MXFP8 Mar 10, 2026
@phu0ngng (Collaborator, Author) commented:

/te-ci JAX L1

@phu0ngng phu0ngng marked this pull request as ready for review March 10, 2026 23:23
@phu0ngng phu0ngng changed the title [JAX] CGEMM + FP8/MXFP8 [JAX] Collective GEMM with FP8 and MXFP8 support Mar 10, 2026
@greptile-apps (Contributor bot) commented on Mar 10, 2026

Greptile Summary

This PR extends the JAX Collective GEMM implementation to support DelayedScalingFP8, CurrentScalingFP8, and MXFP8 quantization recipes, and cleans up the collective GEMM test infrastructure. The core logic change in gemm.py adds MXFP8-aware scale tensor reordering (_reorder_tpsp_leading / _reorder_dp_leading) and updates the SPMD sharding specs to correctly distribute block scales alongside their operands during AllGather and ReduceScatter collectives.
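
As a rough mental model of that reordering (a hypothetical reconstruction; the PR's actual _reorder_tpsp_leading / _reorder_dp_leading may differ):

# Hypothetical sketch: move the TP/SP-sharded axis to the front before the
# collective so each rank's chunk is contiguous, and move the DP axis back to
# the front on the output. jnp.moveaxis is the real JAX op; the helper name
# and calling convention are assumptions.
import jax.numpy as jnp

def reorder_dim_leading(x, dim):
    return jnp.moveaxis(x, dim, 0)

x = jnp.zeros((8, 128, 64))    # (dp, seq, hidden)
y = reorder_dim_leading(x, 1)  # (128, 8, 64): seq axis now leading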

Key changes:

  • GemmPrimitive.impl gains an MXFP8 + Collective path that skips padding, validates 128-alignment, reorders scale tensors for ReduceScatter (both data and scale) and AllGather (scale only; data layout is handled by the kernel)
  • _parse_operand_output_specs is updated to propagate per-operand scale sharding specs, replacing the previous uniform lhs_sharding / none_sharding logic
  • helper.py gains get_quantization_recipe and is_quantize_recipe_supported utility functions for mapping string recipe names to objects, now used throughout the test suite (see the sketch after this list)
  • The test suite is refactored to use per-test-case pytest node IDs in run_test_cgemm.sh, and new test classes cover FP8 / MXFP8 for all three test modules
  • One remaining diagnostic bug: The assertion at gemm.py:704–708 checks lhs_scale_inv.shape[sequence_dim] but its error message says "RHS scale inv sequence dimension", which would mislead users debugging alignment issues on the LHS scale
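
A minimal sketch of the string-to-recipe mapping idea behind those helpers (the actual implementations in transformer_engine/jax/quantize/helper.py may differ in naming, coverage, and hardware-support checks):

from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    MXFP8BlockScaling,
)

# Hypothetical reconstruction; the real helpers may also gate on device support.
_RECIPE_MAP = {
    "DelayedScaling": DelayedScaling,
    "Float8CurrentScaling": Float8CurrentScaling,
    "MXFP8BlockScaling": MXFP8BlockScaling,
}

def get_quantization_recipe(name: str):
    return _RECIPE_MAP[name]()

def is_quantize_recipe_supported(name: str) -> bool:
    return name in _RECIPE_MAP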

Confidence Score: 4/5

  • Safe to merge once the minor diagnostic copy-paste fix is addressed; no correctness or data-integrity issues found beyond what is already tracked in open review threads.
  • The core MXFP8 + Collective GEMM logic is well-structured: scale reordering is guarded by scaling_mode.is_1d_block_scaling() and need_reorder, the sharding spec updates correctly propagate block-scale axes, and the new helpers are clean. The main outstanding issues are misleading assertion messages (copy-paste errors already flagged in prior review threads plus one new one at line 704), none of which affect runtime correctness. The test coverage for the new paths is comprehensive. Score is 4 rather than 5 because of the cluster of diagnostic copy-paste errors in the new assertion messages.
  • transformer_engine/jax/cpp_extensions/gemm.py — the assertion message at lines 704–708 references "RHS scale inv" but checks lhs_scale_inv, consistent with the copy-paste pattern already flagged for lines 698–702.

Important Files Changed

  • transformer_engine/jax/cpp_extensions/gemm.py: Core GEMM primitive extended with MXFP8 + Collective GEMM support; introduces scale reordering helpers, updated sharding specs for block scales, and an NVFP4 guard — but contains multiple copy-paste errors in assertion messages (line 699 says "LHS" for the RHS check, line 705 says "RHS scale inv" for the LHS scale check) that would mislead users debugging alignment failures.
  • transformer_engine/jax/quantize/helper.py: Adds two clean utility functions (get_quantization_recipe and is_quantize_recipe_supported) that bridge string recipe names to recipe objects; straightforward and well-documented.
  • examples/jax/collective_gemm/test_layernorm_mlp_grad.py: Adds FP8/MXFP8 test cases for LayerNorm MLP gradient; uses QuantizerFactory.create_set(n_quantizer_sets=2) correctly to create independent sets per layer. Both the reference and collective calls share the same quantizer_sets object, which is acceptable for purely functional JAX quantizer pytrees.
  • examples/jax/collective_gemm/common.py: Clean refactor: moves shared distributed helpers and imports to the top of the file, adds FP8 tolerances, introduces get_tolerance_dtype helper, and updates cgemm_parser to use --quantize-recipe instead of --fp8-recipe.
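
For the common.py entry above, the tolerance-selection idea might look like the following sketch (get_tolerance_dtype is the PR's helper name; the function below and its concrete tolerance values are assumptions):

import numpy as np

# Hypothetical values: quantized (FP8/MXFP8) comparisons need looser bounds
# than plain bf16/fp16 runs.
_DTYPE_TOLS = {
    "bfloat16": {"rtol": 1e-2, "atol": 1e-2},
    "float16": {"rtol": 1e-3, "atol": 1e-3},
}

def get_tolerance(dtype, quantized=False):
    if quantized:
        return {"rtol": 5e-2, "atol": 5e-2}
    return _DTYPE_TOLS.get(np.dtype(dtype).name, {"rtol": 1e-5, "atol": 1e-8})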

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["tex.gemm() / layernorm_mlp()"] --> B{scaling_mode}
    B -- "NO_SCALING / TENSOR_SCALING" --> C[Standard GEMM path]
    B -- "MXFP8_1D_SCALING" --> D{collective_op?}
    D -- "NONE / is_outer" --> E[apply_padding_to_scale_inv\n+ swizzled_scale]
    D -- "ALL_GATHER / REDUCE_SCATTER" --> F[Assert dims % 128 == 0\nSkip padding]
    F --> G[swizzled_scale on lhs/rhs scale_inv]
    G --> H{need_reorder?}
    H -- "RS" --> I[_reorder_tpsp_leading on lhs\n+ lhs_scale_inv]
    H -- "AG" --> J[_reorder_tpsp_leading on lhs_scale_inv only\nlhs data stays as-is]
    I --> K[GemmPrimitive.inner_primitive.bind]
    J --> K
    C --> K
    E --> K
    K --> L{post-process output}
    L -- "AG + need_reorder" --> M[_reorder_dp_leading on output]
    L -- "other" --> N[return output as-is]
    M --> O[return output]
    N --> O

Last reviewed commit: 6f0c442

@phu0ngng phu0ngng requested a review from denera March 11, 2026 19:52
@phu0ngng phu0ngng force-pushed the cgemm_fp8 branch 2 times, most recently from 4f20d2d to f899fa2 on March 11, 2026 21:07
@phu0ngng (Collaborator, Author) commented:

/te-ci JAX L1

phu0ngng and others added 11 commits March 11, 2026 14:37
pre-commit-ci Bot and others added 6 commits March 11, 2026 14:37
pre-commit-ci Bot and others added 4 commits March 11, 2026 21:38
@phu0ngng (Collaborator, Author) commented:

/te-ci JAX L1

@jberchtold-nvidia (Collaborator) left a review:

LGTM, thanks!

@denera (Collaborator) left a review:

LGTM!

Comment on lines +994 to +1004
lhs_scale_specs = rhs_scale_specs = (None,)
if scaling_mode.is_1d_block_scaling():
    rhs_scale_specs = rhs_specs
    # Set the seq spec to None to trigger AG of the scales, as TE/Common CGEMM
    # does not handle scale collecting yet
    if collective_op.is_all_gather:
        lhs_scale_specs = tuple(
            None if i == sequence_dim else s for i, s in enumerate(lhs_specs)
        )
    else:
        lhs_scale_specs = lhs_specs
A collaborator commented:

This AG is only required for overlap with Userbuffers. We'll conditionally disable it whenever we're using the cuBLASMp backend instead.

No changes needed in this PR, just dropping a note for reference.

@phu0ngng phu0ngng merged commit 14c29da into NVIDIA:main Mar 13, 2026
9 of 12 checks passed
@phu0ngng phu0ngng deleted the cgemm_fp8 branch March 13, 2026 15:27
vthumbe1503 pushed a commit to ksivaman/TransformerEngine-1 that referenced this pull request Apr 1, 2026
* Enable cgemm + FP8 tests

* Implement CGEMM + MXFP8

* added size check for mxfp8

* added tols for assertions

* update tests with recipes

* enable tests + is_quantize_recipe_supported

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
