support rta in pa #1

Merged
valarLip merged 1 commit into main from pa_bf16_rta
Nov 20, 2024

Conversation

@carlushuang
Collaborator

No description provided.

@valarLip valarLip merged commit 6499281 into main Nov 20, 2024
@carlushuang carlushuang deleted the pa_bf16_rta branch November 21, 2024 14:17
alfuyao-amd pushed a commit that referenced this pull request Jun 2, 2025
amd-youchen pushed a commit to amd-youchen/aiter that referenced this pull request Nov 25, 2025
Merge mrope & trtllm allreduce fusion
valarLip added a commit that referenced this pull request Mar 18, 2026
@gyohuangxin gyohuangxin mentioned this pull request Mar 18, 2026
11 tasks
chun-wan pushed a commit that referenced this pull request Apr 29, 2026
bb22fd3 added DPP-based warp_reduce_sum / half_warp_reduce_sum /
warp_shfl_xor_sync_vec helpers in mrope_utils::block_utils that referenced
opus::mov_dpp / opus::number / opus::bool_constant directly. But
rope_common.h gates `#include "opus/opus.hpp"` with #ifdef
__HIP_DEVICE_COMPILE__ — so during the HIP host pass, `opus` is not a
declared namespace. clang HIP parses template bodies in BOTH passes for
non-dependent-name lookup, so any TU that includes rope_common.h without
otherwise pulling in opus.hpp transitively (e.g. via quant_utils.cuh)
fails the host pass with "error: use of undeclared identifier 'opus'".

This caused a CI build failure on
csrc/kernels/rope/general_2c_cached_positions_offsets_fwd_kernels.cu.
The failure was masked locally because the fused_qk_norm_rope JIT target
pulls in quant_utils.cuh (-> opus.hpp) and was building fine.

Two changes, both bit-for-bit equivalent at the GPU instruction level:

1. warp_reduce_sum / half_warp_reduce_sum bodies are now wrapped in
   #ifdef __HIP_DEVICE_COMPILE__ (matching the existing pattern used
   throughout rope_common.h, e.g. line 564). The host pass now skips the
   opus::mov_dpp calls entirely and the function returns `val` unchanged
   — fine because these helpers are __device__-only and never called
   from host code. Device-pass body is unchanged.

2. warp_shfl_xor_sync_vec's tag-dispatch parameter went from
   `opus::number<XorOffset> = {}` to `std::integral_constant<int,
   XorOffset> = {}`. The signature is parsed in both passes regardless
   of #ifdef wrapping the body, so this one CAN'T be hidden the way #1
   is. opus::number<I> is publicly derived from
   std::integral_constant<index_t, I> (csrc/include/opus/opus.hpp:57), so
   existing callers passing opus::number<X>{} continue to work
   unchanged via pass-by-value slicing of the empty derived type.
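
The two changes above can be illustrated with a small host-only sketch. It is
not the code from rope_common.h: `my_number`, `xor_partner` and
`reduce_or_passthrough` are hypothetical stand-ins for opus::number,
warp_shfl_xor_sync_vec and warp_reduce_sum, and the sketch assumes the tag
type derives from std::integral_constant<int, I>.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for opus::number<I>: an empty tag type that publicly
// derives from std::integral_constant, as the commit message describes.
template <int I>
struct my_number : std::integral_constant<int, I> {};

// Hypothetical stand-in for warp_shfl_xor_sync_vec. The tag parameter is
// spelled with the standard type, so the signature parses without any
// project header. Callers may still pass my_number<X>{}: template argument
// deduction accepts a derived class, and the empty tag is sliced by value.
template <int XorOffset>
int xor_partner(int lane, std::integral_constant<int, XorOffset> = {}) {
    return lane ^ XorOffset;
}

// Hypothetical stand-in for warp_reduce_sum. Only the device pass would see
// the intrinsic calls; every other pass returns the value unchanged.
inline float reduce_or_passthrough(float val) {
#ifdef __HIP_DEVICE_COMPILE__
    // Device-only DPP calls would go here (omitted in this sketch).
#endif
    return val;
}
```

Both call styles resolve to the same function: `xor_partner(5, my_number<4>{})`
deduces XorOffset from the derived tag, while `xor_partner<16>(3)` relies on
the defaulted tag argument.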

Verified:
- general_2c_cached_positions_offsets_fwd_kernels.cu now builds (full JIT
  re-run succeeds; previously failed in host pass on line 7260).
- fused_rope_rms_1way_quad and fused_rope_rms_1way (fallback) both
  numerically match pure-FP32 PyTorch reference at rtol=1e-2 atol=0.05
  for shapes that exercise both warp_reduce_sum and warp_shfl_xor_sync_vec
  (NEOX path) and the fallback inline RMSNorm.
- rocprofv3 fallback sweep over 32 (T, Hq, Hk, NEOX/INTR) cells: geomean
  speedup vs pre-388f737ba baseline = 1.22x (identical to the measurement
  before this fix — opus::mov_dpp lowers to the same v_*_dpp instruction
  as __builtin_amdgcn_mov_dpp; opus.hpp:1559-1562 is a thin wrapper).
chun-wan pushed a commit that referenced this pull request May 4, 2026
LiuYinfeng01 pushed a commit that referenced this pull request May 21, 2026
LiuYinfeng01 pushed a commit that referenced this pull request May 26, 2026
valarLip pushed a commit that referenced this pull request May 26, 2026
* [fused_qk_norm_rope] add 1way kernel; replace ds_bpermute with DPP+ds_swizzle in rms-reduce/NEOX

New kernel `fused_qk_norm_rope_1way` mirrors the existing 2way kernel for
single-token-stream models: per-head RMSNorm followed by RoPE on q/k,
supporting NEOX (half-head split) and interleaved (adjacent-pair) styles,
head_size ∈ {64, 128, 256}, BF16/FP16. No partial rotary, no KV cache,
no quantization: a minimal subset of the 2way kernel's scope.

Shared rope_common.h shuffle helpers reworked to take advantage of DPP and
ds_swizzle on gfx9xx. The previous implementation went through __shfl_xor
with width=32 on a 64-lane wave, which lowers to ds_bpermute_b32 (~10 cyc)
even for compile-time-constant XOR offsets:

  block_utils::warp_reduce_sum<float>:
    before: 5x ds_bpermute_b32 (offsets 16,8,4,2,1) + 1x __shfl broadcast
    after:  3x ds_swizzle_b32 (offsets 16,8,4 via XOR mask) +
            2x v_mov_b32_dpp  (offsets 2,1   via quad_perm 0x4e/0xb1)
    The XOR butterfly is symmetric so the post-reduce broadcast is a no-op
    and is removed. Order kept as 16->1 to make the FP32 accumulation
    bitwise-identical to the previous bpermute-based path.
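
Why the post-reduce broadcast is a no-op can be checked with a host-side
simulation. This is a sketch, not the kernel: lanes are modelled as a plain
array, the width of 32 is taken from the text above, and `decode_quad_perm`
assumes the usual DPP encoding of two select bits per lane within a quad.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <initializer_list>

constexpr int kLanes = 32;  // reduction width named in the commit message

// One butterfly step: every lane adds the value held by lane (lane ^ offset).
inline std::array<float, kLanes> xor_step(const std::array<float, kLanes>& in,
                                          int offset) {
    std::array<float, kLanes> out{};
    for (int lane = 0; lane < kLanes; ++lane) {
        out[lane] = in[lane] + in[lane ^ offset];
    }
    return out;
}

// Full reduction in the 16 -> 1 order. Lane i and lane (i ^ offset) add the
// same two operands at each step, so all lanes finish with the same sum.
inline std::array<float, kLanes> butterfly_sum(std::array<float, kLanes> v) {
    for (int offset : {16, 8, 4, 2, 1}) {
        v = xor_step(v, offset);
    }
    return v;
}

// Decode a quad_perm control byte: lane i of each quad reads from the lane
// index stored in bits [2*i+1 : 2*i] (assumed encoding).
inline std::array<int, 4> decode_quad_perm(std::uint8_t ctrl) {
    std::array<int, 4> src{};
    for (int i = 0; i < 4; ++i) {
        src[i] = (ctrl >> (2 * i)) & 3;
    }
    return src;
}
```

Under that encoding, 0xb1 decodes to sources {1, 0, 3, 2} (XOR 1) and 0x4e to
{2, 3, 0, 1} (XOR 2), which matches the offsets 2 and 1 listed above.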

  warp_shfl_xor_sync_vec<T, vec_size, XorOffset>: new helper for vectorised
  constant XOR shuffles, lowers to ds_swizzle. Used at the two NEOX
  neighbour-swap call sites in fused_rope_rms_1way_kernel — replaces the
  runtime `lane + neighbor_offset` arithmetic that lowered to ds_bpermute.

Verification on production config B=2 T=4096 D=128 NEOX, Hq=24 Hk=25 (BF16):
  - Bitwise identical output vs the old ds_bpermute helper across NEOX +
    interleaved, BF16 + FP16, D ∈ {64, 128, 256}, T ∈ {32, 1024, 4096}.
  - Latency: ~8% reduction (652 GB/s achieved, up from 600 GB/s).
  - PMC: SQ_WAIT_INST_ANY -25%, SQ_BUSY_CYCLES -11%, SQ_INSTS_VALU -10%.
  - Disasm: 0 ds_bpermute_b32 (was 10), 7 ds_swizzle_b32, 2 v_mov_b32_dpp.

Also fixes a pre-existing OOB read on the interleaved RoPE path: cos/sin
are only VEC_SIZE/2 elements per lane but the vec_t::load issued a full
VEC_SIZE read, racing past the cos_sin buffer tail on the last token. The
1way kernel now uses scalar loads of exactly VEC_SIZE/2; the 2way kernel
still has the original load (separate fix recommended).

Tests: op_tests/test_fused_qk_norm_rope_cache_quant.py adds a
`test_qk_norm_rope_1way` + sweep across dtype/D/T/interleaved, all pass
checkAllclose(rtol=1e-2, atol=0.05) vs the torch reference.

* [fused_qk_norm_rope] add quad fast path to 1way kernel

Adds a 4-head-group ("quad") fast path to the 1way fused QK-Norm+RoPE kernel
for shapes where num_heads_q % 4 == 0 && num_heads_k % 4 == 0. Each physical
wave (64 lanes) maps to 4 heads x 16 lanes-per-head, packing 4 adjacent heads
into one wave instead of the 1-head-per-wave default path. Renames the internal
"pair-x2" variant to "quad" for clarity; the 2-head "pair" variant did not
pay off vs quad and is dropped.
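
A minimal sketch of the dispatch condition and the lane layout described
above. The names are hypothetical, and the contiguous 16-lanes-per-head
mapping is an assumption; the commit message does not say how lanes are
assigned to heads inside the wave.

```cpp
#include <cassert>

constexpr int kWaveSize = 64;      // physical wave width from the text
constexpr int kHeadsPerWave = 4;   // "quad": four heads per wave
constexpr int kLanesPerHead = kWaveSize / kHeadsPerWave;  // 16

// Eligibility rule stated in the commit message.
inline bool can_use_quad_path(int num_heads_q, int num_heads_k) {
    return num_heads_q % 4 == 0 && num_heads_k % 4 == 0;
}

struct LaneSlot {
    int head_in_wave;  // which of the 4 heads this lane serves
    int lane_in_head;  // position of the lane within that head's 16 lanes
};

// Assumed layout: lanes 0-15 -> head 0, 16-31 -> head 1, and so on.
inline LaneSlot map_lane(int lane) {
    return {lane / kLanesPerHead, lane % kLanesPerHead};
}
```

With this rule Hq=Hk=24 takes the quad path, while Hq=24, Hk=25 does not.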

* [fused_qk_norm_rope] quad kernel: kill NEOX divergent branch, use packed bf16 cvt

Two bit-exact-equivalent optimizations to fused_rope_rms_1way_quad_kernel
identified from rocprofv3 ATT trace + disasm inspection.

1. NEOX rope: replace the divergent if(is_lower_half){...}else{...} with a
   per-lane cndmask select. The divergent branch forced the compiler to
   emit two copies of the rope math AND the bf16 cvt sequence, with
   s_and_saveexec / s_xor / s_or EXEC mask switches between them. The new
   form computes both x*c - nx*nc and x*nc + nx*c in FP32 on every lane
   (same op order as the original divergent code) and selects via cndmask.

2. Add f32x2_to_bf16x2_rne + pack_f32_to_vec_t<T,N> helpers in
   rope_common.h, adapted from the gfx94 RNE branch of float_2_bf16_pair<0>
   in aiter/csrc/kernels/mla/hk/hk_mla_buffer_managers.cuh. The helper
   replaces the compiler default static_cast<bf16>(float) expansion (13
   instructions plus EXEC mask switches per output) with a 10-instruction
   VALU sequence with no EXEC mask manipulation. Used for both RMSNorm
   output and rope output in the NEOX and INTR paths. RNE rounding is
   bit-identical to the ctor for non-NaN inputs (NaN is replaced with
   canonical 0x7FFF, which is unreachable from finite RMSNorm / rope
   inputs).
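
For reference, round-to-nearest-even FP32 to BF16 conversion can be written
on the host with the standard bias-and-shift bit trick. This is a sketch of
the rounding rule only, not the VALU sequence of f32x2_to_bf16x2_rne; the
function names and the low/high packing order are assumptions.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// Convert one float to BF16 bits with round-to-nearest-even.
// NaN maps to the canonical 0x7FFF, as described in the commit message.
inline std::uint16_t f32_to_bf16_rne(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    if ((u & 0x7fffffffu) > 0x7f800000u) {
        return 0x7fffu;  // any NaN payload
    }
    // Add 0x7FFF plus the lowest kept bit so that ties round to even.
    u += 0x7fffu + ((u >> 16) & 1u);
    return static_cast<std::uint16_t>(u >> 16);
}

// Pack two converted values into one 32-bit word (first value in the low half).
inline std::uint32_t pack_bf16x2(float lo, float hi) {
    return static_cast<std::uint32_t>(f32_to_bf16_rne(lo)) |
           (static_cast<std::uint32_t>(f32_to_bf16_rne(hi)) << 16);
}
```

The tie cases show the rule: 1 + 2^-8 rounds down to 0x3F80 (even mantissa),
while 1 + 3*2^-8 rounds up to 0x3F82.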

Verified bit-exact equivalent to pre-change kernel: stashed both files,
built HEAD (53f3f65), dumped raw output bytes for 6 (seed, mode, shape)
configs (NEOX/INTR x T={127,1024,8192} x (Hq,Hk) in {(16,16),(24,24),
(32,32)}); restored optimizations, rebuilt, dumped again; md5sum compare
matched on all 12 binary files (~150M bf16 elements total).

Disasm impact for quad<bf16, 128, NEOX, 6, 6> on gfx942:
  total disasm lines: 1212 -> 591 (-51%)
  default static_cast<bf16> seqs: 48 -> 0
  s_and_saveexec EXEC mask switches: 47 -> 2 (-96%)

Wall-clock impact at B=1, T=8192, Hq=Hk=24, D=128 (Qwen-Image-2 shape):
  NEOX: 224.77 -> 141.05 us (1.59x speedup, -83.7 us)
  INTR: 152.07 -> 123.07 us (1.24x speedup, -29.0 us)

* [fused_qk_norm_rope] fallback 1way kernel: inline RMSNorm + packed bf16 cvt

Apply the same kill-divergent-branch + packed-bf16-cvt optimizations that
landed for the quad fast path (e66f419) to the single-head-per-warp
fused_rope_rms_1way_kernel fallback (used when num_heads_q or num_heads_k
is not divisible by 4 — e.g. Qwen-Image-2 prefill, Hq=24, Hk=25).

Three changes, all bit-exact-equivalent:

1. Inline RMSNorm instead of the shared mrope_utils::warp_rms_norm_<T,N>
   helper. The shared helper does

       acc = sum( static_cast<float>(input[i])^2 )      ; reduce
       input[i] = static_cast<T>(static_cast<float>(input[i]) * s_val
                                  * static_cast<float>(gamma[i]))

   which forces the compiler to re-cast bf16->f32 in the writeback
   (redundant v_lshlrev_b32 conversions) and emit the default 13-instr
   static_cast<bf16>(float) sequence per output. Inlining lets us cache
   the f32 reads in a stack array v[VEC_SIZE], do the writeback in f32,
   and pack via pack_f32_to_vec_t (10 instr per bf16x2 pair). RNE
   rounding bit-identical for finite inputs (NaN -> canonical 0x7fff).

   Per the user's directive, the shared helper is left untouched (it is
   used by 4+ unrelated kernels) — the inlining is local to the 1way
   fallback.

2. NEOX rope: replace the divergent if(is_lower_half){...}else{...} with
   a per-lane cndmask select, same as e66f419 for the quad kernel.
   Both expressions are evaluated in the SAME FP32 op order as the
   original divergent code (mul + mul + sub for lower, mul + mul + add
   for upper) — bit-exact equivalent.

3. Stage both NEOX and INTR rope outputs in float[VEC_SIZE] then pack
   via pack_f32_to_vec_t for the same bf16-cvt-instr-count win.
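
The branch-to-select rewrite in item 2 can be sketched on the host as follows.
The variable names x, nx, c and nc follow the expressions quoted in the quad
kernel commit; what each one holds in the real kernel is an assumption here,
and the ternary stands in for the per-lane cndmask.

```cpp
#include <cassert>

// Original shape: each lane takes one of two divergent paths.
inline float rope_branchy(bool is_lower_half, float x, float nx, float c,
                          float nc) {
    if (is_lower_half) {
        return x * c - nx * nc;
    } else {
        return x * nc + nx * c;
    }
}

// Rewritten shape: every lane computes both results in the same FP32 op
// order, then picks one with a select.
inline float rope_select(bool is_lower_half, float x, float nx, float c,
                         float nc) {
    const float lower = x * c - nx * nc;  // mul, mul, sub
    const float upper = x * nc + nx * c;  // mul, mul, add
    return is_lower_half ? lower : upper;
}
```

Because each expression keeps its operation order, the selected value is the
same one the branch would have produced.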

Verified numerically against the pure-FP32 PyTorch reference at rtol=1e-2
atol=0.05 (matches aiter checkAllclose) on shapes that exercise the
fallback path (Hq,Hk in {(24,25),(25,25),(18,18),(20,21)}, T in
{256,1024,4096}, NEOX+INTR): all 24 cells pass with fail=0/N.

Wall-clock impact (rocprofv3 kernel-trace, B=1, D=128, median of 30 iters,
gfx942 / MI300X). Geomean over 32 (T, Hq, Hk, layout) cells: 1.22x.

Hot shape (Qwen-Image-2 prefill, Hq=24, Hk=25):
  T=256   NEOX:  13.02 -> 10.15 us  (1.28x)
  T=256   INTR:  10.38 ->  9.33 us  (1.11x)
  T=4096  NEOX: 156.11 -> 116.17 us (1.34x)
  T=4096  INTR: 119.98 -> 107.40 us (1.12x)
  T=8192  NEOX: 305.22 -> 225.58 us (1.35x)
  T=8192  INTR: 231.51 -> 206.33 us (1.12x)

NEOX consistently wins more than INTR (~1.30x vs ~1.10x): NEOX gets all
three optimizations, INTR gets only inline-rmsnorm + packed-rope (no
divergent branch to refactor). VGPR (NEOX/INTR) goes 24/20 -> 32/24 — no
spill, occupancy unchanged.

* [fused_qk_norm_rope] fix host-pass build of rope_common.h includers

bb22fd3 added DPP-based warp_reduce_sum / half_warp_reduce_sum /
warp_shfl_xor_sync_vec helpers in mrope_utils::block_utils that referenced
opus::mov_dpp / opus::number / opus::bool_constant directly. But
rope_common.h gates `#include "opus/opus.hpp"` with #ifdef
__HIP_DEVICE_COMPILE__ — so during the HIP host pass, `opus` is not a
declared namespace. clang HIP parses template bodies in BOTH passes for
non-dependent-name lookup, so any TU that includes rope_common.h without
otherwise pulling in opus.hpp transitively (e.g. via quant_utils.cuh)
fails the host pass with "error: use of undeclared identifier 'opus'".

This caused a CI build failure on
csrc/kernels/rope/general_2c_cached_positions_offsets_fwd_kernels.cu.
The failure was masked locally because the fused_qk_norm_rope JIT target
pulls in quant_utils.cuh (-> opus.hpp) and was building fine.

Two changes, both bit-for-bit equivalent at the GPU instruction level:

1. warp_reduce_sum / half_warp_reduce_sum bodies are now wrapped in
   #ifdef __HIP_DEVICE_COMPILE__ (matching the existing pattern used
   throughout rope_common.h, e.g. line 564). The host pass now skips the
   opus::mov_dpp calls entirely and the function returns `val` unchanged
   — fine because these helpers are __device__-only and never called
   from host code. Device-pass body is unchanged.

2. warp_shfl_xor_sync_vec's tag-dispatch parameter went from
   `opus::number<XorOffset> = {}` to `std::integral_constant<int,
   XorOffset> = {}`. The signature is parsed in both passes regardless
   of #ifdef wrapping the body, so this one CAN'T be hidden the way #1
   is. opus::number<I> is publicly derived from
   std::integral_constant<index_t, I> (csrc/include/opus/opus.hpp:57), so
   existing callers passing opus::number<X>{} continue to work
   unchanged via pass-by-value slicing of the empty derived type.

Verified:
- general_2c_cached_positions_offsets_fwd_kernels.cu now builds (full JIT
  re-run succeeds; previously failed in host pass on line 7260).
- fused_rope_rms_1way_quad and fused_rope_rms_1way (fallback) both
  numerically match pure-FP32 PyTorch reference at rtol=1e-2 atol=0.05
  for shapes that exercise both warp_reduce_sum and warp_shfl_xor_sync_vec
  (NEOX path) and the fallback inline RMSNorm.
- rocprofv3 fallback sweep over 32 (T, Hq, Hk, NEOX/INTR) cells: geomean
  speedup vs pre-388f737ba baseline = 1.22x (identical to the measurement
  before this fix — opus::mov_dpp lowers to the same v_*_dpp instruction
  as __builtin_amdgcn_mov_dpp; opus.hpp:1559-1562 is a thin wrapper).

* [fused_qk_norm_rope] quad kernel: scalarize Q/K split branch when QUAD_*_CT is even

When both QUAD_Q_CT and QUAD_K_CT are even, `is_q = global_warp_id <
T*QUAD_Q_CT` is uniform across the full 64-lane physical wave. Readfirstlane
the warp_id so the compiler emits s_cmp + s_cbranch instead of the divergent
v_cmp + s_and_saveexec + s_xor + s_cbranch_execz sequence, saving a few cycles
and EXEC-mask thrash per wave. Falls back to the original per-lane path for
odd QUAD_*_CT (e.g. H=12 -> QUAD=3) where the boundary can cut a wave.
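
The parity argument can be checked with a small host-side model. It assumes,
purely for illustration, that each 64-lane wave carries two consecutive
logical warp IDs (2w and 2w+1); the commit message does not spell out this
mapping, and the function names are hypothetical.

```cpp
#include <cassert>

// True when both logical warps of a wave fall on the same side of the
// Q/K boundary, i.e. the is_q predicate is uniform across the wave.
inline bool wave_is_uniform(int wave, int boundary) {
    const bool first_is_q = (2 * wave) < boundary;
    const bool second_is_q = (2 * wave + 1) < boundary;
    return first_is_q == second_is_q;
}

// Check every wave in the launch.
inline bool all_waves_uniform(int num_waves, int boundary) {
    for (int wave = 0; wave < num_waves; ++wave) {
        if (!wave_is_uniform(wave, boundary)) {
            return false;
        }
    }
    return true;
}
```

An even boundary such as T=5, QUAD_Q_CT=6 (boundary 30) never splits a wave,
while T=5, QUAD_Q_CT=3 (boundary 15) splits wave 7, whose warps are 14 and 15.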

* fix: align 1way RMSNorm cast order with diffusers

* fix: keep cos_sin in fp32 in 1way fused QK norm + RoPE

* style: black reformat assert in 1way op_tests

---------

Co-authored-by: LiuYinfeng01 <yinfeliu@amd.com>
LJ-underdog added a commit that referenced this pull request May 27, 2026
When M*topk*inter_dim is not divisible by group_size (128), the
x.view(-1, 128) reshape fails. This occurs in Silu MoE config
(topk=9, inter_dim=160) when M%4!=0.

Fix: zero-pad last dim to group_size boundary, quantize, trim output.
Zero-padding does not affect absmax scales (zeros don't raise max).
Scale count unchanged: ceil(160/128) == ceil(256/128) == 2.

Also fix scale allocation to use ceiling division instead of floor
division, which under-allocated scales for non-aligned last dims.
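
The padding and scale-count arithmetic can be sketched in a few lines. This is
an illustration in C++ rather than the actual tensor code; `ceil_div`,
`padded_len` and `group_absmax` are hypothetical helper names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Ceiling division: number of groups needed to cover n elements.
inline std::size_t ceil_div(std::size_t n, std::size_t group) {
    return (n + group - 1) / group;
}

// Length of a row after zero-padding to the next group boundary.
inline std::size_t padded_len(std::size_t n, std::size_t group) {
    return ceil_div(n, group) * group;
}

// Per-group absolute maximum over a zero-padded copy of the row.
// Padding zeros never raise a maximum, so the scales are unaffected.
inline std::vector<float> group_absmax(const std::vector<float>& row,
                                       std::size_t group) {
    std::vector<float> padded(padded_len(row.size(), group), 0.0f);
    std::copy(row.begin(), row.end(), padded.begin());
    std::vector<float> out(padded.size() / group, 0.0f);
    for (std::size_t i = 0; i < padded.size(); ++i) {
        out[i / group] = std::max(out[i / group], std::fabs(padded[i]));
    }
    return out;
}
```

For inter_dim=160 and group_size=128, floor division gives 1 scale per row
while ceiling division gives the required 2.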

Bug #1 (CK kernel OOB at M=76/88) is unaffected — separate root cause.

Verified: 6 was-crash Silu M values PASS, 4 SwigluStep regression ≤2%.