
[Optimization][DeepSeekV3.2] Reducing slot_mapping compute frequency from twice per layer to a single pre-processing step. #7367

Merged
zhoutianzi666 merged 13 commits into PaddlePaddle:develop from ShaneGZhu:runner_dev
Apr 16, 2026

Conversation

@ShaneGZhu
Contributor

@ShaneGZhu ShaneGZhu commented Apr 13, 2026

Motivation

Profiled with a 5-layer DeepSeekV3.2-Exp-BF16 model.
-------------- Performance gain with CudaGraph disabled --------------
V0: each decode step takes 17.2 ms and each layer 3.2 ms; the slot_mapping computation inside DSA_Index and DSA_attn accounts for a noticeable share of that:

  • The two computations in one layer add up to roughly 300 us, i.e. 0.3 ms / 3.282 ms = 9.1% of the layer
  • Across one step, the five layers spend 1.5 ms on it, i.e. 1.5 / 17.2 = 8.7% of the step
    (profile screenshot)

Note: the final goal is to move the preprocessing logic into gpu_model_runner, so that it later becomes generic per-step logic rather than a trick for one specific model.

V1: CudaGraph disabled + redundant computation removed
(profile screenshot)
DSA_Indexer and DSA_attn no longer show any slot computation time.
Result: each decode step drops to 15.8 ms, 1.4 ms less than V0 (+8.13%).

-------------- Performance gain with CudaGraph enabled --------------
V2: CudaGraph enabled + redundant slot_mapping computation:
(profile screenshot)
Each decode step takes about 7.2 ms.
V3: CudaGraph enabled + redundant computation removed:
(profile screenshot)
Each decode step takes about 6.2 ms, 1 ms less than V2 (+13.88%).
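The shares and speedups quoted above can be sanity-checked with a few lines of arithmetic; the timings are the ones reported in the profile:

```python
# Sanity-check of the speedup arithmetic reported in the profile above.
layer_ms, step_ms = 3.282, 17.2      # V0: per-layer / per-step decode time
slot_ms_per_layer = 0.3              # two slot_mapping computations per layer

per_layer_share = slot_ms_per_layer / layer_ms       # share of one layer
per_step_share = 5 * slot_ms_per_layer / step_ms     # five layers per step
v1_speedup = (17.2 - 15.8) / 17.2                    # CudaGraph off
v3_speedup = (7.2 - 6.2) / 7.2                       # CudaGraph on

print(f"{per_layer_share:.1%} {per_step_share:.1%} {v1_speedup:.1%} {v3_speedup:.1%}")
# -> 9.1% 8.7% 8.1% 13.9%
```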

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

Modifications

This change reduces the slot_mapping computation in the DeepSeekV3.2 model from twice per layer to a single preprocessing pass, significantly cutting redundant compute, and moves the generic slot_mapping and position_ids computation up into gpu_model_runner.
Core changes (5 files):

  • fastdeploy/model_executor/forward_meta.py
  • fastdeploy/model_executor/layers/attention/dsa_attention_backend.py
  • fastdeploy/model_executor/models/deepseek_v3.py
  • fastdeploy/worker/gpu_model_runner.py
  • fastdeploy/worker/input_batch.py

Minor changes: removed unnecessary parameters from the get_position_ids_and_mask_encoder_batch op and adjusted the unit test files.
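The precomputation itself can be sketched roughly as follows. The helper and its exact signature are hypothetical, not FastDeploy's actual implementation; only the names block_table, position_ids, batch_id_per_token, and block_size are taken from this PR's review snippets.

```python
import numpy as np

def precompute_slot_mapping(position_ids, batch_id_per_token, block_table, block_size):
    """Compute slot_mapping once per step instead of twice per attention layer.

    position_ids:       [num_tokens] position of each token in its sequence
    batch_id_per_token: [num_tokens] which request each token belongs to
    block_table:        [num_seqs, max_blocks] physical KV-cache block ids
    """
    block_idx = position_ids // block_size              # logical block per token
    block_offset = position_ids % block_size            # offset inside the block
    physical_block = block_table[batch_id_per_token, block_idx]
    return physical_block * block_size + block_offset   # flat KV-cache slot

# Example: 2 sequences, block_size 4, decoding one token each.
block_table = np.array([[7, 3], [5, 2]])
position_ids = np.array([5, 1])          # seq 0 at pos 5, seq 1 at pos 1
batch_id_per_token = np.array([0, 1])
print(precompute_slot_mapping(position_ids, batch_id_per_token, block_table, 4))
# -> [13 21]
```

All attention layers can then read the same precomputed tensor from forward_meta instead of recomputing it.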

Usage or Command

None

Accuracy Tests

Benchmark  | Metric                       | Before | After
-----------|------------------------------|--------|------
IFEval_gen | Prompt-level-strict-accuracy | 86.69  | 87.8
           | Inst-level-strict-accuracy   | 91.01  | 91.61
           | Prompt-level-loose-accuracy  | 89.46  | 90.94
           | Inst-level-loose-accuracy    | 93.05  | 93.65

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Apr 13, 2026

Thanks for your contribution!

PaddlePaddle-bot

This comment was marked as outdated.

@codecov-commenter

codecov-commenter commented Apr 13, 2026

Codecov Report

❌ Patch coverage is 96.00000% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@14d5566). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/models/deepseek_v3.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7367   +/-   ##
==========================================
  Coverage           ?   73.89%           
==========================================
  Files              ?      398           
  Lines              ?    54943           
  Branches           ?     8606           
==========================================
  Hits               ?    40598           
  Misses             ?    11632           
  Partials           ?     2713           
Flag Coverage Δ
GPU 73.89% <96.00%> (?)


zhoutianzi666
zhoutianzi666 previously approved these changes Apr 15, 2026
Comment thread fastdeploy/worker/gpu_model_runner.py Outdated

@ShaneGZhu ShaneGZhu marked this pull request as ready for review April 15, 2026 12:32
zhoutianzi666
zhoutianzi666 previously approved these changes Apr 16, 2026
chang-wenbin
chang-wenbin previously approved these changes Apr 16, 2026
Collaborator

@chang-wenbin chang-wenbin left a comment

The unit tests need to be checked and fixed.


@PaddlePaddle-bot PaddlePaddle-bot left a comment

🤖 AI Code Review | 2026-04-16 15:40 CST

📋 Review Summary

PR overview: optimizes the slot_mapping computation in the DeepSeekV3.2 model from two redundant computations per layer to a single preprocessing pass in gpu_model_runner, and cleans up parameters of mask_encoder_batch that are no longer used.
Scope: gpu_model_runner, input_batch, the deepseek_v3 model, dsa_attention_backend, CUDA ops
Tags: Optimization, Models, OP, Executor

Issues

Level | File | Summary
🟡 Suggestion | gpu_model_runner.py:1284 | attn_backends[0] is accessed without an empty-list guard
🟡 Suggestion | gpu_model_runner.py:1296 | assert lacks a descriptive error message

Overall assessment

The optimization is well designed: lifting the slot_mapping and position_ids computation out of the model layer (deepseek_v3.py) and operator layer (dsa_attention_backend.py) into a single preprocessing pass in gpu_model_runner eliminates the per-layer recomputation, and the profile data confirms the performance gain. The duplicated compute_slot_mapping function is correctly removed, and the cleanup of the mask_encoder_batch parameters is thorough. Both minor suggestions are robustness improvements and do not block merging.

Results are stored in self.forward_meta.
"""
# NOTE(zhushengguang): Only support MLAAttentionBackend and DSAAttentionBackend currently.
if not isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)):
🟡 Suggestion: self.attn_backends[0] is accessed here without checking whether the list is empty.

For comparison, _process_reorder in the same file guards against an empty list with if self.attn_backends and getattr(...). Although _initialize_attn_backend is called in __init__ before this method, guaranteeing attn_backends is non-empty, keeping a consistent defensive programming style is recommended.

Suggested change:

if not self.attn_backends or not isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)):
    return

)
block_size = self.cache_config.block_size
block_idx = position_ids // block_size # [num_tokens]
assert self.forward_meta.batch_id_per_token.shape == block_idx.shape
🟡 Suggestion: this assert may be stripped in production (python -O mode); add a descriptive error message to make debugging easier.

assert self.forward_meta.batch_id_per_token.shape == block_idx.shape, (
    f"Shape mismatch: batch_id_per_token {self.forward_meta.batch_id_per_token.shape} "
    f"vs block_idx {block_idx.shape}"
)
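On the -O point: a minimal standalone demonstration (generic Python, not FastDeploy code) that plain asserts vanish under optimized mode:

```python
import subprocess
import sys
import textwrap

# Run a tiny program under `python -O`; the -O flag sets __debug__ to False
# and strips assert statements entirely, so the failing assert never fires.
prog = textwrap.dedent("""
    assert False, "unreachable under -O"
    print("assert was stripped")
""")
out = subprocess.run([sys.executable, "-O", "-c", prog], capture_output=True, text=True)
print(out.stdout.strip())  # -> assert was stripped
```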

@zhoutianzi666 zhoutianzi666 merged commit 2d8338f into PaddlePaddle:develop Apr 16, 2026
36 of 38 checks passed

Labels

None yet

Projects

None yet


5 participants