-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add element-wise multiplication operator. #3787
Conversation
paddle/operators/CMakeLists.txt
Outdated
@@ -57,6 +57,7 @@ op_library(add_op SRCS add_op.cc add_op.cu) | |||
op_library(mean_op SRCS mean_op.cc mean_op.cu) | |||
|
|||
op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function) | |||
op_library(element_wise_mul_op SRCS element_wise_mul_op.cc element_wise_mul_op.cu DEPS math_function) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove DEPS math_function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
Done.
auto Y_e = framework::EigenVector<T>::Flatten(*Y); | ||
auto Z_e = framework::EigenVector<T>::Flatten(*Z); | ||
|
||
Z_e.device(context.GetEigenDevice<Place>()) = X_e * Y_e; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The names of variables are all lowercas. https://google.github.io/styleguide/cppguide.html#Variable_Names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's better to support DotMulProjection
, where the second input is a row vector,
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/gserver/layers/DotMulProjection.cpp#L57
The broadcast
of Eigen can be used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both are needed.
paddle/operators/CMakeLists.txt
Outdated
@@ -57,6 +57,7 @@ op_library(add_op SRCS add_op.cc add_op.cu) | |||
op_library(mean_op SRCS mean_op.cc mean_op.cu) | |||
|
|||
op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function) | |||
op_library(element_wise_mul_op SRCS element_wise_mul_op.cc element_wise_mul_op.cu DEPS math_function) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Elementwise is one word.
elementwise_mul_op
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
Done.
auto Y_e = framework::EigenVector<T>::Flatten(*Y); | ||
auto Z_e = framework::EigenVector<T>::Flatten(*Z); | ||
|
||
Z_e.device(context.GetEigenDevice<Place>()) = X_e * Y_e; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both are needed.
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP(elemwisemul, ops::ElemWiseMulOp, ops::ElemWiseMulOpMaker, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elemwisemul => elementwise_mul
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
Done.
auto dOut_e = framework::EigenVector<T>::Flatten(*dOut); | ||
|
||
dX_e.device(ctx.GetEigenDevice<Place>()) = dOut_e * Y_e; | ||
dY_e.device(ctx.GetEigenDevice<Place>()) = X_e * dOut_e; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to support the case where the gradient dX_e or dY_e is not needed.
@qingqing01 Most of the gradient operator need to handle the case where is one of more of the gradient of the input is not needed. Please change the unittest framework to explicitly test whether the op correctly handles this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@emailweixu Ok, I'll change the unit test framework to handle this.
paddle/operators/CMakeLists.txt
Outdated
recurrent_op | ||
scale_op) | ||
op_library(identity_op DEPS scale_op) | ||
op_library(minus_op DEPS scale_op) | ||
op_library(mul_op DEPS math_function) | ||
op_library(elementwise_mul_op SRCS elementwise_mul_op.cc elementwise_mul_op.cu) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove line 57 and line 63. There is no DEPS for this op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto x_dim = ctx.Input<Tensor>("X")->dims(); | ||
auto y_dim = ctx.Input<Tensor>("Y")->dims(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PADDLE_ENFORCE_NOT_NULL
is needed for input X and Y, similar check in ElementWiseMulOpGrad.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "The first input of elementwise mul op"); | ||
AddInput("Y", "The second input of elementwise mul op"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add the supported shape for input X and Y.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
AddInput("X", "The first input of elementwise mul op"); | ||
AddInput("Y", "The second input of elementwise mul op"); | ||
AddAttr<int>("axis", "Optional input parameter of elementwise mul op"); | ||
AddAttr<int>("broadcast", "Optional input parameter of elementwise mul op"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to describe the meaning and usage for axis and broadcast attr.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从下面单测来看 axis像是y在x中的start_axis,需要更详细的说明。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
using Tensor = framework::Tensor; | ||
template <typename T, int MajorType = Eigen::RowMajor, | ||
typename IndexType = Eigen::DenseIndex> | ||
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From googlge code style:
https://google.github.io/styleguide/cppguide.html#Namespaces
Do not use Namespace aliases at namespace scope in header files except in explicitly marked internal-only namespaces, because anything imported into a namespace in a header file becomes part of the public API exported by that file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
// TODO(gongweibao): if axis is optional? | ||
bool broadcast = ctx.template Attr<int>("broadcast"); | ||
PADDLE_ENFORCE(broadcast, "Do you forget broadcast parameter?"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从这里实现来看 broadcast 要始终为True, 为什么还要加这个参数? 另外这个参数作用是啥?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
broadcast和axis都是optional的参数。
paddle/operators/elementwise_op.h
Outdated
|
||
inline void get_slice(const framework::DDim& x_dims, | ||
const framework::DDim& y_dims, const int axis, int& pre, | ||
int& n, int& post) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从下面实现来看,不是给x_dims, y_dims做slice得, 叫这个名字不是很理解。另外,单独放到ementwise_op.h
里是否合适?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果暂时只有这一个op用到,觉得先放到 elementwise_mul_op.h
中,后续增加新的op,再移出来都行。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), | ||
"Axis should be in range [0, x_dims)"); | ||
|
||
int pre, n, post; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不是特别理解pre, n, post 的意思, 能够加上comments吗? 或者更明确的name吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out = X ⊙ Y
1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
pre=2, n=3*4,post=5
2.shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
pre=2*3, n=4*5,post=1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经加入到comment。
""" | ||
self.compare_grad(self.op, self.inputs) | ||
self.check_grad( | ||
self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_relative_error能调小一些吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
if (x_dims == y_dims || product(y_dims) == 1) { | ||
dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e; | ||
dy_e.device(ctx.GetEigenDevice<Place>()) = x_e * dout_e; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (dx) dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e;
if (dy) dy_e.device(ctx.GetEigenDevice<Place>()) = x_e * dout_e;
需要处理某个输入可能不用计算grad的情况
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
from paddle.v2.framework.op import Operator | ||
|
||
|
||
class TestElementwiseMulOp_Matrix(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要先merge一下develop分支,因为目前的operator的测试框架有一些调整,也需要对应修改一下单元测试的写法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
"broadcast", | ||
R"DOC(When broadcast is set, Y will be broadcast to match shape of X. | ||
)DOC") | ||
.SetDefault(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
broadcast
can be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经删除掉了。
* Out = X ⊙ Y | ||
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 | ||
* pre=2, n=3*4, post=5 | ||
* 2.shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
return; | ||
} | ||
|
||
int axis = ctx.template Attr<int>("axis"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ctx.Attr<int>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
AddOutput("Out", "The output of elementwise mul op"); | ||
AddComment(R"DOC( | ||
Limited elementwise multiple operator.The equation is: Out = X ⊙ Y. | ||
1. The shape of Y should be same of X or |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The shape of Y should be same with X
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
} | ||
if (dy) { | ||
dy->mutable_data<T>(ctx.GetPlace()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
short lines for line 119 - 124:
if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace());
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://google.github.io/styleguide/cppguide.html#Conditionals
这个应该不是强制的。
主要是gdb调试的时候比较郁闷,不知道if后边的语句到底是不是执行了。
|
||
if (y_grad) { | ||
y_grad->Resize(y_dims); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for shorter lines:
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
return; | ||
} | ||
|
||
int axis = ctx.template Attr<int>("axis"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ctx.Attr<int>("axis");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
* Update bbox_utils.py
Fix #3713