Add optimized FMHA kernel on MI355X #629
Merged
- kernels/flash_attn_func.py: add num_kv_heads (default = num_heads = MHA);
split STRIDE_TOKEN/global_idx/head_idx into Q (used by Q,O) and KV (used
by K,V coop_load and DMA); recursive auto-launch threads num_kv_heads.
- tests: --num_kv_heads CLI flag; DEFAULT_CONFIGS extended to 5-tuple
(batch, seq_len, num_heads, num_kv_heads, head_dim); add GQA-8 row
(16, 8192, 64, 8, 128); pytorch reference expands K/V via
repeat_interleave; table/CSV layout adds Hkv column.
- tests: fix aiter.mha_fwd / aiter.fmha_v3_fwd argcount mismatch (now uses
keyword args for trailing optionals).
Verified on MI355X (gfx950) bf16 B=16 S=8192 D=128 iters=100 after clean
FlyDSL rebuild:
MHA (Hkv=64): causal 692.3 / nocausal 636.8 TFLOPS,
MaxErr 3.91e-03 / 2.44e-04 (PASS)
GQA-8 (Hkv=8) : causal 687.1 / nocausal 701.7 TFLOPS,
MaxErr 3.91e-03 / 2.44e-04 (PASS)
Co-authored-by: Cursor <cursoragent@cursor.com>
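The Q/KV head split described in the commit above can be illustrated in plain Python. This is a minimal sketch, assuming the usual GQA convention in which consecutive query heads share one K/V head (the convention implied by expanding K/V with repeat_interleave in the reference). The helper names are hypothetical and are not taken from flash_attn_func.py.

```python
def kv_head_for(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Map a query head to the K/V head it reads (assumed GQA convention)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads
    return q_head // group_size


def repeat_interleave(items, repeats):
    """List analogue of torch.repeat_interleave along the head axis."""
    return [x for x in items for _ in range(repeats)]


# The GQA-8 row from DEFAULT_CONFIGS: 64 query heads, 8 K/V heads.
num_heads, num_kv_heads = 64, 8
kv_heads = list(range(num_kv_heads))

# Reference-side view: expand K/V heads so there is one per query head.
expanded = repeat_interleave(kv_heads, num_heads // num_kv_heads)

# Kernel-side view: each query head indexes its shared K/V head directly.
mapping = [kv_head_for(h, num_heads, num_kv_heads) for h in range(num_heads)]
```

Both views pick the same K/V head for every query head, which is why the expanded PyTorch reference and a kernel that indexes K/V by a separate head index can be compared directly. With num_kv_heads equal to num_heads the mapping is the identity, which is the MHA default.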
- Change log().info(...) to log().debug(...) for the per-rewriter AST
  diff and the final transformed code dump in ast_rewriter.py.
- These messages are noisy compile-time diagnostics; users only need
  them when actively debugging the AST rewrite. They now require
  FLYDSL_DEBUG_LOG_LEVEL=DEBUG (with FLYDSL_DEBUG_LOG_TO_CONSOLE=1 or
  FLYDSL_DEBUG_LOG_TO_FILE=...) to surface, instead of the more commonly
  enabled INFO level.
Co-authored-by: Cursor <cursoragent@cursor.com>
- Add opus_attn C++ templates (d128/d512 causal/noncausal), host/driver,
  and setuptools build
- Add opus_attn Python wrapper, compare/rebuild/install scripts, and
  README
- Add run.sh launcher and test_flash_opus_attn kernel tests
- Extend test_flash_attn_func DEFAULT_CONFIGS with GQA
  kv_heads=num_heads line
Verified: not benchmarked in this commit (functional tests available
under tests/kernels/)
Co-authored-by: Cursor <cursoragent@cursor.com>
- new kernels/flash_attn_opus.py: D=128 bf16 fast path for gfx950+,
modeled after opus_attn/gqa_d128_kernel_template.hpp.
Key OPUS optimizations included:
* 3D grid launch (H, num_q_blocks, B) for better workload distribution
* Double-buffered K LDS with buffer_load_dwordx4_lds DMA
* ds_read_tr16_b64 HW-transpose V reads
* Online softmax with lazy rescale (ballot+read_exec): clamps
row_max := m_running when (row_max - m_row) <= 8.0 across all lanes,
skipping O *= corr (corr == 1)
* s_setprio(1)/s_setprio(0) brackets around GEMM2/rescale cluster
* s_nop 15 + s_nop 7 yield window after s_setprio(0)
* Causal mask via per-element v_cmp_lt + v_cndmask (select chain)
- kernels/flash_attn_func.py: add OPUS dispatcher
* built only when head_dim=128, dtype=bf16, gfx950+
* runtime dispatch when seq_len >= 384 and seq_len % 256 == 0
* gated by FLYDSL_ENABLE_OPUS_PATH=1 (opt-in until perf matches baseline)
* non-eligible runtime shapes and configs fall through unchanged
- env-var knobs in flash_attn_opus.py:
FLYDSL_OPUS_LAZY_RESCALE / FLYDSL_OPUS_SETPRIO / FLYDSL_OPUS_YIELD_NOP
Verified correctness (MaxErr threshold 8e-03):
B=1 S=512 causal+nocausal MaxErr=3.91e-03 / 4.88e-04 PASS
B=1 S=8192 causal+nocausal MaxErr=3.91e-03 / 2.44e-04 PASS
B=4 S=2048 causal MaxErr=3.91e-03 PASS
Performance B=16 S=8192 H=64 D=128 bf16 (MI355X):
            default-path   OPUS-path   OPUS-C++   ASM
  causal    716 TFLOPS     636         1131       595
  nocausal  640 TFLOPS     678         1165       1249
OPUS path is currently below baseline for causal but a small win for
nocausal; ships as opt-in to avoid causal regression. Reaching the
~1074 TFLOPS target (95% of OPUS C++) requires porting OPUS's full
8-cluster pipeline, in-flight Q scaling, and exact sched_barrier_pairs
fence ordering, which remains as future work.
Co-authored-by: Cursor <cursoragent@cursor.com>
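The lazy-rescale rule listed in the commit above can be sketched for a single query row in plain Python. This is an illustration under stated assumptions, not the kernel's code: scores are taken to be in log2 units already, values are scalars, and the check is per row rather than a ballot across all lanes of a wave. The function name `attend_row` is hypothetical.

```python
import math

LAZY_THRESHOLD = 8.0  # threshold quoted in the commit message


def attend_row(score_tiles, value_tiles, lazy=True):
    """Online softmax over KV tiles for one query row (scalar values)."""
    m = float("-inf")  # running row max
    l = 0.0            # running sum of exp2(s - m)
    o = 0.0            # running unnormalized output
    skipped = 0        # number of tiles where the rescale was skipped
    for scores, values in zip(score_tiles, value_tiles):
        tile_max = max(scores)
        if lazy and tile_max - m <= LAZY_THRESHOLD:
            # Keep the old max: corr == 1, so O and l need no rescale.
            skipped += 1
        else:
            m_new = max(m, tile_max)
            corr = 2.0 ** (m - m_new)
            o *= corr
            l *= corr
            m = m_new
        p = [2.0 ** (s - m) for s in scores]
        l += sum(p)
        o += sum(pi * v for pi, v in zip(p, values))
    return o / l, skipped


score_tiles = [[0.0, 1.0], [3.0, 2.5], [20.0, 4.0]]
value_tiles = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
lazy_out, lazy_skipped = attend_row(score_tiles, value_tiles, lazy=True)
exact_out, _ = attend_row(score_tiles, value_tiles, lazy=False)
```

Skipping the rescale is safe because the final division by `l` cancels whatever reference maximum was used; the threshold only bounds how large `exp2(s - m)` can grow (at most 2^8 here) before a true rescale is forced.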
Restructure flash_attn_opus.py to match the OPUS C++ kernel layout
line-by-line at the cluster level. This is the structural foundation
on top of which subsequent perf phases (P2-P5) will land.
Structural changes (every code section labelled with its C++ line range):
- Prologue (C++ 397-436): K[0]/K[1]/V[0] async ladder + mma0 of tile 0
+ causal mask + first-half exp2 + K[2] async kickoff. Establishes the
loop invariant v_s_0_partial = (exp2(s_lo - m_row), s_hi - m_row).
- Main loop (C++ 439-561): rewritten with `j += 2`, processing 2 KV
tiles per iteration across 8 clusters with V double-buffer:
* Cluster 0: V[j-2] async + K[j-2] ds_read
* Cluster 1: GEMM0 S[j-2] + finish softmax of carried v_s[0]
* Cluster 2: K[j] async
* Cluster 3: GEMM2 P[j-3]@v[j-3] via step_k(0..3) +
lazy-rescale check (ballot/all_below) +
sub_row + first-half exp2(v_s[1])
* Cluster 4: V[j-1] async + K[j-1] ds_read
* Cluster 5: GEMM0 S[j-1] + finish softmax of v_s[1]
* Cluster 6: K[j+1] async + causal mask on v_s[0] (= S[j-1])
* Cluster 7: GEMM2 P[j-2]@v[j-2] via step_k(0..3) +
lazy-rescale + sub_row + first-half exp2(v_s[0])
- Epilogue (C++ 565-742): new 13-cluster drainer for the last 3 KV
tiles plus the carried partial v_s[0]. Each tile uses FULL mma1
(no lazy in epilogue, matches C++). Cluster 11 does sub_row +
first-half exp + sched_barrier + second-half exp + cast + scale o,
matching C++ exactly. Cluster 13 emits the final mma1.
- max_num_tiles computed from ceil_div(N, KV_TILE_SIZE) with causal
cap (C++ 383-390) — replaces previous kv_upper token-based bound.
- Loop state extended to carry v_s_0_lo_partial / v_s_0_hi_partial
(partial exp2 fragment) between iterations.
- s_waitcnt encodings refined per C++ template:
* 0xC07F : lgkmcnt(0)
* vmcnt(k+v=4) : main loop and epilogue clusters 0/2/4
* vmcnt(v=2) : epilogue clusters 6/8 (only V outstanding)
* vmcnt(0) : epilogue cluster 10 (drain all VMEM)
- Each cluster boundary uses the C++ triple:
sched_barrier(0) ; s_barrier ; sched_barrier(0).
Deferred to later phases (NOT in P1):
P2: in-flight Q scaling (drop per-FMA c_sm_scale_log2e multiplies)
P3: sched_group_barrier_pairs / _exp_pairs scheduler discipline
P4: inline-asm attn_mask_vec2_imm + 8 register anchors
P5: stagger mechanism (warp_id // 4 asymmetric barriers)
Verified: MaxErr 3.91e-03 (causal) / 9.77e-04 (nocausal) at seq=512,
MaxErr 3.91e-03 (causal) / 2.44e-04 (nocausal) at seq=8192,
all below 8e-03 threshold. Perf intentionally NOT optimized
in P1; current numbers will be addressed by P2-P5.
Co-authored-by: Cursor <cursoragent@cursor.com>
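The s_waitcnt immediates listed in the commit above can be checked with a small encoder and decoder. This sketch assumes the gfx9-style simm16 field layout (vmcnt in bits 3:0 plus 15:14, expcnt in bits 6:4, lgkmcnt in bits 11:8); that assumption is consistent with 0xC07F meaning lgkmcnt(0) with the other counters left at their maximum. The helper names are hypothetical.

```python
def encode_waitcnt(vmcnt=63, expcnt=7, lgkmcnt=15):
    """Pack counters into a gfx9-style s_waitcnt immediate (assumed layout)."""
    if not (0 <= vmcnt <= 63 and 0 <= expcnt <= 7 and 0 <= lgkmcnt <= 15):
        raise ValueError("counter out of range")
    return (
        (vmcnt & 0xF)            # vmcnt low 4 bits   -> bits 3:0
        | (expcnt << 4)          # expcnt             -> bits 6:4
        | (lgkmcnt << 8)         # lgkmcnt            -> bits 11:8
        | ((vmcnt >> 4) << 14)   # vmcnt high 2 bits  -> bits 15:14
    )


def decode_waitcnt(imm):
    """Unpack an immediate into (vmcnt, expcnt, lgkmcnt)."""
    vmcnt = (imm & 0xF) | (((imm >> 14) & 0x3) << 4)
    expcnt = (imm >> 4) & 0x7
    lgkmcnt = (imm >> 8) & 0xF
    return vmcnt, expcnt, lgkmcnt


lgkm_only = encode_waitcnt(lgkmcnt=0)   # wait for LDS/scalar traffic only
vm_four = encode_waitcnt(vmcnt=4)       # allow 4 VMEM ops outstanding
```

A counter left at its maximum value means "do not wait on this counter", so `encode_waitcnt(lgkmcnt=0)` waits only for LDS traffic while leaving the async K/V loads in flight.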
Align with gqa_d128_kernel_template.hpp lines 404-406: pre-multiply Q by
temperature_scale = (1/sqrt(D)) * log2(e) during the prologue load, so
that subsequent softmax math operates directly in log2 space and the
per-FMA "* sm_scale_log2e" multiplications disappear.
Changes
- Prologue Q load (lines 392-433): each bf16 MFMA pack is extended to
  f32x8, multiplied element-wise by c_sm_scale_log2e, then truncf'd back
  to bf16x8 before feeding GEMM0. Constants block is moved above the Q
  load so c_sm_scale_log2e is in scope.
- _sub_row_first_half_exp (lines 545-559): the per-element FMA chain
  fma(s, c_sm_scale_log2e, -m_row*c_sm_scale_log2e) collapses to a plain
  subf(s, m_row) since both operands are already log2-scaled.
- Main loop clusters 3 and 7: lazy-rescale check uses
  m_diff = m_tile_max - m_row (no scaling) and corr = exp2(m_row - m_new).
- Epilogue clusters 3, 7, 11: rescale_eX = exp2(m_row - row_max_eX)
  without the extra c_sm_scale_log2e multiply.
Numerics
- S becomes (sm_scale*log2e) * S_old throughout the pipeline, so every
  (row_max - m_row), corr, rescale and exp2 argument is rescaled by the
  same constant factor; final P and O values are mathematically
  identical to P1. MaxErr unchanged at every test config.
Verified
- seq_len=512 causal MaxErr 3.91e-03, nocausal MaxErr 4.88e-04 (PASS)
- seq_len=8192 causal MaxErr 3.91e-03, nocausal MaxErr 2.44e-04 (PASS)
Per the user's phased plan, only logic alignment + correctness are
checked at this phase; performance optimizations (sched_group_barrier
pairs, inline-asm causal mask, stagger) are deferred to P3-P5.
Co-authored-by: Cursor <cursoragent@cursor.com>
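The log2-space identity behind this commit can be checked numerically. The sketch below works in float64 and ignores the bf16 truncation the kernel applies to the scaled Q, so it only illustrates the algebra; the function names are hypothetical.

```python
import math


def softmax_natural(scores, sm_scale):
    """Reference form: softmax(sm_scale * s) evaluated with exp."""
    m = max(scores)
    e = [math.exp(sm_scale * (s - m)) for s in scores]
    total = sum(e)
    return [x / total for x in e]


def softmax_log2_prescaled(scores, sm_scale):
    """P2 form: fold sm_scale * log2(e) in once, then use exp2(s - m)."""
    c = sm_scale * math.log2(math.e)
    # Scaling the scores stands in for pre-multiplying Q by c before GEMM0.
    scaled = [c * s for s in scores]
    m = max(scaled)
    e = [2.0 ** (s - m) for s in scaled]  # plain subtract, no per-element FMA
    total = sum(e)
    return [x / total for x in e]


head_dim = 128
sm_scale = 1.0 / math.sqrt(head_dim)
scores = [12.0, -3.5, 40.0, 7.25]
ref = softmax_natural(scores, sm_scale)
p2 = softmax_log2_prescaled(scores, sm_scale)
```

The two forms agree because exp(a * x) equals exp2(a * log2(e) * x), so the constant can be applied once to Q instead of once per score element.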
Replace generic sched_barrier(0) fences inside the compute clusters
with the same sched_group_barrier (MFMA/VALU/EXP) groups used by the
OPUS C++ template (gqa_d128_kernel_template.hpp lines 14-30, 455-720).
This gives the LLVM AMDGPU scheduler explicit hints about expected
instruction densities in every pipeline stage so it can reproduce the
intended MFMA-VALU-EXP interleaving.
Changes
- Define module-local mask constants matching LLVM AMDGPU semantics:
MFMA=0x008, VALU=0x002, EXP=0x400.
- Add two Python helpers that mirror the C++ recursive templates:
_sched_barrier_pairs(pairs, valu_cnt, group) (C++ lines 18-23)
_sched_barrier_exp_pairs(pairs, exp_cnt, group) (C++ lines 25-30)
- Insert hint pairs in every loop/epilogue cluster at the same call
sites as the C++ template:
Main loop Cluster 1 → group 1 (exp_pairs<6,3,1>; pairs<10,5,1>)
Cluster 3 → group 2 (pairs<4,5,2>; pairs<6,5,2>; exp_pairs<6,3,2>)
Cluster 5 → group 3 (exp_pairs<6,3,3>; pairs<10,5,3>)
Cluster 7 → group 4 (pairs<4,5,4>; pairs<6,5,4>; exp_pairs<6,3,4>)
Epilogue Cluster 1 → group 5 (exp_pairs<6,3,5>; pairs<10,5,5>)
Cluster 3 → group 6 (pairs<10,5,6>; exp_pairs<6,3,6>)
Cluster 5 → group 7 (exp_pairs<6,3,7>; pairs<10,5,7>)
Cluster 7 → group 8 (pairs<10,5,8>; exp_pairs<6,3,8>)
Cluster 9 → group 9 (exp_pairs<6,3,9>; pairs<10,5,9>)
Cluster 11→ group 10 (pairs<10,5,10>; exp_pairs<6,3,10>)
Behavior
- sched_group_barrier emits no instructions of its own; it constrains
pre-/post-RA scheduling. Numerics are identical to P2: MaxErr is
bit-for-bit the same across every test config.
- Performance currently drops because the hints require matching
instruction densities that the FlyDSL backend does not yet produce
(the C++ kernel relies on hand-tuned register anchors and the
stagger mechanism added in P4/P5 to fully exploit them). The phased
plan explicitly defers performance to later phases.
Verified
- seq_len=512 causal MaxErr 3.91e-03, nocausal MaxErr 4.88e-04 (PASS)
- seq_len=8192 causal MaxErr 3.91e-03, nocausal MaxErr 2.44e-04 (PASS)
Co-authored-by: Cursor <cursoragent@cursor.com>
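The hint helpers in this commit can be modelled as generators of (mask, count, group) triples. This is a sketch under an explicit assumption: each "pair" is taken to mean one MFMA followed by `valu_cnt` VALU (or `exp_cnt` EXP) instructions, which is a guess at what the C++ recursive templates expand to. The functions below only build a list and do not emit any intrinsic.

```python
# Mask values quoted in the commit message.
MFMA, VALU, EXP = 0x008, 0x002, 0x400


def sched_barrier_pairs(pairs, valu_cnt, group):
    """Model of _sched_barrier_pairs: (MFMA x1, VALU x valu_cnt) repeated."""
    hints = []
    for _ in range(pairs):
        hints.append((MFMA, 1, group))
        hints.append((VALU, valu_cnt, group))
    return hints


def sched_barrier_exp_pairs(pairs, exp_cnt, group):
    """Model of _sched_barrier_exp_pairs: (MFMA x1, EXP x exp_cnt) repeated."""
    hints = []
    for _ in range(pairs):
        hints.append((MFMA, 1, group))
        hints.append((EXP, exp_cnt, group))
    return hints


# Main-loop Cluster 1 -> group 1: exp_pairs<6,3,1> then pairs<10,5,1>.
cluster1 = sched_barrier_exp_pairs(6, 3, 1) + sched_barrier_pairs(10, 5, 1)
mfma_total = sum(cnt for mask, cnt, _ in cluster1 if mask == MFMA)
valu_total = sum(cnt for mask, cnt, _ in cluster1 if mask == VALU)
exp_total = sum(cnt for mask, cnt, _ in cluster1 if mask == EXP)
```

Under that assumption, the Cluster 1 hints describe 16 MFMA slots interleaved with 18 EXP and 50 VALU instructions, which gives a feel for the instruction densities the scheduler is being asked to match.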
Align with OPUS C++ template's hand-coded inline assembly for two
classes of low-level constructs (gqa_d128_kernel_template.hpp lines
233-249 and the eight v_s/v_p anchor sites scattered across lines
430-635).
Changes
- Add `_attn_mask_imm_single` helper that emits the
`v_cmp_lt_i32_e64 + v_cndmask_b32_e64` pair with the threshold baked
into the asm string as an immediate literal. `_attn_mask_vec2_imm`
invokes it twice per (thr_x, thr_y) pair, matching the C++ semantics.
Split into two single-asm calls rather than a 4-output struct return
because MLIR's llvm.inline_asm with two simultaneous "=s" sgpr-pair
outputs proved brittle.
- Rewrite `_causal_mask_inplace` to mirror C++ attn_mask_causal_tile:
compute rel = q_pos - k_pos with k_pos = kv_start + i_n*W_N +
lane_group*c_pack, then iterate the 8 (thr_x, thr_y) immediate pairs
derived from the C++ static_for nest. Masks s_lo (i_n=0) and s_hi
(i_n=1) with the same threshold list but with rel_hi = rel_lo - W_N.
- Add `_anchor_vec`, `_anchor_pair`, `_anchor_packs` helpers that emit
`asm volatile("" : "+v"(v))`-style fences using LLVM inline asm with
a tied "=v,0" constraint and has_side_effects=True.
- Place the 8 register anchors at the C++-matching sites:
#1 Prologue (v_s[0]) — C++ line 430
#2 Main Cluster 1 (v_p) — C++ line 454
#3 Main Cluster 3 (v_s[1]) — C++ line 489
#4 Main Cluster 5 (v_p) — C++ line 512
#5 Main Cluster 7 (v_s[0]) — C++ line 553
#6 Epi Cluster 1 (v_p) — C++ line 578
#7 Epi Cluster 3 (v_s[1]) — C++ line 607
#8 Epi Cluster 5 (v_p) — C++ line 635
Behavior
- Numerics identical to P3: per-element causal-mask decision is
q_pos < absolute_K_col, rewritten as rel < threshold. The anchors
emit no real instructions, so the data values are unchanged.
Verified
- seq_len=512 causal MaxErr 3.91e-03, nocausal MaxErr 4.88e-04 (PASS)
- seq_len=8192 causal MaxErr 3.91e-03, nocausal MaxErr 2.44e-04 (PASS)
Performance is still well below the C++ target; the phased plan keeps
that for the remaining stages (P5 stagger and any follow-on tuning).
Co-authored-by: Cursor <cursoragent@cursor.com>
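The rewrite of the mask predicate from `q_pos < absolute_K_col` to `rel < threshold` can be illustrated with a flat row of scores. This is a simplified sketch: it uses consecutive column offsets as thresholds, whereas the kernel derives its thresholds from the lane and pack layout. The function names are hypothetical.

```python
NEG_INF = float("-inf")


def mask_direct(scores, q_pos, k_start):
    """Mask column j when the query position precedes the absolute key column."""
    return [NEG_INF if q_pos < k_start + j else s for j, s in enumerate(scores)]


def mask_relative(scores, q_pos, k_start):
    """Same decision with one subtraction up front and constant thresholds."""
    rel = q_pos - k_start  # computed once per lane
    # thr is the column offset; in the kernel it is baked in as an immediate.
    return [NEG_INF if rel < thr else s for thr, s in enumerate(scores)]


scores = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5]
direct = [mask_direct(scores, q, k_start=64) for q in range(60, 76)]
relative = [mask_relative(scores, q, k_start=64) for q in range(60, 76)]
```

Moving `k_start` to the left-hand side leaves only compile-time constants on the right, which is what allows each comparison to carry its threshold as an immediate operand.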
Mirror the OPUS C++ template's dual-wave-group phase-shift scheme
(gqa_d128_kernel_template.hpp lines 308, 415-418, 748-750):
const int warp_id = __builtin_amdgcn_readfirstlane(
thread_id_x() / WARP_SIZE);
const int stagger = warp_id / 4;
...
if (stagger) {
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
}
...
if (!stagger) {
__builtin_amdgcn_s_barrier();
}
Changes
- Add `OPUS_ENABLE_STAGGER` env-gated flag (FLYDSL_OPUS_STAGGER, default 0).
- Compute a SCALAR (SGPR-resident) stagger value:
wave_id_uni = readfirstlane(tid / WARP_SIZE)
stagger = wave_id_uni / 4
using `rocdl.readfirstlane` + `arith.divsi`. The result feeds two
`arith.cmpi` results: `stagger_is_one_i1` and `stagger_is_zero_i1`.
- Add `_stagger_extra_barrier_if_one` / `_stagger_extra_barrier_if_zero`
helpers that emit inline assembly:
s_cmp_eq_u32 $0, 0
s_cbranch_scc{1,0} 1f
s_barrier
1:
Verified in the final ISA: `21_final_isa.s` shows the expected
s_cmp + s_cbranch_scc + s_barrier triple at both stagger sites and
the asymmetric barrier counts line up across waves.
- Two `if const_expr(OPUS_ENABLE_STAGGER)` gates:
* Prologue stagger site (post-vmcnt, pre-mma0): emits the asymmetric
barrier when ON, an unconditional `sched_barrier(0) + gpu.barrier()`
when OFF.
* Pre-store stagger site (post-`inv_l`, pre-global store): emits the
complementary asymmetric barrier when ON, an unconditional
`gpu.barrier()` when OFF.
- Add `scf` import for completeness (uses confined to helpers).
Why default OFF
- The asymmetric barrier is verified correct at the ISA level. However,
enabling it currently produces wrong results because the FlyDSL kernel
loads V from LDS inside Cluster 3 (via
`_read_v_packs_for_k_substep(0, ...)`), while the C++ reference loads
V into registers in Cluster 2 (tr_load before the cluster-2 barrier).
With phase-shifted execution, warps 4-7 end up reading `s_v[0]` after
warps 0-3 have already issued the Cluster-4 async_load that
overwrites it. Hoisting the V reads is a structural change outside
the P5 scope and is left to a follow-up phase.
- With stagger OFF, all 8 waves stay in lockstep so the LDS lifetime
invariants hold and the kernel still produces correct results.
Verified (FLYDSL_OPUS_STAGGER unset → OFF)
- seq_len=512 causal MaxErr 3.91e-03, nocausal MaxErr 4.88e-04 (PASS)
- seq_len=8192 causal MaxErr 3.91e-03, nocausal MaxErr 2.44e-04 (PASS)
Co-authored-by: Cursor <cursoragent@cursor.com>
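The balance of the asymmetric barriers can be checked with a small counting model. This sketch assumes a wave size of 64 and 8 waves per workgroup, and treats every other barrier in the kernel as a single shared count (`common_barriers`, a hypothetical parameter); it models barrier counts only, not timing.

```python
WARP_SIZE = 64   # assumed wave size
NUM_WAVES = 8    # two groups of four waves


def stagger_of(tid):
    """Scalar stagger value: 0 for waves 0-3, 1 for waves 4-7."""
    wave_id = tid // WARP_SIZE
    return wave_id // 4


def barrier_count(tid, common_barriers):
    """Total s_barrier executions for the wave that owns thread `tid`."""
    stagger = stagger_of(tid)
    extra_prologue = 1 if stagger else 0    # `if (stagger) s_barrier`
    extra_pre_store = 0 if stagger else 1   # `if (!stagger) s_barrier`
    return common_barriers + extra_prologue + extra_pre_store


first_lane_of_each_wave = [w * WARP_SIZE for w in range(NUM_WAVES)]
staggers = [stagger_of(t) for t in first_lane_of_each_wave]
counts = [barrier_count(t, common_barriers=10) for t in first_lane_of_each_wave]
```

Every wave executes the same total number of barriers, so workgroup barriers stay matched; the second wave group simply takes its extra barrier at the start and the first group takes it at the end, which is what produces the phase shift between the two groups.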
The P5 stagger path (FLYDSL_OPUS_STAGGER=1) previously produced wrong
results because V was being read from LDS inside Cluster 3/7/11/13 (i.e.
AFTER the cluster-2/6/10/12 s_barrier), while the C++ template loads V
into registers in the preceding cluster (tr_load BEFORE the s_barrier).
Under the dual-group phase shift, warps 4-7 would end up reading from
s_v[*] AFTER warps 0-3 had already issued the next async_load that
overwrites the same LDS buffer — a data race.
Fix: move all 6 V LDS read sites one cluster earlier so V is captured
into VGPRs BEFORE each cluster-boundary barrier, mirroring the C++
template gqa_d128_kernel_template.hpp exactly:
Main loop:
- Cluster 3 V[j-3] from s_v[0] → hoisted into Cluster 2
- Cluster 7 V[j-2] from s_v[1] → hoisted into Cluster 6
Epilogue:
- Cluster 3 V[max-4] from s_v[0] → hoisted into Cluster 2 (epi)
- Cluster 7 V[max-3] from s_v[1] → hoisted into Cluster 6 (epi)
- Cluster 11 V[max-2] from s_v[0]→ hoisted into Cluster 10 (epi)
- Cluster 13 V[max-1] from s_v[1]→ hoisted into Cluster 12 (epi)
With V in VGPRs across each cluster boundary, peer async_loads
overwriting the LDS buffer are harmless.
Other changes:
- Update banner comment near stagger setup: now states the path is
correctness-safe and no longer warns that V hoisting is required.
- Update inline comments at both stagger sites and OPUS_ENABLE_STAGGER
env-var docs to reflect the new state.
- Keep FLYDSL_OPUS_STAGGER default OFF: the asymmetric barrier
currently regresses throughput in this port (108 → 85 TFLOPS @ S=8192
B=16 causal) due to extra V-substep register pressure across the
barrier. The user can opt in with FLYDSL_OPUS_STAGGER=1 once the
scheduling tradeoff is addressed.
Verified (S=8192 B=16 H=64 D=128 bf16, MI355X, FLYDSL_OPUS_STAGGER=1):
causal: MaxErr 3.91e-03 (matches OPUS C++ bit-for-bit), 85.0 TFLOPS
nocausal: MaxErr 2.44e-04 (matches OPUS C++ bit-for-bit), 78.1 TFLOPS
Verified (S=8192 B=16 H=64 D=128 bf16, MI355X, FLYDSL_OPUS_STAGGER=0):
causal: MaxErr 3.91e-03, 108.5 TFLOPS (no regression vs prior tip)
nocausal: MaxErr 2.44e-04, 103.6 TFLOPS
Also verified at S=256 and S=512 small-scale, both stagger paths PASS.
Co-authored-by: Cursor <cursoragent@cursor.com>
The P5 stagger path was correctness-fixed in the previous commit (V LDS
reads hoisted into Cluster 2/6/10/12). To make the OPUS path fully
self-contained — i.e. setting only `FLYDSL_ENABLE_OPUS_PATH=1` actually
exercises every P1..P6 modification — flip the default for
`FLYDSL_OPUS_STAGGER` to ON.
Now the OPUS path requires no auxiliary env vars: all of LAZY_RESCALE,
SETPRIO, STAGGER, YIELD_NOP default ON, so the kernel is a faithful
end-to-end port of gqa_d128_kernel_template.hpp.
Verified with the exact user-requested command on MI355X:
FLYDSL_ENABLE_OPUS_PATH=1 \
python tests/kernels/test_flash_opus_attn.py --causal \
--dtype bf16 --batch 16 --num_heads 64 --num_kv_heads 64 \
--seq_len 8192 --head_dim 128 --iters 100 --compare
causal: MaxErr 3.91e-03 (bit-for-bit match with OPUS C++), 85.0 TFLOPS
`FLYDSL_OPUS_STAGGER=0` still available as escape hatch for A/B testing.
Co-authored-by: Cursor <cursoragent@cursor.com>
Atomic switch of K/V LDS layout, DMA writers, register-side readers, and
causal mask to match OPUS gqa_d128_kernel_template.hpp (sections 4-5).
Key changes (all in lockstep, no env toggle):
- LDS layout: interleaved K0/V0/K1/V1 double-buffer, 68096 B total
(K tile 8320 bf16 x 2, V tile 8704 bf16 x 2). Line stride is
smem_linear_wave + smem_padding (K: 520 bf16, V: 544 bf16) so the
hardware ds_read / ds_read_tr16_b64 path is bank-conflict-free.
- coop_dma_k / coop_dma_v rewritten on the OPUS u_gk/u_sk/u_gv/u_sv
layouts: each wave owns 8 rows of N x 64 bf16 of D, two d_rpt stripes
per buffer, raw_ptr_buffer_load_lds into the new line-strided slots.
- _read_k_packs_for_buf rewritten on OPUS u_rk: lane%32 = ((m%8)*8+m/8)
with d_rpt-major step_k layout (outer stride 4160 bf16, inner 16,
v_s_hi at +256 bf16 from v_s_lo).
- _read_v_packs_for_k_substep rewritten on OPUS u_rv via
ds_read_tr16_b64 (v4f16) + 8-lane shuffle: 4 D-chunks per step_k with
per-lane base = grp_k*(lane/32) + lane_hi*((lane%16)/4) +
grp_n*((lane/16)%2) + lane_lo*(lane%4).
- _causal_mask_inplace: N-axis thresholds reordered to the OPUS
permutation pi(m) = (m%8)*8 + m/8 so the 8 register-anchored vec2
comparisons cover the permuted N positions of S; v_s_hi delta becomes
-4 (matches OPUS smem_d_n_split=4 N-half offset).
- SmemAllocator finalized against LDS_KV_TOTAL_SIZE (68096 B).
Verified on hyg_trn_rocm7.1 / MI355X with cleared ~/.flydsl/cache:
FLYDSL_ENABLE_OPUS_PATH=1 python tests/kernels/test_flash_opus_attn.py \
--warmup 5 --iters 100
All 15 configs PASSED (max err 3.91e-03 < 1e-2, min cos 0.99999 > 0.99).
Co-authored-by: Cursor <cursoragent@cursor.com>
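The LDS sizing and the N-axis permutation quoted in this commit can be sanity-checked with a few lines of arithmetic. The sketch assumes 2 bytes per bf16 element and that pi(m) is defined on m = 0..63; both are assumptions, since the commit message does not state them explicitly.

```python
BF16_BYTES = 2
K_LINE, V_LINE = 520, 544      # line stride in bf16 elements
K_TILE, V_TILE = 8320, 8704    # tile size in bf16 elements
NUM_BUFFERS = 2                # K0/K1 and V0/V1 double-buffer

# Total LDS footprint of the interleaved K0/V0/K1/V1 layout.
lds_bytes = (K_TILE + V_TILE) * NUM_BUFFERS * BF16_BYTES

# Number of padded lines that fit in each tile.
k_lines, k_rem = divmod(K_TILE, K_LINE)
v_lines, v_rem = divmod(V_TILE, V_LINE)


def pi(m):
    """OPUS N-axis permutation from the commit message."""
    return (m % 8) * 8 + m // 8


perm = [pi(m) for m in range(64)]
```

Writing m as 8a + b shows that pi maps it to 8b + a, i.e. a transpose of an 8x8 index grid, so the permutation is its own inverse and covers each N position exactly once.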
- Add ds_read_b64_tr_b16 immediate-offset inline asm; V LDS reads in OPUS issue order
- Fix K/V coop_dma global row index (n_in_warp*NUM_WAVES+wave_id); clarify u_gk comments
- Add raw buffer resource helper for O; set DMA aux 0 per loader path
- run.sh: run causal opus test without --compare; keep compare variant commented
Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- FLASH_ATTN_OPUS_Kernel_Analysis_Detail.md: FlyDSL OPUS kernel walkthrough
- FLASH_ATTN_OPUS_vs_CPP_Differences.md: Python vs C++ OPUS mapping
- GQA_D128_KERNEL_Analysis_Detail.md: GQA d128 reference-kernel analysis
Verified: documentation only
Co-authored-by: Cursor <cursoragent@cursor.com>
- Split K/V buffer_load_lds addressing into uniform `soffset` and per-lane `voffset`
- Drop redundant `_dma_soff`; softmax exp paths use `rocdl.exp2` instead of arith.exp2
Verified: not run in this session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Add q_rsrc and use buffer_ops.buffer_load for per-step Q MFMA packs
- Remove redundant global half-vec helpers and unused Q/K/V/O pointer locals
Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
… readers
- Move Q/K/V/O resources, DMA knobs, causal tile bounds, and scaled Q packs (buffer_load) ahead of MFMA/sched-group helpers for clearer codegen order
- Pass urk_base_per_lane / urv_base_per_lane into K/V ds_read helpers
- Trim repetitive cluster/epilogue commentary (behavior unchanged intent)
Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Annotate prologue, steady-state clusters, softmax/rescale path, epilogue/store
- Map gqa_d128_kernel_template constructs (layouts, async_load, mma*, barriers)
Verified: comment-only annotations (no codegen logic change intended)
Co-authored-by: Cursor <cursoragent@cursor.com>
- Rename coop_dma/read/wave-max helpers to async_load_* and attn_* for reader parity
- Reorder prologue and inner-cluster sequences (GEMM0, mask, waits, DMA) to follow OPUS pipeline comments
Verified: not re-run in this session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Extend launcher/kernel with stride_q_n, stride_kv_n, head_dim_runtime (defaults unchanged)
- Build Q/K/V buffer resources using base-byte offsets; DMA uses stride_kv_n for GMEM indexing
- Compute softmax scale as rsqrt(head_dim_runtime)*log2e at kernel entry
- Use uniform wave id for K/V DMA LDS line addressing; simplify LDS ptr creation for DMA
- Move Q pack load/load path after barrier; buffer_load uses row-in-block indexing via stride_q_n
Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Split buffer_load → f32 extend → softmax-scale → bf16 trunc into helpers
- Keeps per-step logic identical; improves readability near async K prefetch
Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Load/concat Q to one bf16 shard; scale/trunc via broadcast once; MFMA slices via shuffle
- Replace stagger inline-asm with scf.If + sched_barrier + rocdl.s_barrier for LLVM intrinsic paths
- Add convert-ub-to-llvm to RocmBackend pipeline
Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…cale
- Load Q packs via RawPtrBufferLoadOp (i32 vec) with byte offsets; bitcast to bf16
- Replace Vec extf/mul/truncf in scale_q_all with llvm FPExtOp/FPTruncOp + arith.mulf and fastmath attrs
Verified: not re-run in this session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Add shared alias_scope domain plus lds_{k,v}{0,1} tags on raw_ptr_buffer_load_lds
- Load K mfma packs via aligned llvm.LoadOp with matching alias/noalias scopes
- Pass **kw through buffer_load_to_lds wrappers to raw_ptr_buffer_load_lds
Verified: not benchmarked in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Implement attn_mask as one inline-asm block (2× v_cmp + 2× v_cndmask) for lo/hi pairs
- Split s_hi pair masking into a second constexpr loop; mark asm side effects
- Add scf.if prologue when q_start_pos < KV_TILE_SIZE (C++ attn_mask path)
Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
… order
- Replace xor shuffle reductions with rocdl.permlane32_swap pairs for row-max/sum
- Anchor paired vec operands via llvm.inline_asm struct outputs (=v,=v,0,1)
- Match OPUS prologue: attn_sub_row before anchor; split exp2-first-half slice path
Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Replace wave-sync sites with explicit s_barrier to mirror OPUS C++ path
- Keep sched_barrier pairing unchanged; no intended semantic change
Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…ests
- Move flash_attn_opus.v1.s under exp_isa/ with build.sh, opus_asm_ext, and Python wrapper
- Extend trace_segment_cycles.py and specific_part.json segment fixtures
- test_flash_opus_attn: chunked PyTorch ref for large score tensors; exp_isa path hooks
Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…counts
- Add _anchor_v_o inline-asm pins for four 16xf32 accumulators (C++ v_o_pin pattern)
- Call after non-lazy corr scale and epilogue rescale_e3/e7/e11 paths
- Bump _sched_barrier_pairs (4,6,*) / (6,6,*) / (9,6,10) and _sched_barrier_exp_pairs (7,3,10)
- Refresh exp_isa/flash_attn_opus.v1.s annotations to match
Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…re harness
- trace_segment_cycles: optional specific_part trace pairs, perf-counter series, instruction-type count deltas, and related compare reporting
- ana_trace.sh: factor interval append + print_selected_summary; fyd_cpp_compare only
- seg_asm: update fyd_cpp_compare.json and specific_part.json fixtures
- test_flash_opus_attn: compare path uses aiter asm backend (exp_isa bench commented)
- run.sh: build exp_isa .co before compare; exp_isa/build.sh skips py ext by default
- Minor refresh of opus GQA d128 causal gfx950 ISA reference
Co-authored-by: Cursor <cursoragent@cursor.com>
- point15: drop waitcnt/barrier; anchor on setprio + s_mov + GEMM2 MFMA
- point19: anchor on setprio + MFMA + v_max3 (match current FlyDSL ISA layout)
Co-authored-by: Cursor <cursoragent@cursor.com>
- input_hand_asm_thread_trace.yaml targets hand-asm kernel with waves_per_eu=2 trace dir
- Enable ATT perf counters (VALU/MFMA busy and coexec) for segment analysis
Co-authored-by: Cursor <cursoragent@cursor.com>
…hors
- _attn_sub_row returns packed f32 vectors directly; remove _anchor_v_s calls
- Add _anchor_scalar_f32 in lazy rescale cold path so m_tile_max merges as PHI
- Re-enable llvm.intr_expect on all_below branch predicate
Co-authored-by: Cursor <cursoragent@cursor.com>
…re hooks
- Vendor fmha_fwd 256x64 gfx950 msk0/msk1 .s; fmha_asm.py + fmha_asm_ext.cc
- exp_isa/build.sh: compile opus v1 + both FMHA .co objects; setup.py extension wiring
- Add flash_attn_opus.v0.s snapshot; minor v1.s touch
- input_asm_fmha_thread_trace.yaml; ana_trace dumps asm_fmha b2_s1024 traces
- test_flash_opus_attn: run_exp_isa_fmha_bench; compare uses exp_isa FMHA asm path
- run.sh: duplicate compare invocation after exp_isa build
Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…ead helpers
- Replace mirrored C++ line comments with prologue/main-loop/epilogue/stagger docs
- Drop unused KV load batching constants and _waitcnt_lgkm_0_vm_n wrapper
- Use _waitcnt_vm_n / s_waitcnt(lgkmcnt(0)) directly at cluster boundaries
- Minor epilogue sched_barrier pair count tweaks (e.g. cluster 11)
Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- flash_attn_func.py -> flash_attn_generic.py (flash_attn_generic_kernel)
- flash_attn_opus.py -> flash_attn_gfx950.py (build_flash_attn_dualwave_swp_module, flash_attn_dualwave_swp_gfx950_kernel / launch_flash_attn_dualwave_swp)
- test_flash_opus_attn.py -> test_flash_attn_fwd.py; test_flash_attn_func.py -> test_flash_attn_fwd_ori.py
- Update imports, dispatcher wiring, and test references to new module names
Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
… tree
- Move ana_trace, rocprof YAML inputs, seg_asm fixtures, trace_segment_cycles, perf CSVs, run/compare helpers, and analysis markdown into fmha_opt_tools/
- Remove FLASH_ATTN_OPUS_vs_CPP_Differences.md (superseded by in-tree analysis)
- Update fmha_opt_tools/run.sh for flash_attn_fwd / dualwave env toggles
Co-authored-by: Cursor <cursoragent@cursor.com>
- Dispatch dualwave SWP from flash_attn_generic when dtype is bf16 or f16
- Select mfma_f32_32x32x16_bf16 vs _f16; fp16 pack/trunc for softmax P and O store
- Relax builder dtype check and update module docstrings
Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…JSON
- Update fyd_cpp_compare.json instruction anchors and fyd_attn specific_part
- ana_trace.sh: fix paths under fmha_opt_tools/; default log at.log
- Remove superseded main_loop_cluster0_7*.json (merged into fyd_cpp_compare)
Co-authored-by: Cursor <cursoragent@cursor.com>
…rt.json
- fyd_cpp_compare point11: match attn_sub_row v_sub before setprio 0 / s_barrier
- Drop seg_asm/specific_part.json (superseded by fyd_cpp_compare.json)
Co-authored-by: Cursor <cursoragent@cursor.com>
- Wire FLYDSL_DUALWAVE_SWP_* env vars into build and optional debug_counts launch
- Add TRIGGER_LAZY_ELSE Q/K fixture and lazy branch counter reporting
- Use chunked PyTorch reference when SDPA score workspace exceeds 128M elems
- Warm up benchmark path under torch.profiler before run_perftest
Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
… py ext
- Add build_env_and_run_benchmark.sh (aiter, LLVM, FlyDSL, opus_attn, exp_isa, then compare bench)
- exp_isa/build.sh: re-enable setup.py build_ext --inplace for asm Python extensions
Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Absorb rocm/main infrastructure: CI, custom LLVM tools, external_llvm, backend cmake, and related FlyDSL framework updates
- Resolve conflicts keeping opus_align functionality (GQA, dualwave SWP, flash_attn_generic module naming, 5-tuple test configs)
- From rocm/main in test_flash_attn_fwd_ori: _custom_llvm_tools_env() wrapper and multi-line formatting only
Co-authored-by: Cursor <cursoragent@cursor.com>
…ilds
- Port _custom_llvm_tools_env() from test_flash_attn_fwd_ori to gfx950 test
- Wrap build_flash_attn_dualwave_swp_module in custom LLVM context manager
- fmha_opt_tools/run.sh: enable FLYDSL_FLASH_ATTN_FUNC_USE_CUSTOM_LLVM=1
- Refresh run.sh benchmark command examples (fp16/GQA/compare variants)
Co-authored-by: Cursor <cursoragent@cursor.com>
- Rename launch_flash_attn_func -> launch_flash_attn_generic for module clarity
- rocm backend: remove convert-ub-to-llvm from default LLVM lowering pipeline
- fmha_opt_tools/run.sh: IR dump + cache clear dev workflow; 8192 compare bench
Co-authored-by: Cursor <cursoragent@cursor.com>
175b5e7 to edb4a5d
- Point docs and README to flash_attn_generic/gfx950 and test_flash_attn_fwd
- run_benchmark.sh: 7-field GQA shapes, --num_kv_heads, legacy 6-field fallback
Co-authored-by: Cursor <cursoragent@cursor.com>
- Remove duplicated _custom_llvm_tools_env helpers from gfx950 test harness
- Build dualwave module with bundled FlyDSL LLVM (custom LLVM stays in ori test)
Co-authored-by: Cursor <cursoragent@cursor.com>
Motivation
This PR adds a gfx950-optimized FlashAttention forward fast path for head_dim=128 on MI350-series GPUs. The goal is to bring FlyDSL's FMHA performance closer to the OPUS/hand-tuned baselines while keeping the same math and layout contract as the existing generic FlashAttention path.
The new path targets both bf16 and fp16, supports causal and non-causal operation, and handles both MHA and GQA/MQA (num_kv_heads <= num_heads). It is intended to improve the common D=128 attention workload on gfx950 without changing behavior for unsupported architectures, dimensions, or shapes.
Technical Details
- kernels/flash_attn_func.py to kernels/flash_attn_generic.py, and kept it as the fallback path.
- kernels/flash_attn_gfx950.py, a dual-wave, software-pipelined gfx950 implementation for D=128 bf16/fp16 FlashAttention.
- gpu_arch >= gfx950
- head_dim == 128
- bf16/fp16
- seq_len >= 384 and seq_len % 256 == 0
- num_heads, num_kv_heads, GQA_GROUP_SIZE).
- ds_read_b64_tr_b16 transpose LDS reads, and explicit sched_barrier/sched_group_barrier hints to constrain instruction scheduling.
- s_setprio(1)/(0) to improve handoff between the two wave groups.
- O *= corr when all lanes remain below the rescale threshold
- exp2 placement across clusters to hide transcendental latency behind MFMA chains
- v_cmp_lt_i32 + v_cndmask_b32 pairs
- fmha_opt_tools/ and comparison support for OPUS, aiter_ck, and hand-assembly baselines.
Test Plan
- rocm/pytorch:rocm7.1_ubuntu24.04_py3.12_pytorch_release_2.8.0.
- kernels/flash_attn_gfx950.py implementation on that branch is identical to the code in this PR. The branch also contains benchmark-only test harness changes that are intentionally not included in this PR, but were used for fair comparisons:
  * run_opus_attn_bench(...) runs the OPUS reference kernel for supported bf16 shapes.
  * run_exp_isa_fmha_bench(...) runs the prebuilt hand-assembly FMHA kernel baseline for supported bf16 shapes.
- aiter: 45c428e54ac15b9b49d66018c8a1108b20c8336a
- 7f77ca0dbda4abbf9af06537b2c475f20ccd6007
- fmha_perf_compare_MI355X.csv.
- flash_attn_opus.v1.co
- fmha_fwd_hd128_bf16_1tg_8w_256x64_350_msk0_gm0.co
- fmha_fwd_hd128_bf16_1tg_8w_256x64_350_msk1_gm0.co
- D=128 over both fp16 and bf16.
- S=128..8192, H/Hkv=64/64, 32/32, 16/16, 8/8, plus the GQA case H=64, Hkv=8; large-batch rows include B=4/8/16/32 at S=8192.
- aiter_ck for fp16 and bf16; for bf16, also compared against OPUS and aiter_asm where those baselines are available.
  export FLYDSL_FLASH_ATTN_FUNC_USE_CUSTOM_LLVM=0
  python tests/kernels/test_flash_attn_fwd.py --causal --dtype fp16 --iters 100 --compare
  python tests/kernels/test_flash_attn_fwd.py --causal --dtype bf16 --iters 100 --compare
Test Result
Benchmark CSV before rebase: fmha_perf_compare_MI355X.before_rebase.csv
Benchmark CSV after rebase: fmha_perf_compare_MI355X.csv
Build completed successfully.
All benchmark rows completed, with compared rows matching baseline correctness (MaxErr ratio 1.00x).
fp16 causal: FlyDSL avg 666.0 TFLOPS / 1943.2 us, MaxErr 4.88e-04; aiter_ck avg 530.3 TFLOPS / 2559.0 us. Average FlyDSL/aiter_ck throughput: 119.9%.
bf16 causal: FlyDSL avg 699.2 TFLOPS / 1830.6 us, MaxErr 3.91e-03; OPUS avg 826.9 TFLOPS where available, aiter_ck avg 551.4 TFLOPS, and aiter_asm avg 781.5 TFLOPS. Average ratios: 101.6% vs OPUS, 120.1% vs aiter_ck, 103.1% vs aiter_asm.
fp16 causal full table
bf16 causal full table
Submission Checklist