gen_bw_fn return maybe (#5454)
* modified SetInputArgModifyFn

* Delete the CHECK changes in the assign_op.cpp file

* Format

* Modified the OutputArgModifyFn interface

* add return

* propagate the Maybe error stack from CheckAndConstructOp to the OutputArgModifier callback function

* OutputArgModifier return maybe part_1

* input_arg_modifier return maybe

* gen_bw_fn return maybe

* add MakeGenBackwardOpConf because statement expressions are not allowed outside a function, which made JUST fail inside a lambda (see the sketch after the changed-file summary below)

* add maybe after merging master

* fix bug: JUST in lambda

Co-authored-by: aishangjj <702572275@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
3 people committed Jul 17, 2021
1 parent bffde6d commit de29655
Showing 64 changed files with 318 additions and 159 deletions.
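
Every file below follows the same pattern: backward-op-conf callbacks that returned void now return Maybe<void>, aborting CHECK* macros become CHECK*_OR_RETURN, and callers unwrap results with JUST. The sketch below illustrates that pattern and the lambda pitfall the last two commit bullets mention. It is a minimal stand-alone illustration, not OneFlow's real code: MaybeVoid and JUST_VOID are simplified stand-ins for Maybe<void> and JUST (the real JUST is built on a GCC statement expression, which is also why it cannot appear outside a function body).

// Minimal, self-contained C++17 sketch — NOT OneFlow's real implementation.
#include <iostream>
#include <optional>
#include <string>

struct MaybeVoid {
  std::optional<std::string> error;  // empty == success
  static MaybeVoid Ok() { return {}; }
  static MaybeVoid Err(std::string m) { return {std::move(m)}; }
  bool IsOk() const { return !error.has_value(); }
};

// Expands to a `return`, so it only compiles inside a function or lambda
// whose return type is MaybeVoid — the reason the registration lambdas in
// this commit gain an explicit `-> Maybe<void>`.
#define JUST_VOID(expr)                        \
  do {                                         \
    MaybeVoid maybe_ = (expr);                 \
    if (!maybe_.IsOk()) { return maybe_; }     \
  } while (0)

MaybeVoid CheckNdims(int ndims, int strides) {
  if (ndims != strides) { return MaybeVoid::Err("ndims != strides.size()"); }
  return MaybeVoid::Ok();
}

int main() {
  // With `-> MaybeVoid` the macro's `return` type-checks; without it the
  // lambda's return type is deduced as void and JUST_VOID fails to compile.
  auto gen_bw_fn = [](int ndims, int strides) -> MaybeVoid {
    JUST_VOID(CheckNdims(ndims, strides));  // early-returns the error upward
    return MaybeVoid::Ok();
  };
  std::cout << (gen_bw_fn(2, 3).IsOk() ? "ok" : "error") << "\n";  // "error"
}

With that shape in mind, the per-file changes below are mechanical.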
2 changes: 1 addition & 1 deletion oneflow/core/framework/user_op_grad_registry.h
@@ -24,7 +24,7 @@ namespace oneflow {
namespace user_op {

using AddOpFn = std::function<void(const UserOpConfWrapper&)>;
-using GenBackwardOpConfFn = std::function<void(const UserOpWrapper&, AddOpFn)>;
+using GenBackwardOpConfFn = std::function<Maybe<void>(const UserOpWrapper&, AddOpFn)>;
using BackwardOpConfGenFn = std::function<void(BackwardOpConfContext*)>;

struct OpGradRegistryResult {
6 changes: 3 additions & 3 deletions oneflow/core/job_rewriter/user_grad.cpp
@@ -25,7 +25,7 @@ Maybe<void> GenerateBackwardOpConf(
const Operator& fw_op, std::vector<OperatorConf>* bw_op_confs,
const std::function<LogicalBlobId*(const std::string&)>& DiffLbi4BnInOp,
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4BnInOp) {
-CHECK(fw_op.op_conf().has_user_conf());
+CHECK_OR_RETURN(fw_op.op_conf().has_user_conf());
const UserOpConf& user_conf = fw_op.op_conf().user_conf();
const user_op::OpGradRegistryResult* val =
user_op::UserOpRegistryMgr::Get().GetOpGradRegistryResult(user_conf.op_type_name());
@@ -43,13 +43,13 @@ Maybe<void> GenerateBackwardOpConf(
auto AddOp = [&](const user_op::UserOpConfWrapper& wrapper) {
bw_op_confs->push_back(wrapper.op_conf());
};
-val->gen_bw_fn(fw_user_op, AddOp);
+JUST(val->gen_bw_fn(fw_user_op, AddOp));
}

for (const std::string& ibn : fw_op.input_bns()) {
LogicalBlobId* lbi = DiffLbi4BnInOp(ibn);
if (lbi != nullptr) {
-CHECK_OR_RETURN(lbi->has_op_name() && lbi->has_blob_name())
+CHECK_OR_RETURN(lbi->has_op_name() && lbi->has_blob_name())
<< " user_op: " << fw_op.op_name() << " op_type_name: " << user_conf.op_type_name()
<< " 's input blob " << ibn << " has not generate input diff blob !";
}
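
The CHECK → CHECK_OR_RETURN substitution above changes what failure means: instead of aborting the process, the check becomes an error value the caller can handle, and JUST(val->gen_bw_fn(...)) forwards any error raised inside the callback. A rough sketch of the difference, reusing the simplified MaybeVoid from the earlier sketch; CHECK_OR_RETURN_SKETCH is hypothetical (the real CHECK_OR_RETURN also accepts streamed << context, as in the hunk above, and participates in the "maybe error stack" the commit bullets mention):

#include <cstdlib>
#include <iostream>
#include <optional>
#include <string>

struct MaybeVoid {  // same simplified stand-in as in the earlier sketch
  std::optional<std::string> error;
  static MaybeVoid Ok() { return {}; }
  bool IsOk() const { return !error.has_value(); }
};

// Old style: a failed CHECK kills the whole process.
#define CHECK_SKETCH(cond) \
  if (!(cond)) { std::cerr << "Check failed: " #cond "\n"; std::abort(); }

// New style: a failed check becomes a recoverable error in the Maybe chain.
#define CHECK_OR_RETURN_SKETCH(cond) \
  if (!(cond)) return MaybeVoid{std::string("Check failed: ") + #cond}

MaybeVoid GenerateBackwardOpConfSketch(bool has_user_conf) {
  CHECK_OR_RETURN_SKETCH(has_user_conf);  // caller decides what to do
  return MaybeVoid::Ok();
}

int main() {
  MaybeVoid res = GenerateBackwardOpConfSketch(false);
  if (!res.IsOk()) { std::cout << *res.error << "\n"; }  // logs, keeps running
  return 0;
}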
1 change: 0 additions & 1 deletion oneflow/user/kernels/pool_gpu_kernel.cpp
@@ -81,7 +81,6 @@ class GPUPoolOpKernelState final : public user_op::OpKernelState {
static std::shared_ptr<GPUPoolOpKernelState> FromKernelComputeContext(
const int32_t& dim, const std::string& pooling_type, user_op::KernelComputeContext* ctx) {
if (pooling_type != "MAX" && pooling_type != "AVG") { UNIMPLEMENTED(); }
-const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
const ShapeView& x_shape = ctx->Tensor4ArgNameAndIndex("x", 0)->shape();
const std::string& data_format = ctx->Attr<std::string>("data_format");
const std::string& padding = ctx->Attr<std::string>("padding");
3 changes: 2 additions & 1 deletion oneflow/user/ops/add_n_op.cpp
@@ -53,13 +53,14 @@ REGISTER_USER_OP("add_n")
});

REGISTER_USER_OP_GRAD("add_n").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
-user_op::AddOpFn AddOp) {
+user_op::AddOpFn AddOp) -> Maybe<void> {
int32_t in_size = op.input_size("in");
for (int i = 0; i < in_size; ++i) {
if (op.NeedGenGradTensor4OpInput("in", i)) {
op.BindGradTensorWithOpInput(op.GetGradTensorWithOpOutput("out", 0), "in", i);
}
}
+return Maybe<void>::Ok();
});

} // namespace oneflow
4 changes: 3 additions & 1 deletion oneflow/user/ops/amp_white_identity_op.cpp
@@ -45,7 +45,8 @@ REGISTER_USER_OP("amp_white_identity")
});

REGISTER_USER_OP_GRAD("amp_white_identity")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
@@ -56,6 +57,7 @@ REGISTER_USER_OP_GRAD("amp_white_identity")
op.BindGradTensorWithOpInput(grad_op.output("out", 0), "in", 0);
AddOp(grad_op);
}
+return Maybe<void>::Ok();
});

} // namespace
4 changes: 3 additions & 1 deletion oneflow/user/ops/batch_gather_op.cpp
@@ -85,7 +85,8 @@ REGISTER_USER_OP("batch_gather")
});

REGISTER_USER_OP_GRAD("batch_gather")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
bool need_grad_in = op.NeedGenGradTensor4OpInput("in", 0);
if (need_grad_in) {
const Shape in_shape = op.TensorDesc4ArgNameAndIndex("in", 0).shape();
@@ -102,6 +103,7 @@ REGISTER_USER_OP_GRAD("batch_gather")
op.BindGradTensorWithOpInput(in_grad_op.output("out", 0), "in", 0);
AddOp(in_grad_op);
}
+return Maybe<void>::Ok();
});

} // namespace oneflow
4 changes: 3 additions & 1 deletion oneflow/user/ops/bias_add_op.cpp
@@ -58,7 +58,8 @@ REGISTER_USER_OP("bias_add")
});

REGISTER_USER_OP_GRAD("bias_add")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("a", 0)) {
op.BindGradTensorWithOpInput(op.GetGradTensorWithOpOutput("out", 0), "a", 0);
}
@@ -79,6 +80,7 @@ REGISTER_USER_OP_GRAD("bias_add")
AddOp(grad_op);
op.BindGradTensorWithOpInput(grad_op.output("output_tensor", 0), "b", 0);
}
+return Maybe<void>::Ok();
});

} // namespace oneflow
16 changes: 12 additions & 4 deletions oneflow/user/ops/broadcast_ops_grad.cpp
@@ -54,7 +54,8 @@ std::string CreateReduceSumLikeBlob(const std::string& in_lbn, const Shape& in_s
} // namespace

REGISTER_USER_OP_GRAD("broadcast_add")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
const Shape& z_shape = op.TensorDesc4ArgNameAndIndex("z", 0).shape();
const std::string& dz_lbn = op.GetGradTensorWithOpOutput("z", 0);
if (op.NeedGenGradTensor4OpInput("x", 0)) {
@@ -71,10 +72,12 @@ REGISTER_USER_OP_GRAD("broadcast_add")
CreateReduceSumLikeBlob(dz_lbn, z_shape, y_lbn, y_shape, op.op_name() + "_y", AddOp);
op.BindGradTensorWithOpInput(out_lbn, "y", 0);
}
+return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("broadcast_sub")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
const Shape& z_shape = op.TensorDesc4ArgNameAndIndex("z", 0).shape();
const std::string& dz_lbn = op.GetGradTensorWithOpOutput("z", 0);
if (op.NeedGenGradTensor4OpInput("x", 0)) {
@@ -102,10 +105,12 @@ REGISTER_USER_OP_GRAD("broadcast_sub")
scalar_mul_op.output("out", 0), z_shape, y_lbn, y_shape, op.op_name() + "_y", AddOp);
op.BindGradTensorWithOpInput(out_lbn, "y", 0);
}
+return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("broadcast_mul")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
const Shape& z_shape = op.TensorDesc4ArgNameAndIndex("z", 0).shape();
const std::string& dz_lbn = op.GetGradTensorWithOpOutput("z", 0);
if (op.NeedGenGradTensor4OpInput("x", 0)) {
@@ -136,10 +141,12 @@ REGISTER_USER_OP_GRAD("broadcast_mul")
broadcast_mul_op.output("z", 0), z_shape, y_lbn, y_shape, op.op_name() + "_y", AddOp);
op.BindGradTensorWithOpInput(out_lbn, "y", 0);
}
+return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("broadcast_div")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
const std::string& dz_lbn = op.GetGradTensorWithOpOutput("z", 0);
if (op.NeedGenGradTensor4OpInput("x", 0)) {
const Shape& z_shape = op.TensorDesc4ArgNameAndIndex("z", 0).shape();
@@ -167,6 +174,7 @@ REGISTER_USER_OP_GRAD("broadcast_div")
op.BindGradTensorWithOpInput(grad_op.output("dy", 0), "y", 0);
AddOp(grad_op);
}
+return Maybe<void>::Ok();
});

} // namespace oneflow
3 changes: 2 additions & 1 deletion oneflow/user/ops/cast_op.cpp
@@ -51,7 +51,7 @@ REGISTER_USER_OP("cast")
.SetDataTypeInferFn(InferDataType);

REGISTER_USER_OP_GRAD("cast").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
-user_op::AddOpFn AddOp) {
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
const DataType& dtype = op.TensorDesc4ArgNameAndIndex("in", 0).data_type();
@@ -64,6 +64,7 @@ REGISTER_USER_OP_GRAD("cast").SetGenBackwardOpConfFn([](const user_op::UserOpWra
op.BindGradTensorWithOpInput(cast_grad_op.output("out", 0), "in", 0);
AddOp(cast_grad_op);
}
+return Maybe<void>::Ok();
});

} // namespace
4 changes: 3 additions & 1 deletion oneflow/user/ops/cast_to_static_shape_op.cpp
@@ -48,7 +48,8 @@ REGISTER_USER_OP("cast_to_static_shape")
});

REGISTER_USER_OP_GRAD("cast_to_static_shape")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("input", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper identity_op =
@@ -59,6 +60,7 @@ REGISTER_USER_OP_GRAD("cast_to_static_shape")
op.BindGradTensorWithOpInput(identity_op.output("out", 0), "input", 0);
AddOp(identity_op);
}
+return Maybe<void>::Ok();
});

} // namespace oneflow
12 changes: 9 additions & 3 deletions oneflow/user/ops/clip_by_value_op.cpp
@@ -128,7 +128,8 @@ REGISTER_USER_OP("clip_by_scalar_max_grad")
.SetDataTypeInferFn(InferClipGradDataType);

REGISTER_USER_OP_GRAD("clip_by_scalar")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
@@ -144,10 +145,12 @@ REGISTER_USER_OP_GRAD("clip_by_scalar")
op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
AddOp(grad_op);
}
+return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("clip_by_scalar_min")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
@@ -161,10 +164,12 @@ REGISTER_USER_OP_GRAD("clip_by_scalar_min")
op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
AddOp(grad_op);
}
+return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("clip_by_scalar_max")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
@@ -178,6 +183,7 @@ REGISTER_USER_OP_GRAD("clip_by_scalar_max")
op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
AddOp(grad_op);
}
+return Maybe<void>::Ok();
});

} // namespace oneflow
4 changes: 3 additions & 1 deletion oneflow/user/ops/combined_margin_loss_op.cpp
@@ -106,7 +106,8 @@ REGISTER_USER_OP("combined_margin_loss_grad")
});

REGISTER_USER_OP_GRAD("combined_margin_loss")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op = builder.Op("combined_margin_loss_grad")
@@ -122,6 +123,7 @@ REGISTER_USER_OP_GRAD("combined_margin_loss")
op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
AddOp(grad_op);
}
+return Maybe<void>::Ok();
});

} // namespace oneflow
3 changes: 2 additions & 1 deletion oneflow/user/ops/concat_op.cpp
@@ -68,7 +68,7 @@ Maybe<void> GetSbpSignature(user_op::SbpContext* ctx) {
return Maybe<void>::Ok();
}

-void GenGrapOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+Maybe<void> GenGrapOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
bool need_grad = false;
const int32_t in_size = op.input_size("in");
FOR_RANGE(int32_t, i, 0, in_size) {
@@ -90,6 +90,7 @@ void GenGrapOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
}
AddOp(grad_op);
}
+return Maybe<void>::Ok();
}

Maybe<void> InferDataType(user_op::InferContext* ctx) {
7 changes: 4 additions & 3 deletions oneflow/user/ops/conv_op.cpp
@@ -159,7 +159,7 @@ Maybe<void> CheckAttr(const user_op::UserOpDefWrapper& def,
}
}

-void GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+Maybe<void> GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
const auto& padding_before = op.attr<std::vector<int32_t>>("padding_before");
std::string data_format = op.attr<std::string>("data_format");
std::vector<int32_t> kernel_size = op.attr<std::vector<int32_t>>("kernel_size");
@@ -168,8 +168,8 @@ void GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_op::AddO
int32_t groups = op.attr<int32_t>("groups");

int32_t ndims = kernel_size.size();
-CHECK_EQ(ndims, strides.size());
-CHECK_EQ(ndims, dilation_rate.size());
+CHECK_EQ_OR_RETURN(ndims, strides.size());
+CHECK_EQ_OR_RETURN(ndims, dilation_rate.size());

if (op.user_op_conf().has_input("bias", 0)) {
if (op.NeedGenGradTensor4OpInput("bias", 0)) {
@@ -224,6 +224,7 @@ void GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_op::AddO
op.BindGradTensorWithOpInput(data_grad_op.output("dx", 0), "in", 0);
AddOp(data_grad_op);
}
+return Maybe<void>::Ok();
}

} // namespace
8 changes: 5 additions & 3 deletions oneflow/user/ops/deconv_op.cpp
@@ -136,7 +136,8 @@ Maybe<void> CheckAttr(const user_op::UserOpDefWrapper& def,
}
}

-void GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+Maybe<void> GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) {
const std::string& data_format = op.attr<std::string>("data_format");
const auto& padding_before = op.attr<std::vector<int32_t>>("padding_before");
const auto& kernel_size = op.attr<std::vector<int32_t>>("kernel_size");
@@ -145,8 +145,8 @@ void GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, user_op::Ad
const Shape& weight_shape = op.TensorDesc4ArgNameAndIndex("weight", 0).shape();

const int32_t ndims = kernel_size.size();
-CHECK_EQ(ndims, strides.size());
-CHECK_EQ(ndims, dilation_rate.size());
+CHECK_EQ_OR_RETURN(ndims, strides.size());
+CHECK_EQ_OR_RETURN(ndims, dilation_rate.size());

if (op.NeedGenGradTensor4OpInput("weight", 0)) {
auto filter_grad_op =
@@ -186,6 +187,7 @@ void GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, user_op::Ad
op.BindGradTensorWithOpInput(data_grad_op.output("out", 0), "in", 0);
AddOp(data_grad_op);
}
+return Maybe<void>::Ok();
}

} // namespace
3 changes: 2 additions & 1 deletion oneflow/user/ops/dropout_op.cpp
@@ -93,7 +93,7 @@ REGISTER_USER_OP("dropout_grad")
});

REGISTER_USER_OP_GRAD("dropout").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
-user_op::AddOpFn AddOp) {
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper dropout_grad_op =
@@ -106,6 +106,7 @@ REGISTER_USER_OP_GRAD("dropout").SetGenBackwardOpConfFn([](const user_op::UserOp
op.BindGradTensorWithOpInput(dropout_grad_op.output("dx", 0), "in", 0);
AddOp(dropout_grad_op);
}
+return Maybe<void>::Ok();
});

REGISTER_NO_GRAD_USER_OP("random_mask_like")
4 changes: 3 additions & 1 deletion oneflow/user/ops/expand_dims_op.cpp
@@ -67,7 +67,8 @@ REGISTER_USER_OP("expand_dims")
});

REGISTER_USER_OP_GRAD("expand_dims")
-.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
@@ -79,6 +80,7 @@ REGISTER_USER_OP_GRAD("expand_dims")
op.BindGradTensorWithOpInput(grad_op.output("out", 0), "in", 0);
AddOp(grad_op);
}
+return Maybe<void>::Ok();
});

} // namespace oneflow
3 changes: 2 additions & 1 deletion oneflow/user/ops/expand_op.cpp
@@ -67,7 +67,7 @@ REGISTER_USER_OP("expand_grad")
});

REGISTER_USER_OP_GRAD("expand").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
-user_op::AddOpFn AddOp) {
+user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
@@ -80,6 +80,7 @@ REGISTER_USER_OP_GRAD("expand").SetGenBackwardOpConfFn([](const user_op::UserOpW
op.BindGradTensorWithOpInput(grad_op.output("out", 0), "in", 0);
AddOp(grad_op);
}
+return Maybe<void>::Ok();
});

} // namespace oneflow
(Diffs for the remaining changed files are not shown here.)
