Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cumprod #7278

Merged
merged 24 commits into from
Jan 24, 2022
Merged

Cumprod #7278

Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,6 @@ oneflow
read_onerec,
from_numpy,
cumsum,
cumprod,

.. autofunction:: oneflow.relu
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ limitations under the License.
namespace oneflow {
namespace one {

struct CumsumCaptureState : public AutoGradCaptureState {
// Capture state shared by the cumulative-op (cumsum/cumprod) backward passes:
// stores whether the forward input requires a gradient and the dimension the
// cumulative reduction was applied along.
struct CumCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;  // true iff the forward input requires grad
  int64_t dim = 0;             // dimension the cumulative op ran along
};

class CumsumGrad : public OpExprGradFunction<CumsumCaptureState> {
template<typename StateT>
class CumGrad : public OpExprGradFunction<StateT> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
Expand All @@ -33,8 +34,14 @@ class CumsumGrad : public OpExprGradFunction<CumsumCaptureState> {
return Maybe<void>::Ok();
}

Maybe<void> Capture(CumsumCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
protected:
AttrMap base_attrs_;
};

class CumsumGrad : public CumGrad<CumCaptureState> {
public:
Maybe<void> Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
Expand All @@ -43,8 +50,7 @@ class CumsumGrad : public OpExprGradFunction<CumsumCaptureState> {
ctx->dim = JUST(composed_attrs.GetAttr<int64_t>("dim"));
return Maybe<void>::Ok();
}

Maybe<void> Apply(const CumsumCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const CumCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
Expand All @@ -53,12 +59,38 @@ class CumsumGrad : public OpExprGradFunction<CumsumCaptureState> {
}
return Maybe<void>::Ok();
}

private:
AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("cumsum", CumsumGrad);

// Backward function for the "cumprod" op. Unlike cumsum, the cumprod gradient
// kernel needs both the forward output (y) and the forward input (x), so
// Capture() saves them for Apply().
class CumProdGrad : public CumGrad<CumCaptureState> {
 public:
  Maybe<void> Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed(attrs, base_attrs_);
    ctx->dim = JUST(composed.GetAttr<int64_t>("dim"));
    // Save order matters: Apply() reads SavedTensors()[0] as y, [1] as x.
    ctx->SaveTensorForBackward(outputs.at(0));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const CumCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& saved = ctx->SavedTensors();
      (*in_grads)[0] =
          JUST(functional::CumprodGrad(out_grads.at(0), saved.at(0), saved.at(1), ctx->dim));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("cumprod", CumProdGrad);

} // namespace one
} // namespace oneflow
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1914,3 +1914,11 @@
- name: "cumsum_grad"
signature: "Tensor (Tensor input, Int64 dim) => CumsumGrad"
bind_python: False

- name: "cumprod"
signature: "Tensor (Tensor input, Int64 dim) => Cumprod"
bind_python: True

- name: "cumprod_grad"
signature: "Tensor (Tensor input, Tensor y, Tensor x, Int64 dim) => CumprodGrad"
bind_python: False
47 changes: 42 additions & 5 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1782,9 +1782,11 @@ class ErfinvInplaceFunctor {
std::shared_ptr<OpExpr> op_;
};

class CumsumFunctor {
class CumBaseFunctor {
public:
CumsumFunctor() { op_ = CHECK_JUST(one::OpBuilder("cumsum").Input("x").Output("y").Build()); }
explicit CumBaseFunctor(std::string op_name) {
op_ = CHECK_JUST(one::OpBuilder(op_name).Input("x").Output("y").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t dim) const {
auto ndim = input->ndim();
if (dim < 0) { dim += ndim; }
Expand All @@ -1804,7 +1806,22 @@ class CumsumFunctor {
std::shared_ptr<OpExpr> op_;
};

class CumsumGradFunctor {
// Forward functor for cumsum; all argument handling (dim normalization and
// validation) lives in CumBaseFunctor — this only selects the op name.
class CumsumFunctor : public CumBaseFunctor {
 public:
  CumsumFunctor() : CumBaseFunctor("cumsum") {}
};

// Forward functor for cumprod; reuses CumBaseFunctor's operator() and only
// selects the op name.
class CumProdFunctor : public CumBaseFunctor {
 public:
  CumProdFunctor() : CumBaseFunctor("cumprod") {}
};

// Common base for the cumulative-op grad functors: holds the built op
// expression shared by CumsumGradFunctor and CumProdGradFunctor.
class CumGradBaseFunctor {
 protected:
  std::shared_ptr<OpExpr> op_;
};

class CumsumGradFunctor : public CumGradBaseFunctor {
public:
CumsumGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("cumsum_grad").Input("dy").Output("dx").Build());
Expand All @@ -1815,10 +1832,28 @@ class CumsumGradFunctor {
JUST(attrs.SetAttr<int64_t>("dim", dim));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);
}
};

private:
std::shared_ptr<OpExpr> op_;
// Grad functor for "cumprod_grad". Dispatches the backward op with three
// inputs: dy (grad of forward output), y (forward output), x (forward input).
class CumProdGradFunctor : public CumGradBaseFunctor {
 public:
  CumProdGradFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("cumprod_grad")
                         .Input("dy")
                         .Input("output")
                         .Input("input")
                         .Output("dx")
                         .Build());
  }
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
                           const std::shared_ptr<one::Tensor>& y,
                           const std::shared_ptr<one::Tensor>& x, int64_t dim) const {
    // dim was already validated by CumProdFunctor during the forward pass,
    // so no extra checking is needed here.
    MutableAttrMap grad_attrs;
    JUST(grad_attrs.SetAttr<int64_t>("dim", dim));
    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, y, x}, grad_attrs);
  }
};

} // namespace impl

using namespace impl;
Expand Down Expand Up @@ -1890,6 +1925,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<ErfinvInplaceFunctor>("ErfinvInplace");
m.add_functor<CumsumFunctor>("Cumsum");
m.add_functor<CumsumGradFunctor>("CumsumGrad");
m.add_functor<CumProdFunctor>("Cumprod");
m.add_functor<CumProdGradFunctor>("CumprodGrad");
};

} // namespace functional
Expand Down
34 changes: 34 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4007,6 +4007,40 @@ def OneFlow_CumsumGradOp : OneFlow_BaseOp<"cumsum_grad", [NoSideEffect, DeclareO
let has_data_type_infer_fn = 1;
}

def OneFlow_CumProdOp : OneFlow_BaseOp<"cumprod", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x
);
let output = (outs
OneFlow_Tensor:$y
);
let attrs = (ins
SI64Attr:$dim
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_CumProdGradOp : OneFlow_BaseOp<"cumprod_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
OneFlow_Tensor:$output,
OneFlow_Tensor:$input
);
let output = (outs
OneFlow_Tensor:$dx
);
let attrs = (ins
SI64Attr:$dim
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_ErfInvOp : OneFlow_BaseOp<"erfinv", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x
Expand Down
Loading