implement logcumsumexp
tiancaishaonvjituizi committed Apr 26, 2022
1 parent fccb081 commit 5c3b6bb
Showing 16 changed files with 814 additions and 221 deletions.
68 changes: 67 additions & 1 deletion paddle/fluid/operators/cumsum_op.cc
@@ -26,6 +26,19 @@ class CumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
};

class CumGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logsumexp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "logsumexp");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};

class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -74,17 +87,70 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
}
};

class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of logcumsumexp operator");
AddOutput("Out", "Output of logcumsumexp operator");
AddAttr<int>("axis",
"The dimension to accumulate along. -1 means the last "
"dimension [default -1].")
.SetDefault(-1);
AddAttr<bool>("flatten",
"Whether to compute the logcumsumexp over the flattened array. "
"[default false].")
.SetDefault(false);
AddAttr<bool>("exclusive",
"Whether to perform exclusive logcumsumexp. [default false].")
.SetDefault(false);
AddAttr<bool>("reverse",
"If true, the logcumsumexp is performed in the reversed direction. "
"[default false].")
.SetDefault(false);
AddComment(R"DOC(
Returns the logarithm of the cumulative sum of the exponentials of the input elements along the given axis.
By default, the first element of the result is the same as the first element of
the input. If exclusive is true, the first element of the result is the minimum value of the dtype.
)DOC");
}
};
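Editorial note (not part of the diff): in the notation of the DOC block above, the forward op computes, along the chosen axis,

    y_i = \log \sum_{j \le i} \exp(x_j)

with the exclusive variant shifting the sum to j < i (the DOC notes the first result element is then the minimum value of the dtype, standing in for log 0).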

template <typename T>
class LogcumsumexpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("logcumsumexp_grad");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis")));
grad_op->SetAttr("flatten",
BOOST_GET_CONST(bool, this->GetAttr("flatten")));
grad_op->SetAttr("reverse",
BOOST_GET_CONST(bool, this->GetAttr("reverse")));
grad_op->SetAttr("exclusive",
BOOST_GET_CONST(bool, this->GetAttr("exclusive")));
}
};
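The grad maker above only wires variables; for reference, the gradient that logcumsumexp_grad must realize (stated for the inclusive, non-reversed case; the kernels implementing it are in changed files not shown on this page) is

    \frac{\partial L}{\partial x_j} = \sum_{i \ge j} \frac{\partial L}{\partial y_i} \, e^{x_j - y_i}

i.e. exp(x) times a reverse cumulative sum of the upstream gradient scaled by exp(-y).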

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor,
-                            PD_INFER_META(phi::CumsumInferMeta));
+                            PD_INFER_META(phi::CumInferMeta));
REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
ops::CumsumGradMaker<paddle::framework::OpDesc>,
ops::CumsumGradMaker<paddle::imperative::OpBase>,
CumsumInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker,
ops::LogcumsumexpGradMaker<paddle::framework::OpDesc>,
ops::LogcumsumexpGradMaker<paddle::imperative::OpBase>,
CumsumInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp_grad, ops::CumGradOp);
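As a rough illustration of the numerics behind the newly registered logcumsumexp op (a minimal standalone sketch, not the kernel added by this commit; the real CPU/GPU kernels are among the 16 changed files, most of which are not shown on this page), a 1-D version follows the stable running recurrence y_i = logaddexp(y_{i-1}, x_i):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// logaddexp(a, b) = log(exp(a) + exp(b)), computed without overflow by
// factoring out the larger of the two arguments.
static double LogAddExp(double a, double b) {
  double m = std::max(a, b);
  if (std::isinf(m) && m < 0) return m;  // both inputs are -inf, i.e. log(0)
  return m + std::log(std::exp(a - m) + std::exp(b - m));
}

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0};
  std::vector<double> y(x.size());
  double running = -std::numeric_limits<double>::infinity();  // log(0)
  for (size_t i = 0; i < x.size(); ++i) {
    running = LogAddExp(running, x[i]);  // y_i = log(exp(y_{i-1}) + exp(x_i))
    y[i] = running;
  }
  for (double v : y) std::printf("%.6f ", v);  // 1.000000 2.313262 3.407606
  std::printf("\n");
  return 0;
}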

REGISTER_OP_VERSION(cumsum)
.AddCheckpoint(
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/unary.cc
@@ -234,12 +234,12 @@ void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
out->set_layout(x.layout());
}

-void CumsumInferMeta(const MetaTensor& x,
-                     int axis,
-                     bool flatten,
-                     bool exclusive,
-                     bool reverse,
-                     MetaTensor* out) {
+void CumInferMeta(const MetaTensor& x,
+                  int axis,
+                  bool flatten,
+                  bool exclusive,
+                  bool reverse,
+                  MetaTensor* out) {
auto x_dims = x.dims();
if (flatten) {
out->set_dims(phi::make_ddim({phi::product(x_dims)}));
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/unary.h
@@ -60,12 +60,12 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);

-void CumsumInferMeta(const MetaTensor& x,
-                     int axis,
-                     bool flatten,
-                     bool exclusive,
-                     bool reverse,
-                     MetaTensor* out);
+void CumInferMeta(const MetaTensor& x,
+                  int axis,
+                  bool flatten,
+                  bool exclusive,
+                  bool reverse,
+                  MetaTensor* out);

void DiagInferMeta(const MetaTensor& x,
int offset,
Expand Down