
Conversation

jagadish-amd commented on Aug 7, 2025

mx fp8 is enabled through a cherry-picked patch from rel 2.7. This patch adds support to enable mx fp4.

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v

Ran 452 tests in 23.776s
OK (skipped=340)
Passed 112
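
For context, here is a minimal, purely illustrative sketch of what the mxfp4 layout implies, assuming the usual OCP MX conventions: 32-element blocks that each carry a power-of-two (E8M0-style) scale, FP4 E2M1 payloads whose largest magnitude is 6.0, and two 4-bit codes packed per byte as the Float4_e2m1fn_x2 container suggests. The helper names are hypothetical and this is not the code in this patch.

# Illustrative sketch only; hypothetical helpers, not the implementation in this PR.
import torch

BLOCK_SIZE = 32        # MX convention: one shared scale per 32 elements
F4_E2M1_MAX = 6.0      # largest magnitude representable in FP4 E2M1

def mx_fp4_block_scale(x: torch.Tensor) -> torch.Tensor:
    """One power-of-two (E8M0-style) scale per 32-element block of a 1-D tensor."""
    blocks = x.reshape(-1, BLOCK_SIZE)
    amax = blocks.abs().amax(dim=1).clamp_min(torch.finfo(x.dtype).tiny)
    # Round the scale up to a power of two so block_max / scale stays within +-6.0.
    return torch.exp2(torch.ceil(torch.log2(amax / F4_E2M1_MAX)))

def pack_fp4_pairs(codes: torch.Tensor) -> torch.Tensor:
    """Pack two 4-bit codes per byte, as the *_x2 container implies (nibble order is illustrative)."""
    return (codes[0::2] & 0xF) | ((codes[1::2] & 0xF) << 4)

x = torch.randn(4 * BLOCK_SIZE)
print(mx_fp4_block_scale(x))                               # 4 power-of-two scales
print(pack_fp4_pairs(torch.arange(8, dtype=torch.uint8)))  # 4 packed bytes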

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
jagadish-amd requested review from Copilot and petrex on August 7, 2025 at 06:35

Copilot AI left a comment

Pull Request Overview

This PR adds support for the mx fp4 (mxfp4) format to complement the existing mx fp8 support in PyTorch. The change enables ROCm to use the mxfp4 recipe while maintaining nvfp4 support for non-ROCm platforms.

Key changes:

  • Extended test parameterization to include mxfp4 alongside existing mxfp8 and nvfp4 recipes
  • Added ROCm-specific version checks and data type mappings for Float4_e2m1fn_x2 support
  • Updated the scaling function to handle the mxfp4 recipe with the appropriate constants (see the sketch below)
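
As a rough sketch of that last point (the constants and helper shape here are assumptions, not the actual code in test/test_matmul_cuda.py), switching recipes mainly means switching the target-format maximum used when deriving the block scales:

# Hypothetical recipe-aware scaling helper; illustrative only, not this PR's implementation.
from typing import Optional

import torch

F8_E4M3_MAX = 448.0  # max magnitude of float8 e4m3fn (mxfp8 payload)
F4_E2M1_MAX = 6.0    # max magnitude of float4 e2m1 (mxfp4 payload)

def data_to_block_scale(x: torch.Tensor, block_size: int, recipe: Optional[str] = None) -> torch.Tensor:
    """One power-of-two scale per block, sized for the selected recipe's element format."""
    fmt_max = F4_E2M1_MAX if recipe == "mxfp4" else F8_E4M3_MAX
    amax = x.reshape(-1, block_size).abs().amax(dim=1).clamp_min(1e-12)
    # Round up to a power of two so the rescaled block maximum fits in the target format.
    return torch.exp2(torch.ceil(torch.log2(amax / fmt_max)))

With a helper of that shape, the call sites flagged below only need to pass the recipe through when it is mxfp4, which is exactly the pattern the review suggestion tidies up.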

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

  • test/test_matmul_cuda.py: Extended the blockwise numerical tests to support the mxfp4 recipe with ROCm-specific configurations
  • aten/src/ATen/native/cuda/Blas.cpp: Added ROCm version checks for Float4_e2m1fn_x2 and other float8 types
  • aten/src/ATen/cuda/tunable/GemmHipblaslt.h: Added a HipDataType mapping for Float4_e2m1fn_x2 with ROCm 7.0+ support
  • aten/src/ATen/cuda/CUDADataType.h: Extended the CUDA data type mapping to include ROCm support for Float4_e2m1fn_x2
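
As a hedged illustration of the ROCm-specific test configuration (helper names are hypothetical and the real parameterization may differ), the platform split described in the overview, mxfp4 on ROCm versus nvfp4 elsewhere, could look roughly like this:

# Hypothetical sketch of platform-dependent recipe selection; not the actual test code.
import torch

def blockwise_recipes():
    if torch.version.hip is not None:   # ROCm build of PyTorch
        return ["mxfp8", "mxfp4"]
    return ["mxfp8", "nvfp4"]           # non-ROCm platforms keep nvfp4

def rocm_supports_fp4():
    # The Float4_e2m1fn_x2 mapping added here is gated on ROCm 7.0+, so a test-side
    # guard would look roughly like this (version parsing deliberately simplified).
    hip = torch.version.hip
    return hip is not None and int(hip.split(".")[0]) >= 7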

Comment on lines +1635 to +1636
A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)

Copilot AI Aug 7, 2025

[nitpick] The conditional expression recipe if recipe == "mxfp4" else None is confusing. Consider extracting this logic into a clearer variable assignment or using a more explicit approach to pass the correct arguments to each scaling function.

Suggested change:
- A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
- B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
+ scale_recipe_arg = recipe if recipe == "mxfp4" else None
+ A_scale = scale_func(A_ref, BLOCK_SIZE, scale_recipe_arg)
+ B_scale = scale_func(B_ref, BLOCK_SIZE, scale_recipe_arg)


rocm-repo-management-api bot commented Aug 7, 2025

Jenkins build for commit 722c3351b11bf64117b080f41d52f846f1bdcb79 finished as FAILURE
Links: Blue Ocean view / Build artifacts

@pruthvistony pruthvistony merged commit 2975ee1 into ROCm:release/2.8 Aug 7, 2025
0 of 2 checks passed
tvukovic-amd pushed a commit that referenced this pull request Aug 20, 2025
mx fp8 is enabled through a cherry-picked patch from rel 2.7. This patch adds
support to enable mx fp4.

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v

Ran 452 tests in 23.776s
OK (skipped=340)
Passed 112

---------

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>