
[gfx1250] Optimize FP8/FP4 GEMM kernel and tests #533

Merged
coderfeli merged 3 commits into main from gfx1250/gemm_fp8_opt on May 16, 2026


@aoli26 aoli26 commented May 15, 2026

Motivation

Optimize gfx1250 fp8 gemm:

  • Lift WMMA% by reordering per-quadrant callbacks (filler / B prefetch / store) into the WMMA stream.
  • Add B-streaming to hide B-load latency by streaming B per-quadrant instead of staging the full B tile.
  • Free TDM descriptor slots by routing FP8 scales through buffer_load + LDS.
  • Cut launch overhead via lru_cache + flyc.compile ctypes fast path (~17us/call).

Also drop the FFM COMGR preload shim: on current FFM it pulls in a second LLVM and crashes on import with the error "Option 'spirv-expand-step' registered more than once".

Technical Details

  • Quadrant scheduling reorganization: reorder callbacks between WMMA issues to hide LDS / TDM / store latency.
  • b_streaming=True: issue B fragment loads inside the WMMA loop per quadrant, overlapping with prior-quadrant WMMA. Lower VGPR / LDS footprint.
  • scale_load_path="buffer_lds_stage" | "buffer_lds_stage_ab_split": scales go via buffer_load → LDS instead of TDM. Frees TDM for A/B; the _ab_split variant pipelines A-scale and B-scale separately.
  • Active TDM load handling refactor: cleaner per-warp distribution.
  • Fast launch path: @lru_cache on compile_mxscale_gemm + flyc.compile pre-binding, so calls bypass JitFunction.bind's inspect.Signature + cache-key hashing.
  • COMGR shim removal (python/flydsl/__init__.py, python/flydsl/_compat.py): required for FFM import.
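To make the b_streaming bullet above concrete, here is a minimal toy model of the issue order, not the kernel's actual scheduler. The function name `issue_order` and the string labels are hypothetical; the sketch only assumes that the streamed schedule issues the next quadrant's B load before the current quadrant's WMMA, so that the load is in flight while the WMMA runs.

```python
def issue_order(num_quadrants: int, b_streaming: bool) -> list[str]:
    """Return a toy issue order of B loads and WMMA ops for one K-step.

    Hypothetical model for illustration only; labels are not real instructions.
    """
    ops: list[str] = []
    if not b_streaming:
        # Staged: load every B fragment first, then run all WMMAs.
        ops += [f"load_b[{q}]" for q in range(num_quadrants)]
        ops += [f"wmma[{q}]" for q in range(num_quadrants)]
        return ops
    # Streaming: a prologue loads quadrant 0, then each iteration issues the
    # next quadrant's load ahead of the current quadrant's WMMA.
    ops.append("load_b[0]")
    for q in range(num_quadrants):
        if q + 1 < num_quadrants:
            ops.append(f"load_b[{q + 1}]")
        ops.append(f"wmma[{q}]")
    return ops


staged = issue_order(4, b_streaming=False)
streamed = issue_order(4, b_streaming=True)
print(staged)
print(streamed)
```

In this model the staged order keeps all four B fragments live before the first WMMA, while the streamed order never has more than two outstanding, which is one way to read the "lower VGPR / LDS footprint" claim.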

Test Plan

Run tests/kernels/test_gemm_fp8fp4_gfx1250.py (FP4 / FP8 / A8W4, mcast, irregular-tile) on gfx1250.

Test Result

pytest ::test_mxfp8_gemm: 60 passed, 36 skipped, 0 failed. Other parametrized cases pass on gfx1250.

Submission Checklist

aoli26 added 2 commits May 15, 2026 13:20
Bring kernels/gemm_fp8fp4_gfx1250.py and its test up to the latest
optimized version:

- LRU-cache compile_mxscale_gemm via flyc.compile fast path
- B-streaming compute path option (b_streaming=True)
- FP8 scale async load split path (scale_load_path=
  'buffer_lds_stage'|'buffer_lds_stage_ab_split')
- Quadrant callback / scheduling reorganization
- FP8 active TDM load handling refactor
- Test parametrization extended to cover the new paths
Copilot AI review requested due to automatic review settings May 15, 2026 13:45
@aoli26 aoli26 added the enhancement New feature or request label May 15, 2026

Copilot AI left a comment


Pull request overview

Optimizes the gfx1250 FP8/FP4 GEMM kernel by adding a new FP8 quadrant compute schedule, an opt-in B-streaming schedule, an alternative scale-load path through buffer_load+LDS (with an A/B-split variant), and a lru_cache/flyc.compile fast-launch path. Also removes the FFM COMGR preload shim that fails on current FFM builds and adds new test coverage (b-streaming, scale-load-path matrix, hipGraph capture/replay).

Changes:

  • New compute schedules (fp8_quadrant, b_streaming) and matching hot_loop_scheduler_*, plus buffer_lds_stage[_ab_split] scale-load paths with new TDM-descriptor halves and refactored _select_active_tdm selection.
  • Test additions: b-streaming correctness, scale-load-path matrix, AB-split, hipGraph cudagraph test, and a hipGraph-based bench helper; benchmark CLI gains --scale-load-path, --b-streaming, --use-graph; default bench shape shrunk to 1024³ × 2048.
  • lru_cache on compile_mxscale_gemm + pre-bind via flyc.compile for the ctypes fast launch path; removal of python/flydsl/_compat.py and its _maybe_preload_system_comgr shim invocation.
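The caching half of the fast-launch change can be sketched with the standard library alone. The example below is an assumption-laden illustration, not the PR's code: `compile_gemm` is a hypothetical stand-in for `compile_mxscale_gemm`, and the `flyc.compile` pre-binding step is not modelled because its signature is not shown in this conversation.

```python
from functools import lru_cache

compile_calls = 0  # counts how often the expensive path actually runs


@lru_cache(maxsize=None)
def compile_gemm(m: int, n: int, k: int, b_streaming: bool = False):
    """Hypothetical stand-in for an expensive JIT compile step."""
    global compile_calls
    compile_calls += 1

    # A real implementation would build and return a launchable kernel;
    # this closure just reports the configuration it was built for.
    def launch():
        return (m, n, k, b_streaming)

    return launch


first = compile_gemm(1024, 1024, 2048)
second = compile_gemm(1024, 1024, 2048)  # cache hit: same object returned
third = compile_gemm(1024, 1024, 2048, b_streaming=True)  # new cache key

print(first is second, compile_calls, compile_gemm.cache_info())
```

`functools.lru_cache` keys on the call arguments, so every parameter must be hashable, and a call that spells out a default explicitly is cached separately from one that omits it.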

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.

| File | Description |
| --- | --- |
| kernels/gemm_fp8fp4_gfx1250.py | Adds new compute schedules, B-streaming, AB-split scale path, refactors cluster helpers to gpu.*, caches compile result. |
| tests/kernels/test_gemm_fp8fp4_gfx1250.py | Adds b-streaming, scale-path, AB-split, cudagraph tests; switches launch to flyc.compile fast path; extends bench CLI. |
| python/flydsl/__init__.py | Drops the COMGR preload call to allow FFM import to succeed. |
| python/flydsl/_compat.py | File deleted along with the _maybe_preload_system_comgr shim. |

Comments suppressed due to low confidence (1)

kernels/gemm_fp8fp4_gfx1250.py:2015

  • Same issue as the cluster_position/mcast_masks calls above: gpu.cluster_barrier() does not exist on flydsl.expr.gpu; cluster_barrier is defined in flydsl.expr.rocdl.cluster. This branch fires when loop_iters == 0 and use_cluster=True, so it will raise at compile/codegen time for small-K cluster configurations.
            gpu.cluster_barrier()
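
If the attribute is indeed missing, as this suppressed comment claims, the failure would stay hidden until a configuration reaches that branch. The sketch below illustrates the mechanism with hypothetical stand-ins: `gpu` and `cluster` are plain `SimpleNamespace` objects, not the real `flydsl.expr.gpu` or `flydsl.expr.rocdl.cluster` modules, and `emit_epilogue` is an invented toy builder.

```python
from types import SimpleNamespace

# Hypothetical stand-ins: `gpu` lacks cluster_barrier, `cluster` provides it.
gpu = SimpleNamespace(barrier=lambda: "barrier")
cluster = SimpleNamespace(cluster_barrier=lambda: "cluster_barrier")


def emit_epilogue(loop_iters: int, use_cluster: bool, ns) -> str:
    """Toy builder: only the zero-iteration cluster branch touches ns.cluster_barrier."""
    if loop_iters == 0 and use_cluster:
        return ns.cluster_barrier()
    return "no_barrier"


# Large-K configuration: the branch is never taken, so nothing fails.
print(emit_epilogue(4, True, gpu))

# Small-K cluster configuration: the attribute lookup happens only now.
try:
    emit_epilogue(0, True, gpu)
except AttributeError as exc:
    print("raised:", type(exc).__name__)

# Resolving the helper from the namespace that defines it succeeds.
print(emit_epilogue(0, True, cluster))
```

A test case with `loop_iters == 0` and `use_cluster=True` is what would exercise this path.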


Comment thread kernels/gemm_fp8fp4_gfx1250.py Outdated
Comment thread kernels/gemm_fp8fp4_gfx1250.py Outdated
Comment thread kernels/gemm_fp8fp4_gfx1250.py
Comment thread kernels/gemm_fp8fp4_gfx1250.py Outdated
Comment thread kernels/gemm_fp8fp4_gfx1250.py Outdated
Comment thread kernels/gemm_fp8fp4_gfx1250.py
Comment thread tests/kernels/test_gemm_fp8fp4_gfx1250.py
Comment thread tests/kernels/test_gemm_fp8fp4_gfx1250.py
Comment thread python/flydsl/__init__.py
@aoli26 aoli26 force-pushed the gfx1250/gemm_fp8_opt branch from 198d358 to 27f4217 Compare May 15, 2026 14:23
@aoli26 aoli26 force-pushed the gfx1250/gemm_fp8_opt branch from 27f4217 to fabd391 Compare May 15, 2026 15:01
@coderfeli coderfeli merged commit daa5bcf into main May 16, 2026
12 checks passed
@coderfeli coderfeli deleted the gfx1250/gemm_fp8_opt branch May 16, 2026 02:13


3 participants