Skip to content

[OP] support cfp8 in blackwell mla#7876

Merged
zhoutianzi666 merged 1 commit into
PaddlePaddle:developfrom
zhoutianzi666:add_cfp8
May 21, 2026
Merged

[OP] support cfp8 in blackwell mla#7876
zhoutianzi666 merged 1 commit into
PaddlePaddle:developfrom
zhoutianzi666:add_cfp8

Conversation

@zhoutianzi666
Copy link
Copy Markdown
Collaborator

@zhoutianzi666 zhoutianzi666 commented May 21, 2026

🤖 Paddle-CI-Agent | pr_review | 2026-05-21 13:50:38

📋 Review 摘要

PR 概述:为 MLA Attention 的 Prefill Cache 写入及 Blackwell GPU 的 Decode 路径增加 FP8 KV Cache 量化支持。
变更范围custom_ops/gpu_ops/append_attn/(CUDA kernel)、fastdeploy/model_executor/layers/attention/(Python 调度层)
影响面 Tag[Quantization] [OP]

问题

级别 文件 概述
🟡 建议 fastdeploy/model_executor/layers/attention/mla_attention_backend.py:1003 compiled_mla 全局缓存未区分 FP8/FP16 kernel 类型
❓ 疑问 custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh:233 FP8 量化无 scale 因子,精度影响未知
📝 PR 规范 标题/描述不合规

📝 PR 规范检查

PR 标题为 commit,完全不符合 [Tag] 描述 格式;PR 描述所有 section 均为空占位符,不合规。

标题建议(可直接复制):

  • [Quantization] Add FP8 KV cache support for MLA prefill write and Blackwell decode

PR 描述建议(可直接复制,必须复刻 checklist §D2 模板的完整结构):

## Motivation
为 MLA(Multi-head Latent Attention)的 Prefill Cache 写入路径及 Blackwell GPU Decode 路径新增 FP8(e4m3)KV Cache 量化支持,以降低 KV Cache 显存占用并提升 Blackwell 推理吞吐。

## Modifications
- `custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu/.cuh``PrefillMLAWriteCache` 新增 `cache_quant_type_str` 参数;当取值 `cache_fp8` 时,对 kv_nope/kv_pe 执行 clamp+cast 写入 FP8 KV cache;保留 `none` 的原始 BF16/FP16 路径
- `fastdeploy/model_executor/layers/attention/mla_attention_backend.py``forward_mixed` 中改为读取 `layer.cache_quant_type_str`(默认 `"none"`)传入 prefill cache op;`mla_blackwell` 新增 FP8 分支,当 `latent_cache.dtype == uint8` 时切换到 `mla_decode_fp8` kernel 执行

## Usage or Command
N/A

## Accuracy Tests
N/A

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

总体评价

本 PR 功能方向清晰,FP8 KV Cache 量化的 kernel 实现逻辑基本正确;compiled_mla 全局缓存需增加 kernel 类型区分,FP8 量化精度影响建议补充验证结果。PR 规范(标题/描述)需完善后合入。

@paddle-bot
Copy link
Copy Markdown

paddle-bot Bot commented May 21, 2026

Thanks for your contribution!

@zhoutianzi666 zhoutianzi666 changed the title commit support cfp8 in mla May 21, 2026
Copy link
Copy Markdown

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Paddle-CI-Agent | pr_review | 2026-05-21 13:50:38

📋 Review 摘要

PR 概述:为 MLA Attention 的 Prefill Cache 写入及 Blackwell GPU 的 Decode 路径增加 FP8 KV Cache 量化支持。
变更范围custom_ops/gpu_ops/append_attn/(CUDA kernel)、fastdeploy/model_executor/layers/attention/(Python 调度层)
影响面 Tag[Quantization] [OP]

问题

级别 文件 概述
🟡 建议 fastdeploy/model_executor/layers/attention/mla_attention_backend.py:1003 compiled_mla 全局缓存未区分 FP8/FP16 kernel 类型
❓ 疑问 custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh:233 FP8 量化无 scale 因子,精度影响未知
📝 PR 规范 标题/描述不合规

📝 PR 规范检查

PR 标题为 commit,完全不符合 [Tag] 描述 格式;PR 描述所有 section 均为空占位符,不合规。

标题建议(可直接复制):

  • [Quantization] Add FP8 KV cache support for MLA prefill write and Blackwell decode

PR 描述建议(可直接复制,必须复刻 checklist §D2 模板的完整结构):

## Motivation
为 MLA(Multi-head Latent Attention)的 Prefill Cache 写入路径及 Blackwell GPU Decode 路径新增 FP8(e4m3)KV Cache 量化支持,以降低 KV Cache 显存占用并提升 Blackwell 推理吞吐。

## Modifications
- `custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu/.cuh``PrefillMLAWriteCache` 新增 `cache_quant_type_str` 参数;当取值 `cache_fp8` 时,对 kv_nope/kv_pe 执行 clamp+cast 写入 FP8 KV cache;保留 `none` 的原始 BF16/FP16 路径
- `fastdeploy/model_executor/layers/attention/mla_attention_backend.py``forward_mixed` 中改为读取 `layer.cache_quant_type_str`(默认 `"none"`)传入 prefill cache op;`mla_blackwell` 新增 FP8 分支,当 `latent_cache.dtype == uint8` 时切换到 `mla_decode_fp8` kernel 执行

## Usage or Command
N/A

## Accuracy Tests
N/A

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

总体评价

本 PR 功能方向清晰,FP8 KV Cache 量化的 kernel 实现逻辑基本正确;compiled_mla 全局缓存需增加 kernel 类型区分,FP8 量化精度影响建议补充验证结果。PR 规范(标题/描述)需完善后合入。

from mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16

# from mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8
if use_fp8_cache_kv:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 compiled_mla 全局缓存不区分 FP8/FP16 kernel 类型

当前逻辑:compiled_mla 只在 None 时编译一次并持久复用。若在同一进程内曾以 use_fp8_cache_kv=False(FP16 kernel)完成初始化,后续以 use_fp8_cache_kv=True 调用时,compiled_mla 仍指向 FP16 版本,传入 FP8 tensor 会引发 dtype mismatch 运行时错误(反之亦然)。

建议修复:使用独立变量或 dict 分别缓存两个 kernel 的编译结果:

global compiled_mla_fp8, compiled_mla_fp16
if use_fp8_cache_kv:
    if compiled_mla_fp8 is None:
        compiled_mla_fp8 = cute.compile(mla, ...)
    compiled_mla_fp8(...)
else:
    if compiled_mla_fp16 is None:
        compiled_mla_fp16 = cute.compile(mla, ...)
    compiled_mla_fp16(...)

Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);

if constexpr (std::is_same_v<CT, __nv_fp8_e4m3>) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 FP8 量化未使用 scale 因子

当前实现等价于 scale=1.0 的静态量化(仅 clamp 至 FP8 e4m3 的 ±448 范围)。若 MLA KV cache 的激活值实际分布在较小量级(如 ±10),FP8 e4m3 在该范围内只有约 4 个指数级别,精度损失可能不可忽略。请确认:

  1. 是否有针对 DeepSeek-R1 等目标模型的量化精度对比数据?
  2. 是否刻意省略 scale(如激活已经过归一化处理)?建议在 PR 描述的 Accuracy Tests 段补充说明。

@PaddlePaddle-bot
Copy link
Copy Markdown

PaddlePaddle-bot commented May 21, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-21 15:38:56

CI报告基于以下代码生成(30分钟更新一次):


1 任务总览

当前 Required 任务有 1 个失败、0 个运行中、0 个等待中,暂不建议合入;失败点为主测试任务的 diff coverage 门禁。

总执行(rerun次数) 总任务 ✅ 通过 ❌ 失败 ⏳ 运行中 ⏸️ 等待中 跳过
41(0) 41 37 3 0 1 0

2 任务状态汇总

日志列说明:失败任务直接使用日志链接,运行中任务链接到对应 Job。

2.1 Required任务 : 9/10 通过

必选任务阻塞合并,失败需优先处理。

状态 任务 耗时 根因 修复建议 日志 重跑
Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage 1h21m PR问题:Python新增分支 diff 覆盖率 0% tests/layers补充MLA FP8分支用例 Job -
其余 9 个必选任务通过 - - - - -

2.2 可选任务 — 28/31 通过

可选任务不阻塞合并,失败仅供参考。

状态 任务 耗时 日志 重跑
Check PR Template 13s Job -
Trigger Jenkins for PR 7m19s Job -
⏸️ CI_HPU - - -
其余 28 个可选任务通过 - - -

3 失败详情(仅 required)

Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage — 代码覆盖率不足(置信度: 高)

Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage

  • 状态: ❌ 失败
  • 错误类型: 代码覆盖率不足
  • 置信度: 高
  • 根因摘要: Python新增分支 diff 覆盖率 0%
  • 分析器: ci_analyze_unittest_fastdeploy

失败用例: 无。日志显示 All tests passed,失败发生在 Verify Code Coverage Threshold (80%) 步骤。

根因详情:
PR 新增/修改了 fastdeploy/model_executor/layers/attention/mla_attention_backend.py 中 MLA FP8 cache 相关逻辑,包括 latent_cache.dtype == paddle.uint8 时的 FP8 cast/view、mla_decode_fp8 kernel 选择以及 compiled_mla 调用参数。diff coverage 报告显示这些新增 Python 行没有被现有单测覆盖,导致总 diff coverage 为 0%,低于 80% 门禁,因此 CI 以 exit code 9 失败。

关键日志:

All tests passed
Coverage generation failed (exit code 9)
Failure. Coverage is below 80%.
fastdeploy/model_executor/layers/attention/mla_attention_backend.py (0.0%): Missing lines 1001-1005,1054-1055,1059,1063,1118
GPU Patch Coverage Details:
"total_num_lines": 10, "total_num_violations": 10, "total_percent_covered": 0

修复建议:

  1. tests/layers/ 下为 MLAAttentionBackend.mla_blackwell 增加 FP8 cache 分支单测,构造/Mock latent_cache.dtype == paddle.uint8,覆盖 mla_decode_fp8 import、kernel 选择和 compiled_mla(...) 调用路径(对应 mla_attention_backend.py L1001-L1005、L1054-L1063)。
  2. 如该 Blackwell/CUTLASS 路径在当前 CI 环境不可直接执行,可通过 mock cutlasscuda.bindings.drivermla_decode_fp8/mla_decode_fp16cute.compile 的方式做轻量单元测试,只验证分支选择和调用参数。
  3. 覆盖报告还标出 mla_attention_backend.py L1118,可补充 mha_baseline 的最小 shape 校验用例,或确认该行是否由 diff 计算误纳入后调整覆盖/豁免策略。

修复建议摘要: tests/layers补充MLA FP8分支用例

关联变更: fastdeploy/model_executor/layers/attention/mla_attention_backend.py L812、L950-L1068;custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu L54-L98;custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh L233-L267

链接: 查看日志

@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented May 21, 2026

Codecov Report

❌ Patch coverage is 0% with 11 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@5348902). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...executor/layers/attention/mla_attention_backend.py 0.00% 11 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7876   +/-   ##
==========================================
  Coverage           ?   63.63%           
==========================================
  Files              ?      462           
  Lines              ?    64492           
  Branches           ?     9889           
==========================================
  Hits               ?    41037           
  Misses             ?    20673           
  Partials           ?     2782           
Flag Coverage Δ
GPU 72.76% <0.00%> (?)
XPU 7.11% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@zhoutianzi666 zhoutianzi666 changed the title support cfp8 in mla [OP] support cfp8 in blackwell mla May 21, 2026
Copy link
Copy Markdown
Collaborator

@EmmonsCurse EmmonsCurse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM~ Skip coverage check as it mainly relies on tests with sm_version >= 100

@zhoutianzi666 zhoutianzi666 merged commit 4f14845 into PaddlePaddle:develop May 21, 2026
57 of 61 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants