Add fused attention op backward and python layer. #36498

Merged
merged 62 commits into PaddlePaddle:develop on Oct 26, 2021

Conversation


@limin2021 (Contributor) commented Oct 18, 2021

PR types

New features

PR changes

OPs

Describe

  1. Purpose: this PR aims to improve the computational performance of the attention module.
    To reduce the framework-level op scheduling overhead, this PR implements the attention module by hand at the C++ level and exposes it as a single coarse-grained attention op.
    To reduce memory-access overhead, this PR applies two optimizations:
    (1) When computing q, k and v, the shared input X is reused so that the GEMM, transpose and bias-add at this stage are reduced from three calls to one;
    (2) Kernel fusion is used so that data is passed between what would otherwise be separate CUDA kernels through registers.

  2. Computation logic implemented by fused_attention_op:
    [Figure in the original PR: pseudocode of the fused attention computation]

  3. Differences between fused_attention_op and Paddle's existing MultiHeadAttention layer:
    (1) The scope of the computation is larger; see the pseudocode above.
    (2) The storage format of the q, k, v weights is different.
    Existing layer: stored in three weight tensors, WQ, WK, WV.
    This PR: stored in a single weight tensor, qkv_weight.
    How qkv_weight is obtained from WQ, WK and WV (a numpy sketch of this packing is given after the implementation details below):
    [Figure in the original PR: transformation from WQ, WK, WV to qkv_weight]

  4. Implementation:
    This PR adds the backward implementation of fused_attention_op. Details:

(1) fused_attention_op.cc and fused_attention_op.cu
The C++ implementation of the backward pass of fused_attention_op.
Related preceding PRs:
#34883, #35308, #35350, #35621, #35903, #35905

(2) functional/fused_attention/fused_multi_head_attention():
Adds the static-graph construction path for the functional API.
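
As a minimal sketch of the static-graph path mentioned above, the snippet below builds a program that calls the functional API. The argument names (qkv_weight, linear_weight), the [3, num_heads, head_dim, embed_dim] layout of qkv_weight, and the ability to omit the optional arguments are assumptions based on this PR series rather than a confirmed signature; the fused op also requires a CUDA build of Paddle.

```python
import numpy as np
import paddle
import paddle.incubate.nn.functional as incubate_f

paddle.enable_static()

# Hypothetical sizes, for illustration only.
batch, seq_len, embed_dim, num_heads = 2, 128, 256, 4
head_dim = embed_dim // num_heads

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data('x', [batch, seq_len, embed_dim], 'float32')
    # Packed q/k/v projection weights (assumed layout, see the packing sketch below).
    qkv_weight = paddle.static.data(
        'qkv_weight', [3, num_heads, head_dim, embed_dim], 'float32')
    linear_weight = paddle.static.data(
        'linear_weight', [embed_dim, embed_dim], 'float32')
    out = incubate_f.fused_multi_head_attention(x, qkv_weight, linear_weight)

exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
feed = {
    'x': np.random.rand(batch, seq_len, embed_dim).astype('float32'),
    'qkv_weight': np.random.rand(3, num_heads, head_dim, embed_dim).astype('float32'),
    'linear_weight': np.random.rand(embed_dim, embed_dim).astype('float32'),
}
out_np, = exe.run(main_prog, feed=feed, fetch_list=[out])
```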

(3) test_fused_attention_op.py
Adds tests for the correctness of the backward pass of fused_attention_op.
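
The unit test checks backward correctness against a reference implementation; as a much-simplified illustration, the sketch below compares the analytic gradient produced by the fused backward kernels against a central finite difference. The FusedMultiHeadAttention import path and constructor arguments are assumptions from this PR series, dropout is disabled to keep the forward pass deterministic, and a GPU build of Paddle is assumed since the fused kernels are CUDA-only.

```python
import numpy as np
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention  # assumed import path

paddle.seed(2021)
np.random.seed(2021)

# Dropout disabled so the loss is deterministic and smooth.
layer = FusedMultiHeadAttention(embed_dim=64, num_heads=4,
                                dropout_rate=0.0, attn_dropout_rate=0.0)
x_np = np.random.rand(2, 8, 64).astype('float32')

def loss_and_input(x_val):
    x = paddle.to_tensor(x_val, stop_gradient=False)
    return layer(x).sum(), x

# Analytic gradient from the fused backward kernels.
loss, x = loss_and_input(x_np)
dx, = paddle.grad([loss], [x])
analytic = float(dx.numpy()[0, 0, 0])

# Central finite difference with respect to the same input element.
eps = 1e-3
x_plus, x_minus = x_np.copy(), x_np.copy()
x_plus[0, 0, 0] += eps
x_minus[0, 0, 0] -= eps
numeric = float(loss_and_input(x_plus)[0] - loss_and_input(x_minus)[0]) / (2 * eps)

np.testing.assert_allclose(analytic, numeric, rtol=5e-2, atol=5e-2)
```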

(4) fused_transformer.py / FusedMultiHeadAttention layer:
Adds the FusedMultiHeadAttention layer.
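
A minimal dynamic-graph usage sketch of the layer follows. The paddle.incubate.nn.FusedMultiHeadAttention import path, the constructor arguments, and the additive [batch, num_heads, seq_len, seq_len] mask convention are assumptions based on this PR series, and a CUDA device is assumed because the fused kernels are GPU-only.

```python
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention  # assumed import path

embed_dim, num_heads, batch, seq_len = 256, 4, 2, 128
fused_attn = FusedMultiHeadAttention(embed_dim, num_heads,
                                     dropout_rate=0.1,
                                     attn_dropout_rate=0.1)

x = paddle.rand([batch, seq_len, embed_dim])
# Additive attention mask: 0 keeps a position, a large negative value masks
# it out (an assumed convention).
attn_mask = paddle.zeros([batch, num_heads, seq_len, seq_len])

out = fused_attn(x, attn_mask=attn_mask)  # self-attention, same shape as x
out.mean().backward()                     # exercises the fused backward op
```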

(5) test_fused_attention_op_api.py
Tests the correctness of the fused_attention_op Python API in both dynamic- and static-graph modes.
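
Items 1 and 3 above describe storing WQ, WK and WV in a single qkv_weight tensor so that q, k and v can be produced by one GEMM over the shared input X instead of three. The numpy sketch below shows one way to pack the three weights and verifies the single-GEMM equivalence; the [3, num_heads, head_dim, embed_dim] layout mirrors how the tests in this PR series appear to build qkv_weight, but it should be read as an assumption rather than the authoritative format.

```python
import numpy as np

def pack_qkv_weight(wq, wk, wv, num_heads):
    """Pack three [embed_dim, embed_dim] projection weights into one
    qkv_weight of shape [3, num_heads, head_dim, embed_dim] (assumed layout)."""
    embed_dim = wq.shape[0]
    head_dim = embed_dim // num_heads
    # Transpose each weight to [out_features, in_features], stack along a new
    # leading q/k/v axis, then split the output dimension into heads.
    packed = np.stack([wq.T, wk.T, wv.T], axis=0)   # [3, embed_dim, embed_dim]
    return packed.reshape(3, num_heads, head_dim, embed_dim)

embed_dim, num_heads, seq_len = 8, 2, 4
rng = np.random.default_rng(0)
wq, wk, wv = [rng.standard_normal((embed_dim, embed_dim)).astype('float32')
              for _ in range(3)]
x = rng.standard_normal((seq_len, embed_dim)).astype('float32')

# One GEMM over the shared input X produces q, k and v together.
qkv_weight = pack_qkv_weight(wq, wk, wv, num_heads)
fused = x @ qkv_weight.reshape(3 * embed_dim, embed_dim).T   # [seq_len, 3*embed_dim]
q, k, v = np.split(fused, 3, axis=-1)

# Reference: three separate GEMMs, one per projection.
np.testing.assert_allclose(q, x @ wq, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(k, x @ wk, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(v, x @ wv, rtol=1e-5, atol=1e-5)
```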

Unittest results: [screenshot of unittest results attached in the original PR]

limin2021 and others added 30 commits September 22, 2021 05:01
zkh2016 previously approved these changes Oct 25, 2021
zkh2016 previously approved these changes Oct 26, 2021
xingfeng01 previously approved these changes Oct 26, 2021
lanxianghit previously approved these changes Oct 26, 2021
@lanxianghit merged commit 5119428 into PaddlePaddle:develop on Oct 26, 2021
limin2021 added a commit to limin2021/Paddle that referenced this pull request Oct 26, 2021
lanxianghit pushed a commit that referenced this pull request Oct 27, 2021
ghost pushed a commit to piotrekobi/Paddle that referenced this pull request Nov 3, 2021