
[Paddle Inference] Support GQA Decoder #58472

Merged
merged 7 commits into from Oct 31, 2023

Conversation

zhoutianzi666
Contributor

@zhoutianzi666 zhoutianzi666 commented Oct 30, 2023

PR types

Others

PR changes

Others

Description

Pcard-71500

  • Support GQA in the decoding stage, without affecting the API layer

@zhoutianzi666 zhoutianzi666 changed the title [Paddle Inference] Support GQA [Paddle Inference] Support GQA Decoder Oct 30, 2023
@@ -4279,8 +4279,12 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        MetaTensor* beam_cache_offset_out) {
   int bsz = static_cast<int>(x.dims()[0]);
   auto cache_kv_dims = cache_kv.dims();
-  int num_head = static_cast<int>(cache_kv.dims()[2]);
+  int k_num_head = static_cast<int>(cache_kv.dims()[2]);
+  int v_num_head = k_num_head;
Contributor

There seems to be no need to pull out a separate v_num_head; just use a single kv_num_head.

We also need to check that num_head % kv_num_head == 0.

Contributor Author


There seems to be no need to pull out a separate v_num_head; just use a single kv_num_head.

We also need to check that num_head % kv_num_head == 0.

OK, we will add the check that num_head % kv_num_head == 0. Thanks for the review.
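The divisibility check requested above can be sketched as follows. This is a minimal illustration, not the PR's actual code; `NumHeadsPerGroup` is a hypothetical helper name:

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical helper (not from the PR) illustrating the reviewer's point:
// in grouped-query attention (GQA) every KV head serves a fixed-size group
// of query heads, so num_head must be an integer multiple of kv_num_head.
int NumHeadsPerGroup(int num_head, int kv_num_head) {
  if (kv_num_head <= 0 || num_head % kv_num_head != 0) {
    throw std::invalid_argument(
        "num_head must be divisible by kv_num_head for GQA");
  }
  return num_head / kv_num_head;
}
```

For example, 32 query heads with 8 KV heads gives groups of 4. Standard MHA is the special case kv_num_head == num_head (group size 1), and MQA is kv_num_head == 1.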

@@ -92,6 +92,9 @@ struct Masked_multihead_attention_params {
int beam_width;
Contributor

When writing the new KV into the KV cache, there should be a check that the corresponding blockIdx is the group index.

Contributor Author


When writing the new KV into the KV cache, there should be a check that the corresponding blockIdx is the group index.

Thanks for the review.
Here hi = blockIdx.x still denotes the index of the query head; the key head index is obtained via hi / num_head_per_group.
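The index mapping described in the reply can be sketched like this. It is a minimal illustration under the assumption that num_head_per_group = num_head / kv_num_head; `KvHeadIndex` is a hypothetical name, not the kernel's actual code:

```cpp
#include <cassert>

// Hypothetical sketch (not the PR's kernel code): in the decoder kernel,
// hi = blockIdx.x is a query-head index. Under GQA, each run of
// num_head_per_group consecutive query heads shares one KV head, so the
// KV cache is indexed with hi / num_head_per_group instead of hi.
int KvHeadIndex(int hi, int num_head_per_group) {
  return hi / num_head_per_group;
}
```

For instance, with num_head = 8 and kv_num_head = 2 (so num_head_per_group = 4), query heads 0 through 3 read and write KV head 0, while query heads 4 through 7 use KV head 1.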

Contributor

@MARD1NO MARD1NO left a comment


LGTM

@paddle-bot paddle-bot bot added the contributor External developers label Oct 30, 2023
@zhoutianzi666 zhoutianzi666 merged commit 0651dde into PaddlePaddle:develop Oct 31, 2023
28 checks passed

paddle-bot bot commented Nov 2, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.


paddle-bot bot commented Nov 2, 2023

✅ This PR's description meets the template requirements!
Please wait for other CI results.

@paddle-bot paddle-bot bot removed the contributor External developers label Nov 3, 2023
zeroRains pushed a commit to zeroRains/Paddle that referenced this pull request Nov 8, 2023
Support GQA Decoder in masked_multihead_attention.cu
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
Support GQA Decoder in masked_multihead_attention.cu
5 participants