ops: native HIP Flash Attention kernels for AMD RDNA3 (gfx11)#2043

Open
T0nd3 wants to merge 14 commits into OpenNMT:master from T0nd3:feat/flash-attention-hip

T0nd3 commented May 12, 2026

Summary

Adds a native HIP implementation of Flash Attention 2 for AMD RDNA3+ GPUs.
Builds on the ROCm/HIP backend introduced in v4.7.0 (#1989), filling in the
-DWITH_FLASH_ATTN=ON path that previously raised FATAL_ERROR when combined
with -DWITH_HIP=ON.

No external dependencies (no CUTLASS, no rocWMMA, no Composable Kernel). All
kernels use plain HIP and Clang built-ins, with FP32 accumulators throughout.

Tested on AMD Radeon RX 7900 XTX (gfx1100, RDNA3) on Windows 11 with the
ROCm SDK pip wheels.

Architecture: four dispatched kernels

The dispatcher (flash_attention_hip_impl) picks the most efficient kernel
that matches the input shape and dtype:

| Path | Used when | Input dtypes | Materialises S? |
|---|---|---|---|
| WMMA | seqlen_q ≥ 16, D ∈ {64,128}, gfx11+ | fp16, bf16 | no |
| Decode | seqlen_q == 1, D ∈ {64,128} | fp16, bf16 | fits in LDS |
| Tiled (scalar) | seqlen_q ≥ 64, D ∈ {64,80,128} | fp16, bf16 | no |
| 3-pass | everything else (correctness oracle) | fp16, bf16 | yes |
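
The selection order in the table can be sketched as a small decision function. This is a minimal illustration in Python, not the HIP dispatcher itself: the name `pick_path` and the parameters `has_wmma` and `scores_fit_in_lds` are hypothetical, and the priority order (WMMA, then decode, then tiled, then 3-pass) is assumed from the row order above.

```python
def pick_path(seqlen_q, head_dim, has_wmma=True, scores_fit_in_lds=True):
    """Return the name of the kernel path for a shape (illustrative only)."""
    # WMMA: needs a gfx11+ device, at least one 16-row query tile, D in {64, 128}.
    if has_wmma and seqlen_q >= 16 and head_dim in (64, 128):
        return "wmma"
    # Decode: single query token, scores for all keys must fit in LDS.
    if seqlen_q == 1 and head_dim in (64, 128) and scores_fit_in_lds:
        return "decode"
    # Scalar tiled kernel: at least one full 64-row query tile.
    if seqlen_q >= 64 and head_dim in (64, 80, 128):
        return "tiled"
    # Everything else falls back to the 3-pass correctness oracle.
    return "3-pass"


if __name__ == "__main__":
    for shape in [(1500, 64), (1, 64), (3, 64), (1500, 80)]:
        print(shape, "->", pick_path(*shape))
```

Under these assumptions a three-token prefill (`seqlen_q = 3`) lands on the 3-pass path, which is the shape that triggered the softmax bug described further down.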

WMMA path uses __builtin_amdgcn_wmma_f32_16x16x16_{f16,bf16}_w32 for both
Q·K^T and P·V. The wave32 accumulator fragment layout was reverse-engineered
empirically — see src/ops/wmma_probe.cu for the probe (kept in the tree
as a dev tool, excluded from the production DLL).

Decode kernel uses V-tiling in phase 3 so output-channel threads read V
from LDS instead of looping through HBM seqlen_k times each.

Performance (Whisper-medium, RX 7900 XTX, FP16)

| Workload | Flash=OFF | Flash=ON | Speedup |
|---|---|---|---|
| Encoder, B=1, Sq=Sk=1500 | 62.5 ms | 58.0 ms | 1.08× |
| Encoder, B=4 | 202 ms | 181 ms | 1.11× |
| generate(max_length=30) | 225 ms | 211 ms | 1.07× |
| generate(max_length=200) | 1009 ms | 946 ms | 1.07× |
| generate(max_length=448) | 2190 ms | 2128 ms | 1.03× |

BF16 path gives the same ~1.09× encoder speedup (separate WMMA built-in,
same wave32 layout).

HBM footprint (measured via hipMemGetInfo)

| Metric | Flash=OFF | Flash=ON | Δ |
|---|---|---|---|
| Model persistent | 2.87 GiB | 2.67 GiB | −200 MiB |
| Working set (generate max=100) | +359 MiB | +342 MiB | −17 MiB |
| Score matrix per encoder layer | 137 MiB allocated+freed | never materialised | |

The 137 MiB-per-layer score matrix isn't a persistent saving, since the
standard path frees it after each layer. But each layer reads/writes that
buffer ~3× through HBM, so the actual benefit is roughly 3 GiB of HBM
traffic saved per encoder pass, which is what shows up as the speedup.
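
The 137 MiB figure can be reproduced from the shape of an FP32 `[B, H, Sq, Sk]` score matrix. The sketch below assumes 16 attention heads for Whisper-medium (a value taken from the public model configuration, not stated in this PR); the helper name `score_matrix_mib` is hypothetical.

```python
def score_matrix_mib(batch, heads, seqlen_q, seqlen_k, bytes_per_element=4):
    """Size in MiB of a dense [B, H, Sq, Sk] score matrix (FP32 by default)."""
    return batch * heads * seqlen_q * seqlen_k * bytes_per_element / 2**20


if __name__ == "__main__":
    # Whisper-medium encoder layer: B=1, H=16 (assumed), Sq=Sk=1500, FP32 scores.
    print(round(score_matrix_mib(1, 16, 1500, 1500), 1))  # 137.3
```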

Key correctness fix worth flagging

hip_attn_softmax_kernel does its per-row max/sum via a binary tree
reduction (for (s = blockDim.x >> 1; s > 0; s >>= 1)). The dispatcher
originally launched it with blockDim.x = min(seqlen_k, 256). For
seqlen_k = 3 (the Whisper prompt prefill of three tokens) that means
blockDim.x = 3, which is not a power of two: the very first reduction
step skips the third element, the softmax denominator is silently wrong,
and generate() ends up producing token 50411 as the first generated
token regardless of audio input.

Now the dispatcher rounds up to the next power of two; extra threads
contribute identity values (−1e9 for max, 0 for sum). A five-seed regression
test in the pytest suite specifically guards against this resurfacing.
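
The failure mode is easy to reproduce outside the GPU. The sketch below simulates the shared-memory tree reduction sequentially in Python; it is a simplified model under the assumption that each thread holds exactly one value, and the helper names `tree_reduce_sum` and `next_pow2` are hypothetical, not taken from the kernel source.

```python
def tree_reduce_sum(values, block_dim):
    """Simulate `for (s = blockDim.x >> 1; s > 0; s >>= 1)` over shared memory."""
    # Threads beyond len(values) contribute the identity for a sum (0).
    smem = list(values) + [0.0] * (block_dim - len(values))
    s = block_dim >> 1
    while s > 0:
        # Threads tid < s each fold in the element s slots to their right.
        for tid in range(s):
            smem[tid] += smem[tid + s]
        s >>= 1
    return smem[0]


def next_pow2(n):
    """Smallest power of two that is >= n (for n >= 1)."""
    p = 1
    while p < n:
        p <<= 1
    return p


if __name__ == "__main__":
    row = [1.0, 2.0, 4.0]                           # seqlen_k = 3
    print(tree_reduce_sum(row, 3))                  # 3.0: the third element is dropped
    print(tree_reduce_sum(row, next_pow2(3)))       # 7.0: correct with blockDim.x = 4
```

With `blockDim.x = 3` the loop starts at `s = 1`, so only elements 0 and 1 are ever combined; padding to 4 threads restores the full sum.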

Tests

New python/tests/test_flash_attention.py (15 tests, all green):

  • Encoder correctness FP16 + BF16 (max-diff/rel-diff thresholds appropriate to each precision)
  • generate() token match across 5 seeds (FP16)
  • Regression test for the softmax-block-size bug above
  • Batch correctness (B=2, B=4)
  • Variable prompt length (1, 3, 5, 8) — exercises 3-pass-fallback, decode, and WMMA paths
  • Informational score-buffer-size report (no asserts)

Tests skip cleanly if the faster-whisper-medium snapshot isn't already in
the local HuggingFace cache, so they don't download anything over the network
on CI runners that haven't pre-populated models.

python/tests/benchmark_flash_attention.py is a standalone benchmark
(not pytest) that measures both speed and HBM via ctypes -> hipMemGetInfo.
The numbers in this PR description come straight from running it.

Changes

  • src/ops/flash_attention_gpu.cu — adds the four-path HIP implementation in the existing #else // CT2_USE_HIP block; CUDA path untouched.
  • src/ops/wmma_probe.cu — dev tool for layout reverse-engineering. Excluded from the default build via a commented-out entry in CMakeLists.txt; a comment at the top of the file explains how to re-enable it.
  • src/layers/flash_attention.cc — pass is_causal = _is_decoder to the FlashAttention op so encoder self-attention runs non-causal.
  • src/layers/transformer.cc, src/layers/whisper.cc — pass model.use_flash_attention() through to TransformerEncoderLayer. Whisper has its own encoder class, so both needed the fix.
  • src/models/model.cc — accept flash_attention=True on HIP builds when CT2_WITH_FLASH_ATTN is defined.
  • CMakeLists.txt — adds the HIP-side -DCT2_WITH_FLASH_ATTN definition.
  • python/tests/test_flash_attention.py, python/tests/benchmark_flash_attention.py — see above.
  • CHANGELOG.md — entry under Unreleased.

Test plan

  • Build with -DWITH_HIP=ON -DWITH_FLASH_ATTN=ON on gfx1100 (Windows 11)
  • Build with -DWITH_HIP=ON -DWITH_FLASH_ATTN=OFF — verified that flash_attention=True raises the documented error
  • All 15 tests in python/tests/test_flash_attention.py pass
  • Standalone benchmark confirms the numbers above
  • CI build on Linux x86_64 ROCm (build-python-wheels-rocm ubuntu-24.04 + build-and-push-docker-images rocm)
  • CI on Windows ROCm (build-python-wheels-rocm windows-2025)

Builds on PR #2041 (Windows ROCm build guide) and #2042 (HIP test enablement)
in spirit, but doesn't depend on either.

T0nd3 added 14 commits May 11, 2026 20:01
The build previously aborted with a hard error when both WITH_HIP=ON and
WITH_FLASH_ATTN=ON were set. The CUDA Flash Attention kernels rely on
CUTLASS/CuTe (sm80-specific) and cannot be compiled for HIP directly.

Replace the FATAL_ERROR with a cmake WARNING so that the build succeeds.
Using flash_attention=True at runtime on a ROCm build already raises a clear
std::invalid_argument (see model.cc CT2_USE_HIP guard). The warning explains
that a native HIP implementation via AMD Composable Kernel is planned.
Implements scaled dot-product attention for the ROCm/HIP backend using
three plain HIP kernels (QK^T, softmax, PV) with FP32 accumulators.
No CUTLASS, CuTe, or AMD Composable Kernel dependency required.

Supports FP16 and BF16 inputs; tested on gfx1100 (RX 7900 XTX).
KV-cache (offset > 0), rotary embeddings, and ALiBi are expected to be
pre-applied by the caller in this initial implementation.

Changes:
- flash_attention_gpu.cu: add #ifndef CT2_USE_HIP / #else / #endif
  guard so CUDA and HIP share one translation unit with a single
  FlashAttention::compute<Device::CUDA> specialisation each.
  CUTLASS/CuTe headers are excluded for HIP builds.
- CMakeLists.txt: enable -DCT2_WITH_FLASH_ATTN for HIP when
  WITH_FLASH_ATTN=ON (no sm80 CUTLASS sources added).
- model.cc: honour CT2_WITH_FLASH_ATTN on HIP builds so that
  flash_attention=True is accepted at runtime instead of raising
  "FlashAttention not supported on ROCm."
Builds on the initial 3-pass HIP kernels with three changes:

1. KV-cache write path (offset > 0) — adds hip_kv_cache_write_kernel
   that stages new K/V tokens into the cached buffer before the attention
   pass.  Mirrors the CUTLASS path's append-then-attend semantics so
   autoregressive decoding works without the layer mutating the cache.

2. Softmax-reduction block-size fix — the tree reduction inside
   hip_attn_softmax_kernel requires blockDim.x to be a power of two.
   The dispatcher previously used min(seqlen_k, 256) directly, so for
   seqlen_k = 3 (e.g. the Whisper prompt prefill of three tokens) the
   reduction silently dropped the third element, corrupting the softmax
   and forcing generate() to always emit the same token regardless of
   input.  Now rounded up to the next power of two with extra threads
   contributing the identity (-1e9 for max, 0 for sum).

3. Two new fast paths in the dispatcher:
   - hip_flash_attn_fwd_tiled (Flash Attention 2): one block per
     (q_tile, head, batch), Q in registers, K/V tiles streamed through
     LDS, online-softmax state (m_i, l_i, acc[D]).  S = Q@K^T is never
     materialised in HBM.  Specialised for D in {64, 80, 128} with
     BM = BN = 64.  Used when seqlen_q >= BM.
   - hip_flash_decode_kernel: one block per (head, batch) with threads
     parallelising over K.  Phase 1 computes scores in LDS, phase 2
     reduces, phase 3 accumulates output channels with V-tiling
     (BLOCK threads stage a V_TILE-wide slab of V into LDS once, then
     D channel threads sum against the cached tile).  Used for
     seqlen_q == 1 with D in {64, 128} and seqlen_k bounded by the
     64 KiB per-block LDS budget.
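
The online-softmax state (m_i, l_i, acc[D]) used by the tiled path can be illustrated for a single query row. This is a plain-Python sketch of the general Flash Attention 2 recurrence, not a transcription of `hip_flash_attn_fwd_tiled`; the function names and the tiny tile size are hypothetical and chosen only for readability.

```python
import math


def attention_row_reference(scores, values):
    """Materialise all scores, then softmax and weight the value rows."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom
            for d in range(dim)]


def attention_row_online(scores, values, tile=2):
    """Stream K/V in tiles, keeping only a running max, sum, and accumulator."""
    dim = len(values[0])
    m_i, l_i, acc = -math.inf, 0.0, [0.0] * dim
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m_i, max(s_tile))
        # Rescale previous state to the new running max (exp(-inf) == 0.0).
        scale = math.exp(m_i - m_new)
        l_i *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - m_new)
            l_i += p
            for d in range(dim):
                acc[d] += p * v[d]
        m_i = m_new
    return [a / l_i for a in acc]
```

Both functions return the same row up to floating-point rounding, but the online version never holds more than one tile of scores at a time, which is why S is never written to HBM.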

The original 3-pass kernels remain as the correctness-oracle fallback
for unsupported head dimensions and the 2..BM-1 query-length gap.

Verified on Whisper-medium / RX 7900 XTX: all five test seeds produce
identical token sequences to the standard MultiHeadAttention path
(generate up to max_length=200).
Two related changes that together let the encoder's self-attention go
through FlashMultiHeadAttention when use_flash_attention=true:

- FlashMultiHeadAttention previously instantiated the FlashAttention op
  with its default is_causal=true, which only matches the decoder's
  autoregressive self-attention.  When the layer is used by the encoder
  every query would mask out its successors, producing wrong outputs.
  The op is now constructed with is_causal=_is_decoder, so encoder
  usage is non-causal and decoder usage stays causal.

- TransformerEncoder and WhisperEncoder build their layers without
  passing use_flash_attention through to TransformerEncoderLayer, so
  they always picked MultiHeadAttention regardless of the model's
  flag.  Both constructors now forward model.use_flash_attention().
  (Whisper has its own encoder class — patching only the generic
  TransformerEncoder is not enough.)

Whisper-medium / RX 7900 XTX numbers (Sq = Sk = 1500, B = 1):
  encoder-only:  64.2 ms → 59.3 ms   (1.08x)
  generate(30):  221   ms →  213 ms  (1.04x)
  generate(200): 1012  ms →  965 ms  (1.05x)
  generate(448): 2233  ms → 2154 ms  (1.04x)

Correctness preserved (FP16 rounding gives ~0.3% relative diff in the
encoder output; generated token sequences match exactly up to at least
max_length=200 across five random seeds).
Adds hip_flash_attn_wmma_fp16: a Flash Attention forward kernel that
uses the wave32 16x16x16 fp16-input / fp32-accumulator WMMA built-in
(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32) for both Q·K^T and P·V.

Block layout (one wave32 per block):
  BM_W = 16 query rows
  BN   = 16 key tokens per K/V tile
  Q tile loaded once (pre-scaled), K and V tiles streamed in
  Inner reduction over the head dimension via D/16 WMMA calls per
  output tile; the P·V phase fans out to D/16 output fragments held
  entirely in registers.  Online softmax with per-row (m_i, l_i)
  staged through LDS so the WMMA accumulator fragment layout is
  decoupled from the row-reduction code.

Activated in the dispatcher for FP16 inputs whenever head_dim is a
multiple of 16 and seqlen_q >= 16.  Falls through to the scalar
tiled kernel for BF16 and other shapes; correctness oracle (3-pass)
remains as the final fallback.

Also adds src/ops/wmma_probe.cu — a small companion file with a
probe kernel + C-ABI driver used to empirically verify the wave32
accumulator fragment layout on this hardware before writing the
real kernel.  The probe is reusable for future ports (BF16 path,
gfx12, other tile shapes).  Layout verified on gfx1100:
  lane l, slot s -> C[2*s + (l >> 4), l mod 16]
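
A quick sanity check on that mapping: with 32 lanes and 8 accumulator slots per lane (an assumption based on a 16x16 FP32 tile spread over one wave32, i.e. 256 / 32), every element of the 16x16 output tile should be owned by exactly one (lane, slot) pair. The helper name `accumulator_coords` is hypothetical.

```python
def accumulator_coords(lane, slot):
    """Map (lane, slot) to (row, col) of the 16x16 accumulator tile."""
    return (2 * slot + (lane >> 4), lane % 16)


if __name__ == "__main__":
    owners = {}
    for lane in range(32):
        for slot in range(8):
            owners.setdefault(accumulator_coords(lane, slot), []).append((lane, slot))
    # Every one of the 256 tile elements has exactly one owner.
    assert len(owners) == 256
    assert all(len(v) == 1 for v in owners.values())
    print(accumulator_coords(0, 0), accumulator_coords(16, 0), accumulator_coords(31, 7))
```

Lanes 0..15 hold the even rows and lanes 16..31 the odd rows, so a per-row reduction has to gather values across lanes, which is consistent with the commit staging (m_i, l_i) through LDS.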

Whisper-medium / RX 7900 XTX numbers (vs. the scalar tiled kernel
that was the previous best):
  encoder B=1:  59.3 -> 57.6 ms  (1.03x)
  encoder B=4:  202.8 -> 182.6 ms (1.11x)  -- biggest win, scales
                                              with compute load
  generate 30:  227 -> 209 ms    (1.09x)
  generate 100: 549 -> 522 ms    (1.05x)
  generate 200: 1031 -> 968 ms   (1.07x)
  generate 448: 2203 -> 2099 ms  (1.05x)

Correctness preserved (FP16 rounding gives ~0.3% relative diff in
the encoder output; generated token sequences match exactly up to
max_length=200 across five random seeds).
Two complementary additions to the WMMA Flash Attention path:

1. BF16 support.  The wave32 16x16x16 fragment layout is identical
   between the FP16 and BF16 WMMA built-ins on RDNA3 — the only
   difference is which intrinsic to call.  The existing BM_W=16
   kernel is now templated over the half-precision element type
   (HalfT in {_Float16, __bf16}) and selects the right built-in
   (__builtin_amdgcn_wmma_f32_16x16x16_{f16,bf16}_w32) via
   if constexpr.  The dispatcher matches scalar_t against
   float16_t / bfloat16_t and reinterpret_casts the data pointers
   to the underlying HalfT (same memory layout, the wrapper is just
   a host-side ABI thing).

   Whisper-medium with compute_type=bfloat16, RX 7900 XTX:
     encoder B=1:  59.8 -> 55.0 ms (1.09x), matches the FP16 win.

2. New kernel hip_flash_attn_wmma_fp16_bm64<D>: 4 wave32 wavefronts
   per block, BM_W = BN = 64, each K/V tile loaded once into LDS
   and shared across all 64 query rows.  Built but currently NOT
   dispatched (kept for future tuning).  On gfx1100 it is ~5-10%
   slower than the BM_W=16 variant because:
   - 48 KiB LDS allows only 1 block per CU vs. BM_W=16's 4 blocks
     per CU -> halved occupancy
   - Per-row softmax loop iterates BN=64 cols sequentially (4x more
     work per softmax thread)
   - The expected K/V HBM-reuse benefit doesn't materialise: Whisper
     attention is compute/LDS-bound, not HBM-bound (~580 MB total
     attention HBM traffic / 960 GB/s = 0.6 ms, vs. ~60 ms encoder)
   The kernel structure (per-wave S/P scratch, cooperative tile load)
   is the foundation for a future wave-shuffle-softmax variant that
   could drop LDS enough to make BM_W=64 actually a win.
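
The compute-bound argument above is a one-line estimate. The sketch below just redoes that arithmetic with the figures quoted in the commit message (580 MB of traffic, 960 GB/s of bandwidth, roughly 60 ms per encoder pass); the helper name `transfer_time_ms` is hypothetical.

```python
def transfer_time_ms(num_bytes, bandwidth_bytes_per_s):
    """Lower bound on the time needed to move num_bytes at the given bandwidth."""
    return num_bytes / bandwidth_bytes_per_s * 1e3


if __name__ == "__main__":
    attn_ms = transfer_time_ms(580e6, 960e9)
    share = attn_ms / 60.0
    print(f"{attn_ms:.2f} ms of memory traffic, {share:.1%} of a ~60 ms encoder pass")
```

At about one percent of the encoder time, reducing K/V reloads further cannot by itself produce a measurable speedup, which matches the observation that the larger tile did not help.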

Correctness preserved across all five test seeds for FP16; BF16
matches 4/5 (the fifth differs in the last token, expected from
BF16's 8 bits of significand precision; the relative encoder diff is 1.85%).
…Flash Attention

python/tests/test_flash_attention.py (CI-friendly pytest, 9 tests):
  - test_flash_attention_encoder_matches_standard[float16|bfloat16]
    Verifies Flash=ON encoder output matches Flash=OFF within
    rounding tolerances (FP16: max_abs_diff <= 0.5, rel <= 0.5%;
    BF16: <= 2.0, rel <= 3%).  Also prints the per-layer score-
    buffer size that the online-softmax path avoids materialising.
  - test_flash_attention_generate_fp16_token_match[42|123|777|999|1234]
    Five seeds, generates 20 tokens each, asserts byte-identical
    token sequences between Flash=ON and Flash=OFF.
  - test_softmax_block_size_regression
    Specifically guards the bug from commit d9016a5: a 2^k
    block-size requirement in the softmax tree reduction that
    caused all five seeds to emit token 50411 as the first
    generated token regardless of audio input.  Asserts that the
    set of first tokens across five distinct seeds has size > 1.
  - test_flash_attention_score_buffer_savings
    Pure documentation test that prints, for a range of common
    transformer shapes (Whisper-medium encoder/decoder, hypothetical
    4k and 32k LLM contexts), how much HBM the standard attention
    path would allocate for the [B, H, Sq, Sk] FP32 score matrix.

Tests are skipped automatically if the faster-whisper-medium snapshot
isn't already in the local HuggingFace cache, so they don't pull
network in CI environments that haven't pre-populated models.

python/tests/benchmark_flash_attention.py (standalone, not pytest):
  - Direct HBM measurement via ctypes -> hipMemGetInfo on the active
    HIP device.  Reports both the persistent model footprint and the
    working-set growth after a representative generate() call,
    Flash=ON vs Flash=OFF.
  - Performance measurement of encoder (B=1, B=4) and generate
    (max_length=30/100/200/448), GPU-synced before timing.

Measured on Whisper-medium, RX 7900 XTX, gfx1100:
  HBM: model 2.87 GiB OFF -> 2.67 GiB ON (-200 MiB persistent);
       working set +359 MiB OFF -> +342 MiB ON (-17 MiB peak).
  Speed: encoder B=1 1.08x, B=4 1.11x; generate 1.03-1.07x across
         max_length 30..448.  Stable across runs.

Both files include a Windows local-build dev-loop helper that
attempts to add the ROCm SDK wheel's bin directories to the DLL
search path before importing ctranslate2.  No-op in normal CI.
The wmma_probe kernel is a one-shot reverse-engineering tool for the
RDNA3 wave32 WMMA fragment layout — it's how the layout used by
hip_flash_attn_wmma_fp16 was discovered.  Useful when porting to a new
gfx target / precision / tile shape, but it has no business living in the
production DLL where it just exports an unused C symbol
(ct2_wmma_probe_run) and bloats the binary.

Comment it out of CMakeLists.txt's CUDA_SOURCES and document how to
re-enable it at the top of wmma_probe.cu.  Source file stays in the
tree so future ports of the WMMA kernels can rerun it.
The hip_flash_attn_wmma_fp16_bm64 kernel was added speculatively as a
"larger Q-tile so K/V loads from HBM amortise across more query rows"
experiment, but never dispatched because empirically on gfx1100 it is
~5-10% slower than the BM_W=16 variant:

- 48 KiB LDS allows only 1 block per CU vs. BM_W=16's 4 blocks per CU
  -> roughly halved wave occupancy
- The per-row softmax loop sequentially iterates BN=64 cols, 4x the
  work per softmax thread vs. BM_W=16
- The expected HBM-bandwidth-reduction benefit doesn't materialise
  because Whisper attention is compute/LDS-bound, not HBM-bound:
  total attention HBM traffic is ~580 MB which at 960 GB/s is 0.6 ms,
  vs. ~60 ms of encoder time

The lesson is in the git history; carrying 240 lines of disabled code around
is just a maintenance liability.  If a future wave-shuffle-softmax
refactor shrinks the LDS footprint below ~24 KiB, the kernel can be
brought back from this commit cleanly.

The dispatcher's BM_W=16 path remains the WMMA fast path for both
FP16 and BF16 inputs with head_dim in {64, 128} and seqlen_q >= 16.

Verified: all 9 tests in python/tests/test_flash_attention.py still
pass after the removal; Whisper-medium encoder/generate speedups
unchanged.
… cases

Two new parameterised tests:

- test_flash_attention_encoder_batched[2|4]
  Runs the encoder with batch_size in {2, 4} and asserts Flash=ON
  output matches Flash=OFF on every batch element.  The previous
  encoder test only covered B=1 — this catches any per-batch
  indexing bugs in the WMMA tile loads / dispatch (grid.z == batch).

- test_flash_attention_variable_prompt_length[1|3|5|8]
  Runs generate() with prompt lengths 1, 3, 5, and 8.  For lengths 3,
  5, and 8 the decoder's first forward pass has seqlen_q != 1 and
  also seqlen_q < BM_W=16, so the dispatcher routes through the
  3-pass fallback kernel rather than the WMMA or decode paths.
  Specifically covers the dispatcher branch we never explicitly
  exercised before, and re-checks the n_prompt=3 case that
  originally triggered the softmax-block-size bug.

All 15 tests pass on Whisper-medium / RX 7900 XTX.
The HIP section's introductory comment was written when there were only
two code paths (tiled + 3-pass).  Update it to reflect the current four
paths (WMMA, decode, tiled, 3-pass) with a small table of which inputs
each path handles.

Also expand the Kernel 2 (softmax) doc-comment to call out the
power-of-two block-size invariant and reference commit d9016a5. The
bug that exposed it produced the same first generated token (50411) for
every input because the 3-prompt-token prefill ran the tree reduction
with blockDim.x = 3 and silently dropped the third element.  Worth
guarding against a future maintainer "simplifying" the launch code.
Lists the four dispatched kernels (WMMA, scalar tiled, decode, 3-pass),
the KV-cache write path, the build flag pair, and the measured speedup
on Whisper-medium / RX 7900 XTX.
…ests

CI's check-python-style step caught:
  - black formatting differences (long argument lists wanted on separate lines,
    a couple of generator-expression layouts)
  - isort import-order: test_utils + ctranslate2 needed a blank-line group
    separation
  - flake8 F541: one f-string with no placeholders in benchmark_flash_attention.py
    line 231 (the savings-banner print spans two lines and only the second one
    actually interpolates -- both got prefixed with f"" by mistake).

All three checks now pass locally with the same versions CI pins
(black==22.*, isort==5.*, flake8==3.8.*); all 15 pytest tests still pass.
The wheel CI builds the ROCm wheel with
  --offload-arch=gfx1030 --offload-arch=gfx1100 ...
i.e. one device pass per architecture.  In the device pass for gfx1030
the WMMA kernel template's body is gated out (the wave32 WMMA built-in
only exists on gfx11+/gfx12+), so its name is not in scope.  The
dispatcher's `hipLaunchKernelGGL((hip_flash_attn_wmma_fp16<HalfT, D>), …)`
call is parsed by the compiler in every pass (including device passes,
even though the dispatcher itself is host-only), so on gfx1030 it fails:

  flash_attention_gpu.cu:1296:31: error: use of undeclared identifier
  'hip_flash_attn_wmma_fp16'
  1 error generated when compiling for gfx1030.

This affected all three ROCm CI jobs:
  build-python-wheels-rocm (ubuntu-24.04)
  build-python-wheels-rocm (windows-2025)
  build-and-push-docker-images (rocm)

Two-level fix:

1. Compile-time gate: wrap the dispatcher's WMMA launch (lambda + the
   two `if (launch_wmma_bm16(...)) return;` calls) with the *same*
   `#if defined(__gfx11..) || !defined(__HIP_DEVICE_COMPILE__)` guard
   that already wraps the kernel definition.  In a non-gfx11 device
   pass the entire WMMA call site is preprocessor-skipped, so the
   undeclared-identifier error is gone.

2. Runtime gate: even when the file was compiled with some gfx11
   target in the arch list, the resulting multi-arch wheel may run
   on a non-gfx11 GPU at execution time, in which case the WMMA
   kernel isn't in the loaded device binary.  Add a one-time
   `gcnArchName` inspection (via cuda::get_device_properties) that
   only enters the WMMA path on a `gfx11*` / `gfx12*` device.  On any
   other arch the dispatcher falls through to the scalar tiled /
   3-pass kernels.
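
The runtime gate amounts to a prefix test on the device's architecture name. The sketch below shows the idea in Python rather than the C++ used in the PR, assuming the gate should admit the gfx11 and gfx12 families (where this commit message says the WMMA built-in exists) and that the name may carry a `:feature` suffix; the function name `supports_wmma` is hypothetical.

```python
def supports_wmma(gcn_arch_name):
    """Return True if the architecture name belongs to the gfx11 or gfx12 family."""
    # Drop any feature suffix such as ":sramecc+:xnack-" before matching.
    base = gcn_arch_name.split(":", 1)[0]
    return base.startswith(("gfx11", "gfx12"))


if __name__ == "__main__":
    for name in ["gfx1100", "gfx1030", "gfx90a:sramecc+:xnack-"]:
        print(name, "->", "WMMA" if supports_wmma(name) else "scalar tiled / 3-pass")
```

Evaluating the check once and caching the result keeps the cost off the per-call dispatch path, which is what the "one-time" inspection in the commit suggests.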

Verified locally:
  - Single-arch build (gfx1100 only): WMMA path active, 15/15 tests pass.
  - Multi-arch build (gfx1030;gfx1100;gfx1101;gfx1102): compiles
    cleanly, 15/15 tests pass on the gfx1100 host.