Add fused_scale_bias_relu_conv_bnstats OP #55026
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
6ea7229 to 6df2d66 (Compare)
Sorry to inform you that 6df2d66's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
@Xreki For your information, here is an overview of this PR:
- Review changes from the last PR #54949
- New OP implementation:
  - OP kernel
  - OP registry
  - Unittests
0675338 to 6f3ad2a (Compare)
Regarding the CI errors, @tianshuo78520a will help approve them manually later. @Xreki, please review.
float epsilon,
bool fuse_prologue,
bool exhaustive_search,
int64_t accumulation_count,
What does this parameter do?
It is the number of elements that BatchNorm normalizes over; for a single GPU (non-SyncBatchNorm) it is N*H*W. See Table 42 for details.
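For context, a minimal sketch of how this count would be derived from an NHWC shape (the helper name and signature are hypothetical, not the PR's actual code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// BatchNorm statistics are accumulated over every element sharing a channel,
// i.e. all N*H*W positions of an NHWC tensor. With SyncBatchNorm the count
// would additionally be multiplied by the number of participating GPUs.
int64_t BnAccumulationCount(const std::vector<int64_t>& nhwc_dims,
                            int64_t num_gpus = 1) {
  // nhwc_dims = {N, H, W, C}; the channel dimension C is excluded.
  return nhwc_dims[0] * nhwc_dims[1] * nhwc_dims[2] * num_gpus;
}
```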
@@ -0,0 +1,635 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright year: 2022 -> 2023
done
template <typename T>
using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;

template <typename T, typename Context>
Please add a comment above the function explaining what it does.
done
.build();

std::array<cudnn_frontend::Operation const*, 1> ops = {&finalize_stat_op};
auto op_graph = cudnn_frontend::OperationGraphBuilder()
Looking at L449 - L488, this part of the implementation is identical across the different combinations. Could it be encapsulated into a function?
done
dev_ctx.GetComputeCapability()));
// attr
float exp_decay = 1. - momentum;
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
You could use PADDLE_ENFORCE_LE here.
This is only a warning and does not quit, consistent with the batchnorm implementation:
LOG(ERROR) << "Provided epsilon is smaller than "
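A rough sketch of the warn-and-clamp behavior being described (the helper is hypothetical; the real kernel uses Paddle's logging, and CUDNN_BN_MIN_EPSILON comes from the cuDNN headers):

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <iostream>

// cuDNN's documented lower bound for the BatchNorm epsilon.
constexpr double kCudnnBnMinEpsilon = 1e-5;  // stands in for CUDNN_BN_MIN_EPSILON

// Instead of failing hard via PADDLE_ENFORCE_LE, an out-of-range epsilon is
// logged and raised to the minimum value cuDNN accepts, matching the
// existing batch_norm kernel's behavior.
float SanitizeBnEpsilon(float epsilon) {
  if (epsilon <= kCudnnBnMinEpsilon - FLT_EPSILON) {
    std::cerr << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON; "
              << "clamping to " << kCudnnBnMinEpsilon << "\n";
  }
  return std::max(epsilon, static_cast<float>(kCudnnBnMinEpsilon));
}
```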
using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;

template <typename T, typename Context>
void _FusedScaleBiasReluConvBnstatsImpl(
The leading _ in the function name can be removed.
done
GPU,
ALL_LAYOUT,
phi::fusion::FusedScaleBiasReluConvBnstatsKernel,
phi::dtype::float16) {
Is the bfloat16 type supported?
Not supported yet.
@@ -0,0 +1,243 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2022 -> 2023
done
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(skip_unit_test(), skip_msg)
Since the subclass inherits the base class, it does not need the skip decorators.
done
@@ -92,6 +92,7 @@

NO_FP16_COMPARED_WITH_FP32_OP_LIST = [
'fake_quantize_moving_average_abs_max',
'fused_scale_bias_relu_conv_bnstats',
Why does this operator need to be added to this whitelist?
If it is not added, check_output_with_place will compare the FP16 results against FP32 results. However, this OP does not support FP32 input, so that would fail.
@Xreki The changes have been committed. Would you please take a look?
The CI failures need to be manually approved and should not block the review process. @Xreki
Sorry to inform you that b4cb408's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
b4cb408 to bc6a1cd (Compare)
Sorry to inform you that 5786863's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
LGTM
std::vector<void*> data_ptrs;
std::vector<int64_t> uids;
int64_t uid = 100;
What is uid, and why is it set to 100? Do the uids of different patterns all have to be the same? I suggest unifying uid management in a follow-up.
The uid is used to distinguish different variables. The uids do not need to be the same; they only need to be free of conflicts within a single cudnn operation graph.
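One possible way to unify uid management, as suggested above, would be a small allocator (a hypothetical helper, not the PR's code): each tensor name in an operation graph gets a distinct uid handed out from a counter. The starting value (100 here, matching the snippet) is arbitrary, since uids only have to be unique within one operation graph.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>

// Hands out a unique uid per tensor name within one cudnn operation graph.
// Asking for the same name twice returns the same uid, so the data pointer
// and the tensor descriptor stay consistently paired.
class UidAllocator {
 public:
  explicit UidAllocator(int64_t start = 100) : next_(start) {}

  int64_t Get(const std::string& tensor_name) {
    auto it = uids_.find(tensor_name);
    if (it != uids_.end()) return it->second;  // reuse for the same tensor
    return uids_[tensor_name] = next_++;
  }

 private:
  int64_t next_;
  std::unordered_map<std::string, int64_t> uids_;
};
```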
}

template <typename T, typename Context>
void FusedScaleBiasReluConvBnstatsKernel(
FusedScaleBiasReluConvBnstatsKernel is FusedScaleBiasReluConvBnstatsImpl + BNFinalizeImpl, so shouldn't the functionality be FusedScaleBiasReluConvBn?
Yes, it will be changed in the next PR.
LGTM
LGTM for new fused op
LGTM
LGTM
@Xreki Would you please merge it?
* Add fused_scale_bias_relu_conv_bnstats op
* Review changes
* Fix no CUDNN Frontend build
* Fix PADDLE_ENFORCE format
* Fix PADDLE_ENFORCE CI error
* Rename kernel filename
* Refactor unittest to use paddle eager_op_test
* Fix padding bugs
* Review changes
* test=cuda117
* test=cuda117
PR types
New features
PR changes
OPs
Description
Please merge this PR after #54949
This PR adds fused_scale_bias_relu_conv_bnstats, which is needed for ResUnit fusion. It is implemented using the cuDNN Frontend API.
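For readers unfamiliar with the fused pattern: the op applies a per-channel scale/bias and ReLU to the input, convolves the result, and also emits the per-channel sum and sum of squares of the convolution output so a later BN-finalize step can derive mean and variance. A naive reference sketch of the two non-convolution stages (the convolution in between is omitted for brevity; all names here are illustrative, not the kernel's actual code):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Per-channel statistics later consumed by the BN-finalize step.
struct BnStats {
  std::vector<float> sum;    // per-channel sum
  std::vector<float> sqsum;  // per-channel sum of squares
};

// Prologue: per-channel scale + bias + ReLU on a flattened NHWC tensor.
std::vector<float> ScaleBiasRelu(const std::vector<float>& x,
                                 const std::vector<float>& scale,
                                 const std::vector<float>& bias,
                                 int64_t channels) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const int64_t c = static_cast<int64_t>(i) % channels;  // NHWC: C innermost
    y[i] = std::max(0.0f, x[i] * scale[c] + bias[c]);
  }
  return y;
}

// Epilogue: accumulate per-channel sum and sum of squares (the "bnstats").
BnStats ComputeBnStats(const std::vector<float>& y, int64_t channels) {
  BnStats s{std::vector<float>(channels, 0.0f),
            std::vector<float>(channels, 0.0f)};
  for (size_t i = 0; i < y.size(); ++i) {
    const int64_t c = static_cast<int64_t>(i) % channels;
    s.sum[c] += y[i];
    s.sqsum[c] += y[i] * y[i];
  }
  return s;
}
```

The point of fusing is that the real kernel produces the conv output and these statistics in one pass, instead of re-reading the conv output to compute BN statistics separately.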