Import get_fp8_dtypes from the correct place in bench_gemm_a8wfp4#2602
Pull request overview
Fixes bench_gemm_a8wfp4.py crashing on startup by importing get_fp8_dtypes from the module where it is actually defined (aiter.ops.triton.utils.types) rather than from arch_info.
Changes:
- Import get_fp8_dtypes from aiter.ops.triton.utils.types.
- Replace the arch_info.get_fp8_dtypes() call with get_fp8_dtypes().
  def bench_gemm_fn(M: int, N: int, K: int, metric: str, layout: str):
-     e5m2_type, e4m3_type = arch_info.get_fp8_dtypes()
+     e5m2_type, e4m3_type = get_fp8_dtypes()
e5m2_type is assigned but never used. With the repo’s Ruff check enabled in CI, this will likely raise F841 on this changed line. Use _ for the unused return value (or otherwise use e5m2_type) to avoid lint failures.
Suggested change:
-     e5m2_type, e4m3_type = get_fp8_dtypes()
+     _, e4m3_type = get_fp8_dtypes()
This suggestion makes sense IMHO.
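As a minimal sketch of the suggested change: the snippet below uses a hypothetical stand-in for get_fp8_dtypes that returns placeholder strings, since the real function lives in aiter.ops.triton.utils.types and returns dtype objects. It only illustrates the idiom of binding the unused first element to `_`; the name pick_e4m3 is invented for this example.

```python
def get_fp8_dtypes():
    # Hypothetical stand-in for the real helper: returns an (e5m2, e4m3) pair.
    # Placeholder strings are used here instead of real dtype objects.
    return "e5m2_placeholder", "e4m3_placeholder"


def pick_e4m3():
    # Only the second element is needed, so the first is bound to `_`
    # to signal that it is intentionally discarded.
    _, e4m3_type = get_fp8_dtypes()
    return e4m3_type


print(pick_e4m3())
```

Binding to `_` keeps the unpacking shape identical to the original line, so the change stays a one-token edit.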
A similar fix should have been merged as part of #2434. Please check if main works. If not, please resolve conflicts and we can discuss.
Confirmed: this fix is already on |
Motivation
When running bench_gemm_a8wfp4.py, the script fails immediately with:
Technical Details
bench_gemm_a8wfp4.py was calling arch_info.get_fp8_dtypes(), but get_fp8_dtypes is defined in aiter.ops.triton.utils.types, not in arch_info. Fixed by importing and calling it from the correct module, consistent with every other callsite in the codebase (e.g. test_gemm_a8wfp4.py, test_gemm_a8w8.py, etc.).
Test Plan
Run bench_gemm_a8wfp4.py --model all -M 4096 --model-configs ... --metric time --layout TT -o and verify it no longer raises AttributeError.
Test Result
Script launches successfully after the fix.
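The failure mode described under Technical Details can be illustrated without the aiter package installed. The sketch below is an assumption-laden stand-in: arch_info and fp8_types are plain SimpleNamespace objects invented for this example (not the real modules), and get_fp8_dtypes returns placeholder strings. It shows only why calling the function through a namespace that does not define it raises AttributeError, while calling it through the defining namespace works.

```python
from types import SimpleNamespace


def get_fp8_dtypes():
    # Hypothetical stand-in: returns an (e5m2, e4m3) pair of placeholders.
    return "e5m2_placeholder", "e4m3_placeholder"


# Stand-in for arch_info: it does NOT define get_fp8_dtypes.
arch_info = SimpleNamespace()

# Stand-in for aiter.ops.triton.utils.types: it DOES define the function.
fp8_types = SimpleNamespace(get_fp8_dtypes=get_fp8_dtypes)


def call_via(namespace):
    # Attempt the call and report an AttributeError instead of crashing.
    try:
        return namespace.get_fp8_dtypes()
    except AttributeError as exc:
        return f"AttributeError: {exc}"


old_result = call_via(arch_info)   # mirrors the call before the fix
new_result = call_via(fp8_types)   # mirrors the call after the fix

print(old_result)
print(new_result)
```

In the real script the equivalent check is simply that the import from the defining module succeeds and the benchmark starts.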