-
Notifications
You must be signed in to change notification settings - Fork 756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dev Align torch avgpool #5610
Dev Align torch avgpool #5610
Conversation
CHECK_EQ_OR_RETURN(stride.size(), dim); | ||
for (int32_t stride_dim : stride) { CHECK_GT_OR_RETURN(stride_dim, 0); } | ||
for (int32_t i = 0; i < padding.size(); i++) { | ||
CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个限制为什么要存在呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
@@ -534,7 +313,7 @@ def __init__( | |||
): | |||
super().__init__() | |||
self.kernel_size = _triple(kernel_size) | |||
self.stride = _triple(stride) if stride is not None else _triple(kernel_size) | |||
self.stride = _triple(stride) if (stride is not None) else _triple(kernel_size) | |||
data_format = "NCDHW" | |||
self.channel_pos = ( | |||
"channels_last" if data_format == "NDHWC" else "channels_first" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NDHW没听过,是不是写错了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
D -> depth。也有人写成NTCHW,T->time
): | ||
super().__init__() | ||
self.kernel_size = _pair(kernel_size) | ||
data_format = "NCHW" # only support "NCHW" for now ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果是支持NCHW和NHWC那么可以删掉这个注释,否则加assert
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里暂时保留data_format参数,原因是AMP的dataformat会用到NHWC,如果有需要可以基于这个补充对应的NHWC版本kernel
Speed stats:
|
对齐Pytorch的avgpool系列
和maxpool部分参数不同,因此这个PR暂时还未考虑将这两个pool相关代码共用
doctest