
[Paddle Inference] Add bias input of mmha and simplify mmha. #56411

Merged
merged 8 commits into PaddlePaddle:develop on Aug 25, 2023

Conversation

xiaoxiaohehe001
Contributor

@xiaoxiaohehe001 commented Aug 17, 2023

PR types

Others

PR changes

Others

Description

Add bias input of mmha and simplify mmha.
Related PR: #55344
Pcard-71502
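
A minimal usage sketch of the new bias input (my own illustration, not code taken from this PR). It assumes the op is exposed as paddle.incubate.nn.functional.masked_multihead_attention, that cache_kv is passed as a single 5-D tensor, and that the call returns the attention output together with the updated cache; the shapes follow the docstring changes reviewed below.

import paddle
from paddle.incubate.nn.functional import masked_multihead_attention

batch_size, num_head, head_dim, max_seq_len = 2, 32, 128, 1024

# Fused QKV activation for the current decode step: [batch_size, 3 * num_head * head_dim].
x = paddle.rand([batch_size, 3 * num_head * head_dim]).astype("float16")
# KV cache: [2, batch_size, num_head, max_seq_len, head_dim].
cache_kv = paddle.zeros([2, batch_size, num_head, max_seq_len, head_dim], dtype="float16")
# New optional fused QKV bias added by this PR: [3, num_head, head_dim].
bias = paddle.rand([3, num_head, head_dim]).astype("float16")

out, cache_kv_out = masked_multihead_attention(
    x, cache_kv, bias=bias, seq_len=1, compute_dtype="default")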

@paddle-bot

paddle-bot bot commented Aug 17, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@@ -3987,6 +3987,7 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,

Contributor

The full quantization requirement is:
input can be int32 / float16 / float32
output can be int8 / float16 / float32

I see that the int8-output case is handled here, but the int32-input case is not considered, is it?

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h"

namespace phi {
namespace fusion {

Contributor

Why remove this header file?

Contributor Author

There is no need to include redundant header files; removing it prevents the header from being pulled in by other callers.

@@ -43,6 +45,7 @@ def masked_multihead_attention(
Args:
x (Tensor): The input tensor could be 2-D tensor. Its shape is [batch_size, 3 * num_head * head_dim].
cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
bias (Tensor, optional): The bias tensor. Its shape is [3, num_head, head_dim].
Contributor

Please also add a parameter description for compute_dtype~

Contributor Author

Done~

@@ -77,7 +80,7 @@ def setUp(self):
self.seq_len = 1
self.rotary_emb_dims = 0
self.use_neox_rotary_style = False

self.compute_dtype = "default"
Contributor

Should the unit test also cover the case where the input is int32?

Contributor Author

The unit test already covers the int32 case.
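
For reference, a hedged sketch of what such an int32 configuration could look like (class and attribute names here are hypothetical, not copied from the test file): the quantized variant keeps the knobs shown in the diff above, feeds an int32 x, and states the real compute dtype explicitly.

import unittest

class TestMMHAInt32Input(unittest.TestCase):
    def setUp(self):
        self.seq_len = 1
        self.rotary_emb_dims = 0
        self.use_neox_rotary_style = False
        self.x_dtype = "int32"            # PTQ: x is the int32 output of an int8 GEMM
        self.compute_dtype = "float16"    # dtype the attention math actually runs in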

@@ -53,6 +56,7 @@ def masked_multihead_attention(
seq_len (int, optional): The seq_len, used to get input length. Default 1.
rotary_emb_dims (int, optional): The rotary_emb_dims. Default 1.
use_neox_rotary_style (bool, optional): A flag indicating whether neox_rotary_style is needed or not. Default False.
compute_dtype (string): A compute dtype, used to represent the input data type.
Contributor

Why can't compute_dtype be inferred from the input tensor's dtype?

Contributor Author

In the PTQ case, the input x may be int32. If we inferred the dtype from cache_kv instead, the code would have to be changed again later to support cache_kv quantization.
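
A small illustration of that reasoning (my own sketch, not the actual dispatch code in this PR): with PTQ, an int32 x is just the raw output of an int8 GEMM, so its dtype says nothing about whether the attention math should run in float16 or float32, and the caller has to state the compute dtype explicitly.

import paddle

def resolve_compute_dtype(x, compute_dtype="default"):
    # Hypothetical helper, for illustration only.
    if compute_dtype != "default":
        return compute_dtype            # PTQ path: caller states the dtype explicitly
    if x.dtype == paddle.int32:
        # int32 input carries no information about the intended compute precision.
        raise ValueError("int32 input requires an explicit compute_dtype")
    return x.dtype                      # float16/float32 input: dtype is meaningful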

Contributor
@vivienfanghuagood left a comment

LGTM for API change

Contributor
@lanxianghit left a comment

LGTM for new args

@heavengate merged commit 636dc2f into PaddlePaddle:develop Aug 25, 2023
25 of 26 checks passed
BeingGod pushed a commit to BeingGod/Paddle that referenced this pull request Sep 9, 2023