[JAX] Collective GEMM with FP8 and MXFP8 support#2740
Conversation
/te-ci JAX L1
Greptile Summary: This PR extends the JAX Collective GEMM implementation to support FP8 and MXFP8 quantization recipes.
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["tex.gemm() / layernorm_mlp()"] --> B{scaling_mode}
    B -- "NO_SCALING / TENSOR_SCALING" --> C[Standard GEMM path]
    B -- "MXFP8_1D_SCALING" --> D{collective_op?}
    D -- "NONE / is_outer" --> E[apply_padding_to_scale_inv\n+ swizzled_scale]
    D -- "ALL_GATHER / REDUCE_SCATTER" --> F[Assert dims % 128 == 0\nSkip padding]
    F --> G[swizzled_scale on lhs/rhs scale_inv]
    G --> H{need_reorder?}
    H -- "RS" --> I[_reorder_tpsp_leading on lhs\n+ lhs_scale_inv]
    H -- "AG" --> J[_reorder_tpsp_leading on lhs_scale_inv only\nlhs data stays as-is]
    I --> K[GemmPrimitive.inner_primitive.bind]
    J --> K
    C --> K
    E --> K
    K --> L{post-process output}
    L -- "AG + need_reorder" --> M[_reorder_dp_leading on output]
    L -- "other" --> N[return output as-is]
    M --> O[return output]
    N --> O
```
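The dispatch the flowchart describes can be sketched as plain Python. This is an illustrative reconstruction, not the actual TransformerEngine code; the enum and function names are assumptions made for the sketch.

```python
# Hypothetical sketch of the scale-handling dispatch shown in the flowchart.
# Names (ScalingMode, CollectiveOp, select_scale_path) are illustrative only.
from enum import Enum, auto


class ScalingMode(Enum):
    NO_SCALING = auto()
    TENSOR_SCALING = auto()
    MXFP8_1D_SCALING = auto()


class CollectiveOp(Enum):
    NONE = auto()
    ALL_GATHER = auto()
    REDUCE_SCATTER = auto()


def select_scale_path(scaling_mode, collective_op, dims):
    """Return which scale-handling path the flowchart would take."""
    if scaling_mode is not ScalingMode.MXFP8_1D_SCALING:
        # NO_SCALING / TENSOR_SCALING go down the standard GEMM path
        return "standard"
    if collective_op is CollectiveOp.NONE:
        # non-collective MXFP8: pad scale_inv, then swizzle the scales
        return "pad_and_swizzle"
    # collective MXFP8: padding is skipped, so every GEMM dimension
    # must already be a multiple of 128
    assert all(d % 128 == 0 for d in dims), "CGEMM + MXFP8 needs dims % 128 == 0"
    return "swizzle_only"
```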
Last reviewed commit: 6f0c442
Force-pushed from 4f20d2d to f899fa2.
/te-ci JAX L1
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
/te-ci JAX L1
jberchtold-nvidia left a comment:
LGTM, thanks!
```python
lhs_scale_specs = rhs_scale_specs = (None,)
if scaling_mode.is_1d_block_scaling():
    rhs_scale_specs = rhs_specs
    # Set the seq spec to None to trigger AG of the scales, as TE/Common CGEMM
    # does not handle scale collecting yet
    if collective_op.is_all_gather:
        lhs_scale_specs = tuple(
            None if i == sequence_dim else s for i, s in enumerate(lhs_specs)
        )
    else:
        lhs_scale_specs = lhs_specs
```
This AG is only required for overlap with Userbuffers. We'll conditionally disable it whenever we're using the cuBLASMp backend instead.
No changes needed in this PR, just dropping a note for reference.
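The sequence-spec override discussed above can be illustrated with plain tuples. The spec values and `sequence_dim` here are hypothetical stand-ins for the mesh axis names used in the real code.

```python
# Hypothetical sharding specs for a (batch, sequence, hidden) operand;
# in the actual code these come from the mesh axis names.
lhs_specs = ("dp", "tpsp", None)
sequence_dim = 1

# Setting the sequence entry to None marks that axis of the scales as
# unsharded, which forces the scales to be all-gathered up front instead
# of being gathered while overlapping with the GEMM computation.
lhs_scale_specs = tuple(
    None if i == sequence_dim else s for i, s in enumerate(lhs_specs)
)
# lhs_scale_specs == ("dp", None, None)
```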
* Enable cgemm + FP8 tests
* Implement CGEMM + MXFP8
* added size check for mxfp8
* added tols for assertions
* update tests with recipes
* enable tests + is_quantize_recipe_supported

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Description
This PR extends the JAX Collective GEMM support with the DelayedScalingFP8, CurrentScalingFP8, and MXFP8 quantization recipes.
Unit tests for these quantization recipes are added, and the test infrastructure in the collective GEMM tests is cleaned up.
Note that Collective GEMM + MXFP8 requires all dimensions of the GEMM operands to be divisible by 128.
Additionally, in the CGEMM + MXFP8 + AllGather case, the block scales are still all-gathered on the critical path, unlike the quantized data, whose gather is overlapped with the computation.
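The shape constraints above can be sketched as a small helper. This is an illustrative sketch, not the TE API: the helper names are assumptions, and the only facts taken from the source are the 128-divisibility requirement and MXFP8's one-scale-per-32-element block granularity (from the OCP MX format).

```python
# Illustrative helpers; names are hypothetical, not the TransformerEngine API.
MXFP8_BLOCK = 32  # MXFP8 shares one E8M0 scale per 32-element block


def check_cgemm_mxfp8_shapes(lhs_shape, rhs_shape):
    """CGEMM + MXFP8 skips scale padding, so every dim must be % 128 == 0."""
    for dim in (*lhs_shape, *rhs_shape):
        if dim % 128 != 0:
            raise ValueError(f"CGEMM + MXFP8 needs dims % 128 == 0, got {dim}")


def scale_inv_shape(shape):
    """Scale tensor shape: one scale per 32-element block on the last axis."""
    *lead, last = shape
    return (*lead, last // MXFP8_BLOCK)
```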
Type of change
Checklist: