Skip to content
Open
232 changes: 221 additions & 11 deletions tensorrt_llm/_torch/custom_ops/triton_fused_inv_rope_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def _fused_inv_rope_fp8_quant_per_head(
HALF_ROPE: tl.constexpr,
IS_NEOX: tl.constexpr,
):
# Original 1-token-per-block kernel. Used by the Python dispatcher when
# M < threshold (small-M / GEN-phase shapes), where wrapping the body in
# an outer loop measurably hurts perf even with BLOCK_TOKENS_M=1 (the if/
# else around the bulk creates predicated execution that compiles less
# tightly than a straight return path). See `_fused_inv_rope_fp8_quant_
# per_head_mblock` for the multi-token path used at M >= 1024.
# int64: stride multiply overflows int32 past num_tokens=32768 (IMA).
pid_token = tl.program_id(0).to(tl.int64)
pid_gh = tl.program_id(1).to(tl.int64)
Expand Down Expand Up @@ -158,10 +164,162 @@ def _fused_inv_rope_fp8_quant_per_head(
tl.store(scale_addrs, scales)


@triton.jit
def _fused_inv_rope_fp8_quant_per_head_mblock(
o_ptr,
positions_ptr,
cos_sin_cache_ptr,
fp8_ptr,
scale_ptr,
num_tokens,
heads_per_group: tl.constexpr,
o_stride_token,
o_stride_head,
cache_stride_pos,
fp8_stride_group,
fp8_stride_token,
scale_stride_group,
scale_stride_k,
scale_buf_m: tl.constexpr,
fp8_max: tl.constexpr,
eps: tl.constexpr,
QUANT_GROUP_SIZE: tl.constexpr,
CHUNKS_PER_HEAD: tl.constexpr,
ROPE_START: tl.constexpr,
HALF_ROPE: tl.constexpr,
IS_NEOX: tl.constexpr,
BLOCK_TOKENS_M: tl.constexpr,
):
# Multi-token-per-block variant for M >= 1024. Grid X = ceil(scale_buf_m / BTM).
# scale_buf_m = pad_up(num_tokens, 4) — the BMM consumer's expected
# scale-tensor M dim (see `cute_dsl_fp8_bmm_blackwell` `sf_m = pad_up(m, 4)`).
# The grid may overshoot past scale_buf_m when scale_buf_m isn't a multiple
# of BTM, so the inner loop guards with `pid_token < scale_buf_m` to avoid
# OOB scale writes.
pid_x = tl.program_id(0).to(tl.int64)
pid_gh = tl.program_id(1).to(tl.int64)

g = pid_gh // heads_per_group
head_in_group = pid_gh % heads_per_group
global_head = pid_gh
qb_start = head_in_group * CHUNKS_PER_HEAD

HEAD_DIM: tl.constexpr = CHUNKS_PER_HEAD * QUANT_GROUP_SIZE
rope_abs_start: tl.constexpr = (CHUNKS_PER_HEAD - 1) * QUANT_GROUP_SIZE + ROPE_START
offsets = tl.arange(0, HEAD_DIM)
is_rope = offsets >= rope_abs_start
rope_local = offsets - rope_abs_start
block_offsets = tl.arange(0, CHUNKS_PER_HEAD)
qb_indices = qb_start + block_offsets

# Inner loop over BLOCK_TOKENS_M tokens. tl.range (vs tl.static_range)
# generates a runtime loop instead of a fully-unrolled body. The
# num_stages arg here controls cross-iteration software pipelining
# (load_next overlap with compute_current_and_store) — distinct from
# the launch-site num_stages which only matters for warpgroup GEMM
# kernels. Depth=2 is enough to overlap one load with the previous
# iter's store.
for m_in_block in tl.range(0, BLOCK_TOKENS_M, num_stages=2):
pid_token = pid_x * BLOCK_TOKENS_M + m_in_block

if pid_token >= scale_buf_m:
# Beyond the (4-aligned) scale buffer: skip entirely. Happens
# only on the very last grid block when scale_buf_m % BTM != 0.
pass
elif pid_token >= num_tokens:
# Padding row in [num_tokens, pad_up(num_tokens, 4)): zero scale
# so the BMM dequant sees 0 instead of stale memory.
scale_addrs = (
scale_ptr + g * scale_stride_group + pid_token + qb_indices * scale_stride_k
)
tl.store(scale_addrs, tl.zeros((CHUNKS_PER_HEAD,), dtype=tl.float32))
else:
input_base = o_ptr + pid_token * o_stride_token + global_head * o_stride_head
x = tl.load(input_base + offsets).to(tl.float32)

pos = tl.load(positions_ptr + pid_token)
cache_base = cos_sin_cache_ptr + pos * cache_stride_pos

if IS_NEOX:
is_first_half = rope_local < HALF_ROPE
partner_local = tl.where(
is_first_half, rope_local + HALF_ROPE, rope_local - HALF_ROPE
)
partner_abs = rope_abs_start + partner_local
x_partner = tl.load(input_base + partner_abs, mask=is_rope, other=0.0).to(
tl.float32
)
cs_idx = tl.where(is_first_half, rope_local, rope_local - HALF_ROPE)
cos_v = tl.load(cache_base + cs_idx, mask=is_rope, other=1.0)
sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope, other=0.0)
sign = tl.where(is_first_half, 1.0, -1.0)
rotated = x * cos_v + sign * sin_v * x_partner
else:
x_partner = tl.load(input_base + (offsets ^ 1), mask=is_rope, other=0.0).to(
tl.float32
)
cs_idx = tl.maximum(rope_local >> 1, 0)
cos_v = tl.load(cache_base + cs_idx, mask=is_rope, other=1.0)
sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope, other=0.0)
x_add = x * cos_v + x_partner * sin_v
x_sub = x * cos_v - x_partner * sin_v
is_even = (rope_local & 1) == 0
rotated = tl.where(is_even, x_add, x_sub)

x = tl.where(is_rope, rotated, x)

x_2d = tl.reshape(tl.abs(x), (CHUNKS_PER_HEAD, QUANT_GROUP_SIZE))
block_absmax = tl.maximum(tl.max(x_2d, axis=1), eps)
scales = block_absmax * (1.0 / fp8_max)

scales_exp = tl.reshape(
tl.broadcast_to(
tl.reshape(scales, (CHUNKS_PER_HEAD, 1)),
(CHUNKS_PER_HEAD, QUANT_GROUP_SIZE),
),
(HEAD_DIM,),
)
x_quant = tl.clamp(x / scales_exp, -fp8_max, fp8_max).to(tl.float8e4nv)

fp8_base = (
fp8_ptr
+ g * fp8_stride_group
+ pid_token * fp8_stride_token
+ qb_start * QUANT_GROUP_SIZE
)
tl.store(fp8_base + offsets, x_quant)

scale_addrs = (
scale_ptr + g * scale_stride_group + pid_token + qb_indices * scale_stride_k
)
tl.store(scale_addrs, scales)


def _tma_aligned_size(x: int, tma_align_size_in_elems: int = 4) -> int:
return (x + tma_align_size_in_elems - 1) // tma_align_size_in_elems * tma_align_size_in_elems


def _choose_block_tokens_m(num_tokens: int) -> int:
"""Pick BLOCK_TOKENS_M based on M.

Microbench (DEP8 shape, n_groups=8, heads_per_group=16) on GB300:
- At M < 1024 the single-token-per-block path is the winner — BTM>1 adds
static-unroll overhead that's not amortized at small block count.
- At M >= 1024 multi-token blocks reduce grid size (fewer single-warp
launch costs) and improve SM occupancy. Measured wins:
M=1024 : BTM=8 → 63 µs vs 70 µs (BTM=1) ≈ 10% faster
M=2048+ : BTM=16 → 113 µs vs 137 µs (BTM=1) ≈ 18% faster
M=8192 : BTM=16 → 448 µs vs 533 µs (BTM=1) ≈ 16% faster
"""
if num_tokens >= 4096:
return 32
if num_tokens >= 2048:
return 16
if num_tokens >= 1024:
return 8
return 1


def _fused_inv_rope_fp8_quant_impl(
o: torch.Tensor,
positions: torch.Tensor,
Expand Down Expand Up @@ -199,6 +357,16 @@ def _fused_inv_rope_fp8_quant_impl(
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max

block_tokens_m = _choose_block_tokens_m(num_tokens)
# The scale buffer is consumed by `cute_dsl_fp8_bmm_blackwell` which
# *hard-codes* its m-dim stride as `sf_m = pad_up(m, 4)` (see
# `tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py:3635`). Padding
# to a different alignment makes the BMM read scales from wrong physical
# offsets — silently producing wrong dequant values for ~7-13 gsm8k
# points under DEP8 when BTM > 4. We MUST keep the producer's pad equal
# to pad_up(num_tokens, 4) regardless of BTM. The mblock kernel's grid
# may overshoot scale_buf_m when scale_buf_m % BTM != 0; the inner-loop
# guard handles that (see `_fused_inv_rope_fp8_quant_per_head_mblock`).
tma_aligned_T = _tma_aligned_size(num_tokens, 4)

# FP8 output buffer: fully contiguous [n_groups, num_tokens, d] — must
Expand Down Expand Up @@ -239,14 +407,7 @@ def _fused_inv_rope_fp8_quant_impl(
if positions.dtype != torch.int32 and positions.dtype != torch.int64:
positions = positions.to(torch.int64)

grid = (tma_aligned_T, n_groups * heads_per_group)
_fused_inv_rope_fp8_quant_per_head[grid](
o,
positions,
cos_sin_view,
fp8_buf,
scale_buf,
num_tokens,
common_kwargs = dict(
heads_per_group=heads_per_group,
o_stride_token=o.stride(0),
o_stride_head=o.stride(1),
Expand All @@ -262,9 +423,57 @@ def _fused_inv_rope_fp8_quant_impl(
ROPE_START=nope_dim % quant_group_size,
HALF_ROPE=rope_dim // 2,
IS_NEOX=is_neox,
num_warps=1,
num_stages=1,
)

if block_tokens_m == 1:
# Small-M / GEN path: single-token-per-block kernel (no outer loop).
# Bit-for-bit equivalent to the original V1 kernel.
grid = (tma_aligned_T, n_groups * heads_per_group)
_fused_inv_rope_fp8_quant_per_head[grid](
o,
positions,
cos_sin_view,
fp8_buf,
scale_buf,
num_tokens,
**common_kwargs,
num_warps=1,
num_stages=1,
)
else:
# Large-M / CTX path: multi-token-per-block kernel. num_stages=2
# lets Triton interleave the next iter's TMA load with the current
# iter's compute+store across the unrolled BLOCK_TOKENS_M body.
# num_warps tuning: per-block work scales with BTM × HEAD_DIM.
# V2.3 microbench at M=8192 DEP8 shape:
# BTM=16 nw=1 → 448 µs (3.6 TB/s, 45% peak)
# BTM=16 nw=2 → 420 µs (3.8 TB/s, 48% peak) ← best
# BTM=16 nw=4 → 555 µs (2.9 TB/s, 36% peak) ← reg-pressure regression
# So nw=2 for BTM>=8 (more threads help amortize the larger per-block
# load fan-out), nw=1 for BTM=1 (V1 ruled out nw>1 there).
nw = 2 if block_tokens_m >= 8 else 1
# num_stages: software-pipelining depth across the inner loop.
# BTM=32 has more iterations to pipeline so stages=3 has more value.
ns = 3 if block_tokens_m >= 32 else 2
# ceil_div: tma_aligned_T may not be a multiple of BTM (e.g.
# M=1500, tma_aligned_T=1500, BTM=16 → grid_x=94, max pid_token=1519).
# The kernel's inner-loop `pid_token < scale_buf_m` guard handles
# the overshoot — no OOB scale writes.
grid_x = (tma_aligned_T + block_tokens_m - 1) // block_tokens_m
grid = (grid_x, n_groups * heads_per_group)
_fused_inv_rope_fp8_quant_per_head_mblock[grid](
o,
positions,
cos_sin_view,
fp8_buf,
scale_buf,
num_tokens,
scale_buf_m=tma_aligned_T,
**common_kwargs,
BLOCK_TOKENS_M=block_tokens_m,
num_warps=nw,
num_stages=ns,
)
return fp8_buf, scale_buf


Expand All @@ -282,7 +491,8 @@ def _fused_inv_rope_fp8_quant_fake(
num_tokens, num_heads, head_dim = o.shape
d = heads_per_group * head_dim
num_scale_blocks = d // quant_group_size
tma_aligned_T = _tma_aligned_size(num_tokens, 4)
block_tokens_m = _choose_block_tokens_m(num_tokens)
tma_aligned_T = _tma_aligned_size(num_tokens, max(block_tokens_m, 4))
fp8_buf = torch.empty(
(n_groups, num_tokens, d),
dtype=torch.float8_e4m3fn,
Expand Down
Loading
Loading