-
Notifications
You must be signed in to change notification settings - Fork 758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add deconv cpu impl #5224
Add deconv cpu impl #5224
Conversation
namespace { | ||
|
||
template<typename T> | ||
using Im2ColFunc = void (*)(const T* in_dptr, const ShapeView& in_shape, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
im2col 没用到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的。
const int32_t* padding_before, T* in_diff_ptr); | ||
|
||
template<typename T> | ||
using GemmFunc = void (*)(enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b, const int m, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个也没用到,我们是直接调用的 ofgemm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的。
template<typename T> | ||
struct ConvKernelUtil final { | ||
public: | ||
static void NCDHWIm2Col(const T* in_dptr, const ShapeView& in_shape, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
im2col 相关的感觉都可以删了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的。
struct ConvOpKernelState final : public user_op::OpKernelState { | ||
Im2ColFunc<T> im2col_func_; | ||
Col2ImFunc<T> col2im_func_; | ||
GemmFunc<T> forward_func_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
forward_func_ 这个也没用到
int32_t idx_offset = conv_state->idx_offset_; | ||
|
||
FOR_RANGE(int64_t, i, 0, in->shape().At(0)) { | ||
// channels first: col_buf' = weight(T) * out[i]' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释应该是 weight(T) * in[i] ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,已修正。
conv_state->weight_5d_shape_.At(0), static_cast<T>(1), weight->dptr<T>(), | ||
GetImgDptr<T>(in, i), static_cast<T>(0), col_buf->mut_dptr<T>()); | ||
|
||
// in' = col2im(col_buf') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out = col2im(col_buf)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的。
添加反卷积的cpu实现。