feat(hip): port RoPE to ROCm #223
Merged
demandal25 merged 9 commits on May 20, 2026
Pull request overview
This PR extends the ROCm/HIP RoPE implementation to match CUDA feature coverage by adding int64 position-id dispatch, enabling MLA (2D-K) support in fused RoPE+FP8 quantize and paged-cache append paths, and tuning kernel launch geometry (notably for CDNA3).
Changes:
- Add `DISPATCH_PYTORCH_IDTYPE_TO_CTYPE` and use it in HIP RoPE bindings to support `int32`/`int64` position-id tensors.
- Unify the fused RoPE+quantize+append kernel to support both GQA/MHA (`paged_kv_t`) and MLA (`paged_kv_mla_t`) via a `CacheT` template + `constexpr` branching; add an MLA-specific host launcher.
- Expand ROCm test coverage for pos-id dtype variants and MLA paths; re-enable the cos/sin-cache test on HIP with looser tolerances.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `flashinfer/csrc_rocm/pytorch_extension_utils.h` | Adds an idtype dispatch macro used by ROCm bindings to accept int32/int64 tensors. |
| `flashinfer/csrc_rocm/rope.cu` | Updates HIP bindings to dispatch idtype, adds MLA support in quantize/append paths, and routes MLA cache appends through `paged_kv_mla_t`. |
| `include/flashinfer/attention/generic/pos_enc.cuh` | Generalizes the fused kernel to support MLA cache layout and updates launch geometry using `kWarpSize`. |
| `tests/rocm_tests/test_rope_hip.py` | Adds idtype parametrization, MLA coverage, and re-enables cos/sin-cache testing with HIP-appropriate tolerances. |
| `tests/rocm_tests/test_activation_hip.py` | Removes ad-hoc `__main__` execution block from the test file. |
c7f822c to ae4ced6
demandal25 added a commit to demandal25/flashinfer that referenced this pull request on May 20, 2026

…usivity

Address Copilot review comments on PR ROCm#223:
- Add TORCH_CHECK(offsets.scalar_type() == indptr.scalar_type()) in apply_rope and apply_llama31_rope to prevent UB when caller passes mismatched dtypes (e.g. indptr=int64, offsets=int32).
- Add TORCH_CHECK(!(has_gqa_caches && has_mla_caches)) in rope_quantize_append_paged_kv_cache so callers with both cache sets get a clear error instead of silently falling through to the GQA path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
c3553a9 to fa233b9
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (1)
include/flashinfer/attention/generic/pos_enc.cuh:1236
- Similarly, the Q-noRoPE section unconditionally loads/stores `vec_size` elements per thread with no tail guard. For `no_rope_dim` values that aren't multiples of `rope_dim`, this can read/write out of bounds. Add a tail predicate (or validate `no_rope_dim % rope_dim == 0` up front) so the last chunk doesn't access beyond `q_nope_*` tensors.
} else {
uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
DType* q_nope_in_ptr =
q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
q_nope_in_stride_h);
QuantType* q_nope_out_ptr =
q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
q_nope_out_stride_h);
vec_t q_nope_vec;
q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
#pragma unroll
1fba9c2 to 31c88b9
… utils

Mirrors the CUDA DISPATCH_DLPACK_IDTYPE_TO_CTYPE pattern, dispatching at::kInt→int32_t and at::kLong→int64_t. Used by the rope bindings to support int64 position-id tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
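The dispatch pattern in this commit can be pictured with a small Python analogue. This is only a sketch: `dispatch_idtype`, `index_width_bytes`, and the string dtype tags are hypothetical stand-ins, not the macro's real interface. The only facts taken from the commit are the two mappings (int32 and int64); rejecting every other dtype is an assumption.

```python
import ctypes

# Hypothetical analogue of an index-dtype dispatch: map a dtype tag to a
# concrete C integer type and hand that type to a caller-supplied body.
_IDTYPE_TO_CTYPE = {
    "int32": ctypes.c_int32,  # at::kInt  -> int32_t
    "int64": ctypes.c_int64,  # at::kLong -> int64_t
}


def dispatch_idtype(dtype_name, body):
    """Call body(c_type) with the C type selected by dtype_name."""
    try:
        c_type = _IDTYPE_TO_CTYPE[dtype_name]
    except KeyError:
        # Assumption: any other index dtype is rejected outright.
        raise TypeError(f"unsupported index dtype: {dtype_name}") from None
    return body(c_type)


def index_width_bytes(dtype_name):
    """Example body: report the width of the selected index type."""
    return dispatch_idtype(dtype_name, ctypes.sizeof)
```

The point of the pattern is that the body is written once and instantiated per index type, which is what lets the same binding accept either position-id dtype.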
…CDNA3 tuning

- Add typename CacheT = paged_kv_t<QuantType, IdType> as 7th template parameter to RopeQuantizeAppendPagedKVCacheKernel; rename paged_kv to paged_kv_like so GQA and MLA share one kernel body.
- Add IS_MLA constexpr branch: 2D K pointer arithmetic for input, get_kpe_ptr/get_ckv_ptr cache stores, skip V section.
- Add RopeQuantizeAppendPagedMLACache host launcher that instantiates the kernel with CacheT=paged_kv_mla_t.
- CDNA3 launch geometry: vec_size=16/sizeof(DType) (single global_load_dwordx4), num_threads=max(2*kWarpSize, bdx).
- Add <type_traits> include for std::is_same.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
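The launch-geometry formulas in the last-but-one bullet are easy to evaluate by hand. The sketch below does so in Python. The function name is hypothetical, and the definition of `bdx` (threads needed to cover one `rope_dim`-sized chunk) is an assumption, since the commit message does not define it.

```python
def rope_quantize_launch_geometry(dtype_bytes, rope_dim, warp_size):
    """Return (vec_size, bdx, num_threads) for the formulas quoted in the commit.

    vec_size    = 16 / sizeof(DType)       -> each thread moves 16 bytes
    num_threads = max(2 * kWarpSize, bdx)
    """
    vec_size = 16 // dtype_bytes
    # Assumption: bdx is the number of threads covering one rope chunk,
    # computed with ceiling division and clamped to at least 1.
    bdx = max(1, -(-rope_dim // vec_size))
    num_threads = max(2 * warp_size, bdx)
    return vec_size, bdx, num_threads


# bf16 (2 bytes) on a Wave64 device with rope_dim = 64
vec_size, bdx, num_threads = rope_quantize_launch_geometry(2, 64, 64)
```

For 2-byte elements this gives `vec_size = 8`, so each thread moves 8 × 2 = 16 bytes, which matches the commit's remark about a single 128-bit load. With a 64-wide wavefront the block size floor is 128 threads.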
- Wrap all apply_rope* and apply_llama31_rope* position-id lookups in DISPATCH_PYTORCH_IDTYPE_TO_CTYPE, supporting int32 and int64 pos_ids.
- rope_quantize: add 2D-K MLA branch (k_rope_in.dim()==2 → num_kv_heads=1, head stride aliased to token stride) matching csrc/rope.cu:319-394.
- rope_quantize_append_paged_kv_cache: replace the MLA hard-error with a full MLA dispatch that constructs paged_kv_mla_t and calls the new RopeQuantizeAppendPagedMLACache launcher.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
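Why aliasing the head stride to the token stride is harmless for a 2D K tensor can be seen from the offset arithmetic. The sketch below is illustrative only: `elem_offset` assumes the usual `token * stride_n + head * stride_h + elem` formula, and `k_strides` is a hypothetical helper, not the binding's code.

```python
def elem_offset(token_idx, head_idx, elem_idx, stride_n, stride_h):
    # Assumed offset formula for a strided [token, head, elem] layout.
    return token_idx * stride_n + head_idx * stride_h + elem_idx


def k_strides(shape, strides):
    """Return (num_kv_heads, stride_n, stride_h) for a 2D or 3D K tensor."""
    if len(shape) == 2:
        # MLA: [num_tokens, rope_dim]. There is a single implicit head, so
        # head_idx is always 0 and the stride_h value is never multiplied
        # by anything non-zero; reusing the token stride is just a placeholder.
        return 1, strides[0], strides[0]
    if len(shape) == 3:
        # GQA/MHA: [num_tokens, num_kv_heads, rope_dim]
        return shape[1], strides[0], strides[1]
    raise ValueError("K must be 2D (MLA) or 3D (GQA/MHA)")


# Contiguous 2D K: 3 tokens, rope_dim = 4, stored as a flat buffer.
flat_k = list(range(12))
num_kv_heads, stride_n, stride_h = k_strides((3, 4), (4, 1))
value = flat_k[elem_offset(2, 0, 1, stride_n, stride_h)]  # token 2, elem 1
```

Because the only valid head index is 0, the same kernel indexing code can serve both layouts without a separate 2D code path.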
…tivation test

test_rope_hip.py:
- Enable idtype parametrize (int32/int64) on test_rope_pos_ids.
- Remove skip from test_rope_cos_sin_cache.
- Add MLA variants to test_generalized_rope_quantize_hip and test_generalized_rope_quantize_append_kv_cache_hip.
- Add test_mla_rope_quantize for the no-append MLA path.
- Fix _rope_apply_interleave_f32 to handle 2D K natively via unsqueeze/squeeze inside _rot().

test_activation_hip.py:
- Remove debug if __name__ == "__main__" block.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
In the C++ binding, check that v_in is contiguous along the last dim and its scalar type matches q_rope_in before computing strides. Mirror the check at the Python entry point so mismatched v dtypes are caught early with a clear error message.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…nstraint in rope bindings

Remove self-evident MLA/GQA path labels and stale line-reference comments flagged in simplify/review pass. Add WHY comments explaining that pos_ids is int32-only in paged-cache paths (shared int32 index arithmetic). Document the observed bf16 rounding error (~3e-2) that drives the 5e-2 tolerance in test_rope_cos_sin_cache.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
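For a feel of the magnitudes involved in bf16 rounding, the following sketch rounds a float to the nearest bfloat16 value using only the standard library. It is a generic illustration of the format (8 exponent bits, 7 stored mantissa bits, round-to-nearest-even), not the test's reference code, and `round_to_bf16` is a hypothetical helper name.

```python
import struct


def round_to_bf16(x):
    """Round a finite Python float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 pattern; adding this
    # bias before masking rounds the discarded low half to nearest-even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    rounded = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", rounded))[0]


def bf16_abs_error(x):
    """Absolute error introduced by one rounding of x to bfloat16."""
    return abs(round_to_bf16(x) - x)
```

One rounding step has a relative error of at most 2^-8 (about 3.9e-3), so the absolute error grows with magnitude: between 8 and 16 adjacent bf16 values are 0.0625 apart, and `round_to_bf16(8.03)` returns `8.0`, an error of about 3e-2. Whether values of that size are what produced the figure quoted in the commit is not stated in the source.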
…usivity

Address Copilot review comments on PR ROCm#223:
- Add TORCH_CHECK(offsets.scalar_type() == indptr.scalar_type()) in apply_rope and apply_llama31_rope to prevent UB when caller passes mismatched dtypes (e.g. indptr=int64, offsets=int32).
- Add TORCH_CHECK(!(has_gqa_caches && has_mla_caches)) in rope_quantize_append_paged_kv_cache so callers with both cache sets get a clear error instead of silently falling through to the GQA path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
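The first check guards against one buffer being read with the other tensor's element width. The sketch below shows what that looks like at the byte level. It assumes the failure mode is a reinterpretation of an int32 buffer as int64 (the commit only says "UB"), and `check_same_index_dtype` is a hypothetical Python stand-in for the `TORCH_CHECK`.

```python
import struct

# Two int32 offsets as they would sit in memory (little-endian).
raw_offsets = struct.pack("<2i", 1, 2)

# Reading the same 8 bytes as a single int64 fuses both values into one,
# and a second int64 read would already run past the end of the buffer.
(misread,) = struct.unpack("<q", raw_offsets)


def check_same_index_dtype(offsets_dtype, indptr_dtype):
    """Hypothetical guard mirroring the dtype-equality check in the commit."""
    if offsets_dtype != indptr_dtype:
        raise TypeError(
            f"offsets ({offsets_dtype}) and indptr ({indptr_dtype}) must share a dtype"
        )
```

Here `misread` is `1 + (2 << 32)`, i.e. 8589934593, rather than either of the original offsets, which is why failing fast at the binding is preferable to letting the kernel run.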
…undary

The kernel tiles no_rope_dim in rope_dim-sized chunks using ceiling division. If no_rope_dim is not a multiple of rope_dim, the last chunk reads/writes past the end of the K-nope buffer. Add TORCH_CHECK in both rope_quantize and rope_quantize_append_paged_kv_cache to catch this at the API boundary rather than silently causing OOB in the kernel. All known real-world configs satisfy the constraint (DeepSeek MLA: no_rope_dim=512, rope_dim=64; Llama: no_rope_dim=0).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
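The tiling argument in this commit can be checked with a few lines of arithmetic. The sketch below uses hypothetical helper names. It models only the chunk boundaries described in the message (ceiling division, full-width chunks with no tail guard) and a Python stand-in for the added precondition.

```python
def nope_chunk_spans(no_rope_dim, rope_dim):
    """Element ranges [start, end) touched when tiling with ceiling division."""
    num_chunks = -(-no_rope_dim // rope_dim)  # ceil(no_rope_dim / rope_dim)
    # Every chunk is processed at full rope_dim width (no tail guard).
    return [(i * rope_dim, (i + 1) * rope_dim) for i in range(num_chunks)]


def check_no_rope_dim(no_rope_dim, rope_dim):
    """Stand-in for the API-boundary check: zero or an exact multiple only."""
    if not (no_rope_dim == 0 or no_rope_dim % rope_dim == 0):
        raise ValueError(
            f"no_rope_dim ({no_rope_dim}) must be 0 or a multiple of rope_dim ({rope_dim})"
        )


deepseek_spans = nope_chunk_spans(512, 64)  # 8 chunks, last one ends at 512
ragged_spans = nope_chunk_spans(96, 64)     # 2 chunks, last one ends at 128 > 96
```

With `no_rope_dim=96` and `rope_dim=64`, the second chunk spans elements 64 to 128, which is 32 elements past the end of the buffer; the DeepSeek configuration tiles exactly.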
31c88b9 to 602aff6
…omments

- Add K-tensor shape checks in rope_quantize (token/head/last dims, both 2D MLA and 3D GQA paths) to fail fast on caller shape mismatches instead of producing OOB inside the kernel.
- Add K and cache shape checks in the MLA branch of rope_quantize_append_paged_kv_cache (k_rope_in/k_nope_in token+last dims; ckv_cache/kpe_cache rank, page_size, and last-dim parity with rope_dim/no_rope_dim).
- Update the RopeQuantizeAppendPagedKVCacheKernel doc comment: it dispatches between GQA/MHA and MLA via the CacheT template and the IS_MLA constexpr, not GQA-only.
- Reword the pos_ids int32 rationale in rope_quantize: that path is the non-paged kernel; the real reason is that the kernel template is instantiated with IdType=int32_t.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
demandal25 added a commit that referenced this pull request on May 21, 2026

## Summary

Refresh the FlashInfer+ROCm README aimed at library consumers, refresh the Feature Support Matrix to match what has actually landed on `amd-integration`, and align the ROCm MLA wrapper with the rest of the ROCm backends so `backend="auto"` is accepted everywhere.

### What changed

#### `README.md`

- **Intro and structure.** Tighten the intro to call out HIP-in-repo kernels vs AITER dispatch up front; link to the Feature Support Matrix and AITER sections from the first paragraph. Cross-link CDNA3 / CDNA4 to AMD's official architecture whitepapers on first mention.
- **Feature Support Matrix.** Replaced with a five-column table (Kernel / HIP / AITER / `backend="auto"` resolves to / Notes). New ✅ rows: Cascade (#221), MLA via AITER (#232), RoPE (#223), paged KV-cache append, RMSNorm via AITER (#232), sliding-window decode on the AITER path (#234), activation, quantization, and opt-in `torch.compile` (#210). Every ✅ is backed by a `tests/rocm_tests/test_*_hip.py`. FP8 status is folded into per-row notes rather than a dedicated column.
- **GPU / ROCm / PyTorch.** Consolidated into one section with arch codenames inline (gfx942 → MI300X/MI325X = CDNA3, gfx950 → MI355X = CDNA4). `pip install torch` uses `--index-url` instead of `-f` so pip cannot silently fall back to a CPU-only PyPI wheel (matches CLAUDE.md).
- **Getting Started.** Collapsed the Docker image table to the latest validated tag and pointed at Docker Hub for older releases. Dropped the manual `micromamba activate base` step (the env is auto-activated). Used the concrete image tag plus a `--name=flashinfer-rocm` in the `docker run` snippet.
- **Trying the Examples.** Simplified to point at `examples/` plus one run command, with no wget-based downloads.
- **Install from Source.** Renamed from "Build from Source"; rewrote the ambiguous "Environment name varies …" note (and later removed it once the build / run blocks made the matching tag self-evident).
- **AITER Support.** Collapsed the section intro to avoid re-listing conditions already in the matrix; cross-link Known Limitations. Rewrote Known Limitations preamble to state the two-group split (hard errors vs silently-ignored kwargs). Dropped the redundant Single Prefill Example (Basic Usage already shows the call pattern).
- **Environment Variables.** New section documenting runtime env vars: `FLASHINFER_USE_TORCH_CUSTOM_OPS`, `FLASHINFER_HIP_FUSED_CASCADE`, `FLASHINFER_LOGGING_LEVEL`, `FLASHINFER_DISABLE_JIT`, `ROCM_PATH` / `ROCM_HOME`. Build-time vars stay in `CLAUDE.md` and are linked from here.
- **Runtime Helpers.** Short snippet showing `is_aiter_supported` and `check_torch_rocm_compatibility`; calls out `validate_flashinfer_rocm_arch` as a build-time validator, not a runtime helper.
- **CPX-mode pytest notes.** Split the dense paragraph into labelled bullets (Worker count / Reruns / `slow` marker / HIPBLAS retry).
- **Basic Usage.** Moved to the end of the README as a closing example.
- **License and Acknowledgements.** Added; the contributing reminder lives on its own line.

#### `flashinfer/mla_rocm.py` + `tests/rocm_tests/test_mla_aiter_hip.py`

- Accept `backend="auto"` as an alias for `"aiter"` on the ROCm MLA wrapper (default is now `"auto"` to match every other ROCm wrapper). Previously the wrapper raised `ValueError` on anything other than `"aiter"`, leaving MLA as the odd one out in the public API even though there is exactly one implementation to pick from on ROCm.
- New tests: `test_mla_backend_accepts_auto_and_aiter` (parametrized over both values) and `test_mla_backend_rejects_unsupported` (confirms `backend="fa2"` still raises; runs without a GPU since the check fires before the AITER probe).

## Test plan

- [x] `pre-commit run -a` passes.
- [x] `pre-commit run markdownlint --files README.md` passes after every change.
- [x] Every TOC entry resolves to an `##` heading in the body.
- [x] Every ✅ in the Feature Support Matrix has a backing `tests/rocm_tests/test_*_hip.py`.
- [x] `pytest tests/rocm_tests/test_mla_aiter_hip.py`: 11 passed.
- [x] Render the README on the PR page and visually confirm tables, code blocks, and `<details>` sections look right.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
Closes the remaining gaps between the CUDA and ROCm RoPE implementations: adds int64 position-id dispatch, enables MLA (DeepSeek-class) paths in the fused RoPE+FP8 quantize and paged-cache append kernels, and tunes launch geometry for CDNA3/Wave64.
What changed
`DISPATCH_PYTORCH_IDTYPE_TO_CTYPE` macro

`flashinfer/csrc_rocm/pytorch_extension_utils.h`: New macro mirroring CUDA's `DISPATCH_DLPACK_IDTYPE_TO_CTYPE`, dispatching `at::kInt` → `int32_t` and `at::kLong` → `int64_t`. Used by all five `apply_rope*` / `apply_llama31_rope*` HIP bindings so callers can freely pass either dtype for position tensors.

MLA support in the fused RoPE kernel

`include/flashinfer/attention/generic/pos_enc.cuh`: Adds `typename CacheT = paged_kv_t<QuantType, IdType>` as a 7th template parameter on `RopeQuantizeAppendPagedKVCacheKernel`. A `constexpr bool IS_MLA = std::is_same<CacheT, paged_kv_mla_t<...>>::value` branch selects at compile time:

- `get_kpe_ptr` / `get_ckv_ptr` (MLA) vs `get_k_ptr` / `get_v_ptr` (GQA/MHA)
- `total_blocks_y` excludes the V blocks when `IS_MLA`

The `RopeQuantizeAppendPagedMLACache<DType, IdType, QuantType>` host launcher constructs a `paged_kv_mla_t` and calls the shared kernel.

CDNA3 launch-geometry tuning

Replaces NVIDIA-ported warp32 heuristics with Wave64-aware geometry throughout `pos_enc.cuh`:

| Kernel | Before | After |
|---|---|---|
| `BatchQKApplyRotary*` / `BatchQKApplyLlama31Rotary*` | `vec_size = max(16/sizeof(DType), HEAD_DIM/32)`, 128-thread block | `vec_size = max(16/sizeof(DType), HEAD_DIM/kWarpSize)`, `num_threads = max(2*kWarpSize, bdx)` |
| `RopeQuantize` / `RopeQuantizeAppendPagedKVCache` | `vec_size = 32/sizeof(DType)`, fixed 128 threads | `vec_size = 16/sizeof(DType)` (single `global_load_dwordx4`), `num_threads = max(2*kWarpSize, bdx)` |

All changes use `gpu_iface::kWarpSize` so CUDA continues to see warp32 geometry unchanged.

HIP rope bindings

`flashinfer/csrc_rocm/rope.cu`:

- `apply_rope` / `apply_llama31_rope`: nested `DISPATCH_PYTORCH_IDTYPE_TO_CTYPE` inside the existing dtype dispatch supports both int32 and int64 position tensors. Added `TORCH_CHECK(offsets.scalar_type() == indptr.scalar_type())` to catch dtype mismatches between the two index tensors.
- `rope_quantize`: 2D-K MLA branch. When `k_rope_in.ndim() == 2`, sets `num_kv_heads=1` and aliases the head stride to the token stride. Added `TORCH_CHECK(no_rope_dim == 0 || no_rope_dim % rope_dim == 0)` to prevent OOB in the K-nope tiling loop.
- `rope_quantize_append_paged_kv_cache`: replaces the `TORCH_CHECK(!has_mla_caches)` hard-error with a full `paged_kv_mla_t` dispatch via `RopeQuantizeAppendPagedMLACache`; adds `TORCH_CHECK(!(has_gqa_caches && has_mla_caches))` so callers providing both sets of caches fail fast; moves `v_in` dtype and contiguity validation inside the GQA-only block.

`flashinfer/rope.py`: adds `v.dtype != q_rope.dtype` check in the Python GQA path of `rope_quantize_fp8_append_paged_kv_cache`.

Test coverage

`tests/rocm_tests/test_rope_hip.py`:

- Enables `idtype` parametrize (`torch.int32` / `torch.int64`) on `test_rope_pos_ids`
- Re-enables `test_rope_cos_sin_cache`, using 5× looser tolerances motivated by measured ~3e-2 bfloat16 rounding error (two cast_load/cast_store round-trips through bf16 in the interleaved rotation path)
- Adds `test_mla_rope_quantize` (2D-K, DeepSeek-class config)
- Adds MLA variants to `test_generalized_rope_quantize_hip` and `test_generalized_rope_quantize_append_kv_cache_hip`
- Adds `test_rope_quantize_fp8_append_paged_kv_cache_decode_hip` (decode scenario)
- Fixes the `_rope_apply_interleave_f32` reference helper to handle 2D K natively

`tests/rocm_tests/test_activation_hip.py`: removes stale `if __name__ == "__main__":` block.

Test plan

- `pytest tests/rocm_tests/test_rope_hip.py -m "not slow"`: 13,079 passed on MI300X (gfx942)
- `apply_rope` / `apply_llama31_rope` pass with `int64` position IDs across all `(head_dim, interleave, inplace)` variants
- `test_rope_cos_sin_cache` passes on HIP (previously skipped) with bfloat16 inputs
- `test_mla_rope_quantize` and MLA variants in fused-quantize and cache-append tests pass
- `pytest tests/rocm_tests/ -m "not slow"`: all rope tests pass; full suite results tracked separately
- `pre-commit run -a` passes

Generated with Claude Code