perf(deepseek_v4): fused_compress kernel optimization + DualRMSNorm fusion#723
Merged
Conversation
…e buffer cleanup
Fused-compress kernel (atom/model_ops/v4_kernels/fused_compress.py)
- Split K-loop into two phases (state-only `[0, window_len)` then
input-only `[window_len, K)`): on AMD CDNA masked tl.load still
issues the LD instruction, predicate only suppresses the register
write. Issuing only the live side's loads cuts HBM traffic ~40%
in the bandwidth-bound regime.
Microbench (median per launch, BATCH=50, 200 iters):
HCA (ratio=128 K=128): 35.07us → 28.76us (-18%)
CSA (ratio=4 K=8 ): unchanged (launch-overhead floor at K=8)
- Padding invariant verified: window_len = K - min(j_in_seq+1, K)
bounds K-1-position, so padding (`s < 0`) lies entirely in the
state phase — input phase needs no padding mask.
- Eviction hints: ragged kv_in/score_in marked evict_first
(single-use per program), ape evict_last (small + reused).
- FP8 quant fusion path: `tl.clamp + plain .to(fp8)` (aiter style;
avoids the slow `fp_downcast_rounding="rtne"` path on AMD that
bypasses `v_cvt_pk_fp8_f32`), with UE8M0 scale + MFMA 16x16
preshuffle + .cs streaming stores.
- Bit-exact match vs `fused_compress_attn_reference` (verified at
BF16 precision; ≤1 BF16 ULP due to `tl.exp` HW vs libm).
DualRMSNorm fusion (atom/model_ops/layernorm.py, atom/models/deepseek_v4.py)
- q_norm2 + kv_norm (per-head Q + KV, both head_dim=128) routed
through existing `DualRMSNorm` + `_fused_qk_norm_single_kernel`.
- q_norm2 carries no learnable weight — added `_make_weightless_rmsnorm`
factory (`del weight; weight = None`) so the parameter is absent
from `state_dict` (no loader warning), and a `Q_HAS_WEIGHT` constexpr
in the fused kernel skips the load when the weight is None.
- `DualRMSNorm._eps` resolved once with explicit None-check fallback
(handles both `variance_epsilon` and `eps` attribute names).
Compressor refactor (atom/models/deepseek_v4.py)
- Side-effecting `forward()` returns None — the prior caller-visible
BF16 return was vestigial (paged_decode/paged_prefill read scattered
entries directly from `unified_kv` / FP8 indexer pool).
- `cache_scale` strided fp32 view binding for the Indexer-inner
Compressor (FP8 scale region of the same allocation).
- Auto-detects quant via `kv_cache.dtype != bfloat16`.
CompressPlan decode capacity (atom/model_ops/v4_kernels/compress_plan.py)
- New `decode_capacity_per_ratio` arg: when supplied, the returned
`compress_plan_gpu` slice has fixed length = decode-tight bound
(`max_decode_tokens // ratio + max_bs`) instead of prefill worst
case (~13× larger), with sentinel-fill of trailing rows so the
captured kernel grid is decode-sized but address-stable.
- Empty-fwd path now also produces CompressPlans pointing at the
pre-allocated buffers (sentinel-filled), so capture-time and
replay-time addresses match even on a zero-token fwd.
V4 attention builder (atom/model_ops/attentions/deepseek_v4_attn.py)
- `_decode_compress_cap[ratio]` plumbing for CG decode path.
- Indexer-inner Compressor `cache_scale` view bound from
`runner.v4_csa_idx_kv` per-layer.
- Removed `_build_indexer_compress_slot_mapping` and
`compress_slot_mapping_gpu` (Indexer-inner now uses
block_tables directly).
- Dropped `v4_indexer_decode_logits` and
`v4_indexer_decode_topk_indices` from the metadata pool — these
are write-once GPU scratch with no CPU mirror; allocated per-fwd
via `torch.empty` in `Indexer._score_topk_decode`. Under CG capture
the allocations land in the graph's private pool and replay
reuses the same address (saves ~2 MiB pinned host + ~2 MiB GPU
on the prior `CpuGpuBuffer` overhead).
mark_trace typing (atom/utils/decorators.py)
- `@overload` + ParamSpec: pyright/pylance no longer flags
DualRMSNorm-style decorated callables as "not callable".
Triton MoE block_m default (atom/model_ops/fused_moe_triton.py)
- `ATOM_TRITON_MOE_BLOCK_M` default 64 → 32 (better MI355X tile
occupancy at typical MoE shapes).
GSM8K nshot=5 (DeepSeek-V4-Pro, --level 0, ATOM_USE_TRITON_MOE=1):
flexible-extract 0.9522 ±0.0059 / strict-match 0.9530 ±0.0058
(baseline 0.953/0.954 — within 1σ, no regression)
Contributor
There was a problem hiding this comment.
Pull request overview
This PR optimizes the DeepSeek-V4(-Pro) inference path by reducing fused-compress kernel memory traffic, fusing per-head Q + KV RMSNorm into a single Triton kernel, and simplifying/reshaping V4 decode-time metadata buffers to better fit CUDAGraph constraints.
Changes:
- Reworked V4 fused-compress Triton kernel (two-phase load strategy) and added an FP8 quantized scatter path; Compressor becomes side-effect-only (no return).
- Introduced
DualRMSNormweightless-Q support and updated DeepSeek-V4 attention to use fused Q/K normalization. - Tightened decode-time CompressPlan capacity slicing for CUDAGraph and removed some long-lived preallocated decode scratch buffers in favor of per-forward GPU scratch allocations.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| atom/utils/decorators.py | Adds overloads/ParamSpec typing to preserve @mark_trace callable signatures for type checkers. |
| atom/models/deepseek_v4.py | Switches Compressor to side-effect-only fused scatter, integrates DualRMSNorm, and updates indexer decode scratch allocation behavior. |
| atom/model_ops/v4_kernels/fused_compress.py | Implements the optimized fused-compress kernel, adds optional FP8 quant scatter, and removes legacy return tensor handling. |
| atom/model_ops/v4_kernels/compress_plan.py | Adds decode-capacity slicing support and adjusts sentinel fill logic for CUDAGraph vs eager paths. |
| atom/model_ops/layernorm.py | Adds optional Q-weight support to fused Q/K RMSNorm kernel and introduces DualRMSNorm enhancements. |
| atom/model_ops/fused_moe_triton.py | Adjusts default Triton MoE tiling (ATOM_TRITON_MOE_BLOCK_M). |
| atom/model_ops/attentions/deepseek_v4_attn.py | Updates V4 metadata builder to bind FP8 scale views, removes slot-mapping plumbing, and wires decode capacity slicing. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+826
to
+827
| self._eps = getattr(q_norm, "variance_epsilon", None) or getattr( | ||
| q_norm, "eps", None |
Comment on lines
905
to
912
| # Compress plans (per ratio) for batched fused_compress + update_states. | ||
| # Decode batch: extend_lens = max_seqlen_q for all seqs (uniform). | ||
| # `context_lens_np` is post-extend (from batch.context_lens, set by | ||
| # scheduler after appending this fwd's tokens) — this is what plan | ||
| # generation needs as `seq_lens`. Must run BEFORE | ||
| # `_attach_v4_indexer_meta` since the indexer consumes | ||
| # plan.compress_plan_cpu to derive its FP8 write-side slot_mapping. | ||
| extend_lens_np = np.full(scheduled_bs, max_seqlen_q, dtype=np.int32) |
Comment on lines
996
to
998
| attn_metadata.compress_plans = self._build_compress_plans( | ||
| extend_lens_np, context_lens_np, positions.device | ||
| extend_lens_np, context_lens_np, positions.device, for_decode_cg=False | ||
| ) |
Comment on lines
+826
to
+827
| self._eps = getattr(q_norm, "variance_epsilon", None) or getattr( | ||
| q_norm, "eps", None |
Comment on lines
+632
to
+648
| # FP8 quant path: bind a strided fp32 view of the per-block | ||
| # scale region. Layout per block: [k1*head_dim FP8 region] | ||
| # then [k1 fp32 scale region] then padding (cache_kernels.cu | ||
| # :1209-1239). Strides expressed in fp32 elements. | ||
| nb, k1, aligned_dim = idx_kv.shape | ||
| head_dim = self.index_head_dim | ||
| assert ( | ||
| k1 * aligned_dim | ||
| ) % 4 == 0, f"per-block bytes ({k1 * aligned_dim}) must be 4-aligned" | ||
| block_fp32_stride = (k1 * aligned_dim) // 4 | ||
| scale_fp32_offset = (k1 * head_dim) // 4 | ||
| module.cache_scale = ( | ||
| idx_kv.view(torch.float32) | ||
| .view(-1) | ||
| .as_strided( | ||
| size=(nb, k1), | ||
| stride=(block_fp32_stride, 1), |
Comment on lines
907
to
912
| # `context_lens_np` is post-extend (from batch.context_lens, set by | ||
| # scheduler after appending this fwd's tokens) — this is what plan | ||
| # generation needs as `seq_lens`. Must run BEFORE | ||
| # `_attach_v4_indexer_meta` since the indexer consumes | ||
| # plan.compress_plan_cpu to derive its FP8 write-side slot_mapping. | ||
| extend_lens_np = np.full(scheduled_bs, max_seqlen_q, dtype=np.int32) |
Comment on lines
+311
to
+323
| if PRESHUFFLE: | ||
| TILE: tl.constexpr = 16 | ||
| token_tile_id = slot_in_block // TILE | ||
| token_in_tile = slot_in_block % TILE | ||
| col_tile_id = d // TILE | ||
| col_in_tile = d % TILE | ||
| fp8_offset = ( | ||
| physical_block * kv_cache_block_stride | ||
| + token_tile_id * (TILE * head_dim) | ||
| + col_tile_id * (TILE * TILE) | ||
| + token_in_tile * TILE | ||
| + col_in_tile | ||
| ) |
Comment on lines
+1126
to
+1137
| # Per-fwd write-once GPU scratch — no CPU mirror, no cross-fwd state. | ||
| # Under CUDAGraph capture, torch allocates from the graph's private | ||
| # memory pool and the address is stable across replays at this | ||
| # captured `total_tokens`. No `fill_(-inf)` needed — | ||
| # `top_k_per_row_decode` bounds each row by `n_committed_per_seq[batch]` | ||
| # so unwritten cols are never picked. | ||
| logits = torch.empty( | ||
| total_tokens, | ||
| self._max_model_len_idx, | ||
| dtype=torch.float32, | ||
| device=q_fp8.device, | ||
| ) |
Comment on lines
+1149
to
+1154
| # Per-fwd write-once int32 scratch. Kernel writes exactly `index_topk` | ||
| # ints per row (valid seq-local indices then -1 sentinels). CG-safe | ||
| # for the same reason as `logits` above. | ||
| topk_local = torch.empty( | ||
| total_tokens, self.index_topk, dtype=torch.int32, device=q_fp8.device | ||
| ) |
Comment on lines
+1971
to
+1973
| # Under CUDAGraph capture they land in the graph's private pool | ||
| # and replay reuses the same address; eager keeps the standard | ||
| # caching-allocator fast path. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
A few performance/cleanup changes on the V4-Pro inference path. GSM8K nshot=5 matches baseline locally (strict-match 0.9530 vs baseline 0.9530). NOTE: CI is currently failing for this PR — see open question at the bottom.
Fused-compress kernel (main win)
[0, window_len)then input-only[window_len, K). On AMD CDNA a maskedtl.loadstill issues the LD instruction; the predicate only suppresses the register write-back. Issuing only the live side's loads cuts HBM traffic ~40% in the bandwidth-bound regime.window_len = K - min(j_in_seq+1, K)boundsK-1-position, so padding (s < 0) lies entirely within the state phase. The input phase needs no padding mask.kv_in / score_inmarkedevict_first(single-use per program);apemarkedevict_last(small + reused across programs).tl.clamp + plain .to(fp8)(aiter style; avoids the slow software-RTNE path on AMD that bypassesv_cvt_pk_fp8_f32), with UE8M0 scale + MFMA 16x16 preshuffle +.csstreaming stores.fused_compress_attn_reference(≤1 BF16 ULP, traceable totl.expHW vs libm propagating through softmax-pool; downstream consumer is BF16 anyway).DualRMSNorm fusion of q_norm2 + kv_norm
_fused_qk_norm_single_kernel.q_norm2carries no learnable weight: added_make_weightless_rmsnormfactory (del weight; weight = None) +Q_HAS_WEIGHT: tl.constexprskip path in the fused kernel. Loader no longer warns aboutq_norm2.weight unloaded.DualRMSNorm._epsresolved with explicitis not Nonefallback (handles bothvariance_epsilonandepsattribute names).Compressor refactor (side-effect only)
forward()returnsNone. The prior caller-visible BF16 return was vestigial —paged_decode/paged_prefillread the scattered entries directly fromunified_kv(Main path) or the FP8 indexer pool (Indexer-inner).cache_scale(strided fp32 view of the FP8 block's scale region — same allocation) for the Indexer-inner path.kv_cache.dtype != bfloat16.CompressPlan decode capacity
decode_capacity_per_ratioarg: when supplied, the returnedcompress_plan_gpuslice has fixed length = decode-tight bound (max_decode_tokens // ratio + max_bs), instead of the prefill worst case (~13× larger). Trailing rows are sentinel-filled so the captured kernel grid is decode-sized but addresses stay stable.V4 attention builder cleanup
_decode_compress_cap[ratio]plumbed throughprepare_decode/prepare_prefill/prepare_capture._build_indexer_compress_slot_mappingandcompress_slot_mapping_gpu(Indexer-inner now usesblock_tablesdirectly).v4_indexer_decode_logitsandv4_indexer_decode_topk_indicesfrom the metadata pool. These are write-once GPU scratch with no CPU mirror; they are now allocated per-fwd viatorch.emptyinIndexer._score_topk_decode. Saves ~2 MiB pinned host + ~2 MiB GPU on the priorCpuGpuBufferoverhead. NOTE: this is the prime suspect for the CI accuracy regression — see below.mark_tracetyping@overload+ParamSpec— pyright/pylance no longer flagsDualRMSNorm-style decorated callables as "not callable".Triton MoE block_m default
ATOM_TRITON_MOE_BLOCK_Mdefault 64 → 32 (better tile occupancy on MI355X for typical MoE shapes).Test plan
_fused_compress_attn_kernelfor CSA + HCA configs (numbers in commit body).fused_compress_attn_reference(CSA / HCA / padding mix).--level 0ATOM_USE_TRITON_MOE=1→ 0.9538 / 0.9538 (matches main 0.9500 / 0.9507 within 1σ).Known issue
CI accuracy job at
level=3 + use_cudagraph=Truecollapses to 0.0008 / 0.0000 vs main's 0.9500 / 0.9507 (same engine config). Local--level 0is unaffected (validated above). Local repro atlevel=3is blocked by an unrelated'KernelMetadata' object has no attribute 'cluster_dims'torch+triton bug.Strongest suspect: the
decode_topk_indices_gpu/decode_logits_gpuchange from pre-allocatedCpuGpuBuffer(stable address) to per-fwdtorch.empty. Under piecewise CG, downstream captured kernels bake in the capture-time pointer, buttorch.emptyreturns a different address each call → replay reads stale memory. Plan: revert that one change in a follow-up commit on this branch and re-run CI.