
feat(hip): port RoPE to ROCm #223

Merged
demandal25 merged 9 commits into ROCm:amd-integration from demandal25:port-positional-encoding on May 20, 2026

Conversation

@demandal25
Collaborator

@demandal25 demandal25 commented May 12, 2026

Summary

Closes the remaining gaps between the CUDA and ROCm RoPE implementations: adds int64 position-id dispatch, enables MLA (DeepSeek-class) paths in the fused RoPE+FP8 quantize and paged-cache append kernels, and tunes launch geometry for CDNA3/Wave64.

What changed

DISPATCH_PYTORCH_IDTYPE_TO_CTYPE macro

  • flashinfer/csrc_rocm/pytorch_extension_utils.h — New macro mirroring CUDA's DISPATCH_DLPACK_IDTYPE_TO_CTYPE, dispatching at::kInt→int32_t and at::kLong→int64_t. Used by all five apply_rope* / apply_llama31_rope* HIP bindings so callers can freely pass either dtype for position tensors.
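
The dispatch described above can be sketched in Python as a lookup from a dtype tag to a fixed-width C integer type. This is only an illustration of the pattern: the real macro is a C++ preprocessor construct, and `IDTYPE_TO_CTYPE` and `dispatch_idtype` are hypothetical names, with dtype tags represented as plain strings.

```python
import ctypes

# Hypothetical stand-in for the macro's switch: dtype tag -> C integer type.
IDTYPE_TO_CTYPE = {
    "int32": ctypes.c_int32,  # at::kInt  -> int32_t
    "int64": ctypes.c_int64,  # at::kLong -> int64_t
}


def dispatch_idtype(dtype_name, fn):
    """Call fn with the C type matching dtype_name; reject any other dtype."""
    try:
        c_type = IDTYPE_TO_CTYPE[dtype_name]
    except KeyError:
        raise TypeError(f"unsupported position-id dtype: {dtype_name}") from None
    return fn(c_type)


# Element width the kernel body would be instantiated with for each dtype.
width_int32 = dispatch_idtype("int32", ctypes.sizeof)
width_int64 = dispatch_idtype("int64", ctypes.sizeof)
```

The point of the pattern is that the kernel body is written once against a type parameter, and an unsupported dtype fails at the binding instead of being reinterpreted.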

MLA support in the fused RoPE kernel

  • include/flashinfer/attention/generic/pos_enc.cuh — Adds typename CacheT = paged_kv_t<QuantType, IdType> as a 7th template parameter on RopeQuantizeAppendPagedKVCacheKernel. A constexpr bool IS_MLA = std::is_same<CacheT, paged_kv_mla_t<...>>::value branch selects at compile time:
    • K pointer arithmetic: 2D stride (no head dim) for MLA vs 3D for GQA/MHA
    • Cache writes: get_kpe_ptr / get_ckv_ptr (MLA) vs get_k_ptr / get_v_ptr (GQA/MHA)
    • V section: skipped entirely for MLA — total_blocks_y excludes the V blocks when IS_MLA
  • New RopeQuantizeAppendPagedMLACache<DType, IdType, QuantType> host launcher constructs a paged_kv_mla_t and calls the shared kernel.
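
The 2D versus 3D pointer arithmetic can be illustrated with a small Python sketch. It assumes the offset helper computes `token * stride_n + head * stride_h + elem` with a unit stride on the last dimension; `elem_offset` and `k_rope_offset` are hypothetical names, and the runtime `is_mla` flag stands in for the compile-time IS_MLA branch.

```python
def elem_offset(token_idx, head_idx, elem_idx, stride_n, stride_h):
    """Flat element offset for a (token, head, element) index."""
    return token_idx * stride_n + head_idx * stride_h + elem_idx


def k_rope_offset(token_idx, head_idx, elem_idx, stride_n, stride_h, is_mla):
    """MLA K has no head dimension, so the head index is pinned to 0."""
    if is_mla:
        return elem_offset(token_idx, 0, elem_idx, stride_n, stride_h)
    return elem_offset(token_idx, head_idx, elem_idx, stride_n, stride_h)


# GQA/MHA: K is [tokens, num_kv_heads=8, rope_dim=64], contiguous.
gqa_offset = k_rope_offset(3, 2, 5, stride_n=8 * 64, stride_h=64, is_mla=False)

# MLA: K is [tokens, rope_dim=64]. The head stride is multiplied by 0,
# so whatever value it is aliased to never contributes to the offset.
mla_offset = k_rope_offset(3, 0, 5, stride_n=64, stride_h=64, is_mla=True)
```

Because the head index is always 0 on the MLA path, aliasing the head stride to the token stride is harmless: the term it multiplies vanishes.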

CDNA3 launch-geometry tuning

Replaces NVIDIA-ported warp32 heuristics with Wave64-aware geometry throughout pos_enc.cuh:

| Kernel group | Old | New |
| --- | --- | --- |
| BatchQKApplyRotary* / BatchQKApplyLlama31Rotary* | vec_size = max(16/sizeof(DType), HEAD_DIM/32), 128-thread block | vec_size = max(16/sizeof(DType), HEAD_DIM/kWarpSize), num_threads = max(2*kWarpSize, bdx) |
| RopeQuantize / RopeQuantizeAppendPagedKVCache | vec_size = 32/sizeof(DType), fixed 128 threads | vec_size = 16/sizeof(DType) (single global_load_dwordx4), num_threads = max(2*kWarpSize, bdx) |

All changes use gpu_iface::kWarpSize so CUDA continues to see warp32 geometry unchanged.
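
As a worked example of the geometry formulas, the following Python sketch evaluates them for bfloat16 (2 bytes) and HEAD_DIM = 512. The function names are hypothetical, and `bdx = HEAD_DIM / vec_size` is an assumption about how the block width is derived, not something stated in this PR.

```python
def rotary_vec_size(dtype_bytes, head_dim, warp_size):
    """vec_size = max(16/sizeof(DType), HEAD_DIM/kWarpSize)."""
    return max(16 // dtype_bytes, head_dim // warp_size)


def quantize_vec_size(dtype_bytes):
    """vec_size = 16/sizeof(DType): one 16-byte load per thread."""
    return 16 // dtype_bytes


def block_threads(bdx, warp_size):
    """num_threads = max(2*kWarpSize, bdx)."""
    return max(2 * warp_size, bdx)


BF16_BYTES = 2
HEAD_DIM = 512

vec_warp32 = rotary_vec_size(BF16_BYTES, HEAD_DIM, 32)  # max(8, 16)
vec_wave64 = rotary_vec_size(BF16_BYTES, HEAD_DIM, 64)  # max(8, 8)

bdx = HEAD_DIM // vec_wave64  # assumption: one thread per vec_size slice
threads_wave64 = block_threads(bdx, 64)

quant_vec = quantize_vec_size(BF16_BYTES)  # 8 bf16 values = 16 bytes
```

With a 64-wide wavefront the per-thread vector shrinks from 16 to 8 elements for this head size, which keeps each thread at a single 16-byte access.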

HIP rope bindings

  • flashinfer/csrc_rocm/rope.cu
    • apply_rope / apply_llama31_rope: nested DISPATCH_PYTORCH_IDTYPE_TO_CTYPE inside the existing dtype dispatch supports both int32 and int64 position tensors. Added TORCH_CHECK(offsets.scalar_type() == indptr.scalar_type()) to catch dtype mismatches between the two index tensors.
    • rope_quantize: 2D-K MLA branch — when k_rope_in.ndim() == 2, sets num_kv_heads=1 and aliases the head stride to the token stride. Added TORCH_CHECK(no_rope_dim == 0 || no_rope_dim % rope_dim == 0) to prevent OOB in the K-nope tiling loop.
    • rope_quantize_append_paged_kv_cache: replaces the TORCH_CHECK(!has_mla_caches) hard-error with a full paged_kv_mla_t dispatch via RopeQuantizeAppendPagedMLACache; adds TORCH_CHECK(!(has_gqa_caches && has_mla_caches)) so callers providing both sets of caches fail fast; moves v_in dtype and contiguity validation inside the GQA-only block.
  • flashinfer/rope.py: adds v.dtype != q_rope.dtype check in the Python GQA path of rope_quantize_fp8_append_paged_kv_cache.
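
The fail-fast checks listed above can be sketched in Python as follows. This is a hedged illustration of the validation logic only: `check_index_dtypes` and `select_cache_path` are hypothetical names, dtypes are plain strings, and the error raised when no caches are provided is an assumption rather than documented behaviour.

```python
def check_index_dtypes(offsets_dtype, indptr_dtype):
    """Mirror of the offsets/indptr dtype check: both must use one dtype."""
    if offsets_dtype != indptr_dtype:
        raise ValueError(
            f"offsets ({offsets_dtype}) and indptr ({indptr_dtype}) "
            "must have the same dtype"
        )


def select_cache_path(has_gqa_caches, has_mla_caches):
    """Pick the append path, rejecting callers that pass both cache sets."""
    if has_gqa_caches and has_mla_caches:
        raise ValueError("provide either GQA/MHA caches or MLA caches, not both")
    if has_mla_caches:
        return "mla"
    if has_gqa_caches:
        return "gqa"
    # Assumption: the source does not say what happens with no caches at all.
    raise ValueError("no caches provided")


check_index_dtypes("int64", "int64")  # passes silently
chosen_path = select_cache_path(has_gqa_caches=False, has_mla_caches=True)
```

Rejecting the both-caches case up front avoids the silent fall-through to the GQA path that the review comments flagged.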

Test coverage

  • tests/rocm_tests/test_rope_hip.py:
    • Enables idtype parametrize (torch.int32 / torch.int64) on test_rope_pos_ids
    • Un-skips test_rope_cos_sin_cache, using 5× looser tolerances (5e-2) motivated by a measured bfloat16 rounding error of ~3e-2 (two cast_load/cast_store round-trips through bf16 in the interleaved rotation path)
    • Adds test_mla_rope_quantize (2D-K, DeepSeek-class config)
    • Adds MLA variants to test_generalized_rope_quantize_hip and test_generalized_rope_quantize_append_kv_cache_hip
    • Adds test_rope_quantize_fp8_append_paged_kv_cache_decode_hip (decode scenario)
    • Updates _rope_apply_interleave_f32 reference helper to handle 2D K natively
  • tests/rocm_tests/test_activation_hip.py: removes stale if __name__ == "__main__": block.
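
To give a feel for why bf16 round-trips call for looser tolerances, here is a minimal pure-Python sketch that rounds float32 values to bfloat16 (round-to-nearest-even, finite inputs only) and applies a rotation with rounding on load and on store. `to_bf16` and `rotate_pair` are hypothetical helpers; this does not reproduce the kernel or the measured ~3e-2 figure, it only shows the order of magnitude involved.

```python
import math
import struct


def to_bf16(x):
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add just under half of the discarded 16 bits, plus the tie-break bit.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def rotate_pair(x, y, angle, round_fn=lambda v: v):
    """Rotate (x, y) by angle, rounding the inputs and the outputs."""
    x, y = round_fn(x), round_fn(y)
    c, s = math.cos(angle), math.sin(angle)
    return round_fn(x * c - y * s), round_fn(y * c + x * s)


exact = rotate_pair(3.14159, -2.71828, 0.5)
rounded = rotate_pair(3.14159, -2.71828, 0.5, round_fn=to_bf16)
max_abs_err = max(abs(a - b) for a, b in zip(exact, rounded))
```

bfloat16 keeps only 8 significand bits, so values of magnitude 2 to 8 are spaced 2^-6 to 2^-5 apart, and errors in the 1e-2 range after a couple of round-trips are expected.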

Test plan

  • pytest tests/rocm_tests/test_rope_hip.py -m "not slow" — 13,079 passed on MI300X (gfx942)
  • No regressions in existing GQA/MHA rope paths
  • apply_rope / apply_llama31_rope pass with int64 position IDs across all (head_dim, interleave, inplace) variants
  • test_rope_cos_sin_cache passes on HIP (previously skipped) with bfloat16 inputs
  • test_mla_rope_quantize and MLA variants in fused-quantize and cache-append tests pass
  • pytest tests/rocm_tests/ -m "not slow" — all rope tests pass; full suite results tracked separately
  • pre-commit run -a passes

Generated with Claude Code

Copilot AI review requested due to automatic review settings May 12, 2026 05:40

Copilot AI left a comment

Pull request overview

This PR extends the ROCm/HIP RoPE implementation to match CUDA feature coverage by adding int64 position-id dispatch, enabling MLA (2D-K) support in fused RoPE+FP8 quantize and paged-cache append paths, and tuning kernel launch geometry (notably for CDNA3).

Changes:

  • Add DISPATCH_PYTORCH_IDTYPE_TO_CTYPE and use it in HIP RoPE bindings to support int32/int64 position-id tensors.
  • Unify the fused RoPE+quantize+append kernel to support both GQA/MHA (paged_kv_t) and MLA (paged_kv_mla_t) via a CacheT template + constexpr branching; add an MLA-specific host launcher.
  • Expand ROCm test coverage for pos-id dtype variants and MLA paths; re-enable the cos/sin-cache test on HIP with looser tolerances.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| flashinfer/csrc_rocm/pytorch_extension_utils.h | Adds an idtype dispatch macro used by ROCm bindings to accept int32/int64 tensors. |
| flashinfer/csrc_rocm/rope.cu | Updates HIP bindings to dispatch idtype, adds MLA support in quantize/append paths, and routes MLA cache appends through paged_kv_mla_t. |
| include/flashinfer/attention/generic/pos_enc.cuh | Generalizes the fused kernel to support MLA cache layout and updates launch geometry using kWarpSize. |
| tests/rocm_tests/test_rope_hip.py | Adds idtype parametrization, MLA coverage, and re-enables cos/sin-cache testing with HIP-appropriate tolerances. |
| tests/rocm_tests/test_activation_hip.py | Removes ad-hoc __main__ execution block from the test file. |


Comment thread flashinfer/csrc_rocm/rope.cu
Comment thread flashinfer/csrc_rocm/rope.cu
Comment thread flashinfer/csrc_rocm/rope.cu Outdated
Comment thread flashinfer/csrc_rocm/rope.cu
@demandal25 demandal25 changed the title fix(hip): port RoPE positional encoding to ROCm — MLA support, idtype dispatch, CDNA3 tuning fix(hip): port RoPE positional encoding to ROCm May 12, 2026
@demandal25 demandal25 marked this pull request as draft May 12, 2026 14:14
@demandal25 demandal25 changed the title fix(hip): port RoPE positional encoding to ROCm feat(hip): port RoPE positional encoding to ROCm May 18, 2026
@demandal25 demandal25 force-pushed the port-positional-encoding branch from c7f822c to ae4ced6 Compare May 20, 2026 02:23
@demandal25 demandal25 changed the title feat(hip): port RoPE positional encoding to ROCm feat(hip): port RoPE positional encoding to ROCm using HIP May 20, 2026
demandal25 added a commit to demandal25/flashinfer that referenced this pull request May 20, 2026
…usivity

Address Copilot review comments on PR ROCm#223:
- Add TORCH_CHECK(offsets.scalar_type() == indptr.scalar_type()) in
  apply_rope and apply_llama31_rope to prevent UB when caller passes
  mismatched dtypes (e.g. indptr=int64, offsets=int32).
- Add TORCH_CHECK(!(has_gqa_caches && has_mla_caches)) in
  rope_quantize_append_paged_kv_cache so callers with both cache sets
  get a clear error instead of silently falling through to the GQA path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@demandal25 demandal25 marked this pull request as ready for review May 20, 2026 04:56
Copilot AI review requested due to automatic review settings May 20, 2026 04:56
@demandal25 demandal25 force-pushed the port-positional-encoding branch from c3553a9 to fa233b9 Compare May 20, 2026 04:57

Copilot AI left a comment

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Comments suppressed due to low confidence (1)

include/flashinfer/attention/generic/pos_enc.cuh:1236

  • Similarly, the Q-noRoPE section unconditionally loads/stores vec_size elements per thread with no tail guard. For no_rope_dim values that aren’t multiples of rope_dim, this can read/write out of bounds. Add a tail predicate (or validate no_rope_dim % rope_dim == 0 up front) so the last chunk doesn’t access beyond q_nope_* tensors.
    } else {
      uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
      uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
      uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;

      DType* q_nope_in_ptr =
          q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
                                           q_nope_in_stride_h);
      QuantType* q_nope_out_ptr =
          q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
                                            q_nope_out_stride_h);

      vec_t q_nope_vec;
      q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
#pragma unroll

Comment thread flashinfer/csrc_rocm/rope.cu
Comment thread flashinfer/csrc_rocm/rope.cu
Comment thread include/flashinfer/attention/generic/pos_enc.cuh
@demandal25 demandal25 changed the title feat(hip): port RoPE positional encoding to ROCm using HIP feat(hip): port RoPE to ROCm — MLA, int64 idtype, CDNA3 tuning, AITER decode fixes May 20, 2026
Copilot AI review requested due to automatic review settings May 20, 2026 05:23
@demandal25 demandal25 force-pushed the port-positional-encoding branch from 1fba9c2 to 31c88b9 Compare May 20, 2026 05:23
@demandal25 demandal25 changed the title feat(hip): port RoPE to ROCm — MLA, int64 idtype, CDNA3 tuning, AITER decode fixes feat(hip): port RoPE to ROCm — MLA, int64 idtype, CDNA3 tuning May 20, 2026
@demandal25 demandal25 changed the title feat(hip): port RoPE to ROCm — MLA, int64 idtype, CDNA3 tuning feat(hip): port RoPE to ROCm May 20, 2026

Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.

Comment thread flashinfer/csrc_rocm/rope.cu
Comment thread flashinfer/csrc_rocm/rope.cu
Comment thread flashinfer/csrc_rocm/rope.cu Outdated
Comment thread include/flashinfer/attention/generic/pos_enc.cuh
demandal25 and others added 7 commits May 20, 2026 06:19
… utils

Mirrors the CUDA DISPATCH_DLPACK_IDTYPE_TO_CTYPE pattern, dispatching
at::kInt→int32_t and at::kLong→int64_t. Used by the rope bindings to
support int64 position-id tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…CDNA3 tuning

- Add typename CacheT = paged_kv_t<QuantType, IdType> as 7th template
  parameter to RopeQuantizeAppendPagedKVCacheKernel; rename paged_kv
  to paged_kv_like so GQA and MLA share one kernel body.
- Add IS_MLA constexpr branch: 2D K pointer arithmetic for input,
  get_kpe_ptr/get_ckv_ptr cache stores, skip V section.
- Add RopeQuantizeAppendPagedMLACache host launcher that instantiates
  the kernel with CacheT=paged_kv_mla_t.
- CDNA3 launch geometry: vec_size=16/sizeof(DType) (single
  global_load_dwordx4), num_threads=max(2*kWarpSize, bdx).
- Add <type_traits> include for std::is_same.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Wrap all apply_rope* and apply_llama31_rope* position-id lookups in
  DISPATCH_PYTORCH_IDTYPE_TO_CTYPE, supporting int32 and int64 pos_ids.
- rope_quantize: add 2D-K MLA branch (k_rope_in.dim()==2 → num_kv_heads=1,
  head stride aliased to token stride) matching csrc/rope.cu:319-394.
- rope_quantize_append_paged_kv_cache: replace the MLA hard-error with
  a full MLA dispatch that constructs paged_kv_mla_t and calls the new
  RopeQuantizeAppendPagedMLACache launcher.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…tivation test

test_rope_hip.py:
- Enable idtype parametrize (int32/int64) on test_rope_pos_ids.
- Remove skip from test_rope_cos_sin_cache.
- Add MLA variants to test_generalized_rope_quantize_hip and
  test_generalized_rope_quantize_append_kv_cache_hip.
- Add test_mla_rope_quantize for the no-append MLA path.
- Fix _rope_apply_interleave_f32 to handle 2D K natively via
  unsqueeze/squeeze inside _rot().

test_activation_hip.py:
- Remove debug if __name__ == "__main__" block.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
In the C++ binding, check that v_in is contiguous along the last dim and
its scalar type matches q_rope_in before computing strides.  Mirror the
check at the Python entry point so mismatched v dtypes are caught early
with a clear error message.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…nstraint in rope bindings

Remove self-evident MLA/GQA path labels and stale line-reference comments
flagged in simplify/review pass. Add WHY comments explaining that pos_ids is
int32-only in paged-cache paths (shared int32 index arithmetic). Document the
observed bf16 rounding error (~3e-2) that drives the 5e-2 tolerance in
test_rope_cos_sin_cache.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…usivity

Address Copilot review comments on PR ROCm#223:
- Add TORCH_CHECK(offsets.scalar_type() == indptr.scalar_type()) in
  apply_rope and apply_llama31_rope to prevent UB when caller passes
  mismatched dtypes (e.g. indptr=int64, offsets=int32).
- Add TORCH_CHECK(!(has_gqa_caches && has_mla_caches)) in
  rope_quantize_append_paged_kv_cache so callers with both cache sets
  get a clear error instead of silently falling through to the GQA path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…undary

The kernel tiles no_rope_dim in rope_dim-sized chunks using ceiling division.
If no_rope_dim is not a multiple of rope_dim, the last chunk reads/writes past
the end of the K-nope buffer. Add TORCH_CHECK in both rope_quantize and
rope_quantize_append_paged_kv_cache to catch this at the API boundary rather
than silently causing OOB in the kernel.

All known real-world configs satisfy the constraint (DeepSeek MLA:
no_rope_dim=512, rope_dim=64; Llama: no_rope_dim=0).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
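
The tiling problem described in the commit above comes down to ceiling-division arithmetic, which the following Python sketch works through. The helper names are hypothetical; only the chunking rule (rope_dim-sized chunks, ceiling division) and the example dimensions come from the commit message, apart from the 96/64 case, which is an invented counterexample.

```python
def nope_chunks(no_rope_dim, rope_dim):
    """Number of rope_dim-sized chunks covering no_rope_dim (ceil division)."""
    return -(-no_rope_dim // rope_dim)


def overrun_elems(no_rope_dim, rope_dim):
    """Elements the last chunk touches beyond the end of the buffer."""
    return nope_chunks(no_rope_dim, rope_dim) * rope_dim - no_rope_dim


def check_no_rope_dim(no_rope_dim, rope_dim):
    """API-boundary guard: no_rope_dim must be 0 or a multiple of rope_dim."""
    if no_rope_dim != 0 and no_rope_dim % rope_dim != 0:
        raise ValueError(
            f"no_rope_dim={no_rope_dim} must be 0 or a multiple of "
            f"rope_dim={rope_dim}"
        )


deepseek_overrun = overrun_elems(512, 64)  # 8 full chunks
llama_overrun = overrun_elems(0, 64)       # no chunks at all
bad_overrun = overrun_elems(96, 64)        # hypothetical: 2 chunks cover 128
```

The guard passes exactly when the overrun is zero, so checking divisibility at the binding is equivalent to proving the last chunk stays in bounds.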
@demandal25 demandal25 force-pushed the port-positional-encoding branch from 31c88b9 to 602aff6 Compare May 20, 2026 06:20
…omments

- Add K-tensor shape checks in rope_quantize (token/head/last dims, both 2D MLA and 3D GQA paths) to fail fast on caller shape mismatches instead of producing OOB inside the kernel.
- Add K and cache shape checks in the MLA branch of rope_quantize_append_paged_kv_cache (k_rope_in/k_nope_in token+last dims; ckv_cache/kpe_cache rank, page_size, and last-dim parity with rope_dim/no_rope_dim).
- Update the RopeQuantizeAppendPagedKVCacheKernel doc comment: it dispatches between GQA/MHA and MLA via the CacheT template and the IS_MLA constexpr, not GQA-only.
- Reword the pos_ids int32 rationale in rope_quantize: that path is the non-paged kernel; the real reason is that the kernel template is instantiated with IdType=int32_t.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings May 20, 2026 06:38

Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated no new comments.

@demandal25 demandal25 merged commit c81ea05 into ROCm:amd-integration May 20, 2026
4 of 6 checks passed
@demandal25 demandal25 deleted the port-positional-encoding branch May 20, 2026 10:45
demandal25 added a commit that referenced this pull request May 21, 2026
## Summary

Refresh the FlashInfer+ROCm README aimed at library consumers, refresh
the Feature Support Matrix to match what has actually landed on
`amd-integration`, and align the ROCm MLA wrapper with the rest of the
ROCm backends so `backend="auto"` is accepted everywhere.

### What changed

#### `README.md`

- **Intro and structure.** Tighten the intro to call out HIP-in-repo
kernels vs AITER dispatch up front; link to the Feature Support Matrix
and AITER sections from the first paragraph. Cross-link CDNA3 / CDNA4 to
AMD's official architecture whitepapers on first mention.
- **Feature Support Matrix.** Replaced with a five-column table (Kernel
/ HIP / AITER / `backend="auto"` resolves to / Notes). New ✅ rows:
Cascade (#221), MLA via AITER (#232), RoPE (#223), paged KV-cache
append, RMSNorm via AITER (#232), sliding-window decode on the AITER
path (#234), activation, quantization, and opt-in `torch.compile`
(#210). Every ✅ is backed by a `tests/rocm_tests/test_*_hip.py`. FP8
status is folded into per-row notes rather than a dedicated column.
- **GPU / ROCm / PyTorch.** Consolidated into one section with arch
codenames inline (gfx942 → MI300X/MI325X = CDNA3, gfx950 → MI355X =
CDNA4). `pip install torch` uses `--index-url` instead of `-f` so pip
cannot silently fall back to a CPU-only PyPI wheel (matches CLAUDE.md).
- **Getting Started.** Collapsed the Docker image table to the latest
validated tag and pointed at Docker Hub for older releases. Dropped the
manual `micromamba activate base` step (the env is auto-activated). Used
the concrete image tag plus a `--name=flashinfer-rocm` in the `docker
run` snippet.
- **Trying the Examples.** Simplified to point at `examples/` plus one
run command — no wget-based downloads.
- **Install from Source.** Renamed from "Build from Source"; rewrote the
ambiguous "Environment name varies …" note (and later removed it once
the build / run blocks made the matching tag self-evident).
- **AITER Support.** Collapsed the section intro to avoid re-listing
conditions already in the matrix; cross-link Known Limitations. Rewrote
Known Limitations preamble to state the two-group split (hard errors vs
silently-ignored kwargs). Dropped the redundant Single Prefill Example
(Basic Usage already shows the call pattern).
- **Environment Variables.** New section documenting runtime env vars —
`FLASHINFER_USE_TORCH_CUSTOM_OPS`, `FLASHINFER_HIP_FUSED_CASCADE`,
`FLASHINFER_LOGGING_LEVEL`, `FLASHINFER_DISABLE_JIT`, `ROCM_PATH` /
`ROCM_HOME`. Build-time vars stay in `CLAUDE.md` and are linked from
here.
- **Runtime Helpers.** Short snippet showing `is_aiter_supported` and
`check_torch_rocm_compatibility`; calls out
`validate_flashinfer_rocm_arch` as a build-time validator, not a runtime
helper.
- **CPX-mode pytest notes.** Split the dense paragraph into labelled
bullets (Worker count / Reruns / `slow` marker / HIPBLAS retry).
- **Basic Usage.** Moved to the end of the README as a closing example.
- **License and Acknowledgements.** Added; the contributing reminder
lives on its own line.

#### `flashinfer/mla_rocm.py` + `tests/rocm_tests/test_mla_aiter_hip.py`

- Accept `backend="auto"` as an alias for `"aiter"` on the ROCm MLA
wrapper (default is now `"auto"` to match every other ROCm wrapper).
Previously the wrapper raised `ValueError` on anything other than
`"aiter"`, leaving MLA as the odd one out in the public API even though
there is exactly one implementation to pick from on ROCm.
- New tests: `test_mla_backend_accepts_auto_and_aiter` (parametrized
over both values) and `test_mla_backend_rejects_unsupported` (confirms
`backend="fa2"` still raises; runs without a GPU since the check fires
before the AITER probe).
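
A minimal sketch of the backend aliasing described above, assuming the check is a plain string comparison performed before any GPU or AITER probing. `resolve_mla_backend` is a hypothetical name, not the wrapper's actual API.

```python
def resolve_mla_backend(backend="auto"):
    """Map the user-facing backend string to the one ROCm implementation."""
    if backend in ("auto", "aiter"):
        return "aiter"
    raise ValueError(
        f"unsupported MLA backend on ROCm: {backend!r} "
        "(expected 'auto' or 'aiter')"
    )


default_backend = resolve_mla_backend()
explicit_backend = resolve_mla_backend("aiter")
```

Because the validation is pure string handling, a test for the rejection path (for example `backend="fa2"`) needs no GPU.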

## Test plan

- [x] `pre-commit run -a` passes.
- [x] `pre-commit run markdownlint --files README.md` passes after every
change.
- [x] Every TOC entry resolves to an `##` heading in the body.
- [x] Every ✅ in the Feature Support Matrix has a backing
`tests/rocm_tests/test_*_hip.py`.
- [x] `pytest tests/rocm_tests/test_mla_aiter_hip.py` — 11 passed.
- [x] Render the README on the PR page and visually confirm tables, code
blocks, and `<details>` sections look right.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

2 participants