Skip to content

[OP][RL]update attn_mask_q 2#7371

Merged
ckl117 merged 1 commit intoPaddlePaddle:developfrom
ckl117:dev_attn_mask_q_2
Apr 13, 2026
Merged

[OP][RL]update attn_mask_q 2#7371
ckl117 merged 1 commit intoPaddlePaddle:developfrom
ckl117:dev_attn_mask_q_2

Conversation

@ckl117
Copy link
Copy Markdown
Collaborator

@ckl117 ckl117 commented Apr 13, 2026

Motivation

attn_mask_q最后一个维度仅需2个vec即可表示双向mask

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

💡 如若此PR是Cherry Pick,PR标题需遵循格式,在最开始加上[Cherry-Pick]标签,以及最后面加上原PR ID,例如[Cherry-Pick][CI] Add check trigger and logic(#5191)

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot
Copy link
Copy Markdown

paddle-bot bot commented Apr 13, 2026

Thanks for your contribution!

Copy link
Copy Markdown

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 AI Code Review | 2026-04-13 21:33 CST

📋 Review 摘要

PR 概述:将 get_attn_mask_q 算子输出张量的最后一个维度从 4 优化为 2,减少内存占用。

变更范围custom_ops/gpu_ops/get_attn_mask_q.cu

影响面 Tag[OP]

📝 PR 规范检查

检查项 状态
标题 Tag ✅ 包含 [OP][RL]
Motivation ✅ 已填写
Modifications ❌ 未填写(建议补充说明变更逻辑)
Unit Tests ❌ 未添加( Checklist 未勾选)

标题建议(当前已符合,但建议补充):

  • 当前:[OP][RL]update attn_mask_q 2
  • 建议:[OP][Optimization] reduce attn_mask_q output dimension from 4 to 2

问题

级别 文件 概述
🟡 建议 custom_ops/gpu_ops/get_attn_mask_q.cu:52 需验证 Paddle flashmask_attention 是否支持 2 维输入
🟡 建议 custom_ops/gpu_ops/get_attn_mask_q.cu:70 缺少验证输出形状变更的单元测试

总体评价

这是一个内存优化 PR,减少了 attn_mask_q 输出张量的维度。但存在下游兼容性风险和缺少单元测试的问题,建议验证后再合入。

startend_row_vec[1] = cu_seqlens_q[max_batch_size];
startend_row_vec[2] = 0;
startend_row_vec[3] = this_batch_q_end;
// startend_row_vec[1] = cu_seqlens_q[max_batch_size];
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 startend_row_vec 从 4 维改为 2 维后,需要验证下游 Paddle 的 flashmask_attention 函数是否支持 2 维输入。

当前 flash_attn_backend.py:153,168attn_mask_q 通过 startend_row_indices 参数传给 Paddle 的 flashmask_attention/flashmask_attention_v4。如果这些函数严格要求 4 维输入,会导致运行时错误。

建议:

  1. tests/layers/test_flash_attn_func.py 中使用 2 维输入运行 test_fa3_with_masktest_fa4 等测试验证兼容性
  2. 或确认 Paddle 的 flashmask_attention 函数支持 2 维 startend_row_indices

reinterpret_cast<int4*>(startend_row_indices_ptr +
cu_seqlens_k_idx * 4)[0] =
reinterpret_cast<int4*>(startend_row_vec)[0];
reinterpret_cast<int2*>(startend_row_indices_ptr +
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 变更了输出张量形状(从 [..., 4] 改为 [..., 2]),但未添加相应的单元测试验证。

建议:

  1. tests/operators/test_flash_mask_attn.py 中添加测试,验证输出形状确实为 2
  2. 验证输出值的正确性([q_end, min_q_idx]
  3. 更新 PR Checklist 勾选 "Add unit tests"

Copy link
Copy Markdown
Collaborator

@zoooo0820 zoooo0820 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@codecov-commenter
Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@e83d458). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7371   +/-   ##
==========================================
  Coverage           ?   73.51%           
==========================================
  Files              ?      383           
  Lines              ?    53612           
  Branches           ?     8411           
==========================================
  Hits               ?    39414           
  Misses             ?    11517           
  Partials           ?     2681           
Flag Coverage Δ
GPU 73.51% <ø> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ckl117 ckl117 merged commit 26c47c2 into PaddlePaddle:develop Apr 13, 2026
34 of 37 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants