
[FlyDSL] fused RoPE kernel with layout APIs #300

Merged
coderfeli merged 7 commits into ROCm:main from amd-weisun:rope-layout-api
Apr 1, 2026
Conversation


@amd-weisun amd-weisun commented Mar 27, 2026

As discussed in #272. Replaces the manual byte-offset arithmetic in fused_rope_cache_kernel with the FlyDSL layout API (make_layout + crd2idx) for all structured address computations (Q/K/V, cos/sin, KV cache).

  • Declare tensor layouts as (shape, stride) tuples, use crd2idx for address computation instead of manual mul+add chains
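For illustration, the (shape, stride) to linear-index mapping that make_layout + crd2idx provide can be sketched in plain Python. This is a simplified stand-in, not the FlyDSL API itself, and the tensor extents below are made up:

```python
# Sketch of layout-based addressing: a layout is a (shape, stride) pair,
# and crd2idx is the dot product of a coordinate with the strides --
# exactly what the manual mul+add chains were computing by hand.

def crd2idx(coord, shape, stride):
    """Map a (possibly nested) coordinate to a linear index."""
    if isinstance(shape, tuple):
        return sum(crd2idx(c, s, d) for c, s, d in zip(coord, shape, stride))
    return coord * stride

# Hypothetical [num_tokens, num_heads, head_dim] Q tensor, row-major:
q_shape = (4, 8, 128)
q_stride = (8 * 128, 128, 1)

# Address of token 2, head 3, dim 5 -- same as 2*1024 + 3*128 + 5.
assert crd2idx((2, 3, 5), q_shape, q_stride) == 2437
```

With the layout declared once, every access site becomes a single crd2idx call instead of a hand-written offset expression.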

92/92 tests passed; cross-validated against AITER results.

Test Plan

Usage:

# Fast CI — correctness only (GPT-OSS 120B TP=8, 10 tests):
PYTHONPATH=./ pytest tests/kernels/test_fused_rope_cache.py -v -s

# All models × TPs (multi-model sweep):
FLYDSL_ALL_MODELS=1 PYTHONPATH=./ pytest tests/kernels/test_fused_rope_cache.py -v -s

# With benchmarking + optional AITER comparison:
FLYDSL_BENCH=1 AITER_REPO=../aiter PYTHONPATH=./ pytest tests/kernels/test_fused_rope_cache.py -v -s

# CLI — all models:
PYTHONPATH=./ python tests/kernels/test_fused_rope_cache.py --all-models

# CLI — with benchmark + AITER comparison:
FLYDSL_BENCH=1 AITER_REPO=../aiter PYTHONPATH=./ python tests/kernels/test_fused_rope_cache.py --all-models

Test Result

  • Tested on MI350: zero numerical error vs. the PyTorch reference. Performance: 1.4-1.6x faster than Triton (AITER) across all configs (GPT-OSS-120B, Qwen3, Llama-3.1); both cache layouts verified. Cross-validated against AITER output.

Submission Checklist

@amd-weisun amd-weisun changed the title [FlyDSL] Migrate RoPE kernel to layout API (make_layout + crd2idx) [FlyDSL] fused RoPE kernel with layout APIs Mar 27, 2026
@amd-weisun amd-weisun mentioned this pull request Mar 27, 2026
1 task
@amd-weisun amd-weisun marked this pull request as ready for review March 27, 2026 11:40
Copilot AI review requested due to automatic review settings March 27, 2026 11:40

Copilot AI left a comment


Pull request overview

Adds a fused RoPE + KV-cache kernel implementation that uses FlyDSL’s layout APIs for structured address calculations, along with a dedicated correctness/benchmarking test harness.

Changes:

  • Introduce kernels/fused_rope_cache_kernel.py implementing a 2-launch fused RoPE + KV-cache write using make_layout + crd2idx.
  • Add tests/kernels/test_fused_rope_cache.py covering flash/non-flash cache layouts, bf16/f16 dtypes, negative-slot behavior, and optional benchmarking/AITER cross-check.
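As a rough illustration of the negative-slot behavior the tests cover, assuming the common slot-mapping convention where a negative slot marks a padding token whose K/V must not be written to the cache (the helper below is hypothetical, not from the PR):

```python
# Sketch: scatter per-token K/V values into a flat cache, skipping any
# token whose slot index is negative (padding / no-write convention).

def write_kv(cache, slots, values):
    """Write values into cache at the given slots; slots < 0 are skipped."""
    for slot, val in zip(slots, values):
        if slot >= 0:
            cache[slot] = val
    return cache

cache = [None] * 4
write_kv(cache, [2, -1, 0], ["k0", "k1", "k2"])
assert cache == ["k2", None, "k0", None]  # slot -1 left the cache untouched
```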

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Files changed:

  • kernels/fused_rope_cache_kernel.py: New fused kernel builder using layout-based indexing for Q/K/V and KV-cache address computations.
  • tests/kernels/test_fused_rope_cache.py: New correctness + optional perf/AITER validation coverage for the fused kernel across layouts/dtypes.


amd-weisun and others added 2 commits March 27, 2026 13:25
Replace manual byte-offset arithmetic in fused_rope_cache_kernel with
FlyDSL layout API (make_layout + crd2idx) for all structured address
computations (Q/K/V, cos/sin, KV cache).

- Add _crd2idx_i32 helper to unwrap int_tuple -> i32 scalar for buffer
  offset math (same pattern as mfma_preshuffle_pipeline.py)
- Declare tensor layouts as (shape, stride) tuples, use crd2idx for
  address computation instead of manual mul+add chains
- Paired-half RoPE offset stays as arith (piecewise ±half_dim is not
  expressible as an affine layout stride)

92/92 tests passed, 0 errors, no performance regression

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
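For reference, the paired-half rotation whose offset the commit above keeps as arith can be sketched in plain Python. This is a NeoX-style reference under assumed conventions (inverse-frequency base, element i paired with i + half_dim), not the kernel code itself:

```python
import math

def rope_paired_half(x, pos, base=10000.0):
    """Reference paired-half RoPE on one head vector.

    Element i (< half_dim) pairs with i + half_dim; the rotated partner
    lives at +half_dim or -half_dim depending on which half i is in --
    the piecewise offset that is not expressible as an affine stride.
    """
    d = len(x)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        inv_freq = base ** (-2.0 * i / d)
        c, s = math.cos(pos * inv_freq), math.sin(pos * inv_freq)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out

# Position 0 rotates by angle 0, so the vector is unchanged.
assert rope_paired_half([1.0, 2.0, 3.0, 4.0], 0) == [1.0, 2.0, 3.0, 4.0]
```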
@amd-weisun amd-weisun marked this pull request as draft March 27, 2026 14:19
@amd-weisun
Contributor Author

As discussed with Felix, I am also working on converting the paired-half RoPE offset to the layout API.


Copilot AI left a comment


Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.




Copilot AI left a comment


Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated no new comments.



Collaborator

@coderfeli coderfeli left a comment


Overall clean implementation — layout API usage is correct, test coverage is solid (flash/non-flash, bf16/f16, negative slots, multi-model). A few suggestions below.

@amd-weisun amd-weisun marked this pull request as ready for review March 31, 2026 09:16
- Remove unreachable `vec_dwords != VEC_WIDTH` branches in bitcast
  calls (vec_dwords is always 4, VEC_WIDTH is always 8)
- Remove stale comment on launch function signature
- Remove sys.path manipulation in test (use PYTHONPATH instead)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@coderfeli coderfeli merged commit 5a1385c into ROCm:main Apr 1, 2026
5 checks passed