From e94f42c001a0450a5fca2c3cc476e8d5ad0be048 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 3 May 2022 21:48:45 +0800 Subject: [PATCH] update --- paddle/fluid/operators/cum_op.cc | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 592adf8971c67..7043d47a26b1e 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -26,19 +26,6 @@ class CumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; }; -class CumGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp"); - OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logsumexp"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - "Out@GRAD", "logsumexp"); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - } -}; - class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -115,6 +102,19 @@ the input. If exlusive is true, the first element of the result is the minimum v } }; +class LogcumsumexpGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logcumsumexp"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logcumsumexp"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "logcumsumexp"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + template class LogcumsumexpGradMaker : public framework::SingleGradOpMaker { public: @@ -142,6 +142,8 @@ namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor, PD_INFER_META(phi::CumInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp, LogcumsumexpInferShapeFunctor, + PD_INFER_META(phi::CumInferMeta)); REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker, ops::CumsumGradMaker, ops::CumsumGradMaker, @@ -149,8 +151,8 @@ REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker, REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker, ops::LogcumsumexpGradMaker, ops::LogcumsumexpGradMaker, - CumsumInferShapeFunctor); -REGISTER_OPERATOR(logcumsumexp_grad, ops::CumGradOp); + LogcumsumexpInferShapeFunctor); +REGISTER_OPERATOR(logcumsumexp_grad, ops::LogcumsumexpGradOp); REGISTER_OP_VERSION(cumsum) .AddCheckpoint(