[BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn#7210
xiaoxiaohehe001 wants to merge 2 commits into PaddlePaddle:develop from
Conversation
Thanks for your contribution!
Pull request overview
This PR targets the SM90 flash_mask_attention custom op: it fixes a kernel launch dimension error caused by deriving batch_size from the wrong source, and relaxes (removes) some runtime shape assertions to accept inputs preallocated at max_batch whose effective batch is smaller.
Changes:
- Derive `batch_size` from `cu_seqlens_k` instead of `cu_seqlens_q` (avoids the inflated `batch_size` caused by max_batch preallocation).
- Comment out/remove the runtime shape checks requiring `seq_len_encoder` to exactly equal `batch_size` (avoids false failures).
```cpp
PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
```
Deriving batch_size from cu_seq_k here is reasonable, but the consistency checks between seq_len_encoder / cu_seq_q and batch_size were removed at the same time. If the real batch_size from cu_seq_k exceeds seq_len_encoder.dims()[0] or cu_seq_q.dims()[0] - 1, the kernel will still access seq_len_encoder[bidb] / cu_seq_q[bidb] with grid_dims.z = batch_size, producing out-of-bounds reads and undefined behavior. Consider relaxing the original "==" checks into lower-bound checks (e.g. seq_len_encoder.dims()[0] >= batch_size and cu_seq_q.dims()[0] >= batch_size + 1) to at least rule out OOB; if smaller shapes are to be allowed, params.batch_size and the launch grid must be shrunk accordingly.
Suggested change:

```cpp
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape");
PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape");
```
```cpp
const int batch_size = cu_seq_k.dims()[0] - 1;
PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
```
This change is meant to accommodate cu_seqlens_q / seq_lens_encoder being preallocated at max_batch while the effective batch is smaller. Consider adding a unit test covering that case (e.g. construct cu_seq_q / seq_len_encoder whose first dim is larger than the actual batch_size, while cu_seq_k still has real batch_size + 1 entries), so that a later change restoring the "==" assertions or switching the batch_size derivation back to cu_seq_q causes a visible regression.
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
##           develop    #7210   +/-   ##
==========================================
  Coverage         ?   73.25%
==========================================
  Files            ?      376
  Lines            ?    52949
  Branches         ?     8264
==========================================
  Hits             ?    38789
  Misses           ?    11443
  Partials         ?     2717
```

Flags with carried forward coverage won't be shown.
Motivation
Background
In the SM90 flash mask attention op, the inputs cu_seqlens_q and seq_lens_encoder may be preallocated along a max_batch dimension, so their effective length can be smaller than the tensor's first-dim size. Deriving batch_size as cu_seq_q.dims()[0] - 1 then yields an inflated value (max_batch rather than the real batch size), making the batch dimension of the subsequent kernel launch incorrect.

cu_seq_k is always filled according to the real batch size, so deriving batch_size from cu_seq_k gives the correct value.

Meanwhile, because the shapes of cu_seqlens_q / seq_lens_encoder may equal max_batch (larger than the actual batch size), assertions such as PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0]) would falsely fail, so those checks are commented out for now.
Checklist
- Tags (choose from): [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run `pre-commit` before commit.
- For a `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.