Skip to content

perf(deepseek_v4): fused_compress kernel optimization + DualRMSNorm fusion#723

Merged
valarLip merged 3 commits into
mainfrom
feat/v4-fused-compress-opt
May 8, 2026
Merged

perf(deepseek_v4): fused_compress kernel optimization + DualRMSNorm fusion#723
valarLip merged 3 commits into
mainfrom
feat/v4-fused-compress-opt

Conversation

@valarLip
Copy link
Copy Markdown
Collaborator

@valarLip valarLip commented May 8, 2026

Summary

A few performance/cleanup changes on the V4-Pro inference path. GSM8K nshot=5 matches baseline locally (strict-match 0.9530 vs baseline 0.9530). NOTE: CI is currently failing for this PR — see open question at the bottom.

Fused-compress kernel (main win)

  • Split the K-loop into two phases: state-only [0, window_len) then input-only [window_len, K). On AMD CDNA a masked tl.load still issues the LD instruction; the predicate only suppresses the register write-back. Issuing only the live side's loads cuts HBM traffic ~40% in the bandwidth-bound regime.
  • HCA (ratio=128, K=128): 35.07µs → 28.76µs (-18%); CSA (ratio=4, K=8) unchanged (kernel body is too small — launch-overhead floor at K=8).
  • Padding invariant verified: window_len = K - min(j_in_seq+1, K) bounds K-1-position, so padding (s < 0) lies entirely within the state phase. The input phase needs no padding mask.
  • Eviction hints: kv_in / score_in marked evict_first (single-use per program); ape marked evict_last (small + reused across programs).
  • FP8 quant fusion path now uses tl.clamp + plain .to(fp8) (aiter style; avoids the slow software-RTNE path on AMD that bypasses v_cvt_pk_fp8_f32), with UE8M0 scale + MFMA 16x16 preshuffle + .cs streaming stores.
  • Bit-exact match vs fused_compress_attn_reference (≤1 BF16 ULP, traceable to tl.exp HW vs libm propagating through softmax-pool; downstream consumer is BF16 anyway).

DualRMSNorm fusion of q_norm2 + kv_norm

  • Both are head_dim=128 — routed through the existing _fused_qk_norm_single_kernel.
  • q_norm2 carries no learnable weight: added _make_weightless_rmsnorm factory (del weight; weight = None) + Q_HAS_WEIGHT: tl.constexpr skip path in the fused kernel. Loader no longer warns about q_norm2.weight unloaded.
  • DualRMSNorm._eps resolved with explicit is not None fallback (handles both variance_epsilon and eps attribute names).

Compressor refactor (side-effect only)

  • forward() returns None. The prior caller-visible BF16 return was vestigial — paged_decode / paged_prefill read the scattered entries directly from unified_kv (Main path) or the FP8 indexer pool (Indexer-inner).
  • Bound cache_scale (strided fp32 view of the FP8 block's scale region — same allocation) for the Indexer-inner path.
  • Auto-detects quant via kv_cache.dtype != bfloat16.

CompressPlan decode capacity

  • New decode_capacity_per_ratio arg: when supplied, the returned compress_plan_gpu slice has fixed length = decode-tight bound (max_decode_tokens // ratio + max_bs), instead of the prefill worst case (~13× larger). Trailing rows are sentinel-filled so the captured kernel grid is decode-sized but addresses stay stable.
  • Empty-fwd path now also returns CompressPlans pointing at the pre-allocated buffers (sentinel-filled), so capture-time and replay-time addresses match even on a zero-token fwd.

V4 attention builder cleanup

  • _decode_compress_cap[ratio] plumbed through prepare_decode / prepare_prefill / prepare_capture.
  • Removed _build_indexer_compress_slot_mapping and compress_slot_mapping_gpu (Indexer-inner now uses block_tables directly).
  • Dropped v4_indexer_decode_logits and v4_indexer_decode_topk_indices from the metadata pool. These are write-once GPU scratch with no CPU mirror; they are now allocated per-fwd via torch.empty in Indexer._score_topk_decode. Saves ~2 MiB pinned host + ~2 MiB GPU on the prior CpuGpuBuffer overhead. NOTE: this is the prime suspect for the CI accuracy regression — see below.

mark_trace typing

  • @overload + ParamSpec — pyright/pylance no longer flags DualRMSNorm-style decorated callables as "not callable".

Triton MoE block_m default

  • ATOM_TRITON_MOE_BLOCK_M default 64 → 32 (better tile occupancy on MI355X for typical MoE shapes).

Test plan

  • Microbench _fused_compress_attn_kernel for CSA + HCA configs (numbers in commit body).
  • Bit-exact validation vs fused_compress_attn_reference (CSA / HCA / padding mix).
  • Local GSM8K nshot=3 — DeepSeek-V4-Pro --level 0 ATOM_USE_TRITON_MOE=1 → 0.9538 / 0.9538 (matches main 0.9500 / 0.9507 within 1σ).
  • CI accuracy regression at level=3 + CG — see below.

Known issue

CI accuracy job at level=3 + use_cudagraph=True collapses to 0.0008 / 0.0000 vs main's 0.9500 / 0.9507 (same engine config). Local --level 0 is unaffected (validated above). Local repro at level=3 is blocked by an unrelated 'KernelMetadata' object has no attribute 'cluster_dims' torch+triton bug.

Strongest suspect: the decode_topk_indices_gpu / decode_logits_gpu change from pre-allocated CpuGpuBuffer (stable address) to per-fwd torch.empty. Under piecewise CG, downstream captured kernels bake in the capture-time pointer, but torch.empty returns a different address each call → replay reads stale memory. Plan: revert that one change in a follow-up commit on this branch and re-run CI.

…e buffer cleanup

Fused-compress kernel (atom/model_ops/v4_kernels/fused_compress.py)
  - Split K-loop into two phases (state-only `[0, window_len)` then
    input-only `[window_len, K)`): on AMD CDNA masked tl.load still
    issues the LD instruction, predicate only suppresses the register
    write. Issuing only the live side's loads cuts HBM traffic ~40%
    in the bandwidth-bound regime.
    Microbench (median per launch, BATCH=50, 200 iters):
      HCA (ratio=128 K=128): 35.07us → 28.76us (-18%)
      CSA (ratio=4   K=8 ): unchanged (launch-overhead floor at K=8)
  - Padding invariant verified: window_len = K - min(j_in_seq+1, K)
    bounds K-1-position, so padding (`s < 0`) lies entirely in the
    state phase — input phase needs no padding mask.
  - Eviction hints: ragged kv_in/score_in marked evict_first
    (single-use per program), ape evict_last (small + reused).
  - FP8 quant fusion path: `tl.clamp + plain .to(fp8)` (aiter style;
    avoids the slow `fp_downcast_rounding="rtne"` path on AMD that
    bypasses `v_cvt_pk_fp8_f32`), with UE8M0 scale + MFMA 16x16
    preshuffle + .cs streaming stores.
  - Bit-exact match vs `fused_compress_attn_reference` (verified at
    BF16 precision; ≤1 BF16 ULP due to `tl.exp` HW vs libm).

DualRMSNorm fusion (atom/model_ops/layernorm.py, atom/models/deepseek_v4.py)
  - q_norm2 + kv_norm (per-head Q + KV, both head_dim=128) routed
    through existing `DualRMSNorm` + `_fused_qk_norm_single_kernel`.
  - q_norm2 carries no learnable weight — added `_make_weightless_rmsnorm`
    factory (`del weight; weight = None`) so the parameter is absent
    from `state_dict` (no loader warning), and a `Q_HAS_WEIGHT` constexpr
    in the fused kernel skips the load when the weight is None.
  - `DualRMSNorm._eps` resolved once with explicit None-check fallback
    (handles both `variance_epsilon` and `eps` attribute names).

Compressor refactor (atom/models/deepseek_v4.py)
  - Side-effecting `forward()` returns None — the prior caller-visible
    BF16 return was vestigial (paged_decode/paged_prefill read scattered
    entries directly from `unified_kv` / FP8 indexer pool).
  - `cache_scale` strided fp32 view binding for the Indexer-inner
    Compressor (FP8 scale region of the same allocation).
  - Auto-detects quant via `kv_cache.dtype != bfloat16`.

CompressPlan decode capacity (atom/model_ops/v4_kernels/compress_plan.py)
  - New `decode_capacity_per_ratio` arg: when supplied, the returned
    `compress_plan_gpu` slice has fixed length = decode-tight bound
    (`max_decode_tokens // ratio + max_bs`) instead of prefill worst
    case (~13× larger), with sentinel-fill of trailing rows so the
    captured kernel grid is decode-sized but address-stable.
  - Empty-fwd path now also produces CompressPlans pointing at the
    pre-allocated buffers (sentinel-filled), so capture-time and
    replay-time addresses match even on a zero-token fwd.

V4 attention builder (atom/model_ops/attentions/deepseek_v4_attn.py)
  - `_decode_compress_cap[ratio]` plumbing for CG decode path.
  - Indexer-inner Compressor `cache_scale` view bound from
    `runner.v4_csa_idx_kv` per-layer.
  - Removed `_build_indexer_compress_slot_mapping` and
    `compress_slot_mapping_gpu` (Indexer-inner now uses
    block_tables directly).
  - Dropped `v4_indexer_decode_logits` and
    `v4_indexer_decode_topk_indices` from the metadata pool — these
    are write-once GPU scratch with no CPU mirror; allocated per-fwd
    via `torch.empty` in `Indexer._score_topk_decode`. Under CG capture
    the allocations land in the graph's private pool and replay
    reuses the same address (saves ~2 MiB pinned host + ~2 MiB GPU
    on the prior `CpuGpuBuffer` overhead).

mark_trace typing (atom/utils/decorators.py)
  - `@overload` + ParamSpec: pyright/pylance no longer flags
    DualRMSNorm-style decorated callables as "not callable".

Triton MoE block_m default (atom/model_ops/fused_moe_triton.py)
  - `ATOM_TRITON_MOE_BLOCK_M` default 64 → 32 (better MI355X tile
    occupancy at typical MoE shapes).

GSM8K nshot=5 (DeepSeek-V4-Pro, --level 0, ATOM_USE_TRITON_MOE=1):
  flexible-extract 0.9522 ±0.0059 / strict-match 0.9530 ±0.0058
  (baseline 0.953/0.954 — within 1σ, no regression)
Copilot AI review requested due to automatic review settings May 8, 2026 14:48
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR optimizes the DeepSeek-V4(-Pro) inference path by reducing fused-compress kernel memory traffic, fusing per-head Q + KV RMSNorm into a single Triton kernel, and simplifying/reshaping V4 decode-time metadata buffers to better fit CUDAGraph constraints.

Changes:

  • Reworked V4 fused-compress Triton kernel (two-phase load strategy) and added an FP8 quantized scatter path; Compressor becomes side-effect-only (no return).
  • Introduced DualRMSNorm weightless-Q support and updated DeepSeek-V4 attention to use fused Q/K normalization.
  • Tightened decode-time CompressPlan capacity slicing for CUDAGraph and removed some long-lived preallocated decode scratch buffers in favor of per-forward GPU scratch allocations.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
atom/utils/decorators.py Adds overloads/ParamSpec typing to preserve @mark_trace callable signatures for type checkers.
atom/models/deepseek_v4.py Switches Compressor to side-effect-only fused scatter, integrates DualRMSNorm, and updates indexer decode scratch allocation behavior.
atom/model_ops/v4_kernels/fused_compress.py Implements the optimized fused-compress kernel, adds optional FP8 quant scatter, and removes legacy return tensor handling.
atom/model_ops/v4_kernels/compress_plan.py Adds decode-capacity slicing support and adjusts sentinel fill logic for CUDAGraph vs eager paths.
atom/model_ops/layernorm.py Adds optional Q-weight support to fused Q/K RMSNorm kernel and introduces DualRMSNorm enhancements.
atom/model_ops/fused_moe_triton.py Adjusts default Triton MoE tiling (ATOM_TRITON_MOE_BLOCK_M).
atom/model_ops/attentions/deepseek_v4_attn.py Updates V4 metadata builder to bind FP8 scale views, removes slot-mapping plumbing, and wires decode capacity slicing.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +826 to +827
self._eps = getattr(q_norm, "variance_epsilon", None) or getattr(
q_norm, "eps", None
Comment on lines 905 to 912
# Compress plans (per ratio) for batched fused_compress + update_states.
# Decode batch: extend_lens = max_seqlen_q for all seqs (uniform).
# `context_lens_np` is post-extend (from batch.context_lens, set by
# scheduler after appending this fwd's tokens) — this is what plan
# generation needs as `seq_lens`. Must run BEFORE
# `_attach_v4_indexer_meta` since the indexer consumes
# plan.compress_plan_cpu to derive its FP8 write-side slot_mapping.
extend_lens_np = np.full(scheduled_bs, max_seqlen_q, dtype=np.int32)
Comment on lines 996 to 998
attn_metadata.compress_plans = self._build_compress_plans(
extend_lens_np, context_lens_np, positions.device
extend_lens_np, context_lens_np, positions.device, for_decode_cg=False
)
Copilot AI review requested due to automatic review settings May 8, 2026 16:18
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.

Comment on lines +826 to +827
self._eps = getattr(q_norm, "variance_epsilon", None) or getattr(
q_norm, "eps", None
Comment on lines +632 to +648
# FP8 quant path: bind a strided fp32 view of the per-block
# scale region. Layout per block: [k1*head_dim FP8 region]
# then [k1 fp32 scale region] then padding (cache_kernels.cu
# :1209-1239). Strides expressed in fp32 elements.
nb, k1, aligned_dim = idx_kv.shape
head_dim = self.index_head_dim
assert (
k1 * aligned_dim
) % 4 == 0, f"per-block bytes ({k1 * aligned_dim}) must be 4-aligned"
block_fp32_stride = (k1 * aligned_dim) // 4
scale_fp32_offset = (k1 * head_dim) // 4
module.cache_scale = (
idx_kv.view(torch.float32)
.view(-1)
.as_strided(
size=(nb, k1),
stride=(block_fp32_stride, 1),
Comment on lines 907 to 912
# `context_lens_np` is post-extend (from batch.context_lens, set by
# scheduler after appending this fwd's tokens) — this is what plan
# generation needs as `seq_lens`. Must run BEFORE
# `_attach_v4_indexer_meta` since the indexer consumes
# plan.compress_plan_cpu to derive its FP8 write-side slot_mapping.
extend_lens_np = np.full(scheduled_bs, max_seqlen_q, dtype=np.int32)
Comment on lines +311 to +323
if PRESHUFFLE:
TILE: tl.constexpr = 16
token_tile_id = slot_in_block // TILE
token_in_tile = slot_in_block % TILE
col_tile_id = d // TILE
col_in_tile = d % TILE
fp8_offset = (
physical_block * kv_cache_block_stride
+ token_tile_id * (TILE * head_dim)
+ col_tile_id * (TILE * TILE)
+ token_in_tile * TILE
+ col_in_tile
)
Comment on lines +1126 to +1137
# Per-fwd write-once GPU scratch — no CPU mirror, no cross-fwd state.
# Under CUDAGraph capture, torch allocates from the graph's private
# memory pool and the address is stable across replays at this
# captured `total_tokens`. No `fill_(-inf)` needed —
# `top_k_per_row_decode` bounds each row by `n_committed_per_seq[batch]`
# so unwritten cols are never picked.
logits = torch.empty(
total_tokens,
self._max_model_len_idx,
dtype=torch.float32,
device=q_fp8.device,
)
Comment on lines +1149 to +1154
# Per-fwd write-once int32 scratch. Kernel writes exactly `index_topk`
# ints per row (valid seq-local indices then -1 sentinels). CG-safe
# for the same reason as `logits` above.
topk_local = torch.empty(
total_tokens, self.index_topk, dtype=torch.int32, device=q_fp8.device
)
Comment on lines +1971 to +1973
# Under CUDAGraph capture they land in the graph's private pool
# and replay reuses the same address; eager keeps the standard
# caching-allocator fast path.
@valarLip valarLip merged commit e87b2f3 into main May 8, 2026
26 of 33 checks passed
@valarLip valarLip deleted the feat/v4-fused-compress-opt branch May 8, 2026 16:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants