Conversation
…8 GEMM test scripts
The merge commit 6a18cd6 accidentally preserved conflict markers in gemm_op_a4w4.py. Apply the gfx-aware dispatch fix (same pattern as gemm_op_a8w8.py) — use (gfx, cu_num, M, N, K) key when the CSV has a gfx column, fall back to (cu_num, M, N, K) for old CSVs. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
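The gfx-aware dispatch with legacy fallback that this commit describes can be sketched roughly as follows. This is a minimal illustration only; the dictionary layout, function name, and config values are assumptions, not aiter's actual code.

```python
def lookup_config(tuned, gfx, cu_num, M, N, K):
    # Prefer the gfx-aware key when the CSV carried a gfx column...
    cfg = tuned.get((gfx, cu_num, M, N, K))
    if cfg is None:
        # ...fall back to the legacy key for old CSVs without a gfx column.
        cfg = tuned.get((cu_num, M, N, K))
    return cfg

# Hypothetical tuned-config dict: one gfx-tagged row, one legacy row.
tuned = {
    ("gfx942", 304, 128, 128, 7168): "kernel_A",
    (304, 64, 64, 4096): "kernel_B",  # legacy row, no gfx
}
print(lookup_config(tuned, "gfx942", 304, 128, 128, 7168))  # kernel_A
print(lookup_config(tuned, "gfx942", 304, 64, 64, 4096))    # kernel_B
```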
…-arch kernel collisions, share build_tune_dict helpers across all 9 CK GEMM modules
…-free build_targets.py
…tion in all 8 GEMM READMEs
Pull request overview
This PR fixes non-reproducible / incorrect CK GEMM kernel selection in multi-arch builds by making tuning CSVs gfx-aware at build time, disambiguating runtime dispatch by CU count, and replacing the commutative tuple hash used by the dispatch maps.
Changes:
- Add gfx-aware build targeting (`(gfx, cu_num)`), CSV filtering, and shared codegen helpers for GEMM instance generation.
- Update C++ dispatch maps to key by `(cu_num, M, N, K)` (and `(cu_num, B, M, N, K)` for batched) and introduce shared dispatch utilities + non-commutative hashing.
- Update tuners, runtime Python config lookups, tests, and example tuning CSVs/schema to include a `gfx` column.
Reviewed changes
Copilot reviewed 62 out of 69 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| op_tests/test_gemm_codegen.py | New CPU-only regression tests for gfx-aware build targets, CSV filtering, and Python lookup keys |
| op_tests/test_gemm_a8w8.py | Filter tuned-shape detection by (gfx, cu_num); add CSV-driven shape lists and result export |
| op_tests/test_gemm_a8w8_blockscale.py | Add CSV-driven runs + output; change perf iteration defaults |
| op_tests/configs/gemm_codegen_gfx_filter.csv | Repro tuning CSV including gfx column for regression coverage |
| op_tests/configs/gemm_codegen_gfx_filter_bpreshuffle.csv | Repro bpreshuffle tuning CSV including gfx column |
| gradlib/gradlib/GemmTuner.py | Include gfx in tuning keys and emitted results |
| csrc/include/gemm_dispatch_utils.h | New shared dispatch hash + CU detection helpers and map aliases |
| csrc/cktile_gemm_a8w8_bpreshuffle/README.md | Document gfx column in tuned CSV schema |
| csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py | Use shared build_tune_dict + write_lookup_header helpers |
| csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu | Switch lookup key to include cu_num; use shared dispatch utils |
| csrc/ck_gemm_a8w8/README.md | Document gfx column in tuned CSV schema |
| csrc/ck_gemm_a8w8/gen_instances.py | Use shared build_tune_dict + write_lookup_header helpers |
| csrc/ck_gemm_a8w8/gemm_a8w8.cu | Switch lookup key to include cu_num; use shared dispatch utils |
| csrc/ck_gemm_a8w8/gemm_a8w8_tune.py | Emit/expect gfx in tuning key/CSV schema |
| csrc/ck_gemm_a8w8_bpreshuffle/README.md | Document gfx column in tuned CSV schema |
| csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py | Use shared build_tune_dict + write_lookup_header helpers |
| csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle.cu | Switch lookup key to include cu_num; use shared dispatch utils |
| csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py | Emit/expect gfx in tuning key/CSV schema |
| csrc/ck_gemm_a8w8_blockscale/README.md | Document gfx column in tuned CSV schema |
| csrc/ck_gemm_a8w8_blockscale/gen_instances.py | Use shared build_tune_dict + write_lookup_header helpers |
| csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py | Use shared build_tune_dict + write_lookup_header helpers |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu | Switch lookup key to include cu_num; use shared dispatch utils |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py | Emit/expect gfx in tuning key/CSV schema |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu | Switch lookup key to include cu_num; use shared dispatch utils |
| csrc/ck_gemm_a8w8_blockscale_bpreshuffle/README.md | Document gfx column in tuned CSV schema |
| csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py | Use shared build_tune_dict + write_lookup_header helpers |
| csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gemm_a8w8_blockscale_bpreshuffle.cu | Switch lookup key to include cu_num; use shared dispatch utils |
| csrc/ck_gemm_a4w4_blockscale/README.md | Document gfx column in tuned CSV schema |
| csrc/ck_gemm_a4w4_blockscale/gen_instances.py | Use shared build_tune_dict + write_lookup_header helpers |
| csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale.cu | Switch lookup key to include cu_num; use shared dispatch utils |
| csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py | Include gfx in tuning info keys |
| csrc/ck_deepgemm/gen_instances.py | Use shared build_tune_dict + write_lookup_header helpers |
| csrc/ck_deepgemm/deepgemm.cu | Switch lookup key to include cu_num; use shared dispatch utils |
| csrc/ck_batched_gemm_bf16/README.md | Document gfx column in tuned CSV schema |
| csrc/ck_batched_gemm_bf16/gen_instances.py | Use shared batched tune dict + shared lookup writer |
| csrc/ck_batched_gemm_bf16/batched_gemm_bf16.cu | Switch lookup key to include cu_num; use shared batched dispatch utils |
| csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py | Emit/expect gfx in tuning info keys |
| csrc/ck_batched_gemm_a8w8/README.md | Document gfx column in tuned CSV schema |
| csrc/ck_batched_gemm_a8w8/gen_instances.py | Use shared batched tune dict + shared lookup writer |
| csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8.cu | Switch lookup key to include cu_num; use shared batched dispatch utils |
| csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py | Emit/expect gfx in tuning info keys |
| aiter/utility/base_tuner.py | Add gfx to common tuning keys and retune filtering |
| aiter/ops/gemm_op_a8w8.py | Make Python tuned-config lookups prefer (gfx, cu_num, ...) keys with fallback |
| aiter/ops/gemm_op_a4w4.py | Make Python tuned-config lookups prefer (gfx, cu_num, ...) keys with fallback |
| aiter/ops/batched_gemm_op_bf16.py | Make Python tuned-config lookups prefer (gfx, cu_num, ...) keys with fallback |
| aiter/ops/batched_gemm_op_a8w8.py | Make Python tuned-config lookups prefer (gfx, cu_num, ...) keys with fallback |
| aiter/jit/utils/chip_info.py | Add (gfx, cu_num) build targets and shared tune/lookup codegen helpers |
| aiter/jit/utils/build_targets.py | New pure-Python env-driven build target resolver + tune DF filter |
| aiter/configs/model_configs/dsv3_bf16_tuned_gemm.csv | Update schema to include gfx column |
| aiter/configs/model_configs/dsv3_a8w8_bpreshuffle_tuned_gemm.csv | Update schema to include gfx column |
| aiter/configs/model_configs/dsv3_a4w4_blockscale_tuned_gemm.csv | Update schema to include gfx column |
| aiter/configs/bf16_tuned_gemm.csv | Update schema header to include gfx column |
| aiter/configs/bf16_tuned_batched_gemm.csv | Update schema to include gfx column |
| aiter/configs/a8w8_tuned_batched_gemm.csv | Update schema to include gfx column |
| aiter/configs/a8w8_blockscale_tuned_gemm.csv | Update schema to include gfx column |
| aiter/configs/a8w8_blockscale_bpreshuffle_tuned_gemm.csv | Update schema header to include gfx column |
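The schema change listed across these CSVs amounts to a leading `gfx` column that build-time filtering can key on. A toy sketch of that idea follows; the column names and kernel names here are assumptions based on the descriptions above, not the real config files.

```python
import csv
import io

# Illustrative tuned-CSV content with the new leading gfx column.
tuned_csv = """gfx,cu_num,M,N,K,kernelName
gfx942,304,128,128,7168,kernel_A
gfx950,256,128,128,7168,kernel_B
"""
rows = list(csv.DictReader(io.StringIO(tuned_csv)))

# Build-time filtering keeps only rows matching the (gfx, cu_num) target,
# so the same CSV can carry tuning results for multiple GPUs.
target = ("gfx942", "304")
kept = [r for r in rows if (r["gfx"], r["cu_num"]) == target]
print([r["kernelName"] for r in kept])  # ['kernel_A']
```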
```diff
 block_shape = (128, 128)
 TEST_NUM_ITERS = 100

-@perftest(num_iters=5)
+@perftest(num_iters=TEST_NUM_ITERS)
 def run_torch(x, weight, x_scale, w_scale, dtype=dtypes.bf16):
```
TEST_NUM_ITERS default is raised to 100, which can make this script significantly slower to run (and potentially too slow if used in automated runs). Consider keeping a small default (e.g. 5) and adding a CLI flag / env var to override for benchmarking.
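One possible shape of this suggestion is an environment-variable override: keep the fast default for automated runs and let benchmark runs opt into more iterations. The variable name `AITER_TEST_NUM_ITERS` is hypothetical, not an existing aiter setting.

```python
import os

# Keep a small default (5) so CI stays fast; benchmark runs can set
# AITER_TEST_NUM_ITERS=100 (hypothetical name) to override it.
TEST_NUM_ITERS = int(os.environ.get("AITER_TEST_NUM_ITERS", "5"))
```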
Pull request overview
Copilot reviewed 68 out of 75 changed files in this pull request and generated 4 comments.
```python
saved_info = (
    "\n".join(saved_files) if saved_files else " (no files updated)"
)
raise RuntimeError(
```
I changed this back to raising an error intentionally. The goal here is not just to auto-resolve duplicate shape entries, but to make the submitter explicitly aware that duplicates were introduced. Silent auto-fix can hide config quality issues and let bad updates slip through unnoticed. Failing fast gives clear visibility, forces cleanup at the source, and helps keep the merged tuning configs maintainable.
Thanks @yzhou103! I see - I made the change initially because this was breaking the build_aiter_image CI (raise aborts the wheel build). Would it work to raise only when the save-back actually fails, and log a warning otherwise? The warning still includes the full duplicate list and which files were cleaned, so visibility is preserved. Or have I misunderstood something?
Thanks, that makes sense, and I agree warning-only would be functionally safe if merged_df is also updated in the current run. My hesitation is more about policy than correctness: we would still like main to stay as clean and stable as possible, and a patch introducing many duplicate shape entries is not really the expected pattern. And we also want the person updating the CSVs to be very explicitly aware of exactly what happened and which shape entries were duplicated, rather than this becoming a mostly transparent cleanup path.
I noticed there are many duplicate shape entries in this PR, but my understanding is that it is not adding any new tuned shapes, so I am a bit confused about where these duplicates are coming from.
@yzhou103 quick comment: there should not be duplicates within CSVs now, and duplicates across CSVs (e.g. model configs) are not introduced by this PR.
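The raise-vs-warn tradeoff discussed in this thread can be made concrete with a toy merge routine. This is an illustrative sketch only; the function name, key columns, and `strict` flag are assumptions, not aiter's actual `update_config_files` logic.

```python
import logging

logger = logging.getLogger("aiter")

def merge_tuning_rows(rows, key_cols=("cu_num", "M", "N", "K"), strict=True):
    # Deduplicate tuning rows by shape key, keeping the first occurrence.
    seen, duplicates, merged = set(), [], []
    for row in rows:
        key = tuple(row[c] for c in key_cols)
        if key in seen:
            duplicates.append(key)
            continue
        seen.add(key)
        merged.append(row)
    if duplicates:
        msg = f"duplicate tuned shapes: {sorted(set(duplicates))}"
        if strict:
            # Fail fast: force cleanup at the source, maximum visibility.
            raise RuntimeError(msg)
        # Warn-only: full duplicate list stays visible, build continues.
        logger.warning(msg)
    return merged
```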
…eshuffle_tuned_gemm_qwen3.5_397b.csv
to resolve the duplicated shapes quickly, we can run the relevant test like

hold till atom test passes
…#2645 cherry-pick PR #2645 introduced csrc/include/gemm_dispatch_utils.h which references the SynchronizedCache<Key, T> template. That template was added by a SEPARATE earlier PR (#2221, 2026-04-15) to csrc/include/aiter_hip_common.h on main, but #2221 was never on release/v0.1.12. Cherry-picking #2645 alone fails to compile with: csrc/include/gemm_dispatch_utils.h:46:12: error: no template named 'SynchronizedCache' Hand-port just the template definition (23 lines + 2 includes) — minimum needed for the dispatch fix to compile. Skips #2221's other changes (replacing std::unordered_map usages in 12 .cu files), since release/v0.1.12 doesn't have those dispatch sites.
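For context, the `SynchronizedCache<Key, T>` template being hand-ported is essentially a lock-guarded memoizing map. A Python analogue is sketched below; this is illustrative only, since the real template is C++ in `aiter_hip_common.h`.

```python
import threading

class SynchronizedCache:
    # Lock-guarded memoizing map: the factory runs at most once per key,
    # mirroring the per-device-ID caching described for the C++ template.
    def __init__(self, factory):
        self._factory = factory
        self._lock = threading.Lock()
        self._data = {}

    def get(self, key):
        with self._lock:
            if key not in self._data:
                self._data[key] = self._factory(key)
            return self._data[key]

calls = []
cache = SynchronizedCache(lambda dev: calls.append(dev) or f"cu_num[{dev}]")
cache.get(0); cache.get(0); cache.get(1)
print(calls)  # [0, 1] -- factory ran once per device ID
```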
PR #2645 introduced 'kid' as a placeholder/stub on this line; PR #2734 later refactored the tune.py files and properly named the variable. Our cherry-pick of #2645 onto release/v0.1.12 doesn't include #2734's refactor, so the dangling 'kid' reference fails ruff F821. Use the inner loop variable 'i' (kernel index) which is what 'kid' was meant to refer to in this context (KernelID = i).
[release/v0.1.12] Cherry-pick #2645 (multi-arch CK GEMM dispatch) + SynchronizedCache backport for v0.1.12.post2
Problem
Three related bugs caused incorrect kernel dispatch and non-reproducible builds, plus a fourth critical limitation in the build system.
1. **Non-deterministic builds.** CSV rows were filtered by the live GPU's `cu_num` at build time. CI builds without a GPU skipped this filter, accidentally compiling all CSV rows. The same `GPU_ARCHS` setting produced different `.so` files depending on whether a GPU was present — making builds non-reproducible and CI coverage misleading.
2. **Multi-arch dispatch collision.** C++ dispatch maps were keyed by `(M, N, K)` only. In `GPU_ARCHS=gfx942;gfx950` builds, shapes shared between architectures silently overwrote each other (last-writer-wins), selecting the wrong kernel with no error or warning.
3. **Hash collisions.** `IntTupleHash` used plain XOR, which is order-independent — permutations like `(304, 128, 7168)` and `(304, 7168, 128)` hash identically. On 160 real GEMM shapes, only 64 unique buckets were produced (60% collision rate), degrading most lookups to linear scan.
4. **No retuning path during `PREBUILD_KERNELS` builds.** When installing aiter with `PREBUILD_KERNELS=1`, there was no way to retune GEMM shapes against the target machine. Tuning CSVs shipped with the repo are tuned on specific hardware — if the target machine differs (different architecture, binned variant, or simply newer hardware with better kernels), all inference uses the pre-existing CSV results unchanged. There was no mechanism to rebenchmark those shapes on the actual deployment GPU and update the kernels accordingly during installation.

Fix
- **Build targeting** — `get_build_targets()` in `chip_info.py` returns `(gfx, cu_num)` pairs from `GPU_ARCHS` (or the live GPU when unset), deterministically regardless of GPU presence. Adds `gfx` as the first column in all tuning CSV schemas. `GPU_ARCHS=native` now correctly routes to live-GPU detection instead of raising.
- **Runtime arch detection** — `get_gfx_runtime()` in `chip_info.py` always queries the live GPU via rocminfo, independent of `GPU_ARCHS`. The four GEMM Python dispatch modules (`gemm_op_a8w8`, `gemm_op_a4w4`, `batched_gemm_op_a8w8`, `batched_gemm_op_bf16`) use this for runtime arch selection.
- **C++ dispatch key** — extended to `(gfx, cu_num, M, N, K)` for non-batched modules and `(gfx, cu_num, B, M, N, K)` for batched modules. Both are resolved at runtime from the current HIP device. `get_device_cu_num()` and `get_device_gfx()` are new helpers in `gemm_dispatch_utils.h`, cached per device ID via `SynchronizedCache` so that a single process calling `hipSetDevice()` across GPUs of different architectures always dispatches to the correct kernel.
- **Hash function** — replaces per-module XOR `IntTupleHash` with a shared boost-style mixing hash (`0x9e3779b9` constant). On 160 real GEMM shapes: 160 unique buckets vs 64 before, zero collisions.
- **CSV pretune workflow** — `aiter/utility/pretune.py` adds a retuning path for `PREBUILD_KERNELS` builds. All tuned shapes in the CSVs are rebenchmarked on the live GPU, results tagged with `(gfx, cu_num)`, and the inference `.so` rebuilt. Two modes:
  - `PRETUNE_MODULES=module_gemm_a8w8_blockscale_tune python setup.py develop` — runs automatically after the `PREBUILD_KERNELS` compilation phase.
  - `python3 aiter/utility/pretune.py module_gemm_a8w8_blockscale_tune` — retunes on an already-installed aiter without a full rebuild.

  Existing rows for other architectures are preserved — a single CSV can contain tuning results for multiple GPUs.
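The XOR-vs-mixing difference can be demonstrated in a few lines. This is a sketch of the boost-style `hash_combine` idea using the `0x9e3779b9` constant mentioned above; the Python below is illustrative and is not the C++ in `gemm_dispatch_utils.h`.

```python
MASK32 = 0xFFFFFFFF

def xor_hash(shape):
    # Old scheme: plain XOR is commutative, so permutations collide.
    h = 0
    for v in shape:
        h ^= v
    return h

def mixing_hash(shape):
    # boost-style hash_combine: each element is folded in with shifts,
    # so element order changes the result.
    h = 0
    for v in shape:
        h ^= (v + 0x9E3779B9 + ((h << 6) & MASK32) + (h >> 2)) & MASK32
    return h

a, b = (304, 128, 7168), (304, 7168, 128)
print(xor_hash(a) == xor_hash(b))        # True: permutations collide
print(mixing_hash(a) == mixing_hash(b))  # False: order now matters
```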
Scope
The diff is large because the same correctness fix is applied mechanically across 9 modules. The novel logic is confined to a small number of files — see the reviewer guide. All 10 codegen scripts and 10 C++ files must move together; any partial split leaves a broken dispatch map.
Files changed
- `aiter/jit/utils/build_targets.py` — `GFX_CU_NUM_MAP`, `get_build_targets_env()`, `filter_tune_df()`, `_parse_gpu_archs_env()` (torch-free)
- `aiter/jit/utils/chip_info.py` — `get_build_targets()`, `get_gfx_runtime()`, `build_tune_dict()`, `write_lookup_header()`; fix `GPU_ARCHS=native`
- `aiter/ops/gemm_op_*.py` (4 files) — `get_gfx()` → `get_gfx_runtime()`
- `csrc/include/gemm_dispatch_utils.h` — `GemmDispatchHash`, `GemmDispatchMap<>`, `get_device_cu_num()`, `get_device_gfx()`
- `gen_instances.py` (9 files) + 1× `gen_instances_cktile.py` — `(M,N,K)` → `(gfx,cu_num,M,N,K)`; shared helpers
- `.cu` files — `IntTupleHash` removed; `GemmDispatchMap<>` type alias
- tuning CSVs — `gfx` column prepended (fmoe CSVs intentionally excluded — different dispatch path)
- `README.md` files — document the `gfx` column in the tuned CSV schema
- `aiter/jit/core.py` — `update_config_files()`: duplicate CSV entries → `logger.warning` instead of crash; `dedup_keys` now includes `gfx` when present in the merged CSV (prevents false duplicate detection across arch-sharing-cu_num targets)
- `aiter/utility/base_tuner.py`, `gemm_a8w8_blockscale_tune.py`, `gemm_a4w4_blockscale_tune.py` — `(gfx, cu_num)` instead of `cu_num` alone; auto-populate `gfx` when absent; `get_gfx` → `get_gfx_runtime` so tuned CSV rows are tagged with the live GPU arch
- `setup.py` — `PRETUNE_MODULES` env var; invoke `run_pretune_modules()` after the `ThreadPoolExecutor` completes
- `aiter/utility/pretune.py` — `PRETUNE_MODULES` build-time and standalone retuning CLI; `run_pretune()`, `run_pretune_modules()`, `_parse_module_list()`
- `op_tests/test_pretune.py` — new tests (including the `_unsupported` exclusion), no GPU required
- `op_tests/test_gemm_codegen.py` — `GPU_ARCHS` guard
- `csrc/ck_gemm_a8w8{,_blockscale,_bpreshuffle}/gemm_*_tune.py` (3 files) — `_clear_op_caches()`: clear module-level CSV dict caches alongside `lru_cache`; `--compare`/`--run_config` benchmark flows now pick up updated CSVs correctly

Follow-up
- Python config lookup (`get_CKGEMM_config`) and C++ dispatch are decoupled — the Python side finding a shape does not guarantee the C++ map dispatches it. Whether to unify these is left as a follow-up.
- `get_gfx()` picks the last arch in multi-arch `GPU_ARCHS` lists (e.g. `gfx950` on MI300X with `GPU_ARCHS=gfx942;gfx950`). The codegen path (`gen_instances.py`) is resolved via `get_build_targets()` + `filter_tune_df()`. The dispatch key and tuning path now use `get_gfx_runtime()`. Broader runtime callers (`fused_moe.py`, `mla.py`, `attention.py`, capability gating in general) still use `get_gfx()` — left as a follow-up.
- The current build check only verifies the `.so` count. Inspecting generated `*_lookup.h` headers for `{"gfx942",` and `{"gfx950",` key prefixes would catch dispatch regressions earlier.

Testing
- Codegen tests (`test_gemm_codegen.py`): 31 tests covering `get_build_targets()` routing (including a separator-only `GPU_ARCHS` guard), `build_tune_dict()` filtering, and `write_lookup_header()` key format — no GPU required. The runtime dispatch section is skipped gracefully without torch.
- Pretune tests (`test_pretune.py`): 84 tests covering script resolution, CSV deduplication, multi-path merging, module list parsing, and the `_unsupported` exclusion — no GPU required.
- `GPU_ARCHS=gfx942;gfx950` build — lookup headers contain both `{"gfx942",` and `{"gfx950",` keys.
- `test_gemm_a8w8.py` passes on MI300X.