
Integration flash attention #49869

Merged
21 commits merged into PaddlePaddle:develop on Mar 1, 2023

Conversation

@kuizhiqing (Member) commented on Jan 16, 2023

PR types

New features

PR changes

OPs

Describe

Integrate flash-attention into PaddlePaddle.

Usage

from paddle.nn.functional.flash_attention import flash_attention
flash_attention(q, k, v, dropout)
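
For reference, a minimal self-contained sketch of the call above. The tensor shapes, dtype, and dropout value here are illustrative assumptions, not requirements stated in this PR:

import paddle
from paddle.nn.functional.flash_attention import flash_attention

# Flash attention kernels run in half precision on a supported CUDA GPU;
# the layout is assumed to be [batch_size, seq_len, num_heads, head_dim].
q = paddle.randn([2, 1024, 16, 64]).astype('float16')
k = paddle.randn([2, 1024, 16, 64]).astype('float16')
v = paddle.randn([2, 1024, 16, 64]).astype('float16')

# dropout is the attention dropout probability; 0.0 disables it.
out = flash_attention(q, k, v, 0.1)
# Depending on the version, the call may return the output tensor directly or
# an (output, softmax) tuple; check python/paddle/nn/functional/flash_attention.py.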

Validated with the PaddleFleetX GPT 1.3B model.

Performance impact: 700 ms/step -> 619 ms/step, i.e. step time reduced by ~12% (a ~1.13x speedup).

Convergence results are shown below:
[Image: flash_loss_1 — convergence loss curves]

@paddle-bot bot commented on Jan 16, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot bot commented on Jan 16, 2023

❌ This PR was not created using the PR template. You can refer to this Demo.
Please use the PR template; it helps save our maintainers' time so that more developers can get help.

sneaxiy previously approved these changes Feb 23, 2023
sneaxiy previously approved these changes Mar 1, 2023
JiabinYang previously approved these changes Mar 1, 2023
@JiabinYang (Contributor) left a comment

LGTM for pkg size

zhangbo9674 previously approved these changes Mar 1, 2023

@zhangbo9674 (Contributor) left a comment

LGTM for setup

Resolved review threads on now-outdated code:
cmake/external/flashattn.cmake
paddle/phi/backends/dynload/flashattn.cc
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
python/paddle/nn/functional/flash_attention.py
paddle/phi/kernels/gpu/flash_attn_kernel.cu
@zyfncg (Contributor) left a comment

@jeff41404 Could you take a look at whether this interface counts as a newly added API?

Resolved review threads:
paddle/phi/api/yaml/ops.yaml
paddle/phi/infermeta/ternary.h (outdated)
paddle/phi/kernels/flash_attn_grad_kernel.h (outdated)
paddle/phi/kernels/gpu/arange_kernel.cu (outdated)
@kuizhiqing dismissed stale reviews from zhangbo9674, JiabinYang, and sneaxiy via bb70b92 on Mar 1, 2023 07:10
@JiabinYang (Contributor) left a comment

LGTM for pkg size

@qili93 (Contributor) left a comment

LGTM

@kuizhiqing requested a review from sneaxiy on Mar 1, 2023 12:46
@kuizhiqing changed the title from "[WIP] integration flash attention" to "Integration flash attention" on Mar 1, 2023
@sneaxiy merged commit 6161178 into PaddlePaddle:develop on Mar 1, 2023
