[release/2.8] Add mx fp4 support #2472
Conversation
mx fp8 is enabled through a cherry-pick patch from release 2.7. This patch adds support to enable mx fp4.

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 23.776s
OK (skipped=340)
Passed 112

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Pull Request Overview
This PR adds support for the mx fp4 (mxfp4) format to complement the existing mx fp8 support in PyTorch. The change enables ROCm to use the mxfp4 recipe while maintaining nvfp4 support for non-ROCm platforms.
Key changes:
- Extended test parameterization to include mxfp4 alongside existing mxfp8 and nvfp4 recipes
- Added ROCm-specific version checks and data type mappings for Float4_e2m1fn_x2 support
- Updated the scaling function to handle the mxfp4 recipe with appropriate constants (a sketch follows below)
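For context, the "appropriate constants" are the element-format maxima used when deriving per-block scales: float8_e4m3fn tops out at 448, while float4_e2m1 tops out at 6.0. Below is a minimal sketch of what such a scaling helper could look like, assuming a hypothetical `mx_block_scale` name and the standard 32-element MX block size; this is an illustration, not the PR's actual code.

```python
import torch

# Illustrative constants; BLOCK_SIZE and the helper name are assumptions
# for this sketch, not identifiers taken from the PR.
FP8_E4M3_MAX = 448.0   # largest finite float8_e4m3fn value
FP4_E2M1_MAX = 6.0     # largest finite float4_e2m1 value
BLOCK_SIZE = 32        # MX formats share one scale per 32 elements

def mx_block_scale(t: torch.Tensor, block_size: int = BLOCK_SIZE, recipe=None) -> torch.Tensor:
    """Compute one power-of-two scale per block along the last dimension.

    Assumes the last dimension is divisible by block_size.
    """
    # The recipe only changes which element-format maximum the scale targets.
    fmt_max = FP4_E2M1_MAX if recipe == "mxfp4" else FP8_E4M3_MAX
    blocks = t.reshape(*t.shape[:-1], -1, block_size)
    amax = blocks.abs().amax(dim=-1).clamp(min=torch.finfo(t.dtype).tiny)
    # Round the scale up to a power of two (E8M0-style) so that
    # block_max / scale stays within the target element format.
    return torch.exp2(torch.ceil(torch.log2(amax / fmt_max)))
```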
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| test/test_matmul_cuda.py | Extended blockwise numerical tests to support the mxfp4 recipe with ROCm-specific configurations (see the sketch after this table) |
| aten/src/ATen/native/cuda/Blas.cpp | Added ROCm version checks for Float4_e2m1fn_x2 and other float8 types |
| aten/src/ATen/cuda/tunable/GemmHipblaslt.h | Added HipDataType mapping for Float4_e2m1fn_x2 with ROCm 7.0+ support |
| aten/src/ATen/cuda/CUDADataType.h | Extended CUDA data type mapping to include ROCm support for Float4_e2m1fn_x2 |
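As a rough illustration of the test/test_matmul_cuda.py change, a recipe parameterization with ROCm-specific gating might look like the sketch below. The decorator, test name, and skip conditions are assumptions based on the overview above, not lines from the diff.

```python
import unittest
import torch
from torch.testing._internal.common_utils import parametrize

@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_numerics(self, recipe):
    # Assumed platform gating: nvfp4 stays non-ROCm only, mxfp4 targets ROCm.
    if torch.version.hip is not None and recipe == "nvfp4":
        raise unittest.SkipTest("nvfp4 recipe is not supported on ROCm")
    if torch.version.hip is None and recipe == "mxfp4":
        raise unittest.SkipTest("mxfp4 recipe here targets ROCm")
    ...  # build A/B, compute block scales, and compare scaled matmul results
```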
A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
Copilot AI, Aug 7, 2025
[nitpick] The conditional expression `recipe if recipe == "mxfp4" else None` is confusing. Consider extracting this logic into a clearer variable assignment or using a more explicit approach to pass the correct arguments to each scaling function.
Suggested change:
- A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
- B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
+ scale_recipe_arg = recipe if recipe == "mxfp4" else None
+ A_scale = scale_func(A_ref, BLOCK_SIZE, scale_recipe_arg)
+ B_scale = scale_func(B_ref, BLOCK_SIZE, scale_recipe_arg)
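Hoisting the conditional into scale_recipe_arg, as suggested, evaluates the recipe check once and keeps the two call sites identical; non-mxfp4 recipes (mxfp8, nvfp4) continue to pass None, so their behavior is unchanged.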
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Jenkins build for 722c3351b11bf64117b080f41d52f846f1bdcb79 commit finished as FAILURE
mx fp8 is enabled through a cherry-pick patch from release 2.7. This patch adds support to enable mx fp4.

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 23.776s
OK (skipped=340)
Passed 112

---------

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>