Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flashattention support qkvpacked and varlen #63289

Merged
merged 18 commits into from
May 7, 2024

Conversation

kircle888
Copy link
Contributor

PR Category

Performance Optimization

PR Types

Performance

Description

FlashAttention支持QKVPacked输入和Padded Varlen输入,以避免几种情况下的预处理和后处理开销
nn.functional.flash_attention增加flash_attn_qkvpacked和flash_attn_varlen_qkvpacked
flash_attn_qkvpacked接受5维输入qkv,形状为[batchsize, seqlen , num_heads/num_heads_k + 2, num_heads_k, head_dim]
flash_attn_varlen_qkvpacked接受4维输入qkv,形状为[total_seq_len, num_heads/num_heads_k + 2, num_heads_k, head_dim],当参数varlen_padded为False时,输入输出为unpad形式(与flash_attn_unpadded相似),当varlen_padded为True时,输入输出为padded形式(即可以直接将flash_attn_qkvpacked的输入batchsize和seqlen维度flatten得到total_seq_len维度)

Copy link

paddle-bot bot commented Apr 7, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@CLAassistant
Copy link

CLAassistant commented Apr 7, 2024

CLA assistant check
All committers have signed the CLA.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 7, 2024
GuoxiaWang
GuoxiaWang previously approved these changes Apr 9, 2024
Copy link
Contributor

@GuoxiaWang GuoxiaWang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@wanghuancoder wanghuancoder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel如果支持了stride,需要再paddle/fluid/eager/auto_code_generator/generator/eager_gen.py的strided_op_list中添加一下。

paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu Outdated Show resolved Hide resolved
paddle/phi/kernels/gpu/flash_attn_kernel.cu Show resolved Hide resolved
paddle/phi/kernels/gpu/flash_attn_kernel.cu Show resolved Hide resolved
Copy link

paddle-ci-bot bot commented Apr 17, 2024

Sorry to inform you that 9a130e6's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

GuoxiaWang
GuoxiaWang previously approved these changes Apr 24, 2024
Copy link
Contributor

@GuoxiaWang GuoxiaWang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

sneaxiy
sneaxiy previously approved these changes Apr 24, 2024
@GuoxiaWang
Copy link
Contributor

API doc PaddlePaddle/docs#6608

@jzhang533 jzhang533 self-assigned this Apr 24, 2024
jeff41404
jeff41404 previously approved these changes Apr 24, 2024
XiaoguangHu01
XiaoguangHu01 previously approved these changes Apr 25, 2024
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kircle888 kircle888 dismissed stale reviews from jeff41404, sneaxiy, and GuoxiaWang via 4fc88be April 25, 2024 08:46
jzhang533
jzhang533 previously approved these changes Apr 28, 2024
Copy link
Contributor

@jzhang533 jzhang533 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

python/paddle/nn/functional/flash_attention.py Outdated Show resolved Hide resolved
python/paddle/nn/functional/flash_attention.py Outdated Show resolved Hide resolved
Comment on lines 517 to 519
out(Tensor): The attention tensor.
4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • 重复加这段的原因是?
  • 以及 flash_attn_unpadded 似乎没对外暴露,需要公开么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是一个操作错误,已删除
这个PR不会修改flash_attn_unpadded,也不会将它公开

kircle888 and others added 3 commits April 29, 2024 11:08
Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM,这个 PR 先合吧,文档小问题之后再单独提 PR 就行


Returns:
- out(Tensor). The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
- softmax(Tensor). The softmax tensor. None if return_softmax is False.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returns: 后面的内容不要出现冒号..因为官网文档的解析会把冒号前面的部分自动解析成 Returns type(很扯的功能)

Comment on lines +415 to +419
softmax->set_dtype(qkv.dtype());
softmax_lse->set_dtype(qkv.dtype());
if (seed_offset) {
seed_offset->set_dtype(phi::DataType::INT64);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几个output可以设置dim吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分是和原有的FlashAttnInferMeta一样的,这几个output的shape是在flash_attn_utils.h的FlashAttnFwdParamsV2中Resize的,没有在这里设置dim

@GuoxiaWang GuoxiaWang requested a review from zyfncg May 7, 2024 08:08
@sneaxiy sneaxiy merged commit d6e4163 into PaddlePaddle:develop May 7, 2024
30 checks passed
yinfan98 pushed a commit to yinfan98/Paddle that referenced this pull request May 7, 2024
* Flashattention support qkvpacked and varlen

* fix codestyle

* fix codestyle

* FlashAttention kvReduceGQA Performance Optimization

* Fix problem with windows

* code clean

* update third_party/flashattn

* update errormsg and docs

* update api

* update doc

* update doctest

* update doc, test=document_fix

* update doc, test=document_fix

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* update doc

---------

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
yinfan98 pushed a commit to yinfan98/Paddle that referenced this pull request May 7, 2024
* Flashattention support qkvpacked and varlen

* fix codestyle

* fix codestyle

* FlashAttention kvReduceGQA Performance Optimization

* Fix problem with windows

* code clean

* update third_party/flashattn

* update errormsg and docs

* update api

* update doc

* update doctest

* update doc, test=document_fix

* update doc, test=document_fix

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* update doc

---------

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
yinfan98 pushed a commit to yinfan98/Paddle that referenced this pull request May 7, 2024
add

int4_1

int4_2

FLAGS_logging_pir_py_code (PaddlePaddle#63981)

* FLAGS_logging_pir_py_code

* FLAGS_logging_pir_py_code_dir

---------

Co-authored-by: jiahy0825 <jiahongyu@baidu.com>

[Cleanup] Remove Flake8 config in `.editorconfig` (PaddlePaddle#64027)

【PIR Dist Op Reg No.19】 reg pull_box_sparse (PaddlePaddle#62982)

* fix

* fix

* fix

* fix

* fix

* fix

* add test

* add

* fix

* fix

* add out

* fix

* codestyle

* fix

* fix backward

* merge

[Dy2St][PIR] Hold backward program in GradNode (PaddlePaddle#63694)

Co-authored-by: xiongkun <xiongkun03@baidu.com>
Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>

split test.cmake: add new test_cases.cmake (PaddlePaddle#64007)

[PIR] Support sparse_slice and sparse_sum in pt (PaddlePaddle#64009)

* support sparse_slice and sparse_sum in pt

* support sparse_slice and sparse_sum in pt

* support sparse_slice and sparse_sum in pt

option for WITH_CPP_TEST (PaddlePaddle#63896)

* option for WITH_CPP_TEST

* fix

* Fix

* Fix

[PIR] Fix `attributes_num` of `SliceArrayOp` (PaddlePaddle#64013)

[Dy2St] Use `full_graph=True` outside dy2st uts (part1) (PaddlePaddle#64058)

[Dy2St] Use `full_graph=True` outside dy2st uts (part2) (PaddlePaddle#64059)

fix typo (PaddlePaddle#64060)

Co-authored-by: jiahy0825 <jiahongyu@baidu.com>

update (PaddlePaddle#64042)

Replace paddle/fluid/platform/device/gpu/gpu_dnn.h (PaddlePaddle#63819)

* Fix

* Fix

* Fix

Clean lookup_table_v2_op.h lookup_table_v2_op.cu (PaddlePaddle#64020)

* Fix

* ci

refine GetTensorListFromArgs (PaddlePaddle#64045)

Revert "【Hackathon 6th Fundable Projects 3 No.60】Remove fluid operator chunk_…" (PaddlePaddle#64050)

This reverts commit 88b1a6e.

[Prim][PIR] support floor_divide op forward in prim pir (PaddlePaddle#64023)

* floor-div-dev

* update test

[CINN] Reconstruct shape_analysis (PaddlePaddle#63790)

* reconstruct shape_analysis

* fix input value shape infer

* fix merge bugs

* fix concat and gather op InferSymbolicShape

* fix merge bug

* fix value_to_shape_or_data hash error and add some checks

* fix set shape for null value

* fix group op lazy infer

* add IsStaticShape check

* fix merge bug

* support static dim check and set for VectorType

* change auto to detail type

[XPU] fix bugs in processing of attention_mask and fix_seed_offset on XPU (PaddlePaddle#64003)

* [XPU] fix segmentfault caused by setting fix_seed_offset on XPU

* cast attention_mask to float32 when necessary

fix merge bug (PaddlePaddle#64069)

【Fix PIR Unittest No.125、147、481】Fix some 0D uts in PIR mode (part1) (PaddlePaddle#64064)

[Prim][VJP]support autogen to remove unused composite in .yaml (PaddlePaddle#64054)

* support autogen to remove unused composite in .yaml

* fix bug

[PIR] Fix typo `set_pit_tests_properties` -> `set_pir_tests_properties` (PaddlePaddle#64063)

[Dy2St] Use `full_graph=True` outside dy2st uts (part3) (PaddlePaddle#64066)

[PIR save/load] Open more tests for paddle.save and paddle.load (PaddlePaddle#64044)

* open more tests for paddle.save and paddle.load

* fix

API Improvement for paddle.nn.functional.group_norm and paddle.nn.GroupNorm (PaddlePaddle#63881)

* update group_norm

* update trt plugin

* update trt plugin

* fix trt plugin

* fix trt plugin

* fix test

* fix test

* fix ci windows inference

* update kernel function names and add v2 test

* fix

* fix fp16 test

Revert "【Hackathon 6th Fundable Projects 3 No.81】Remove fluid operators ctc_a…" (PaddlePaddle#64049)

This reverts commit 2134ead.

Clean paddle/fluid/operators/fused/attention_layer_norm.h (PaddlePaddle#64051)

* Fix

* Fix

 Replace operators::math to phi::math in fluid/operators (PaddlePaddle#63854)

[CINN]Clean usless loop_reorder_aligment tactic (PaddlePaddle#63998)

* [CINN]Clean usless loop_reorder_aligment tactic

* fix source

【Hackathon 6th Fundable Projects 3 No.396】fluid operator yolo_box_head (PaddlePaddle#63783)

* Fix

* Fix

* Fix

* Fix

* Fix

【Hackathon 6th Fundable Projects 3 No.240】fluid operator moe (PaddlePaddle#63929)

【Hackathon 6th Fundable Projects 3 No.82】fluid operator cudnn_lstm (PaddlePaddle#63936)

* Fix

* Fix

* Fix

* Fix

[CINN] Remove useless log (PaddlePaddle#64052)

[pir_save_load] add pir for test_jit_save_load.py (PaddlePaddle#63958)

* add jit load.train

* modify backward program lost

* modify

* combine eval and train

* modify 8 case of jit.save.load

* modify jit_save_load case

* rename jit_save_load

* change name all

* modify timeout

* modify new case

* modify TestJitSaveLoadMultiMethods

* modify cpu tensor no holder bug

Flashattention support qkvpacked and varlen (PaddlePaddle#63289)

* Flashattention support qkvpacked and varlen

* fix codestyle

* fix codestyle

* FlashAttention kvReduceGQA Performance Optimization

* Fix problem with windows

* code clean

* update third_party/flashattn

* update errormsg and docs

* update api

* update doc

* update doctest

* update doc, test=document_fix

* update doc, test=document_fix

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* update doc

---------

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

【PIR Dist Op Reg No.20】 reg global_gather (PaddlePaddle#63867)

* reg global_gather

* reg global_gather

* reg_global_gather

* fix

* fix

* fix

* fix conflict

* fix conflict

* Update ops_api_gen.py

* Update ops_api_gen.py

Fix backward program kwargs error when process inplace value (PaddlePaddle#63939)

【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True fix (PaddlePaddle#63880)

* support kwargs for recompute when use_reentrant == True

* recover third party

merge main

lint

delete printf

change flash attn version
co63oc pushed a commit to co63oc/Paddle that referenced this pull request May 10, 2024
* Flashattention support qkvpacked and varlen

* fix codestyle

* fix codestyle

* FlashAttention kvReduceGQA Performance Optimization

* Fix problem with windows

* code clean

* update third_party/flashattn

* update errormsg and docs

* update api

* update doc

* update doctest

* update doc, test=document_fix

* update doc, test=document_fix

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* update doc

---------

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet