Skip to content

[Models] add fleet model fallback#7732

Open
xiaoguoguo626807 wants to merge 5 commits intoPaddlePaddle:developfrom
xiaoguoguo626807:fleet
Open

[Models] add fleet model fallback#7732
xiaoguoguo626807 wants to merge 5 commits intoPaddlePaddle:developfrom
xiaoguoguo626807:fleet

Conversation

@xiaoguoguo626807
Copy link
Copy Markdown

@xiaoguoguo626807 xiaoguoguo626807 commented May 7, 2026

Motivation

新增 PaddleFleet 作为模型推理后端(--model-impl paddlefleet),通过将 PaddleFleet TransformerLayer 中的 core_attention 替换为 FastDeploy Attention 内核,实现在 PaddleFleet 模型结构上复用 FastDeploy 的 KV Cache 和高性能 Attention 计算。

Modifications

  • config.py: 新增 paddlefleetModelImpl 类型定义
  • engine/args_utils.py: 支持 --model-impl paddlefleet CLI 参数,并补充校验逻辑
  • model_executor/models/paddleformers/base_fleet.py: 新增 PaddleFleetModelBase 基类、FastDeployAttention 层及 patch_paddlefleet_core_attention 替换函数
  • model_executor/models/paddleformers/__init__.py: 注册 PaddleFleetForCausalLM 模型类

Usage or Command

python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --model-impl paddlefleet

Accuracy Tests

N/A(本 PR 新增 PaddleFleet 推理后端,尚未提供与参考实现的 logits 对齐数据)

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot
Copy link
Copy Markdown

paddle-bot Bot commented May 7, 2026

Thanks for your contribution!

@PaddlePaddle-bot
Copy link
Copy Markdown

PaddlePaddle-bot commented May 8, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-08 18:32:13

CI报告基于以下代码生成(30分钟更新一次):


1 任务总览

存在 1 个 Required 失败任务,需优先处理后方可合并。

总执行(rerun次数) 总任务 ✅ 通过 ❌ 失败 ⏳ 运行中 ⏸️ 等待中 跳过
19(0) 19 12 2 4 1 0

2 任务状态汇总

2.1 Required任务 : 1/2 通过

必选任务阻塞合并,失败需优先处理。

状态 任务 耗时 根因 修复建议 日志 重跑
Approval 7s PR问题:PR新增多处logger.info调用,未获指定RD审批 联系xyxinyang或zyyzghb对本PR进行审批 Job -
其余 1 个必选任务通过 - - - - -

2.2 可选任务 — 11/17 通过

可选任务不阻塞合并,失败仅供参考。

状态 任务 耗时 日志 重跑
Check PR Template 16s Job -
xpu_build_test / xpu-build-test - Job -
FD-Build-Linux / fd-build - Job -
Run iluvatar Tests / run_iluvatar_cases - Job -
Trigger Jenkins for PR - Job -
⏸️ CI_HPU - - -
其余 11 个可选任务通过 - - -

3 失败详情(仅 required)

Approval — 代码规范(置信度: 高)

Approval

  • 状态: ❌ 失败
  • 错误类型: 代码规范
  • 置信度: 高
  • 根因摘要: PR新增多处logger.info调用,未获指定RD审批
  • 分析器: 通用分析(fallback)

根因详情:
check_approval.sh 脚本检测到 PR diff 中包含 logger.info 等日志修改,触发了 FastDeploy 的日志行为审批规则。该规则要求修改 .info/.debug/.error/log_request 相关日志行为时,必须获得指定 FastDeploy RD(xyxinyang(zhouchong) 或 zyyzghb(zhangyongyue))的审批。PR 当前尚未获得该审批,导致检查失败(exit code 6)。

关键日志:

Detected log modification in diff:
+            logger.info("Initializing PaddleFormers backend.")
+            logger.info(f"Patched {patched_count} attention layers with FastDeploy")
+                logger.info(
+            logger.info(
+            logger.info(f"Replaced core_attention with FastDeployAttention for layer {fd_layer_id}")
+        logger.info(f"Successfully replaced {patched_count} core_attention layers with FastDeployAttention")
0. You must have one FastDeploy RD (xyxinyang(zhouchong), zyyzghb(zhangyongyue)) approval for modifying logging behavior.
There are 1 approved errors.
##[error]Process completed with exit code 6.

修复建议:

  1. 请联系 FastDeploy RD:xyxinyang(zhouchong) 或 zyyzghb(zhangyongyue) 对本 PR 进行审批
  2. 若认为这些日志添加不属于需要审批的范围,可与相关 RD 确认规则是否适用

修复建议摘要: 联系xyxinyang或zyyzghb对本PR进行审批

链接: 查看日志

PaddlePaddle-bot

This comment was marked as outdated.

Copy link
Copy Markdown

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Paddle-CI-Agent | pr_review | 2026-05-08 18:46:22

📋 Review 摘要

PR 概述:新增 PaddleFleet 作为模型推理后端(--model-impl paddlefleet),通过替换 TransformerLayer 中的 core_attention 为 FastDeploy Attention 内核,实现在 PaddleFleet 模型结构上复用 FastDeploy KV Cache 和高性能 Attention 计算
变更范围model_executor/models/paddleformers/model_executor/models/model_base.pyconfig.pyengine/args_utils.pyworker/worker_process.py
影响面 Tag[Models] [FDConfig] [Engine]

📝 PR 规范检查

标题 [Models] add fleet model fallback 格式合规,Tag 合法。描述结构完整,但 Checklist 勾选状态有 3 处不一致:

  1. [x] Provide accuracy results. — Accuracy Tests 节为 N/A,应改为 [ ]
  2. [ ] Add unit tests. — 未勾选但 PR 描述中未说明不添加单测的原因(checklist 要求"Please write the reason in this PR if no unit tests")
  3. [x] If the current PR is submitting to the release branch... — 此 PR 目标分支为 develop,非 release 分支,应改为 [ ]

PR 描述建议(可直接复制):

## Motivation
新增 PaddleFleet 作为模型推理后端(`--model-impl paddlefleet`),通过将 PaddleFleet TransformerLayer 中的 `core_attention` 替换为 FastDeploy Attention 内核,实现在 PaddleFleet 模型结构上复用 FastDeploy 的 KV Cache 和高性能 Attention 计算。

## Modifications
- `config.py`: 新增 `paddlefleet``ModelImpl` 类型定义
- `engine/args_utils.py`: 支持 `--model-impl paddlefleet` CLI 参数,并补充校验逻辑
- `model_executor/models/paddleformers/base_fleet.py`: 新增 `PaddleFleetModelBase` 基类、`FastDeployAttention` 层及 `patch_paddlefleet_core_attention` 替换函数
- `model_executor/models/paddleformers/__init__.py`: 注册 `PaddleFleetForCausalLM` 模型类

## Usage or Command
```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --model-impl paddlefleet
```

## Accuracy Tests
N/A(本 PR 新增 PaddleFleet 推理后端,尚未提供与参考实现的 logits 对齐数据)

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [x] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
  - 原因:本 PR 为新增 PaddleFleet 推理后端,依赖 paddlefleet 环境,暂未添加单测,后续补充。
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

问题

级别 文件 概述
🟡 建议 fastdeploy/model_executor/models/model_base.py:206 安装命令中含有多余单引号 dev20260507',导致用户看到错误命令
🟡 建议 fastdeploy/model_executor/models/paddleformers/base_fleet.py:432 load_weights 为空 pass,绕过 FastDeploy 权重加载路径,quant_config 等特性可能无法生效
❓ 疑问 fastdeploy/model_executor/models/paddleformers/base_fleet.py:283 get_tensor_model_parallel_group is not None 永远为 True,建议移除冗余判断
📝 PR 规范 Checklist 3 处勾选状态不一致(详见上方 §PR 规范检查)

总体评价

整体设计思路清晰,通过 FastDeployAttention 替换 PaddleFleet core_attention 的 patch 机制合理。建议修复安装命令中的 typo,并补充 load_weights 绕过说明或重构加载逻辑,以避免 quant_config 等 FastDeploy 特性静默失效。

raise ImportError(
"paddlefleet backend requires paddlefleet to be installed.\n"
"Please install with [change cuda version if needed ]:\n"
"python -m pip install paddlefleet==0.3.0.dev20260507' "
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 安装命令字符串含有多余的单引号。

"python -m pip install paddlefleet==0.3.0.dev20260507' "dev20260507 后跟了一个 ',会导致用户看到错误的安装命令。

建议修复:

"python -m pip install paddlefleet==0.3.0.dev20260507 "

model_input = layer(model_input, decoder_input=inputs_embeds)
else:
model_input = layer(model_input)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 load_weights 方法为空(pass),完全绕过了 FastDeploy 标准权重加载路径。

FastDeploy 的 load_weights 是权重加载的核心接口(量化、dtype 转换、sharded 加载均通过此路径)。在 __init__ 中用 from_pretrained 直接加载权重并让 load_weights 留空,可能导致:

  1. quant_config 等 FastDeploy 特性无法生效
  2. 与上层 load_model() 调用约定冲突(调用方预期 load_weights 能正常执行)

建议至少在方法体中添加注释说明此设计决策,或在 load_weights 内调用 from_pretrained 并移除 __init__ 中的加载逻辑。

initialize_fleet(strategy)
logger.info(
f"Initialized PaddleFleet parallel_state via initialize_fleet "
f"(dp={parallel_config.data_parallel_size}, "
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 get_tensor_model_parallel_group is not None 判断永远为 True。

get_tensor_model_parallel_group 是从 paddlefleet.parallel_state 直接 import 的函数对象,代码执行至此处时 import 已成功,该对象不可能为 None。此判断是冗余的,实际生效的只有后半个条件。

建议简化为:

if get_tensor_model_parallel_group(False) is None:

@codecov-commenter
Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 3.46154% with 251 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@78b5462). Learn more about missing BASE report.

Files with missing lines Patch % Lines
.../model_executor/models/paddleformers/base_fleet.py 1.65% 237 Missing and 1 partial ⚠️
...oy/model_executor/models/paddleformers/__init__.py 11.11% 7 Missing and 1 partial ⚠️
fastdeploy/model_executor/models/model_base.py 0.00% 4 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7732   +/-   ##
==========================================
  Coverage           ?   71.29%           
==========================================
  Files              ?      397           
  Lines              ?    55835           
  Branches           ?     8741           
==========================================
  Hits               ?    39808           
  Misses             ?    13285           
  Partials           ?     2742           
Flag Coverage Δ
GPU 71.29% <3.46%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants