From a1331f987775de1ee24aa95cee4c2ffef46fddb8 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 6 Mar 2018 23:36:38 +0800 Subject: [PATCH] refine elementwise_mul_op --- paddle/fluid/operators/elementwise_mul_op.h | 83 ++----------------- .../fluid/operators/elementwise_op_function.h | 2 +- 2 files changed, 9 insertions(+), 76 deletions(-) diff --git a/paddle/fluid/operators/elementwise_mul_op.h b/paddle/fluid/operators/elementwise_mul_op.h index 46d69ed87d4c2..e2b59b3112096 100644 --- a/paddle/fluid/operators/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise_mul_op.h @@ -40,80 +40,14 @@ class ElementwiseMulKernel : public framework::OpKernel { }; template -struct ElementwiseMulGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e * y_e; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = x_e * dz_e; - } - } +struct IdentityGrad_DX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } }; template -struct ElementwiseMulBroadCastGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n)) - .broadcast(Eigen::DSizes(pre, 1)) - .reshape(Eigen::DSizes(x_e.size())); - - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e * y_e_bcast; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (x_e * dz_e) - .reshape(Eigen::DSizes(pre, n)) - .sum(Eigen::array{{0}}); - } - } +struct IdentityGrad_DY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } }; - -template -struct ElementwiseMulBroadCast2GradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n, 1)) - .broadcast(Eigen::DSizes(pre, 1, post)) - .reshape(Eigen::DSizes(x_e.size())); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e * y_e_bcast; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (x_e * dz_e) - .reshape(Eigen::DSizes(pre, n, post)) - .sum(Eigen::array{{0, 2}}); - } - } -}; - template class ElementwiseMulGradKernel : public framework::OpKernel { public: @@ -127,12 +61,11 @@ class ElementwiseMulGradKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); int axis = ctx.Attr("axis"); - ElementwiseGradCompute, - ElementwiseMulBroadCastGradFunctor, - ElementwiseMulBroadCast2GradFunctor>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute, + IdentityGrad_DY>(ctx, *x, *y, *out, *dout, axis, dx, + dy, IdentityGrad_DX(), + IdentityGrad_DY()); } }; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index ffda53a383ced..0b4238436ffcc 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -301,7 +301,7 @@ struct ElemwiseGradNoBroadcast { dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); } if (dy_ != nullptr) { - dy_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); + dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]); } }