Support Triton MLA Attention Backend#7852
Conversation
|
chang-wenbin seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it. |
|
Thanks for your contribution! |
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## develop #7852 +/- ##
==========================================
Coverage ? 62.85%
==========================================
Files ? 466
Lines ? 64720
Branches ? 9898
==========================================
Hits ? 40680
Misses ? 21271
Partials ? 2769
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
PaddlePaddle-bot
left a comment
There was a problem hiding this comment.
🤖 Paddle-CI-Agent | pr_review |
2026-05-19 18:08:17
📋 Review 摘要
PR 概述:新增基于 Triton 的 MLA Attention Backend(TRITON_MLA_ATTN),作为现有 MLA_ATTN(custom CUDA op)的纯 Python Triton 替代方案,同时支持 CUDAGraph 兼容。
变更范围:layers/attention/(新增 triton backend + triton kernels)、models/deepseek_v3.py、config.py、platforms/、worker/gpu_model_runner.py
影响面 Tag:[OP] [Models] [FDConfig]
问题
| 级别 | 文件 | 概述 |
|---|---|---|
| 🟡 建议 | triton_ops/decode_attention.py |
_decode_grouped_att_m_fwd 中 BLOCK_DPE 仅针对 DeepSeek-V3 维度硬编码,else 分支令 BLOCK_DPE=0,其他 MLA 模型 RoPE 注意力贡献被跳过,导致精度错误 |
| 🟡 建议 | triton_mla_attention_backend.py |
init_attention_metadata 热路径中多次调用 .item()(decode_mask.sum().item() 和 paddle.sum(full_seq_lens).item()),触发 CPU-GPU Sync,阻塞推理流水线 |
| ❓ 疑问 | models/deepseek_v3.py |
新增 need_do_attention 守卫影响所有 MLA 后端(不仅 Triton),需确认对 MLA_ATTN 路径无精度退化 |
| 📝 PR 规范 | — | 标题缺失 Tag,描述所有段落为空 |
🟡 详细分析
1. _decode_grouped_att_m_fwd 中 BLOCK_DPE 硬编码(decode_attention.py)
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
else:
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0 # ← RoPE 部分被完全跳过当前只有 DeepSeek-V3 full(512+64)和 small(256+32)两种维度走了正确的 RoPE 分块路径。else 分支 BLOCK_DPE=0 会使 stage1 内 if BLOCK_DPE > 0 判断分支失效,其他 MLA 模型(不同 qk_rope_head_dim)的 PE 部分注意力贡献为零,结果精度错误。
建议:将 BLOCK_DPE 改为运行时参数(从 qk_rope_head_dim 传入),或在 else 分支中用 triton.next_power_of_2 动态计算 BLOCK_DMODEL 和 BLOCK_DPE。若本 PR 明确仅支持 DeepSeek-V3,请在函数 docstring 中注明限制。
2. init_attention_metadata 热路径 CPU-GPU Sync(triton_mla_attention_backend.py)
每次 forward 调用 init_attention_metadata 时:
decode_bs = int(decode_mask.sum().item()) # GPU → CPU sync #1
total_kv_len = int(paddle.sum(full_seq_lens).item()) # GPU → CPU sync #2两次 .item() 均为同步操作,会序列化 GPU 与 CPU 执行,在高并发或长 context 场景下显著降低吞吐(checklist §C 表层信号)。
建议:优先从 forward_meta 已有的 Python-level 整数元数据获取 decode_bs;若无对应字段,请在 PR 注释中说明此为临时方案并计划后续优化。
3. deepseek_v3.py 的 need_do_attention 守卫影响所有 MLA 后端
need_do_attention = forward_meta.max_len_tensor_cpu[1] > 0 or forward_meta.max_len_tensor_cpu[2] > 0
if hidden_states.shape[0] > 0 and need_do_attention:此修改并非 Triton-only,MLA_ATTN(MLAAttentionBackend)也会走此路径。当 hidden_states.shape[0] > 0 但两个 max_len 均为 0 时,attention 层被整体跳过,residual 不更新,downstream 层拿到未经 LayerNorm 处理的原始 residual。若此场景仅在 CUDAGraph padding 阶段触发,请在代码注释中明确说明,并提供针对 MLA_ATTN 路径的回归测试结果。
📝 PR 规范检查
PR 标题缺少官方 Tag,且 Motivation、Modifications、Usage or Command、Accuracy Tests 四个段落均为空(仅有模板占位注释)。
标题建议(可直接复制):
[Feature] Support Triton MLA Attention Backend
PR 描述建议(可直接复制):
## Motivation
当前 MLA 注意力计算依赖 custom CUDA op(`MLA_ATTN` 后端),编译和调试成本较高。本 PR 新增基于 Triton 的 MLA 注意力后端(`TRITON_MLA_ATTN`),以纯 Python Triton kernel 实现 KV cache 写入和 decode split-KV 注意力,提升可移植性与可调试性,并通过预分配缓冲区支持 CUDAGraph。
## Modifications
- 新增 `fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py`:实现 `TritonMLAAttentionBackend` 类,prefill 阶段使用 `flash_attn_unpadded`/`flash_attention_v3_varlen`,decode 阶段使用 Triton split-KV kernel;预分配中间 buffer 以兼容 CUDAGraph
- 新增 `triton_ops/decode_attention.py`:split-KV decode attention Triton kernel(stage1 grouped decode + stage2 cross-split reduction)
- 新增 `triton_ops/mla_cache_kernel.py`:将 `[compressed_kv || k_pe]` 写入 paged latent cache 的 Triton kernel
- `fastdeploy/config.py`:`use_mla_cache` 判断加入 `TRITON_MLA_ATTN`
- `fastdeploy/platforms/base.py` / `cuda.py`:注册 `TRITON_MLA_ATTN` backend 枚举与工厂
- `fastdeploy/worker/gpu_model_runner.py`:`mla_cache` 标志及 `isinstance` 判断加入 `TritonMLAAttentionBackend`
- `fastdeploy/model_executor/models/deepseek_v3.py`:添加 `need_do_attention` 守卫,防止 CUDAGraph padding batch 触发空注意力计算
- `scripts/.coveragerc`:将新增 triton 文件排除在覆盖率统计之外
- `tests/deterministic/`:新增 decode attention 和 MLA cache kernel 正确性测试
## Usage or Command
```bash
FD_ATTENTION_BACKEND=TRITON_MLA_ATTN python -m fastdeploy.entrypoints.openai.api_server \
--model deepseek-ai/DeepSeek-V3 ...
```
## Accuracy Tests
N/A(本 PR 未提供与参考实现的精度对比数据;已通过 `tests/deterministic/` 下的正确性测试验证 kernel 输出与 numpy 参考实现的最大绝对误差在 FP16/BF16 容忍范围内)
## Checklist
- [ ] Add at least a tag in the PR title.
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.总体评价
新增的 Triton MLA 后端整体设计思路清晰,CUDAGraph 兼容的预分配缓冲区方案合理,测试覆盖了正确性和确定性验证。主要关注点:BLOCK_DPE 硬编码限制了对非 DeepSeek-V3 维度 MLA 模型的支持,以及热路径 CPU-GPU 同步需进一步优化。PR 规范(标题 Tag + 描述填写)需完善后合入。
CI报告基于以下代码生成(30分钟更新一次): 1 任务总览❌ 有 2 个 Required 任务失败,需优先处理后方可合并。
2 任务状态汇总2.1 Required任务 : 8/10 通过
2.2 可选任务 — 27/30 通过
3 失败详情(仅 required)Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage — 覆盖率不足(置信度: 高)Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage
覆盖率详情:
根因详情: 关键日志: 修复建议:
修复建议摘要: 为deepseek_v3.py:1108和cuda.py:77新增单元测试或申请豁免 关联变更: PR 新增 Triton MLA Attention 后端,涉及 Approval — PR审批检查(置信度: 高)Approval
根因详情: 关键日志: 修复建议:
修复建议摘要: 请@xyxinyang或@zyyzghb对PR进行Approve审批 关联变更: 新增 |
Motivation
Modifications
Usage or Command
Accuracy Tests
Checklist
[FDConfig],[APIServer],[Engine],[Scheduler],[PD Disaggregation],[Executor],[Graph Optimization],[Speculative Decoding],[RL],[Models],[Quantization],[Loader],[OP],[KVCache],[DataProcessor],[BugFix],[Docs],[CI],[Optimization],[Feature],[Benchmark],[Others],[XPU],[HPU],[GCU],[DCU],[Iluvatar],[Metax]]pre-commitbefore commit.releasebranch, make sure the PR has been submitted to thedevelopbranch, then cherry-pick it to thereleasebranch with the[Cherry-Pick]PR tag.