Integrate Cutlass Fused Multihead Attention in PHI #49910

Closed
wants to merge 37 commits

Conversation

@MARD1NO (Contributor) commented on Jan 18, 2023

PR types

New features

PR changes

OPs

Describe

Integrate Cutlass fused multihead attention

A custom attention_mask can be added.

For the CUTLASS 2.11.0 compatibility issue, refer to the changes in PR #50073 (comment).

Documentation: (see attached screenshot)

Benchmark

Environment: CUDA 11.6, A100 40GB

The benchmark cases are borrowed from xformers.

FP16:

Without mask:

(b, seq_len, num_head, head_size)   Cutlass Time (ms)   Naive Time (ms)
32, 128, 16, 64                     0.05                0.25
64, 128, 16, 16                     0.07                0.27
64, 128, 16, 32                     0.08                0.32
64, 512, 16, 16                     0.91                2.75
64, 512, 16, 32                     0.92                2.97
64, 512, 16, 64                     1.04                3.36
64, 1024, 16, 128                   6.62                18.31
64, 1024, 16, 256                   15.36               22.52

With mask:

(b, seq_len, num_head, head_size)   Cutlass Time (ms)   Naive Time (ms)
32, 128, 16, 64                     0.09                0.29
64, 128, 16, 16                     0.12                0.35
64, 128, 16, 32                     0.13                0.40
64, 512, 16, 16                     1.58                3.91
64, 512, 16, 32                     1.58                4.12
64, 512, 16, 64                     1.68                4.53
64, 1024, 16, 128                   9.44                22.92
64, 1024, 16, 256                   17.97               27.12

Inference cases, FP16:

(b, q_seq, kv_seq, num_head, head_size)   Cutlass Time (us)   Naive Time (us)
1, 900, 6000, 32 (PETR FMCA)              376.51              625.10
1, 4096, 4096, 8, 32 (Diffusion FMHA)     554.87              1494.80
1, 4096, 77, 8, 40 (Diffusion FMCA)       34.81               119.14
1, 197, 197, 12, 64 (VIT FMHA)            21.70               73.08

Comparison script:

import paddle


def naive_attention_impl(query, key, value, mask, scale):
    # query/key/value: [batch, seq_len, num_head, head_size]
    # Move the head dimension forward: [batch, num_head, seq_len, head_size]
    query = paddle.transpose(query, [0, 2, 1, 3])
    key = paddle.transpose(key, [0, 2, 1, 3])
    value = paddle.transpose(value, [0, 2, 1, 3])

    # Scaled QK^T plus the (broadcastable) additive attention mask
    qk_res = paddle.matmul(query, key, transpose_y=True)
    attention = qk_res * scale
    attention = attention + mask
    softmax_result = paddle.nn.functional.softmax(attention, -1)

    # Attention-weighted values, transposed back to [batch, seq_len, num_head, head_size]
    result = paddle.matmul(softmax_result, value)
    result = paddle.transpose(result, [0, 2, 1, 3])
    return result
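
For reference, a minimal sketch of how the comparison above can be driven, using the first benchmark shape. The mask shape and the usual 1/sqrt(head_size) scale are my assumptions, and the fused kernel's Python entry point is not shown in this excerpt, so the call to it is left as a comment rather than a concrete API.

# Minimal driver for the naive reference above (shape from the first benchmark
# row). The fused cutlass op is sketched as a comment only, since its Python
# entry point is not part of this excerpt.
import paddle

batch, seq_len, num_head, head_size = 32, 128, 16, 64
scale = head_size ** -0.5  # assumed scale: 1/sqrt(head_size)

query = paddle.randn([batch, seq_len, num_head, head_size]).cast("float16")
key = paddle.randn([batch, seq_len, num_head, head_size]).cast("float16")
value = paddle.randn([batch, seq_len, num_head, head_size]).cast("float16")
# Additive mask broadcastable over query rows: [batch, num_head, 1, seq_len]
mask = paddle.zeros([batch, num_head, 1, seq_len], dtype="float16")

ref_out = naive_attention_impl(query, key, value, mask, scale)
# fused_out = <cutlass fused multihead attention op>(query, key, value, mask, scale)
# assert paddle.allclose(ref_out.cast("float32"), fused_out.cast("float32"), atol=1e-2)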

TODO List

generate.sh is "borrowed" from xformers: a shell script generates the template-specialized kernels so they can be compiled in parallel, speeding up the build.

Kernel generation could later be reimplemented as a Python script (see the sketch below).
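
As a rough illustration of that follow-up item, a sketch of what a Python-based generator could look like. The header name, macro, data types, and tile sizes below are illustrative placeholders, not the actual contents of generate.sh.

# Hypothetical Python replacement for generate.sh: emit one .cu file per
# template specialization so the build system can compile them in parallel.
# All names (header, macro, dtypes, tile sizes) are illustrative placeholders.
import itertools
from pathlib import Path

DTYPES = ["cutlass::half_t", "float"]   # assumed element types
ARCHS = [70, 75, 80]                    # assumed SM targets
QUERIES_PER_BLOCK = [32, 64]            # assumed tile sizes

TEMPLATE = """// Auto-generated file, do not edit.
#include "fused_multi_head_attention_kernel.h"  // hypothetical header

INSTANTIATE_FMHA_KERNEL({dtype}, {arch}, {qpb});
"""

def generate(out_dir="generated_kernels"):
    out = Path(out_dir)
    out.mkdir(exist_ok=True)
    for dtype, arch, qpb in itertools.product(DTYPES, ARCHS, QUERIES_PER_BLOCK):
        name = "fmha_{}_sm{}_q{}.cu".format(dtype.split("::")[-1], arch, qpb)
        (out / name).write_text(TEMPLATE.format(dtype=dtype, arch=arch, qpb=qpb))

if __name__ == "__main__":
    generate()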

@paddle-bot bot commented on Jan 18, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@mnicely commented on Jan 19, 2023

@MARD1NO I wanted to make you aware of a change coming to our fMHA (773)

@hwu36 for vis

@MARD1NO (Contributor, Author) commented on Jan 19, 2023

> @MARD1NO I wanted to make you aware of a change coming to our fMHA (773)
>
> @hwu36 for vis

Thanks, I will keep following this update

@MARD1NO marked this pull request as ready for review on February 2, 2023, 02:41
@@ -7,7 +7,8 @@ exclude: |
python/paddle/utils/gast/.+|
.+_pb2\.py|
python/paddle/fluid/tests/unittests/npu/.+|
python/paddle/fluid/tests/unittests/mlu/.+
python/paddle/fluid/tests/unittests/mlu/.+|
paddle/phi/kernels/fusion/cutlass/fused_multi_head_attention/.+
@MARD1NO (Contributor, Author): The code introduced here is external xformers code, so it is excluded from formatting for now.

import os

from op_test import OpTest

# Ensure we use float type to accumulate
os.environ["FLAGS_gemm_use_half_precision_compute_type"] = "0"
@MARD1NO (Contributor, Author): This ensures that the GEMMs in the naive reference implementation used for comparison accumulate in float.
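
A small illustration (not from the PR) of why accumulating in half precision is problematic: once a partial sum is large, small fp16 contributions can be rounded away entirely.

# Illustrative only: adding a small value to a large fp16 partial sum is lost,
# because the fp16 spacing at 2048 is 2; float32 keeps the result exact.
import paddle

big = paddle.to_tensor(2048.0, dtype="float16")
small = paddle.to_tensor(1.0, dtype="float16")
print(big + small)                                  # still 2048.0 in fp16
print(big.cast("float32") + small.cast("float32"))  # 2049.0 in fp32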

# https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.sh

#!/bin/bash
set -ex
@MARD1NO (Contributor, Author): This follows the xformers kernel template generation script so that the specialized kernels can be compiled in parallel, speeding up the build.

(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
(my_warp_id / MM0::Mma::WarpCount::kM)};

if (kAddMask) {
@MARD1NO (Contributor, Author): If the mask is added after the QK matmul, the scale must be applied to the accumulators in registers beforehand, rather than being folded into the final tiled softmax.
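
One way to see why (s is the softmax scale, P = QK^T the matmul accumulator, M the additive mask): deferring the scale to the tiled softmax would also scale the mask.

$$
\operatorname{softmax}\bigl(s\,P + M\bigr)
\;\neq\;
\operatorname{softmax}\bigl(s\,(P + M)\bigr)
= \operatorname{softmax}\bigl(s\,P + s\,M\bigr)
$$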

cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
}

int32_t mask_iter_m = kMaskBroadcastRow ? 1 : problem_size_0_m;
@MARD1NO (Contributor, Author): This adds a specialization for masks that are broadcast along the row dimension.
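
A host-side illustration (shapes are my assumption based on the comment above) of the row-broadcast case the specialization targets: a mask with a single query row broadcasts across all query positions, so the kernel only needs to visit one mask row per tile (mask_iter_m == 1).

# Illustrative only: a [batch, num_head, 1, kv_seq] mask broadcasts over the
# query dimension, matching the mask_iter_m == 1 fast path above.
import paddle

batch, num_head, q_seq, kv_seq = 1, 8, 4, 6

logits = paddle.randn([batch, num_head, q_seq, kv_seq])
row_broadcast_mask = paddle.randn([batch, num_head, 1, kv_seq])

masked = logits + row_broadcast_mask                   # broadcast over q rows
full_mask = paddle.expand(row_broadcast_mask, logits.shape)
assert paddle.allclose(masked, logits + full_mask)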

@vivienfanghuagood (Contributor): LGTM

@MARD1NO marked this pull request as draft on February 7, 2023, 05:36
@MARD1NO marked this pull request as ready for review on February 22, 2023, 08:27
@zhoutianzi666 (Contributor) commented on Feb 24, 2023

(screenshot attached)

https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/incubate/nn/FusedMultiHeadAttention_cn.html#fusedmultiheadattention

  • Should the "fused" in cutlass_fused_multi_head attention be removed from this API's name?

@MARD1NO (Contributor, Author) commented on Feb 24, 2023

> https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/incubate/nn/FusedMultiHeadAttention_cn.html#fusedmultiheadattention
>
> Should the "fused" in cutlass_fused_multi_head attention be removed from this API's name?

The "fused" here refers to the three operations q matmul k, softmax, and attention matmul v. The original FusedMultiHeadAttention actually splits the computation into multiple operators, whereas this cutlass operator completes it with a single kernel launch.

heavengate previously approved these changes on Mar 8, 2023
@MARD1NO closed this on Mar 27, 2023