Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaishaonvjituizi committed May 3, 2022
1 parent 518c75a commit e94f42c
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions paddle/fluid/operators/cum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,6 @@ class CumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
};

class CumGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logsumexp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "logsumexp");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};

class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
Expand Down Expand Up @@ -115,6 +102,19 @@ the input. If exlusive is true, the first element of the result is the minimum v
}
};

class LogcumsumexpGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logcumsumexp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logcumsumexp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "logcumsumexp");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};

template <typename T>
class LogcumsumexpGradMaker : public framework::SingleGradOpMaker<T> {
public:
Expand Down Expand Up @@ -142,15 +142,17 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor,
PD_INFER_META(phi::CumInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp, LogcumsumexpInferShapeFunctor,
PD_INFER_META(phi::CumInferMeta));
REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
ops::CumsumGradMaker<paddle::framework::OpDesc>,
ops::CumsumGradMaker<paddle::imperative::OpBase>,
CumsumInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker,
ops::LogcumsumexpGradMaker<paddle::framework::OpDesc>,
ops::LogcumsumexpGradMaker<paddle::imperative::OpBase>,
CumsumInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp_grad, ops::CumGradOp);
LogcumsumexpInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp_grad, ops::LogcumsumexpGradOp);

REGISTER_OP_VERSION(cumsum)
.AddCheckpoint(
Expand Down

0 comments on commit e94f42c

Please sign in to comment.