
fix(ck_gemm): fix multi-arch build targeting and kernel dispatch across all CK GEMM modules #2645

Merged

valarLip merged 75 commits into main from fix/gemm_codegen_gfx_build_targets on Apr 17, 2026

Conversation

eppaneamd (Contributor) commented Apr 7, 2026

Problem

Three related bugs caused incorrect kernel dispatch and non-reproducible builds, plus a fourth critical limitation in the build system.

1. Non-deterministic builds. CSV rows were filtered by the live GPU's cu_num at build time. CI builds without a GPU skipped this filter, accidentally compiling all CSV rows. The same GPU_ARCHS setting produced different .so files depending on whether a GPU was present — making builds non-reproducible and CI coverage misleading.

2. Multi-arch dispatch collision. C++ dispatch maps were keyed by (M, N, K) only. In GPU_ARCHS=gfx942;gfx950 builds, shapes shared between architectures silently overwrote each other (last-writer-wins), selecting the wrong kernel with no error or warning.

3. Hash collisions. IntTupleHash used plain XOR, which is order-independent — permutations like (304, 128, 7168) and (304, 7168, 128) hash identically. On 160 real GEMM shapes, only 64 unique buckets were produced (a 60% collision rate), degrading most lookups to linear scan (see the sketch after this list).

4. No retuning path during PREBUILD_KERNELS builds. When installing aiter with PREBUILD_KERNELS=1, there was no way to retune GEMM shapes against the target machine. Tuning CSVs shipped with the repo are tuned on specific hardware — if the target machine differs (different architecture, binned variant, or simply newer hardware with better kernels), all inference uses the pre-existing CSV results unchanged. There was no mechanism to rebenchmark those shapes on the actual deployment GPU and update the kernels accordingly during installation.
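
A minimal Python model of the failure in item 3 (the real IntTupleHash is C++; this sketch only reproduces its order-independent XOR combining):

```python
def xor_tuple_hash(shape):
    # The old IntTupleHash combined element hashes with plain XOR.
    # XOR is commutative, so permuted (M, N, K) tuples collide.
    h = 0
    for v in shape:
        h ^= hash(v)
    return h

# The two permutations from item 3 land in the same bucket:
assert xor_tuple_hash((304, 128, 7168)) == xor_tuple_hash((304, 7168, 128))
```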

Fix

Build targeting — get_build_targets() in chip_info.py returns (gfx, cu_num) pairs from GPU_ARCHS (or the live GPU when unset), deterministically regardless of GPU presence. Adds gfx as the first column in all tuning CSV schemas. GPU_ARCHS=native now correctly routes to live-GPU detection instead of raising.
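
A sketch of the env-driven resolution, assuming a simplified parser (the function and map names come from the file list below; the CU counts and parsing details here are illustrative, not the shipped table):

```python
import os

# Illustrative arch -> CU-count entries; the real GFX_CU_NUM_MAP lives in
# aiter/jit/utils/build_targets.py.
GFX_CU_NUM_MAP = {"gfx942": 304, "gfx950": 256}

def get_build_targets_env():
    """Resolve (gfx, cu_num) build targets from GPU_ARCHS alone - no live GPU probe."""
    archs = os.environ.get("GPU_ARCHS", "")
    return [(gfx, GFX_CU_NUM_MAP[gfx]) for gfx in archs.split(";") if gfx]

# GPU_ARCHS="gfx942;gfx950" -> [("gfx942", 304), ("gfx950", 256)],
# with or without a GPU attached, so the same .so set is built every time.
```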

Runtime arch detection — get_gfx_runtime() in chip_info.py always queries the live GPU via rocminfo, independent of GPU_ARCHS. The four GEMM Python dispatch modules (gemm_op_a8w8, gemm_op_a4w4, batched_gemm_op_a8w8, batched_gemm_op_bf16) use this for runtime arch selection.
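
Conceptually the runtime probe reduces to the following (a hedged sketch: the actual parsing in chip_info.py is not reproduced here):

```python
import re
import subprocess

def get_gfx_runtime():
    # Always ask the live GPU; GPU_ARCHS is deliberately ignored here.
    out = subprocess.run(["rocminfo"], capture_output=True, text=True).stdout
    match = re.search(r"gfx\d+[a-z]*", out)  # first GPU agent arch, e.g. "gfx942"
    return match.group(0) if match else None
```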

C++ dispatch key — extended to (gfx, cu_num, M, N, K) for non-batched modules and (gfx, cu_num, B, M, N, K) for batched modules. Both are resolved at runtime from the current HIP device. get_device_cu_num() and get_device_gfx() are new helpers in gemm_dispatch_utils.h, cached per device ID via SynchronizedCache so that a single process calling hipSetDevice() across GPUs of different architectures always dispatches to the correct kernel.
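
A Python analogue of that key resolution (the shipped helpers are C++; lru_cache stands in for SynchronizedCache, lookup() is a hypothetical wrapper, and gcnArchName is a ROCm-only torch property):

```python
from functools import lru_cache

import torch

@lru_cache(maxsize=None)  # cached per device ID, like SynchronizedCache
def device_key(device_id: int):
    props = torch.cuda.get_device_properties(device_id)
    gfx = props.gcnArchName.split(":")[0]  # "gfx942:sramecc+:xnack-" -> "gfx942"
    return gfx, props.multi_processor_count  # (gfx, cu_num)

def lookup(tune_dict, device_id, M, N, K):
    # Dispatch on the full (gfx, cu_num, M, N, K) key, so switching between
    # devices of different architectures in one process picks the right kernel.
    return tune_dict.get((*device_key(device_id), M, N, K))
```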

Hash function — replaces per-module XOR IntTupleHash with a shared boost-style mixing hash (0x9e3779b9 constant). On 160 real GEMM shapes: 160 unique buckets vs 64 before, zero collisions.
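
The same two shapes under a boost-style combine, modeled in Python with 64-bit wrapping (the shipped hash is the C++ GemmDispatchHash in gemm_dispatch_utils.h; this is only a behavioral sketch):

```python
MASK64 = (1 << 64) - 1

def mix(seed, value):
    # boost::hash_combine-style step: the shifts make the result depend on
    # element order, unlike plain XOR.
    return (seed ^ (value + 0x9E3779B9 + ((seed << 6) & MASK64) + (seed >> 2))) & MASK64

def tuple_hash(shape):
    seed = 0
    for v in shape:
        seed = mix(seed, v)
    return seed

# The permutations that collided under XOR now hash differently:
assert tuple_hash((304, 128, 7168)) != tuple_hash((304, 7168, 128))
```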

CSV pretune workflow — aiter/utility/pretune.py adds a retuning path for PREBUILD_KERNELS builds. All tuned shapes in the CSVs are rebenchmarked on the live GPU, results tagged with (gfx, cu_num), and the inference .so rebuilt. Two modes:

  • Build-time: PRETUNE_MODULES=module_gemm_a8w8_blockscale_tune python setup.py develop — runs automatically after the PREBUILD_KERNELS compilation phase.
  • Standalone: python3 aiter/utility/pretune.py module_gemm_a8w8_blockscale_tune — retunes on an already-installed aiter without a full rebuild.

Existing rows for other architectures are preserved — a single CSV can contain tuning results for multiple GPUs.
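
For illustration, a multi-arch tuned CSV could then carry rows like these (only the leading gfx and cu_num columns are what this PR specifies; the shape values and trailing kernel columns here are hypothetical):

```csv
gfx,cu_num,M,N,K,kernelId,splitK
gfx942,304,304,128,7168,12,0
gfx950,256,304,128,7168,7,0
```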

Scope

The diff is large because the same correctness fix is applied mechanically across 9 modules. The novel logic is confined to a small number of files — see the reviewer guide. All 10 codegen scripts and 10 C++ files must move together; any partial split leaves a broken dispatch map.

Files changed

File — Change
aiter/jit/utils/build_targets.py — New: GFX_CU_NUM_MAP, get_build_targets_env(), filter_tune_df(), _parse_gpu_archs_env() (torch-free)
aiter/jit/utils/chip_info.py — get_build_targets(), get_gfx_runtime(), build_tune_dict(), write_lookup_header(); fix GPU_ARCHS=native
aiter/ops/gemm_op_*.py (4 files) — get_gfx() → get_gfx_runtime()
csrc/include/gemm_dispatch_utils.h — New: GemmDispatchHash, GemmDispatchMap<>, get_device_cu_num(), get_device_gfx()
gen_instances.py + 1× gen_instances_cktile.py — Dispatch key (M,N,K) → (gfx,cu_num,M,N,K); shared helpers
10× C++ .cu files — IntTupleHash removed; GemmDispatchMap<> type alias
8 primary + 17 model_configs tuning CSVs — gfx column prepended (fmoe CSVs intentionally excluded — different dispatch path)
README.md — CSV schema examples updated
aiter/jit/core.py — Fix update_config_files(): duplicate CSV entries → logger.warning instead of crash; dedup_keys now includes gfx when present in merged CSV (prevents false duplicate detection across arch-sharing-cu_num targets)
aiter/utility/base_tuner.py, gemm_a8w8_blockscale_tune.py, gemm_a4w4_blockscale_tune.py — Filter tuning data by (gfx, cu_num) instead of cu_num alone; auto-populate gfx when absent; get_gfx → get_gfx_runtime so tuned CSV rows are tagged with the live GPU arch
setup.py — Read PRETUNE_MODULES env var; invoke run_pretune_modules() after ThreadPoolExecutor completes
aiter/utility/pretune.py — New: PRETUNE_MODULES build-time and standalone retuning CLI; run_pretune(), run_pretune_modules(), _parse_module_list()
op_tests/test_pretune.py — New: 9 test functions (script resolution, CSV deduplication, multi-path merging, module list parsing, _unsupported exclusion), no GPU required
op_tests/test_gemm_codegen.py — New: unit tests for build targeting and runtime dispatch; separator-only GPU_ARCHS guard
csrc/ck_gemm_a8w8{,_blockscale,_bpreshuffle}/gemm_*_tune.py (3 files) — Fix _clear_op_caches(): clear module-level CSV dict caches alongside lru_cache; --compare/--run_config benchmark flows now pick up updated CSVs correctly

Follow-up

  • Python CSV lookup (get_CKGEMM_config) and C++ dispatch are decoupled — the Python side finding a shape does not guarantee the C++ map dispatches it. Whether to unify these is left as a follow-up.
  • get_gfx() picks the last arch in multi-arch GPU_ARCHS lists (e.g. gfx950 on MI300X with GPU_ARCHS=gfx942;gfx950). The codegen path (gen_instances.py) is resolved via get_build_targets() + filter_tune_df(). The dispatch key and tuning path now use get_gfx_runtime(). Broader runtime callers (fused_moe.py, mla.py, attention.py, capability gating in general) still use get_gfx() — left as a follow-up.
  • CI validation only checks .so count. Inspecting generated *_lookup.h headers for {"gfx942", and {"gfx950", key prefixes would catch dispatch regressions earlier.

Testing

  • Unit tests (test_gemm_codegen.py): 31 tests covering get_build_targets() routing (including separator-only GPU_ARCHS guard), build_tune_dict() filtering, write_lookup_header() key format — no GPU required. Runtime dispatch section skipped gracefully without torch.
  • Pretune tests (test_pretune.py): 84 tests covering script resolution, CSV deduplication, multi-path merging, module list parsing, _unsupported exclusion — no GPU required.
  • Repro CSVs: shapes from the original regression verified correct dispatch on MI300X.
  • Multi-arch: GPU_ARCHS=gfx942;gfx950 build — lookup headers contain both {"gfx942", and {"gfx950", keys.
  • Runtime: test_gemm_a8w8.py passes on MI300X.

eppaneamd and others added 21 commits March 27, 2026 15:27
The merge commit 6a18cd6 accidentally preserved conflict markers in
gemm_op_a4w4.py. Apply the gfx-aware dispatch fix (same pattern as
gemm_op_a8w8.py) — use (gfx, cu_num, M, N, K) key when the CSV has a
gfx column, fall back to (cu_num, M, N, K) for old CSVs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…-arch kernel collisions, share build_tune_dict helpers across all 9 CK GEMM modules
@eppaneamd eppaneamd requested review from a team and Copilot April 7, 2026 21:30
github-actions Bot commented Apr 7, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label — Tests
ci:triton-355 — Run Triton tests on MI355 in addition to MI325
ci:sglang — SGLang integration tests
ci:atom — ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm — vLLM benchmark
ci:all — All of the above

Add labels via the sidebar or gh pr edit 2645 --add-label <label>

Copilot AI left a comment

Pull request overview

This PR fixes non-reproducible / incorrect CK GEMM kernel selection in multi-arch builds by making tuning CSVs gfx-aware at build time, disambiguating runtime dispatch by CU count, and replacing the commutative tuple hash used by the dispatch maps.

Changes:

  • Add gfx-aware build targeting ((gfx, cu_num)), CSV filtering, and shared codegen helpers for GEMM instance generation.
  • Update C++ dispatch maps to key by (cu_num, M, N, K) (and (cu_num, B, M, N, K) for batched) and introduce shared dispatch utilities + non-commutative hashing.
  • Update tuners, runtime Python config lookups, tests, and example tuning CSVs/schema to include a gfx column.

Reviewed changes

Copilot reviewed 62 out of 69 changed files in this pull request and generated 6 comments.

Summary per file

File — Description
op_tests/test_gemm_codegen.py — New CPU-only regression tests for gfx-aware build targets, CSV filtering, and Python lookup keys
op_tests/test_gemm_a8w8.py — Filter tuned-shape detection by (gfx, cu_num); add CSV-driven shape lists and result export
op_tests/test_gemm_a8w8_blockscale.py — Add CSV-driven runs + output; change perf iteration defaults
op_tests/configs/gemm_codegen_gfx_filter.csv — Repro tuning CSV including gfx column for regression coverage
op_tests/configs/gemm_codegen_gfx_filter_bpreshuffle.csv — Repro bpreshuffle tuning CSV including gfx column
gradlib/gradlib/GemmTuner.py — Include gfx in tuning keys and emitted results
csrc/include/gemm_dispatch_utils.h — New shared dispatch hash + CU detection helpers and map aliases
csrc/cktile_gemm_a8w8_bpreshuffle/README.md — Document gfx column in tuned CSV schema
csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py — Use shared build_tune_dict + write_lookup_header helpers
csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu — Switch lookup key to include cu_num; use shared dispatch utils
csrc/ck_gemm_a8w8/README.md — Document gfx column in tuned CSV schema
csrc/ck_gemm_a8w8/gen_instances.py — Use shared build_tune_dict + write_lookup_header helpers
csrc/ck_gemm_a8w8/gemm_a8w8.cu — Switch lookup key to include cu_num; use shared dispatch utils
csrc/ck_gemm_a8w8/gemm_a8w8_tune.py — Emit/expect gfx in tuning key/CSV schema
csrc/ck_gemm_a8w8_bpreshuffle/README.md — Document gfx column in tuned CSV schema
csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py — Use shared build_tune_dict + write_lookup_header helpers
csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle.cu — Switch lookup key to include cu_num; use shared dispatch utils
csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py — Emit/expect gfx in tuning key/CSV schema
csrc/ck_gemm_a8w8_blockscale/README.md — Document gfx column in tuned CSV schema
csrc/ck_gemm_a8w8_blockscale/gen_instances.py — Use shared build_tune_dict + write_lookup_header helpers
csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py — Use shared build_tune_dict + write_lookup_header helpers
csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu — Switch lookup key to include cu_num; use shared dispatch utils
csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py — Emit/expect gfx in tuning key/CSV schema
csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu — Switch lookup key to include cu_num; use shared dispatch utils
csrc/ck_gemm_a8w8_blockscale_bpreshuffle/README.md — Document gfx column in tuned CSV schema
csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py — Use shared build_tune_dict + write_lookup_header helpers
csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gemm_a8w8_blockscale_bpreshuffle.cu — Switch lookup key to include cu_num; use shared dispatch utils
csrc/ck_gemm_a4w4_blockscale/README.md — Document gfx column in tuned CSV schema
csrc/ck_gemm_a4w4_blockscale/gen_instances.py — Use shared build_tune_dict + write_lookup_header helpers
csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale.cu — Switch lookup key to include cu_num; use shared dispatch utils
csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py — Include gfx in tuning info keys
csrc/ck_deepgemm/gen_instances.py — Use shared build_tune_dict + write_lookup_header helpers
csrc/ck_deepgemm/deepgemm.cu — Switch lookup key to include cu_num; use shared dispatch utils
csrc/ck_batched_gemm_bf16/README.md — Document gfx column in tuned CSV schema
csrc/ck_batched_gemm_bf16/gen_instances.py — Use shared batched tune dict + shared lookup writer
csrc/ck_batched_gemm_bf16/batched_gemm_bf16.cu — Switch lookup key to include cu_num; use shared batched dispatch utils
csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py — Emit/expect gfx in tuning info keys
csrc/ck_batched_gemm_a8w8/README.md — Document gfx column in tuned CSV schema
csrc/ck_batched_gemm_a8w8/gen_instances.py — Use shared batched tune dict + shared lookup writer
csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8.cu — Switch lookup key to include cu_num; use shared batched dispatch utils
csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py — Emit/expect gfx in tuning info keys
aiter/utility/base_tuner.py — Add gfx to common tuning keys and retune filtering
aiter/ops/gemm_op_a8w8.py — Make Python tuned-config lookups prefer (gfx, cu_num, ...) keys with fallback
aiter/ops/gemm_op_a4w4.py — Make Python tuned-config lookups prefer (gfx, cu_num, ...) keys with fallback
aiter/ops/batched_gemm_op_bf16.py — Make Python tuned-config lookups prefer (gfx, cu_num, ...) keys with fallback
aiter/ops/batched_gemm_op_a8w8.py — Make Python tuned-config lookups prefer (gfx, cu_num, ...) keys with fallback
aiter/jit/utils/chip_info.py — Add (gfx, cu_num) build targets and shared tune/lookup codegen helpers
aiter/jit/utils/build_targets.py — New pure-Python env-driven build target resolver + tune DF filter
aiter/configs/model_configs/dsv3_bf16_tuned_gemm.csv — Update schema to include gfx column
aiter/configs/model_configs/dsv3_a8w8_bpreshuffle_tuned_gemm.csv — Update schema to include gfx column
aiter/configs/model_configs/dsv3_a4w4_blockscale_tuned_gemm.csv — Update schema to include gfx column
aiter/configs/bf16_tuned_gemm.csv — Update schema header to include gfx column
aiter/configs/bf16_tuned_batched_gemm.csv — Update schema to include gfx column
aiter/configs/a8w8_tuned_batched_gemm.csv — Update schema to include gfx column
aiter/configs/a8w8_blockscale_tuned_gemm.csv — Update schema to include gfx column
aiter/configs/a8w8_blockscale_bpreshuffle_tuned_gemm.csv — Update schema header to include gfx column


Comment thread aiter/ops/gemm_op_a8w8.py
Comment thread aiter/ops/gemm_op_a8w8.py
Comment thread aiter/jit/utils/chip_info.py
Comment thread aiter/jit/utils/chip_info.py
Comment thread csrc/include/gemm_dispatch_utils.h Outdated
Comment on lines 21 to 26
block_shape = (128, 128)
TEST_NUM_ITERS = 100


-@perftest(num_iters=5)
+@perftest(num_iters=TEST_NUM_ITERS)
def run_torch(x, weight, x_scale, w_scale, dtype=dtypes.bf16):
Copilot AI commented Apr 7, 2026

TEST_NUM_ITERS default is raised to 100, which can make this script significantly slower to run (and potentially too slow if used in automated runs). Consider keeping a small default (e.g. 5) and adding a CLI flag / env var to override for benchmarking.

Copilot AI left a comment

Pull request overview

Copilot reviewed 68 out of 75 changed files in this pull request and generated 4 comments.



Comment thread aiter/utility/pretune.py
Comment thread setup.py
Comment thread csrc/include/gemm_dispatch_utils.h
Comment thread op_tests/test_gemm_a8w8_blockscale.py
@eppaneamd eppaneamd requested a review from valarLip April 15, 2026 18:23
@valarLip valarLip requested a review from yzhou103 April 16, 2026 02:18
Comment thread aiter/jit/core.py
saved_info = (
    "\n".join(saved_files) if saved_files else " (no files updated)"
)
raise RuntimeError(
yzhou103 (Contributor) commented:

I changed this back to raising an error intentionally. The goal here is not just to auto-resolve duplicate shape entries, but to make the submitter explicitly aware that duplicates were introduced. Silent auto-fix can hide config quality issues and let bad updates slip through unnoticed. Failing fast gives clear visibility, forces cleanup at the source, and helps keep the merged tuning configs maintainable.

eppaneamd (Contributor, Author) commented:

Thanks @yzhou103! I see - I made the change initially because this was breaking the build_aiter_image CI (raise aborts the wheel build). Would it work to raise only when the save-back actually fails, and log a warning otherwise? The warning still includes the full duplicate list and which files were cleaned, so visibility is preserved. Or have I misunderstood something?

yzhou103 (Contributor) commented:

Thanks, that makes sense, and I agree warning-only would be functionally safe if merged_df is also updated in the current run. My hesitation is more about policy than correctness: we would still like main to stay as clean and stable as possible, and a patch introducing many duplicate shape entries is not really the expected pattern. And we also want the person updating the CSVs to be very explicitly aware of exactly what happened and which shape entries were duplicated, rather than this becoming a mostly transparent cleanup path.

yzhou103 (Contributor) commented:

I noticed there are many duplicate shape entries in this PR, but my understanding is that it is not adding any new tuned shapes, so I am a bit confused about where these duplicates are coming from.

eppaneamd (Contributor, Author) commented:

@yzhou103 quick comment: there should not be duplicates within CSVs now, and duplicates across CSVs (e.g. model configs) are not introduced by this PR.

eppaneamd (Contributor, Author) commented:

For example:

qwen3_235b — introduced in ba596dc (tune gemm and moe for Qwen3 MoE models #1585)
qwen3.5_397b — introduced in 3fcfeac (tuned qwen3.5 gemm #2485)

yzhou103 (Contributor) commented:

To resolve the duplicated shapes quickly, we can run the relevant test, e.g.

python op_tests/test_gemm_a8w8_blockscale.py

valarLip (Collaborator) commented:

hold till ATOM test passes

@valarLip valarLip merged commit 727253a into main Apr 17, 2026
57 of 59 checks passed
@valarLip valarLip deleted the fix/gemm_codegen_gfx_build_targets branch April 17, 2026 09:00
sunway513 added a commit that referenced this pull request Apr 22, 2026
…#2645 cherry-pick

PR #2645 introduced csrc/include/gemm_dispatch_utils.h which references
the SynchronizedCache<Key, T> template. That template was added by a
SEPARATE earlier PR (#2221, 2026-04-15) to csrc/include/aiter_hip_common.h
on main, but #2221 was never on release/v0.1.12.

Cherry-picking #2645 alone fails to compile with:
  csrc/include/gemm_dispatch_utils.h:46:12: error: no template named 'SynchronizedCache'

Hand-port just the template definition (23 lines + 2 includes) — minimum
needed for the dispatch fix to compile. Skips #2221's other changes
(replacing std::unordered_map usages in 12 .cu files), since release/v0.1.12
doesn't have those dispatch sites.
sunway513 added a commit that referenced this pull request Apr 23, 2026
PR #2645 introduced 'kid' as a placeholder/stub on this line; PR #2734
later refactored the tune.py files and properly named the variable.
Our cherry-pick of #2645 onto release/v0.1.12 doesn't include #2734's
refactor, so the dangling 'kid' reference fails ruff F821.

Use the inner loop variable 'i' (kernel index) which is what 'kid'
was meant to refer to in this context (KernelID = i).
sunway513 added a commit that referenced this pull request Apr 23, 2026
[release/v0.1.12] Cherry-pick #2645 (multi-arch CK GEMM dispatch) + SynchronizedCache backport for v0.1.12.post2
