Skip to content

Commit

Permalink
Cumprod (#7278)
Browse files Browse the repository at this point in the history
* add cumprod

* fix name

* rename

* fix when specified dim is 1

* add docstr

* add docstr

* refine

* add WITH_CUDA

* refine

* refine

* Update python/oneflow/framework/docstr/math_ops.py

fix docstr

Co-authored-by: Yao Chi <later@usopp.net>

* Update python/oneflow/framework/docstr/math_ops.py

fix docstr

Co-authored-by: Yao Chi <later@usopp.net>

* fix docstr

* refine

* refine

* fix include

* refine

* refine

* refine

Co-authored-by: Yao Chi <later@usopp.net>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
3 people committed Jan 24, 2022
1 parent 04a83e2 commit 3d6d467
Show file tree
Hide file tree
Showing 15 changed files with 931 additions and 355 deletions.
1 change: 1 addition & 0 deletions docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,6 @@ oneflow
decode_onerec,
from_numpy,
cumsum,
cumprod,

.. autofunction:: oneflow.relu
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ limitations under the License.
namespace oneflow {
namespace one {

struct CumsumCaptureState : public AutoGradCaptureState {
struct CumCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
int64_t dim = 0;
};

class CumsumGrad : public OpExprGradFunction<CumsumCaptureState> {
template<typename StateT>
class CumGrad : public OpExprGradFunction<StateT> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
Expand All @@ -33,8 +34,14 @@ class CumsumGrad : public OpExprGradFunction<CumsumCaptureState> {
return Maybe<void>::Ok();
}

Maybe<void> Capture(CumsumCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
protected:
AttrMap base_attrs_;
};

class CumsumGrad : public CumGrad<CumCaptureState> {
public:
Maybe<void> Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
Expand All @@ -43,8 +50,7 @@ class CumsumGrad : public OpExprGradFunction<CumsumCaptureState> {
ctx->dim = JUST(composed_attrs.GetAttr<int64_t>("dim"));
return Maybe<void>::Ok();
}

Maybe<void> Apply(const CumsumCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const CumCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
Expand All @@ -53,12 +59,38 @@ class CumsumGrad : public OpExprGradFunction<CumsumCaptureState> {
}
return Maybe<void>::Ok();
}

private:
AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("cumsum", CumsumGrad);

// Backward op-expr for "cumprod": dx = CumprodGrad(dy, y, x, dim), where y is the
// forward output and x the forward input (both saved during capture).
class CumProdGrad : public CumGrad<CumCaptureState> {
 public:
  Maybe<void> Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed(attrs, base_attrs_);
    ctx->dim = JUST(composed.GetAttr<int64_t>("dim"));
    // Save order matters for Apply: index 0 holds the forward output y,
    // index 1 holds the forward input x.
    ctx->SaveTensorForBackward(outputs.at(0));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const CumCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    in_grads->resize(1);
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    const auto& saved = ctx->SavedTensors();
    (*in_grads)[0] =
        JUST(functional::CumprodGrad(out_grads.at(0), saved.at(0), saved.at(1), ctx->dim));
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("cumprod", CumProdGrad);

} // namespace one
} // namespace oneflow
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1919,3 +1919,11 @@
- name: "cumsum_grad"
signature: "Tensor (Tensor input, Int64 dim) => CumsumGrad"
bind_python: False

# Python-facing cumulative-product functor (see CumProdFunctor in math_functor.cpp).
- name: "cumprod"
  signature: "Tensor (Tensor input, Int64 dim) => Cumprod"
  bind_python: True

# Backward functor for cumprod; takes forward output y and input x in addition to dy.
# Only dispatched from the autograd engine, so it is not exposed to Python.
- name: "cumprod_grad"
  signature: "Tensor (Tensor input, Tensor y, Tensor x, Int64 dim) => CumprodGrad"
  bind_python: False
47 changes: 42 additions & 5 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1815,9 +1815,11 @@ class ErfinvInplaceFunctor {
std::shared_ptr<OpExpr> op_;
};

class CumsumFunctor {
class CumBaseFunctor {
public:
CumsumFunctor() { op_ = CHECK_JUST(one::OpBuilder("cumsum").Input("x").Output("y").Build()); }
explicit CumBaseFunctor(std::string op_name) {
op_ = CHECK_JUST(one::OpBuilder(op_name).Input("x").Output("y").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t dim) const {
auto ndim = input->ndim();
if (dim < 0) { dim += ndim; }
Expand All @@ -1837,7 +1839,22 @@ class CumsumFunctor {
std::shared_ptr<OpExpr> op_;
};

class CumsumGradFunctor {
// Forward functor for "cumsum": all logic (dim normalization/validation and op
// dispatch) lives in CumBaseFunctor; this only supplies the op name.
class CumsumFunctor : public CumBaseFunctor {
public:
CumsumFunctor() : CumBaseFunctor("cumsum") {}
};

// Forward functor for "cumprod": reuses CumBaseFunctor's operator() (dim
// normalization/validation and dispatch), parameterized only by the op name.
class CumProdFunctor : public CumBaseFunctor {
public:
CumProdFunctor() : CumBaseFunctor("cumprod") {}
};

// Shared base for the cum* backward functors: holds the built op expression.
// Derived classes build op_ in their constructors and define operator().
class CumGradBaseFunctor {
protected:
std::shared_ptr<OpExpr> op_;
};

class CumsumGradFunctor : public CumGradBaseFunctor {
public:
CumsumGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("cumsum_grad").Input("dy").Output("dx").Build());
Expand All @@ -1848,10 +1865,28 @@ class CumsumGradFunctor {
JUST(attrs.SetAttr<int64_t>("dim", dim));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);
}
};

private:
std::shared_ptr<OpExpr> op_;
// Functional wrapper for the "cumprod_grad" user op.
// Inputs: dy (upstream grad), y (forward output), x (forward input); output: dx.
class CumProdGradFunctor : public CumGradBaseFunctor {
 public:
  CumProdGradFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("cumprod_grad")
                         .Input("dy")
                         .Input("output")
                         .Input("input")
                         .Output("dx")
                         .Build());
  }
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
                           const std::shared_ptr<one::Tensor>& y,
                           const std::shared_ptr<one::Tensor>& x, int64_t dim) const {
    // dim needs no validation here: CumProdFunctor already checked it in the
    // forward pass, and this functor is only reached through autograd.
    MutableAttrMap attr_map;
    JUST(attr_map.SetAttr<int64_t>("dim", dim));
    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, y, x}, attr_map);
  }
};

} // namespace impl

using namespace impl;
Expand Down Expand Up @@ -1923,6 +1958,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<ErfinvInplaceFunctor>("ErfinvInplace");
m.add_functor<CumsumFunctor>("Cumsum");
m.add_functor<CumsumGradFunctor>("CumsumGrad");
m.add_functor<CumProdFunctor>("Cumprod");
m.add_functor<CumProdGradFunctor>("CumprodGrad");
};

} // namespace functional
Expand Down
34 changes: 34 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4054,6 +4054,40 @@ def OneFlow_CumsumGradOp : OneFlow_BaseOp<"cumsum_grad", [NoSideEffect, DeclareO
let has_data_type_infer_fn = 1;
}

// ODS definition of the "cumprod" user op: cumulative product of x along
// attribute `dim`, producing y with the same shape as x.
// NOTE(review): shape/dtype preservation is implied by the enabled infer fns
// below; confirm against the op's C++ infer implementations.
def OneFlow_CumProdOp : OneFlow_BaseOp<"cumprod", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x
);
let output = (outs
OneFlow_Tensor:$y
);
let attrs = (ins
SI64Attr:$dim
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

// ODS definition of the "cumprod_grad" user op: computes dx from the upstream
// gradient dy, the forward output, and the forward input, along attribute `dim`.
// Matches the CumprodGrad signature in functional_api.yaml (dy, y, x, dim).
def OneFlow_CumProdGradOp : OneFlow_BaseOp<"cumprod_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
OneFlow_Tensor:$output,
OneFlow_Tensor:$input
);
let output = (outs
OneFlow_Tensor:$dx
);
let attrs = (ins
SI64Attr:$dim
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_ErfInvOp : OneFlow_BaseOp<"erfinv", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x
Expand Down
Loading

0 comments on commit 3d6d467

Please sign in to comment.