Ensure GemmTuner doesn't generate invalid input combinations for SplitK kernels #2721
Tests verify that asm_gemm_all_solutions correctly filters candidates where gdx*gdy exceeds the 1024-entry semaphore array limit, while preserving valid small-grid and boundary-grid combinations. No GPU required; aiter stack is stubbed for offline execution.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
025ba36 to a294190
Pull request overview
Prevents GemmTuner.asm_gemm_all_solutions() from generating SplitK ASM tuning tasks whose grid size would exceed the semaphore workspace capacity, avoiding potential out-of-bounds writes during tuning for large GEMM shapes.
Changes:
- Add a SplitK-only guard in asm_gemm_all_solutions() to skip candidates where gdx * gdy > 1024.
- Add offline unit tests validating filtering behavior for large, small, and boundary grid sizes.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| gradlib/gradlib/GemmTuner.py | Adds SplitK task filtering based on computed grid size to avoid semaphore overflow scenarios. |
| gradlib/test_gemm_tuner_splitk.py | Adds offline tests (with stubs) to validate the new SplitK grid-size guard behavior. |
Comments suppressed due to low confidence (1)
gradlib/gradlib/GemmTuner.py:405
- The 1024 limit is a hard-coded magic number here. Since the semaphore workspace size appears to be defined elsewhere (e.g. aiter/ops/gemm_op_a16w16.py allocates a fixed (16, 64) semaphore tensor = 1024 entries), consider centralizing this as a named constant (or deriving it from the allocator) to prevent future drift and to document why 1024 is the correct bound.
    ) and get_gfx() == "gfx950":
        logger.warning(
            f"ASM gemm only supports indtype=bf16 and outdtype=bf16 and k%256==0 and not scaleAB is supported in {get_gfx()}, but actual indtype is {self.indtype}, outdtype is {self.outdtype}, k is {self.k}, scaleAB is {self.scaleAB}"
        )
        self.asm_gtimedf = pd.DataFrame(columns=["gtimems", "libtype"])
        return []
    }
    for name, mod in stubs.items():
        sys.modules.setdefault(name, mod)


_install_stubs()
Installing stub modules into sys.modules at import time can leak into the rest of the test process (e.g., later tests may unexpectedly import the stubs instead of the real aiter package). Consider scoping this to the test module via unittest.mock.patch.dict(sys.modules, ...) (and/or only stubbing when aiter cannot be imported) so other tests aren’t impacted.
@copilot apply changes based on this feedback
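A minimal sketch of the scoping suggested above, assuming a single stub is enough for illustration. `_make_stubs` and `TestWithScopedStubs` are hypothetical names, and the real test file stubs far more of the aiter/ROCm stack. The point is only that `unittest.mock.patch.dict` restores `sys.modules` once the patch stops, so the stubs cannot leak into later tests.

```python
import sys
import types
import unittest
from unittest.mock import patch


def _make_stubs():
    # Hypothetical stub set for illustration; the real file stubs many modules.
    return {"aiter": types.ModuleType("aiter")}


class TestWithScopedStubs(unittest.TestCase):
    def setUp(self):
        # patch.dict snapshots sys.modules and restores it when stopped,
        # so the stub only exists for the duration of each test.
        patcher = patch.dict(sys.modules, _make_stubs())
        patcher.start()
        self.addCleanup(patcher.stop)

    def test_stub_visible_inside_test(self):
        import aiter  # resolves to the stub while the patch is active

        self.assertIs(aiter, sys.modules["aiter"])
```

Compared with `sys.modules.setdefault(...)` at import time, this keeps the stubbing local to the test case, at the cost of having to import the module under test inside the patched scope.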
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
class TestSplitKSemaphoreGuard(unittest.TestCase):

    @patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")

class TestSplitKSemaphoreGuard(unittest.TestCase):

    @patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")
    @patch("gradlib.GemmTuner.generate_data", return_value=None)

        # tile 64x64 on a 4096x4096 grid => gdx=64, gdy=64 => 4096 > 1024
        gemm = _make_gemm(m=4096, n=4096, k=256)

        with patch.object(Gemm, "get_asm_kernels",

                    f"Task with splitK={splitK} has grid {gdx}x{gdy}={gdx*gdy} > 1024",
                )

    @patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")

                )

    @patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")
    @patch("gradlib.GemmTuner.generate_data", return_value=None)

        # tile 128x128 on 128x128 => gdx=1, gdy=1 => 1 <= 1024
        gemm = _make_gemm(m=128, n=128, k=256)

        with patch.object(Gemm, "get_asm_kernels",

        self.assertGreater(len(splitk_tasks), 0,
                           "Expected SplitK tasks for a small grid, got none")

    @patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")

                           "Expected SplitK tasks for a small grid, got none")

    @patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")
    @patch("gradlib.GemmTuner.generate_data", return_value=None)

        # tile=64, m=64*32=2048, n=64*32=2048 => gdx=32, gdy=32 => exactly 1024
        gemm = _make_gemm(m=2048, n=2048, k=256)

        with patch.object(Gemm, "get_asm_kernels",
Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/60c3afdd-caeb-41ad-b1b6-ab84a202a9e2 Co-authored-by: apicciau <227765539+apicciau@users.noreply.github.com>
…x1250

Introduces the FlyDSL A16W16 GEMM kernel for RDNA4 (gfx1250) and integrates it as a first-class tunable backend in GemmTuner, alongside the existing splitk_hgemm and ASM paths.

New files:
- aiter/ops/flydsl/kernels/gemm_a16w16_gfx1250.py: WMMA 16x16x32 kernel using RDNA4 wave32; handles K-padding and N-stride internally; supports fp16/bf16 input, configurable tiling (tile_m/n/k), warp layout (m/n_warp), double-buffering (num_buffers), waves_per_eu, and L2 prefetch distance

Changes to existing files:
- aiter/ops/flydsl/gemm_kernels.py: add get_flydsl_a16w16_gfx1250_kernels() catalog and get_flydsl_a16w16_gfx1250_kernel_params() lookup; kernel name encodes all config parameters for reversible CSV serialisation
- gradlib/gradlib/GemmTuner.py: import the new kernel; add run_flydsl_gemm_a16w16() run function; add flydsl_a16w16_gemm_all_sols() enumerator; route gfx1250 through the a16w16 path in run_asm_triton_sols() while other architectures continue using the existing splitk_hgemm path; also restores the ASM SplitK semaphore guard (gdx*gdy <= 1024) that was missing on main (also tracked in PR #2721)
- aiter/tuned_gemm.py: add flydsl_a16w16_gemm() dispatch function; update the flydsl config lookup to resolve a16w16 kernel names, falling back to splitk_hgemm; select the correct call site based on the resolved config
@@ -0,0 +1,197 @@
"""
Tests for GemmTuner.asm_gemm_all_solutions SplitK semaphore guard.
I'd suggest we not add a tuner test here. I've added tuning_tests in op_tests, which run daily.
if splitK > 1:
    gdx = (self.n + tile_n - 1) // tile_n
    gdy = (self.m + tile_m - 1) // tile_m
    if gdx * gdy > 1024:
My earlier concern was that this 1024 limit is not an independent tuning rule. It is tied to the current semaphore workspace size in gemm_op_a16w16.py, which is a fixed (16, 64) allocation. If we hardcode 1024 again in GemmTuner.py, the tuner will need to be updated separately whenever that workspace size changes, and that is easy to miss.
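One way to address this concern, sketched under assumptions: `SEMAPHORE_SHAPE`, `MAX_SPLITK_GRID`, and `splitk_grid_fits` are hypothetical names, not existing aiter symbols. The idea is that the allocator and the tuner both read the bound from one definition, so the tuner follows automatically if the workspace size changes.

```python
import math

# Assumption: mirrors the fixed (16, 64) semaphore tensor mentioned in the review.
# In a real change this would live next to the allocator and be imported by the tuner.
SEMAPHORE_SHAPE = (16, 64)
MAX_SPLITK_GRID = math.prod(SEMAPHORE_SHAPE)  # 16 * 64 = 1024 entries


def splitk_grid_fits(m, n, tile_m, tile_n, limit=MAX_SPLITK_GRID):
    """Return True when the SplitK grid needs at most `limit` semaphore entries."""
    gdx = (n + tile_n - 1) // tile_n  # ceil(n / tile_n)
    gdy = (m + tile_m - 1) // tile_m  # ceil(m / tile_m)
    return gdx * gdy <= limit
```

With this shape, the guard in the tuner would read as `if not splitk_grid_fits(...)` rather than comparing against a literal 1024.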
Motivation
ASM SplitK kernels use a semaphore array of size gdx * gdy (grid X × grid Y). When the grid exceeds 1024 entries, the semaphore write goes out of bounds, causing silent corruption or crashes during tuning.
Large matrix shapes (e.g. M=4096, N=4096 with a 64×64 tile) produce gdx * gdy = 4096, well above the limit.
Technical Details
Adds a guard in Gemm.asm_gemm_all_solutions() (gradlib/gradlib/GemmTuner.py) that skips any SplitK candidate where gdx * gdy > 1024 before appending to the task list.
The check only applies when splitK > 1 (clean kernels are unaffected).
Test Plan
Added offline unit tests in gradlib/test_gemm_tuner_splitk.py covering three cases:
- Large grid (gdx*gdy > 1024): all generated tasks must satisfy the constraint
- Small grid (gdx*gdy = 1): SplitK tasks must still be generated (no false filtering)
- Boundary grid (gdx*gdy = 1024): tasks at the exact limit must be kept
Tests stub the full aiter/ROCm stack and run without a GPU.
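The grid arithmetic behind the three cases can be checked by hand. The snippet below is a standalone illustration, not the test file itself: `grid_size` and `cases` are hypothetical names, and the shapes and tiles are the ones quoted in the test comments.

```python
LIMIT = 1024  # semaphore entries available, per the PR description


def grid_size(m, n, tile_m, tile_n):
    """Return (gdx, gdy) using the same ceiling division as the guard."""
    gdx = (n + tile_n - 1) // tile_n
    gdy = (m + tile_m - 1) // tile_m
    return gdx, gdy


# (label, m, n, square tile size) taken from the test comments in this PR.
cases = [
    ("large", 4096, 4096, 64),
    ("small", 128, 128, 128),
    ("boundary", 2048, 2048, 64),
]

results = {}
for label, m, n, tile in cases:
    gdx, gdy = grid_size(m, n, tile, tile)
    # A candidate is kept only when the grid fits in the semaphore array.
    results[label] = (gdx * gdy, gdx * gdy <= LIMIT)

print(results)
# {'large': (4096, False), 'small': (1, True), 'boundary': (1024, True)}
```

The boundary case matters because the guard uses a strict `>` comparison: a grid of exactly 1024 entries still fits and must not be filtered out.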
Test Result
Submission Checklist