Skip to content

Ensure GemmTuner doesn't generate invalid input combinations for SplitK kernels#2721

Open
apicciau wants to merge 6 commits intomainfrom
apicciau/fix_asm_splitk
Open

Ensure GemmTuner doesn't generate invalid input combinations for SplitK kernels#2721
apicciau wants to merge 6 commits intomainfrom
apicciau/fix_asm_splitk

Conversation

@apicciau
Copy link
Copy Markdown
Contributor

@apicciau apicciau commented Apr 13, 2026

Motivation

ASM SplitK kernels allocate a semaphore array of size gdx * gdy (grid X × grid Y). When the grid exceeds 1024 entries, the semaphore write goes out-of-bounds, causing silent corruption or crashes during tuning.

Large matrix shapes (e.g. M=4096, N=4096 with a 64×64 tile) produce gdx * gdy = 4096, well above the limit.

Technical Details

Adds a guard in Gemm.asm_gemm_all_solutions() (gradlib/gradlib/GemmTuner.py) that skips any SplitK candidate where gdx * gdy > 1024 before appending to the task list.

The check only applies when splitK > 1 (clean kernels are unaffected).

Test Plan

Added offline unit tests in gradlib/test_gemm_tuner_splitk.py covering three cases:

  • Large grid (gdx*gdy > 1024): all generated tasks must satisfy the constraint
  • Small grid (gdx*gdy = 1): SplitK tasks must still be generated (no false filtering)
  • Boundary (gdx*gdy = 1024): tasks at the exact limit must be kept

Tests stub the full aiter/ROCm stack and run without a GPU:

cd gradlib && python -m pytest test_gemm_tuner_splitk.py -v

Test Result

3 passed in 0.93s

Submission Checklist

@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-355 Run Triton tests on MI355 in addition to MI325
ci:sglang SGLang integration tests
ci:atom ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm vLLM benchmark
ci:all All of the above

Add labels via the sidebar or gh pr edit 2721 --add-label <label>

@apicciau apicciau marked this pull request as ready for review April 17, 2026 15:08
@apicciau apicciau requested review from a team and Copilot April 17, 2026 15:08
apicciau and others added 3 commits April 17, 2026 17:10
Tests verify that asm_gemm_all_solutions correctly filters candidates
where gdx*gdy exceeds the 1024-entry semaphore array limit, while
preserving valid small-grid and boundary-grid combinations.

No GPU required; aiter stack is stubbed for offline execution.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@apicciau apicciau force-pushed the apicciau/fix_asm_splitk branch from 025ba36 to a294190 Compare April 17, 2026 15:10
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Prevents GemmTuner.asm_gemm_all_solutions() from generating SplitK ASM tuning tasks whose grid size would exceed the semaphore workspace capacity, avoiding potential out-of-bounds writes during tuning for large GEMM shapes.

Changes:

  • Add a SplitK-only guard in asm_gemm_all_solutions() to skip candidates where gdx * gdy > 1024.
  • Add offline unit tests validating filtering behavior for large, small, and boundary grid sizes.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
gradlib/gradlib/GemmTuner.py Adds SplitK task filtering based on computed grid size to avoid semaphore overflow scenarios.
gradlib/test_gemm_tuner_splitk.py Adds offline tests (with stubs) to validate the new SplitK grid-size guard behavior.
Comments suppressed due to low confidence (1)

gradlib/gradlib/GemmTuner.py:405

  • The 1024 limit is a hard-coded magic number here. Since the semaphore workspace size appears to be defined elsewhere (e.g. aiter/ops/gemm_op_a16w16.py allocates a fixed (16, 64) semaphore tensor = 1024 entries), consider centralizing this as a named constant (or deriving it from the allocator) to prevent future drift and to document why 1024 is the correct bound.
        ) and get_gfx() == "gfx950":
            logger.warning(
                f"ASM gemm only supports indtype=bf16 and outdtype=bf16 and k%256==0 and not scaleAB is supported in {get_gfx()}, but actual indtype is {self.indtype}, outdtype is {self.outdtype}, k is  {self.k}, scaleAB is {self.scaleAB}"
            )

            self.asm_gtimedf = pd.DataFrame(columns=["gtimems", "libtype"])
            return []

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread gradlib/test_gemm_tuner_splitk.py
Comment thread gradlib/test_gemm_tuner_splitk.py
Comment on lines +93 to +98
}
for name, mod in stubs.items():
sys.modules.setdefault(name, mod)


_install_stubs()
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Installing stub modules into sys.modules at import time can leak into the rest of the test process (e.g., later tests may unexpectedly import the stubs instead of the real aiter package). Consider scoping this to the test module via unittest.mock.patch.dict(sys.modules, ...) (and/or only stubbing when aiter cannot be imported) so other tests aren’t impacted.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

class TestSplitKSemaphoreGuard(unittest.TestCase):

@patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name patch

class TestSplitKSemaphoreGuard(unittest.TestCase):

@patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")
@patch("gradlib.GemmTuner.generate_data", return_value=None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name patch

# tile 64x64 on a 4096x4096 grid => gdx=64, gdy=64 => 4096 > 1024
gemm = _make_gemm(m=4096, n=4096, k=256)

with patch.object(Gemm, "get_asm_kernels",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name patch

f"Task with splitK={splitK} has grid {gdx}x{gdy}={gdx*gdy} > 1024",
)

@patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name patch

)

@patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")
@patch("gradlib.GemmTuner.generate_data", return_value=None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name patch

# tile 128x128 on 128x128 => gdx=1, gdy=1 => 1 <= 1024
gemm = _make_gemm(m=128, n=128, k=256)

with patch.object(Gemm, "get_asm_kernels",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name patch

self.assertGreater(len(splitk_tasks), 0,
"Expected SplitK tasks for a small grid, got none")

@patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name patch

"Expected SplitK tasks for a small grid, got none")

@patch("gradlib.GemmTuner.get_gfx", return_value="gfx942")
@patch("gradlib.GemmTuner.generate_data", return_value=None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name patch

# tile=64, m=64*32=2048, n=64*32=2048 => gdx=32, gdy=32 => exactly 1024
gemm = _make_gemm(m=2048, n=2048, k=256)

with patch.object(Gemm, "get_asm_kernels",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name patch

Copilot AI and others added 2 commits April 17, 2026 15:29
@apicciau apicciau self-assigned this Apr 17, 2026
apicciau added a commit that referenced this pull request Apr 20, 2026
…x1250

Introduces the FlyDSL A16W16 GEMM kernel for RDNA4 (gfx1250) and integrates
it as a first-class tunable backend in GemmTuner, alongside the existing
splitk_hgemm and ASM paths.

New files:
- aiter/ops/flydsl/kernels/gemm_a16w16_gfx1250.py: WMMA 16x16x32 kernel
  using RDNA4 wave32; handles K-padding and N-stride internally; supports
  fp16/bf16 input, configurable tiling (tile_m/n/k), warp layout (m/n_warp),
  double-buffering (num_buffers), waves_per_eu, and L2 prefetch distance

Changes to existing files:
- aiter/ops/flydsl/gemm_kernels.py: add get_flydsl_a16w16_gfx1250_kernels()
  catalog and get_flydsl_a16w16_gfx1250_kernel_params() lookup; kernel name
  encodes all config parameters for reversible CSV serialisation
- gradlib/gradlib/GemmTuner.py: import the new kernel; add
  run_flydsl_gemm_a16w16() run function; add flydsl_a16w16_gemm_all_sols()
  enumerator; route gfx1250 through the a16w16 path in run_asm_triton_sols()
  while other architectures continue using the existing splitk_hgemm path;
  also restores the ASM SplitK semaphore guard (gdx*gdy <= 1024) that was
  missing on main (also tracked in PR #2721)
- aiter/tuned_gemm.py: add flydsl_a16w16_gemm() dispatch function; update
  the flydsl config lookup to resolve a16w16 kernel names, falling back to
  splitk_hgemm; select the correct call site based on the resolved config
@@ -0,0 +1,197 @@
"""
Tests for GemmTuner.asm_gemm_all_solutions SplitK semaphore guard.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest we do not add tuner test here. I've add tuning_tests in op_tests which runs daily.

if splitK > 1:
gdx = (self.n + tile_n - 1) // tile_n
gdy = (self.m + tile_m - 1) // tile_m
if gdx * gdy > 1024:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My earlier concern was that this 1024 limit is not an independent tuning rule. It is tied to the current semaphore workspace size in gemm_op_a16w16.py, which is a fixed (16, 64) allocation. If we hardcode 1024 again in GemmTuner.py, the tuner will need to be updated separately whenever that workspace size changes, and that is easy to miss.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants