Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add element-wise multiplication operator. #3787

Merged
merged 20 commits into from
Sep 13, 2017

Conversation

gongweibao
Copy link
Contributor

Fix #3713

@gongweibao gongweibao changed the title Add element-wise multiply operator. Add element-wise multiple operator. Aug 31, 2017
@@ -57,6 +57,7 @@ op_library(add_op SRCS add_op.cc add_op.cu)
op_library(mean_op SRCS mean_op.cc mean_op.cu)

op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
op_library(element_wise_mul_op SRCS element_wise_mul_op.cc element_wise_mul_op.cu DEPS math_function)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove DEPS math_function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.
Done.

auto Y_e = framework::EigenVector<T>::Flatten(*Y);
auto Z_e = framework::EigenVector<T>::Flatten(*Z);

Z_e.device(context.GetEigenDevice<Place>()) = X_e * Y_e;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The names of variables are all lowercas. https://google.github.io/styleguide/cppguide.html#Variable_Names

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to support DotMulProjection , where the second input is a row vector,
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/gserver/layers/DotMulProjection.cpp#L57

The broadcast of Eigen can be used.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both are needed.

@qingqing01 qingqing01 requested a review from pkuyym August 31, 2017 13:21
@@ -57,6 +57,7 @@ op_library(add_op SRCS add_op.cc add_op.cu)
op_library(mean_op SRCS mean_op.cc mean_op.cu)

op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
op_library(element_wise_mul_op SRCS element_wise_mul_op.cc element_wise_mul_op.cu DEPS math_function)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Elementwise is one word.
elementwise_mul_op

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.
Done.

auto Y_e = framework::EigenVector<T>::Flatten(*Y);
auto Z_e = framework::EigenVector<T>::Flatten(*Z);

Z_e.device(context.GetEigenDevice<Place>()) = X_e * Y_e;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both are needed.

} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(elemwisemul, ops::ElemWiseMulOp, ops::ElemWiseMulOpMaker,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elemwisemul => elementwise_mul

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.
Done.

auto dOut_e = framework::EigenVector<T>::Flatten(*dOut);

dX_e.device(ctx.GetEigenDevice<Place>()) = dOut_e * Y_e;
dY_e.device(ctx.GetEigenDevice<Place>()) = X_e * dOut_e;
Copy link
Collaborator

@emailweixu emailweixu Aug 31, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to support the case where the gradient dX_e or dY_e is not needed.
@qingqing01 Most of the gradient operator need to handle the case where is one of more of the gradient of the input is not needed. Please change the unittest framework to explicitly test whether the op correctly handles this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@emailweixu Ok, I'll change the unit test framework to handle this.

@emailweixu emailweixu changed the title Add element-wise multiple operator. Add element-wise multiplication operator. Aug 31, 2017
@gongweibao gongweibao changed the title Add element-wise multiplication operator. [WIP]Add element-wise multiplication operator. Sep 10, 2017
@gongweibao gongweibao changed the title [WIP]Add element-wise multiplication operator. Add element-wise multiplication operator. Sep 10, 2017
recurrent_op
scale_op)
op_library(identity_op DEPS scale_op)
op_library(minus_op DEPS scale_op)
op_library(mul_op DEPS math_function)
op_library(elementwise_mul_op SRCS elementwise_mul_op.cc elementwise_mul_op.cu)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove line 57 and line 63. There is no DEPS for this op.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto x_dim = ctx.Input<Tensor>("X")->dims();
auto y_dim = ctx.Input<Tensor>("Y")->dims();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PADDLE_ENFORCE_NOT_NULL is needed for input X and Y, similar check in ElementWiseMulOpGrad.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of elementwise mul op");
AddInput("Y", "The second input of elementwise mul op");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the supported shape for input X and Y.

#3885 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

AddInput("X", "The first input of elementwise mul op");
AddInput("Y", "The second input of elementwise mul op");
AddAttr<int>("axis", "Optional input parameter of elementwise mul op");
AddAttr<int>("broadcast", "Optional input parameter of elementwise mul op");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to describe the meaning and usage for axis and broadcast attr.

#3885 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

从下面单测来看 axis像是y在x中的start_axis,需要更详细的说明。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From googlge code style:

https://google.github.io/styleguide/cppguide.html#Namespaces

Do not use Namespace aliases at namespace scope in header files except in explicitly marked internal-only namespaces, because anything imported into a namespace in a header file becomes part of the public API exported by that file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


// TODO(gongweibao): if axis is optional?
bool broadcast = ctx.template Attr<int>("broadcast");
PADDLE_ENFORCE(broadcast, "Do you forget broadcast parameter?");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

从这里实现来看 broadcast 要始终为True, 为什么还要加这个参数? 另外这个参数作用是啥?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

broadcast和axis都是optional的参数。


inline void get_slice(const framework::DDim& x_dims,
const framework::DDim& y_dims, const int axis, int& pre,
int& n, int& post) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

从下面实现来看,不是给x_dims, y_dims做slice得, 叫这个名字不是很理解。另外,单独放到ementwise_op.h里是否合适?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果暂时只有这一个op用到,觉得先放到 elementwise_mul_op.h 中,后续增加新的op,再移出来都行。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");

int pre, n, post;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不是特别理解pre, n, post 的意思, 能够加上comments吗? 或者更明确的name吗?

Copy link
Contributor Author

@gongweibao gongweibao Sep 12, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out = X ⊙ Y
1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
    pre=2, n=3*4,post=5
2.shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
    pre=2*3, n=4*5,post=1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经加入到comment。

"""
self.compare_grad(self.op, self.inputs)
self.check_grad(
self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_relative_error能调小一些吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


if (x_dims == y_dims || product(y_dims) == 1) {
dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e;
dy_e.device(ctx.GetEigenDevice<Place>()) = x_e * dout_e;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (dx) dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e;
if (dy) dy_e.device(ctx.GetEigenDevice<Place>()) = x_e * dout_e;

需要处理某个输入可能不用计算grad的情况

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

from paddle.v2.framework.op import Operator


class TestElementwiseMulOp_Matrix(unittest.TestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要先merge一下develop分支,因为目前的operator的测试框架有一些调整,也需要对应修改一下单元测试的写法

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"broadcast",
R"DOC(When broadcast is set, Y will be broadcast to match shape of X.
)DOC")
.SetDefault(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

broadcast can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经删除掉了。

* Out = X ⊙ Y
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
* pre=2, n=3*4, post=5
* 2.shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
Copy link
Contributor

@qingqing01 qingqing01 Sep 12, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return;
}

int axis = ctx.template Attr<int>("axis");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.Attr<int>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

AddOutput("Out", "The output of elementwise mul op");
AddComment(R"DOC(
Limited elementwise multiple operator.The equation is: Out = X ⊙ Y.
1. The shape of Y should be same of X or
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shape of Y should be same with X

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

short lines for line 119 - 124:

if (dx)  dx->mutable_data<T>(ctx.GetPlace());
if (dy)  dy->mutable_data<T>(ctx.GetPlace());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://google.github.io/styleguide/cppguide.html#Conditionals
这个应该不是强制的。

主要是gdb调试的时候比较郁闷,不知道if后边的语句到底是不是执行了。


if (y_grad) {
y_grad->Resize(y_dims);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for shorter lines:

if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

return;
}

int axis = ctx.template Attr<int>("axis");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.Attr<int>("axis");

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@gongweibao gongweibao merged commit 8778957 into PaddlePaddle:develop Sep 13, 2017
@gongweibao gongweibao deleted the addop branch September 13, 2017 06:18
@kuke kuke mentioned this pull request Sep 14, 2017
@gongweibao gongweibao restored the addop branch September 15, 2017 03:38
@gongweibao gongweibao deleted the addop branch January 17, 2021 07:41
heavengate pushed a commit to heavengate/Paddle that referenced this pull request Aug 16, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants