[OP][RL]update attn_mask_q 2 #7371
Thanks for your contribution!
PaddlePaddle-bot
left a comment
🤖 AI Code Review |
2026-04-13 21:33 CST
📋 Review Summary
PR overview: Shrinks the last dimension of the get_attn_mask_q operator's output tensor from 4 to 2, reducing memory usage.
Scope of change: custom_ops/gpu_ops/get_attn_mask_q.cu
Impact tag: [OP]
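📝 PR Convention Check
| Check | Status |
|---|---|
| Title tag | ✅ Contains [OP] and [RL] |
| Motivation | ✅ Filled in |
| Modifications | ❌ Not filled in (consider describing the change logic) |
| Unit Tests | ❌ Not added (Checklist item unchecked) |
Title suggestion (the current title already complies; a more descriptive one is suggested):
- Current: [OP][RL]update attn_mask_q 2
- Suggested: [OP][Optimization] reduce attn_mask_q output dimension from 4 to 2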
Issues
| Severity | File | Summary |
|---|---|---|
| 🟡 Suggestion | custom_ops/gpu_ops/get_attn_mask_q.cu:52 | Verify that Paddle's flashmask_attention accepts a last dimension of size 2 |
| 🟡 Suggestion | custom_ops/gpu_ops/get_attn_mask_q.cu:70 | No unit test covers the changed output shape |
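Overall Assessment
This is a memory-optimization PR that shrinks the last dimension of the attn_mask_q output tensor. However, it carries a downstream compatibility risk and lacks unit tests, so verification is recommended before merging.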
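To make the memory claim concrete, the sketch below computes the buffer size before and after the change. It assumes one row per KV token and 4-byte int32 elements (suggested by the kernel's int4/int2 stores). The helper name `mask_buffer_bytes` and the token count are illustrative only and do not come from the PR.

```python
# Hypothetical helper: size in bytes of a [kv_token_num, last_dim] int32 buffer.
INT32_BYTES = 4  # assumption: the kernel writes 32-bit ints


def mask_buffer_bytes(kv_token_num: int, last_dim: int) -> int:
    """Return the byte size of a dense int32 buffer with the given shape."""
    if kv_token_num < 0 or last_dim <= 0:
        raise ValueError("kv_token_num must be >= 0 and last_dim must be > 0")
    return kv_token_num * last_dim * INT32_BYTES


if __name__ == "__main__":
    tokens = 8192  # illustrative token count, not taken from the PR
    before = mask_buffer_bytes(tokens, 4)
    after = mask_buffer_bytes(tokens, 2)
    print(f"before: {before} bytes, after: {after} bytes, saved: {before - after} bytes")
```

Under these assumptions the buffer is halved regardless of the token count, since only the last dimension changes.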
startend_row_vec[1] = cu_seqlens_q[max_batch_size];
startend_row_vec[2] = 0;
startend_row_vec[3] = this_batch_q_end;
// startend_row_vec[1] = cu_seqlens_q[max_batch_size];
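🟡 Suggestion: After startend_row_vec shrinks from 4 elements to 2, verify that Paddle's downstream flashmask_attention function accepts a startend_row_indices tensor whose last dimension is 2.
Currently, flash_attn_backend.py:153,168 passes attn_mask_q to Paddle's flashmask_attention/flashmask_attention_v4 via the startend_row_indices parameter. If these functions strictly require a last dimension of 4, this will cause a runtime error.
Suggestions:
- Run tests such as test_fa3_with_mask and test_fa4 in tests/layers/test_flash_attn_func.py with the 2-wide input to verify compatibility
- Or confirm that Paddle's flashmask_attention function supports a 2-wide startend_row_indices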
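One low-cost way to surface an incompatibility early is a shape guard placed in front of the attention call. The sketch below is framework-free and works on a plain shape tuple. The name `check_startend_row_indices_shape` and the default `allowed_widths` are assumptions for illustration; the widths that flashmask_attention actually accepts should be confirmed against the installed Paddle version.

```python
from typing import Iterable, Sequence


def check_startend_row_indices_shape(
    shape: Sequence[int],
    allowed_widths: Iterable[int] = (1, 2, 4),  # assumption: confirm with Paddle docs
) -> int:
    """Return the last-dimension width if it is allowed, else raise ValueError."""
    if len(shape) == 0:
        raise ValueError("startend_row_indices must have at least one dimension")
    width = shape[-1]
    allowed = tuple(allowed_widths)
    if width not in allowed:
        raise ValueError(f"last dimension is {width}, expected one of {allowed}")
    return width


if __name__ == "__main__":
    # Shape layout here is illustrative: (batch, heads, kv_tokens, width).
    print(check_startend_row_indices_shape((1, 1, 128, 2)))
```

A guard like this fails with a readable message at the call site instead of an opaque error inside the attention kernel.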
reinterpret_cast<int4*>(startend_row_indices_ptr +
                        cu_seqlens_k_idx * 4)[0] =
    reinterpret_cast<int4*>(startend_row_vec)[0];
reinterpret_cast<int2*>(startend_row_indices_ptr +
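🟡 Suggestion: The output tensor shape changed (from [..., 4] to [..., 2]), but no unit test was added to verify it.
Suggestions:
- Add a test in tests/operators/test_flash_mask_attn.py verifying that the last dimension of the output is indeed 2
- Verify the correctness of the output values ([q_end, min_q_idx])
- Update the PR Checklist and check "Add unit tests"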
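A test along these lines could check both the shape and the value range. The sketch is pure Python and operates on a nested list standing in for the operator's output. `validate_attn_mask_q_rows` is a hypothetical helper, and the assumption that every entry lies in [0, total_q_tokens] is inferred from the values the kernel writes (query offsets), not from a specification.

```python
from typing import List


def validate_attn_mask_q_rows(
    rows: List[List[int]], kv_token_num: int, total_q_tokens: int
) -> None:
    """Raise AssertionError if rows is not a [kv_token_num, 2] table of valid q offsets."""
    assert len(rows) == kv_token_num, f"expected {kv_token_num} rows, got {len(rows)}"
    for i, row in enumerate(rows):
        assert len(row) == 2, f"row {i} has width {len(row)}, expected 2"
        for value in row:
            # Assumption: entries are q-token offsets, so they fall in [0, total_q_tokens].
            assert 0 <= value <= total_q_tokens, f"row {i} value {value} out of range"


if __name__ == "__main__":
    # Made-up data for 3 KV tokens and 4 query tokens; not real operator output.
    fake_output = [[4, 0], [4, 1], [4, 2]]
    validate_attn_mask_q_rows(fake_output, kv_token_num=3, total_q_tokens=4)
    print("ok")
```

In a real test, `rows` would come from the operator's output converted to a nested list, and an exact comparison against a reference implementation would replace the range check.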
Codecov Report
✅ All modified and coverable lines are covered by tests.
@@ Coverage Diff @@
## develop #7371 +/- ##
==========================================
Coverage ? 73.51%
==========================================
Files ? 383
Lines ? 53612
Branches ? 8411
==========================================
Hits ? 39414
Misses ? 11517
Partials ? 2681
Motivation
The last dimension of attn_mask_q needs only 2 vecs to represent a bidirectional mask.
Modifications
Usage or Command
Accuracy Tests
Checklist
- [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- pre-commit before commit.
- release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.