Enabled fp8 gemm gelu_aux_bias #315
Conversation
HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
&bias_type, sizeof(bias_type)));
}
#if HIP_VERSION >= 70000000
Shouldn't this be a HIPBLASLT_VERSION check instead? I understand it will create a discrepancy between the cpp UT and the TE code, but I'd rather keep TE itself correct.
I need to check which HIPBLASLT_VERSION first includes the commit.
You can check for 1.0, which was released in ROCm 7.0.
The changes landed in version 0.15.0, and the version was later bumped to 1.0.0. But as you mentioned, 1.0 was released in ROCm 7.0, so based on this I can guard it with >= 1.0.0. What are your thoughts, @ipanfilo?
0.15 is also the version used to guard MXFP8, so it should be OK to use.
Updated
Not updated
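The guard discussed in this thread could be sketched as follows. This is a minimal illustration, not the PR's code: it assumes hipBLASLt exposes `HIPBLASLT_VERSION_MAJOR` and `HIPBLASLT_VERSION_MINOR` macros, supplies stand-in values so it compiles without the library, and uses a hypothetical helper `version_at_least` and a hypothetical feature macro `TE_SKETCH_HAS_FP8_GELU_AUX_BIAS`.

```cpp
// Sketch: guard a feature on the hipBLASLt version rather than on HIP_VERSION.
// In a real build these macros are assumed to come from the hipBLASLt headers;
// the fallback values below are stand-ins so the sketch compiles on its own.
#ifndef HIPBLASLT_VERSION_MAJOR
#define HIPBLASLT_VERSION_MAJOR 1  // stand-in value (assumption)
#endif
#ifndef HIPBLASLT_VERSION_MINOR
#define HIPBLASLT_VERSION_MINOR 0  // stand-in value (assumption)
#endif

// Hypothetical helper: true when (major, minor) >= (req_major, req_minor).
constexpr bool version_at_least(int major, int minor, int req_major, int req_minor) {
  return major > req_major || (major == req_major && minor >= req_minor);
}

// Hypothetical feature macro, not part of TE or hipBLASLt.
// Enabled for 0.15 and later, which also covers 1.0 and greater.
#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
#define TE_SKETCH_HAS_FP8_GELU_AUX_BIAS 1
#else
#define TE_SKETCH_HAS_FP8_GELU_AUX_BIAS 0
#endif

// Compile-time checks of the comparison logic.
static_assert(version_at_least(0, 15, 0, 15), "0.15 satisfies >= 0.15");
static_assert(version_at_least(1, 0, 0, 15), "1.0 satisfies >= 0.15");
static_assert(!version_at_least(0, 14, 0, 15), "0.14 does not satisfy >= 0.15");
```

Comparing major and minor separately matters here because the version jumped from 0.15.0 to 1.0.0, so a check on the minor number alone would reject 1.0.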
NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!");

// check consistency of arguments:
Not part of your changes, but it looks like those 2 comment lines are outdated and no longer relevant to the code.
Removed the comments.
}
#else
hipDeviceProp_t prop;
NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, 0));
Better to move hipGetDeviceProperties under the if, so it is not called every time for no reason.
Updated
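The suggestion above, querying device properties only on the path that needs them, might look roughly like this. It is a sketch under stated assumptions: `DeviceProps`, `query_device_props`, `needs_workaround`, and the condition on `major` are all hypothetical stand-ins so the example compiles without HIP, and they do not reflect the actual condition used in the PR.

```cpp
// Stand-in type so the sketch compiles without HIP; all names are hypothetical.
struct DeviceProps {
  int major = 9;
};

static int g_query_count = 0;  // counts how often the device query runs

// Stand-in for a device-property query such as hipGetDeviceProperties.
DeviceProps query_device_props() {
  ++g_query_count;
  return DeviceProps{};
}

// Queries the device only when the branch that needs the properties is taken,
// instead of querying unconditionally before the if.
bool needs_workaround(bool feature_requested) {
  if (feature_requested) {
    const DeviceProps prop = query_device_props();
    return prop.major < 10;  // hypothetical condition
  }
  return false;
}
```

With this shape, calls that never enter the branch pay nothing for the query.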
HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
&bias_type, sizeof(bias_type)));
}
#if HIP_VERSION >= 70000000
Not updated
8955dcf to 2bd2d58
Description
Added support for fp8 gelu_aux_bias.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: