add oss flash fmha and fmhca support #49438
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
Force-pushed 4cecf1d to 9c1e126
void TrtCrossMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
#ifdef PADDLE_WITH_TENSORRT
Shouldn't this #ifdef macro be wrapped around the code after line 562?
My understanding is that the early stop here needs to happen before the fusion is built, so it is placed at the very beginning of ApplyImpl.
After discussion, we now perform the early stop via a runtime check. Please take another look.
paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
Outdated
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();

#ifdef PADDLE_WITH_TENSORRT
Same as above.
After discussion, we now perform the early stop via a runtime check. Please take another look.
paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc
Outdated
        std::get<2>(trt_version) * 10 <
    8520) {
  VLOG(3) << "Flash attention oss plugin only available for trt version >= "
             "8.5.2.2. Stop this pass";
This only prints a log message? Shouldn't it return here?
It looks like the pass is only registered for TRT 8.5.2.2, so this check should be unnecessary here.
After consideration, we still perform the runtime early stop here, so a return has now been added.
paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
Outdated
paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
Outdated
refine compile, fix compile
Force-pushed 9691d16 to 6a679e8
LGTM
PR types
Performance optimization
PR changes
OPs
Describe
Add NVIDIA TensorRT OSS plugin flash attention and cross attention support to accelerate the inference of Stable Diffusion and other models.
With the flash attention and cross attention plugins, Stable Diffusion latency drops from 1.52 s to 1.02 s.
TensorRT 8.5.2 is required to use these plugins.
Using nsys, we can see that the plugins are successfully invoked by the unit test in a TRT 8.5.2.2 environment.