diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h
index e2689cefd43a7..806a81fc20760 100644
--- a/paddle/fluid/operators/elementwise/elementwise_functor.h
+++ b/paddle/fluid/operators/elementwise/elementwise_functor.h
@@ -275,5 +275,32 @@ struct MulGradXYFunctor<Complex<InT>, Complex<OutT>> {
   }
 };
 
+// Ternary compare
+template <typename T>
+struct MaxGradXFunctor {
+  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
+    return dout * static_cast<T>(x > y);
+  }
+};
+template <typename T>
+struct MaxGradYFunctor {
+  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
+    return dout * static_cast<T>(x <= y);
+  }
+};
+
+template <typename InT, typename OutT>
+struct MaxGradXYFunctor {
+  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(
+      const InT& x, const InT& y, const InT& dout) {
+    paddle::framework::Array<OutT, 2> outs;
+    // dx = dout * (x > y)
+    outs[0] = static_cast<OutT>(dout * static_cast<InT>(x > y));
+    // dy = dout * (x <= y)
+    outs[1] = static_cast<OutT>(dout * static_cast<InT>(x <= y));
+    return outs;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu
index 760429200889b..eaf7774428565 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu
@@ -24,15 +24,41 @@ class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     std::vector<const framework::Tensor*> ins;
     std::vector<framework::Tensor*> outs;
-    const auto& cuda_ctx =
+    const auto& dev_ctx =
         ctx.template device_context<platform::CUDADeviceContext>();
 
     int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
     LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-        cuda_ctx, ins, &outs, axis, MaxFunctor<T>());
+        dev_ctx, ins, &outs, axis, MaxFunctor<T>());
   }
 };
 
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  const auto& dev_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
+  const auto place = ctx.GetPlace();
+  if (dx != nullptr && dy != nullptr) {
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    GetGradXAndYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dx, dy, MaxGradXYFunctor<T, T>());
+  } else if (dx != nullptr && dy == nullptr) {
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dx, MaxGradXFunctor<T>());
+  } else if (dx == nullptr && dy != nullptr) {
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dy, MaxGradYFunctor<T>());
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h
index a7a49fed87151..cff30be50a3d1 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h
@@ -64,6 +64,28 @@ struct MaxGradDy {
   }
 };
 
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
+      ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
+}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy);
+#endif
+
 template <typename DeviceContext, typename T>
 class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
  public:
@@ -74,12 +96,11 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out = dout;  // out is not necessary
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    auto* out = dout;  // Fake out, not used
-    int axis = ctx.Attr<int>("axis");
-    ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
-        ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
+
+    ElementwiseMaxGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
   }
 };
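
Note on the convention the new ternary functors encode: dx = dout * (x > y) and dy = dout * (x <= y), so for every element exactly one of the two masks is 1. On a tie (x == y) the whole upstream gradient is routed to y, and dx + dy == dout always holds, so no gradient is dropped or double-counted. A minimal standalone sketch of the same math (a host-only mirror of the functors above: HOSTDEVICE and the Paddle framework types are dropped so it compiles as plain C++):

#include <cassert>

// Host-only mirrors of MaxGradXFunctor / MaxGradYFunctor from this PR.
template <typename T>
struct MaxGradXFunctor {
  T operator()(const T& x, const T& y, const T& dout) const {
    return dout * static_cast<T>(x > y);  // dx = dout * (x > y)
  }
};
template <typename T>
struct MaxGradYFunctor {
  T operator()(const T& x, const T& y, const T& dout) const {
    return dout * static_cast<T>(x <= y);  // dy = dout * (x <= y)
  }
};

int main() {
  MaxGradXFunctor<float> dx;
  MaxGradYFunctor<float> dy;
  // x > y: the whole upstream gradient flows to x.
  assert(dx(2.f, 1.f, 3.f) == 3.f && dy(2.f, 1.f, 3.f) == 0.f);
  // Tie (x == y): x <= y holds, so the gradient flows to y only.
  assert(dx(1.f, 1.f, 3.f) == 0.f && dy(1.f, 1.f, 3.f) == 3.f);
  return 0;
}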
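
The CPU/GPU split works by SFINAE overload selection on the DeviceContext template parameter: elementwise_max_op.h defines the std::enable_if-gated CPU body, and under PADDLE_WITH_CUDA / PADDLE_WITH_HIP it additionally declares (but does not define) the CUDA overload, whose definition lives in elementwise_max_op.cu. A minimal sketch of the dispatch pattern, using hypothetical CPUDeviceContext / CUDADeviceContext stand-ins for the Paddle platform types:

#include <iostream>
#include <type_traits>

struct CPUDeviceContext {};   // stand-in for platform::CPUDeviceContext
struct CUDADeviceContext {};  // stand-in for platform::CUDADeviceContext

// Enabled only when DeviceContext is the CPU context.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type
ElementwiseMaxGrad() {
  std::cout << "CPU path (ElemwiseGradCompute)\n";
}

// Enabled only when DeviceContext is the CUDA context. In the PR this
// overload is only declared in the header; its definition sits in the
// .cu file so host-only translation units never see device code.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, CUDADeviceContext>::value>::type
ElementwiseMaxGrad() {
  std::cout << "CUDA path (GetGradXAndYOut / GetGradXOrYOut)\n";
}

int main() {
  ElementwiseMaxGrad<CPUDeviceContext, float>();   // picks the CPU overload
  ElementwiseMaxGrad<CUDADeviceContext, float>();  // picks the CUDA overload
}

Exactly one enable_if condition can be satisfied per instantiation, so ElementwiseMaxGradKernel::Compute can call ElementwiseMaxGrad<DeviceContext, T>(...) unconditionally and the right implementation is chosen at compile time.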