
[BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn#7210

Open
xiaoxiaohehe001 wants to merge 2 commits into PaddlePaddle:develop from xiaoxiaohehe001:fix_flash_mask_attn_sm90

Conversation

Collaborator

@xiaoxiaohehe001 xiaoxiaohehe001 commented Apr 7, 2026

Motivation

  • Fix the source of the batch_size derivation in DispatchFlashAttentionMask: derive it from cu_seq_k instead of cu_seq_q
  • Comment out the runtime PADDLE_ENFORCE shape checks

Background

In the SM90 flash mask attention operator, the input shapes of cu_seqlens_q and seq_lens_encoder may be preallocated along a max_batch dimension, so their actual valid length can be smaller than the first dimension of the tensor. In that case, deriving batch_size as cu_seq_q.dims()[0] - 1 yields an oversized value (equal to max_batch rather than the real batch size), making the batch dimension of the subsequent kernel launch incorrect.

cu_seq_k is always filled for the real batch size, so deriving batch_size from cu_seq_k instead yields the correct value.

Also, because the shape of cu_seqlens_q / seq_lens_encoder may be max_batch (larger than the actual batch size), existing assertions such as PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0]) would fail spuriously, so those checks are commented out for now.

(screenshot attached: infoflow 2026-04-07 15-45-41)

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 7, 2026 07:44
@paddle-bot

paddle-bot bot commented Apr 7, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR targets the SM90 flash_mask_attention custom operator: it fixes a kernel-launch dimension error caused by deriving batch_size from the wrong tensor, and relaxes (removes) some runtime shape assertions so that inputs preallocated to max_batch with a smaller effective batch are accepted.

Changes:

  • Derive batch_size from cu_seqlens_k instead of cu_seqlens_q (avoids the oversized batch_size caused by max_batch preallocation).
  • Comment out / remove the runtime shape checks requiring seq_len_encoder and batch_size to match exactly (avoids spurious failures).


PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");

Copilot AI Apr 7, 2026


Deriving batch_size from cu_seq_k here is reasonable, but with the consistency checks between seq_len_encoder / cu_seq_q and batch_size removed, if the real batch_size from cu_seq_k exceeds seq_len_encoder.dims()[0] or cu_seq_q.dims()[0] - 1, the kernel still launches with grid_dims.z = batch_size and accesses seq_len_encoder[bidb] / cu_seq_q[bidb], producing out-of-bounds reads and undefined behavior. Consider relaxing the original "==" checks to lower-bound checks (e.g. seq_len_encoder.dims()[0] >= batch_size and cu_seq_q.dims()[0] >= batch_size + 1) so that at least no OOB access can occur; if smaller shapes must be allowed, params.batch_size / the launch grid need to shrink accordingly.

Suggested change
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape");
PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape");

Comment on lines +52 to 56
const int batch_size = cu_seq_k.dims()[0] - 1;

PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");

Copilot AI Apr 7, 2026


This change is meant to support the case where cu_seqlens_q / seq_lens_encoder are preallocated to max_batch while the effective batch is smaller. Consider adding a unit test covering that case (e.g. construct cu_seq_q / seq_len_encoder with a first dimension larger than the actual batch_size, while cu_seq_k still has the real batch_size + 1 entries), so that a future change restoring the "==" assertions or deriving batch_size from cu_seq_q again would be caught as a regression.

Copilot generated this review using guidance from repository custom instructions.
@codecov-commenter

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@18f0124). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7210   +/-   ##
==========================================
  Coverage           ?   73.25%           
==========================================
  Files              ?      376           
  Lines              ?    52949           
  Branches           ?     8264           
==========================================
  Hits               ?    38789           
  Misses             ?    11443           
  Partials           ?     2717           
Flag | Coverage Δ
GPU  | 73.25% <ø> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.


