Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gen_bw_fn return maybe #5454

Merged
merged 36 commits into from
Jul 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
d897143
modified SetInputArgModifyFn
luqiang-guo Jul 1, 2021
e33df72
Delete the CHECK changes in the assign_op.cpp file
luqiang-guo Jul 9, 2021
b6732e1
Format
luqiang-guo Jul 9, 2021
6e674d3
Modified the OutputArgModifyFn interface
luqiang-guo Jul 9, 2021
038c42f
Merge branch 'master' into Replace_check_using_maybe_check_part_SetIn…
luqiang-guo Jul 9, 2021
bb22c4d
Merge branch 'master' into Replace_check_using_maybe_check_part_Outpu…
luqiang-guo Jul 9, 2021
c7fccf9
add return
luqiang-guo Jul 9, 2021
9c3fe8b
maybe error stack from CheckAndConstructOp to OutputArgModifier callb…
liufengwei0103 Jul 10, 2021
3434303
maybe error stack from CheckAndConstructOp to OutputArgModifier callb…
liufengwei0103 Jul 10, 2021
46ddb46
Merge branch 'OutputArgModifier_return_maybe' of https://github.com/o…
liufengwei0103 Jul 10, 2021
2d0b5a7
OutputArgModifier return maybe part_1
liufengwei0103 Jul 10, 2021
1d97b6c
Merge branch 'OutputArgModifier_return_maybe_part_1' of https://githu…
liufengwei0103 Jul 10, 2021
dc45b30
maybe error stack from CheckAndConstructOp to OutputArgModifier callb…
liufengwei0103 Jul 10, 2021
1f0e385
Merge branch 'OutputArgModifier_return_maybe' of https://github.com/o…
liufengwei0103 Jul 10, 2021
fab170a
Merge branch 'OutputArgModifier_return_maybe_part_1' of https://githu…
liufengwei0103 Jul 10, 2021
eb78626
input_arg_modifier return maybe
liufengwei0103 Jul 10, 2021
041eecf
gen_bw_fn return maybe
liufengwei0103 Jul 10, 2021
3c213d8
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 16, 2021
648ff8b
add MakeGenBackwardOpConf because ofstatement-expression not allowed …
liufengwei0103 Jul 16, 2021
164a047
Merge branch 'master' into gen_bw_fn_return_maybe
liufengwei0103 Jul 16, 2021
4ef9bf7
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 16, 2021
b93b81c
add maybe after merge master
liufengwei0103 Jul 16, 2021
4473cde
Merge branch 'gen_bw_fn_return_maybe' of https://github.com/Oneflow-I…
liufengwei0103 Jul 16, 2021
88e0979
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 16, 2021
d5f4b10
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
6b1a885
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
3b43178
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
3842722
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
33c2e3e
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
7899fbc
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
4dfdb70
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 16, 2021
edaf987
Merge branch 'master' into gen_bw_fn_return_maybe
oneflow-ci-bot Jul 17, 2021
74ab623
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 17, 2021
09a53d3
Merge branch 'gen_bw_fn_return_maybe' of https://github.com/Oneflow-I…
liufengwei0103 Jul 17, 2021
5a90886
fix bug: JUST in lambda
liufengwei0103 Jul 17, 2021
fba73cf
Merge branch 'master' into gen_bw_fn_return_maybe
liufengwei0103 Jul 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion oneflow/core/framework/user_op_grad_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace oneflow {
namespace user_op {

using AddOpFn = std::function<void(const UserOpConfWrapper&)>;
using GenBackwardOpConfFn = std::function<void(const UserOpWrapper&, AddOpFn)>;
using GenBackwardOpConfFn = std::function<Maybe<void>(const UserOpWrapper&, AddOpFn)>;
using BackwardOpConfGenFn = std::function<void(BackwardOpConfContext*)>;

struct OpGradRegistryResult {
Expand Down
6 changes: 3 additions & 3 deletions oneflow/core/job_rewriter/user_grad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Maybe<void> GenerateBackwardOpConf(
const Operator& fw_op, std::vector<OperatorConf>* bw_op_confs,
const std::function<LogicalBlobId*(const std::string&)>& DiffLbi4BnInOp,
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4BnInOp) {
CHECK(fw_op.op_conf().has_user_conf());
CHECK_OR_RETURN(fw_op.op_conf().has_user_conf());
const UserOpConf& user_conf = fw_op.op_conf().user_conf();
const user_op::OpGradRegistryResult* val =
user_op::UserOpRegistryMgr::Get().GetOpGradRegistryResult(user_conf.op_type_name());
Expand All @@ -43,13 +43,13 @@ Maybe<void> GenerateBackwardOpConf(
auto AddOp = [&](const user_op::UserOpConfWrapper& wrapper) {
bw_op_confs->push_back(wrapper.op_conf());
};
val->gen_bw_fn(fw_user_op, AddOp);
JUST(val->gen_bw_fn(fw_user_op, AddOp));
}

for (const std::string& ibn : fw_op.input_bns()) {
LogicalBlobId* lbi = DiffLbi4BnInOp(ibn);
if (lbi != nullptr) {
CHECK(lbi->has_op_name() && lbi->has_blob_name())
CHECK_OR_RETURN(lbi->has_op_name() && lbi->has_blob_name())
<< " user_op: " << fw_op.op_name() << " op_type_name: " << user_conf.op_type_name()
<< " 's input blob " << ibn << " has not generate input diff blob !";
}
Expand Down
1 change: 0 additions & 1 deletion oneflow/user/kernels/pool_gpu_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ class GPUPoolOpKernelState final : public user_op::OpKernelState {
static std::shared_ptr<GPUPoolOpKernelState> FromKernelComputeContext(
const int32_t& dim, const std::string& pooling_type, user_op::KernelComputeContext* ctx) {
if (pooling_type != "MAX" && pooling_type != "AVG") { UNIMPLEMENTED(); }
const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
const ShapeView& x_shape = ctx->Tensor4ArgNameAndIndex("x", 0)->shape();
const std::string& data_format = ctx->Attr<std::string>("data_format");
const std::string& padding = ctx->Attr<std::string>("padding");
Expand Down
3 changes: 2 additions & 1 deletion oneflow/user/ops/add_n_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ REGISTER_USER_OP("add_n")
});

REGISTER_USER_OP_GRAD("add_n").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) {
user_op::AddOpFn AddOp) -> Maybe<void> {
int32_t in_size = op.input_size("in");
for (int i = 0; i < in_size; ++i) {
if (op.NeedGenGradTensor4OpInput("in", i)) {
op.BindGradTensorWithOpInput(op.GetGradTensorWithOpOutput("out", 0), "in", i);
}
}
return Maybe<void>::Ok();
});

} // namespace oneflow
4 changes: 3 additions & 1 deletion oneflow/user/ops/amp_white_identity_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ REGISTER_USER_OP("amp_white_identity")
});

REGISTER_USER_OP_GRAD("amp_white_identity")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
Expand All @@ -56,6 +57,7 @@ REGISTER_USER_OP_GRAD("amp_white_identity")
op.BindGradTensorWithOpInput(grad_op.output("out", 0), "in", 0);
AddOp(grad_op);
}
return Maybe<void>::Ok();
});

} // namespace
Expand Down
4 changes: 3 additions & 1 deletion oneflow/user/ops/batch_gather_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ REGISTER_USER_OP("batch_gather")
});

REGISTER_USER_OP_GRAD("batch_gather")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
bool need_grad_in = op.NeedGenGradTensor4OpInput("in", 0);
if (need_grad_in) {
const Shape in_shape = op.TensorDesc4ArgNameAndIndex("in", 0).shape();
Expand All @@ -102,6 +103,7 @@ REGISTER_USER_OP_GRAD("batch_gather")
op.BindGradTensorWithOpInput(in_grad_op.output("out", 0), "in", 0);
AddOp(in_grad_op);
}
return Maybe<void>::Ok();
});

} // namespace oneflow
4 changes: 3 additions & 1 deletion oneflow/user/ops/bias_add_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ REGISTER_USER_OP("bias_add")
});

REGISTER_USER_OP_GRAD("bias_add")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("a", 0)) {
op.BindGradTensorWithOpInput(op.GetGradTensorWithOpOutput("out", 0), "a", 0);
}
Expand All @@ -79,6 +80,7 @@ REGISTER_USER_OP_GRAD("bias_add")
AddOp(grad_op);
op.BindGradTensorWithOpInput(grad_op.output("output_tensor", 0), "b", 0);
}
return Maybe<void>::Ok();
});

} // namespace oneflow
16 changes: 12 additions & 4 deletions oneflow/user/ops/broadcast_ops_grad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ std::string CreateReduceSumLikeBlob(const std::string& in_lbn, const Shape& in_s
} // namespace

REGISTER_USER_OP_GRAD("broadcast_add")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
const Shape& z_shape = op.TensorDesc4ArgNameAndIndex("z", 0).shape();
const std::string& dz_lbn = op.GetGradTensorWithOpOutput("z", 0);
if (op.NeedGenGradTensor4OpInput("x", 0)) {
Expand All @@ -71,10 +72,12 @@ REGISTER_USER_OP_GRAD("broadcast_add")
CreateReduceSumLikeBlob(dz_lbn, z_shape, y_lbn, y_shape, op.op_name() + "_y", AddOp);
op.BindGradTensorWithOpInput(out_lbn, "y", 0);
}
return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("broadcast_sub")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
const Shape& z_shape = op.TensorDesc4ArgNameAndIndex("z", 0).shape();
const std::string& dz_lbn = op.GetGradTensorWithOpOutput("z", 0);
if (op.NeedGenGradTensor4OpInput("x", 0)) {
Expand Down Expand Up @@ -102,10 +105,12 @@ REGISTER_USER_OP_GRAD("broadcast_sub")
scalar_mul_op.output("out", 0), z_shape, y_lbn, y_shape, op.op_name() + "_y", AddOp);
op.BindGradTensorWithOpInput(out_lbn, "y", 0);
}
return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("broadcast_mul")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
const Shape& z_shape = op.TensorDesc4ArgNameAndIndex("z", 0).shape();
const std::string& dz_lbn = op.GetGradTensorWithOpOutput("z", 0);
if (op.NeedGenGradTensor4OpInput("x", 0)) {
Expand Down Expand Up @@ -136,10 +141,12 @@ REGISTER_USER_OP_GRAD("broadcast_mul")
broadcast_mul_op.output("z", 0), z_shape, y_lbn, y_shape, op.op_name() + "_y", AddOp);
op.BindGradTensorWithOpInput(out_lbn, "y", 0);
}
return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("broadcast_div")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
const std::string& dz_lbn = op.GetGradTensorWithOpOutput("z", 0);
if (op.NeedGenGradTensor4OpInput("x", 0)) {
const Shape& z_shape = op.TensorDesc4ArgNameAndIndex("z", 0).shape();
Expand Down Expand Up @@ -167,6 +174,7 @@ REGISTER_USER_OP_GRAD("broadcast_div")
op.BindGradTensorWithOpInput(grad_op.output("dy", 0), "y", 0);
AddOp(grad_op);
}
return Maybe<void>::Ok();
});

} // namespace oneflow
3 changes: 2 additions & 1 deletion oneflow/user/ops/cast_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ REGISTER_USER_OP("cast")
.SetDataTypeInferFn(InferDataType);

REGISTER_USER_OP_GRAD("cast").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) {
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
const DataType& dtype = op.TensorDesc4ArgNameAndIndex("in", 0).data_type();
Expand All @@ -64,6 +64,7 @@ REGISTER_USER_OP_GRAD("cast").SetGenBackwardOpConfFn([](const user_op::UserOpWra
op.BindGradTensorWithOpInput(cast_grad_op.output("out", 0), "in", 0);
AddOp(cast_grad_op);
}
return Maybe<void>::Ok();
});

} // namespace
Expand Down
4 changes: 3 additions & 1 deletion oneflow/user/ops/cast_to_static_shape_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ REGISTER_USER_OP("cast_to_static_shape")
});

REGISTER_USER_OP_GRAD("cast_to_static_shape")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("input", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper identity_op =
Expand All @@ -59,6 +60,7 @@ REGISTER_USER_OP_GRAD("cast_to_static_shape")
op.BindGradTensorWithOpInput(identity_op.output("out", 0), "input", 0);
AddOp(identity_op);
}
return Maybe<void>::Ok();
});

} // namespace oneflow
12 changes: 9 additions & 3 deletions oneflow/user/ops/clip_by_value_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ REGISTER_USER_OP("clip_by_scalar_max_grad")
.SetDataTypeInferFn(InferClipGradDataType);

REGISTER_USER_OP_GRAD("clip_by_scalar")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
Expand All @@ -144,10 +145,12 @@ REGISTER_USER_OP_GRAD("clip_by_scalar")
op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
AddOp(grad_op);
}
return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("clip_by_scalar_min")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
Expand All @@ -161,10 +164,12 @@ REGISTER_USER_OP_GRAD("clip_by_scalar_min")
op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
AddOp(grad_op);
}
return Maybe<void>::Ok();
});

REGISTER_USER_OP_GRAD("clip_by_scalar_max")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
Expand All @@ -178,6 +183,7 @@ REGISTER_USER_OP_GRAD("clip_by_scalar_max")
op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
AddOp(grad_op);
}
return Maybe<void>::Ok();
});

} // namespace oneflow
4 changes: 3 additions & 1 deletion oneflow/user/ops/combined_margin_loss_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ REGISTER_USER_OP("combined_margin_loss_grad")
});

REGISTER_USER_OP_GRAD("combined_margin_loss")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op = builder.Op("combined_margin_loss_grad")
Expand All @@ -122,6 +123,7 @@ REGISTER_USER_OP_GRAD("combined_margin_loss")
op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
AddOp(grad_op);
}
return Maybe<void>::Ok();
});

} // namespace oneflow
3 changes: 2 additions & 1 deletion oneflow/user/ops/concat_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Maybe<void> GetSbpSignature(user_op::SbpContext* ctx) {
return Maybe<void>::Ok();
}

void GenGrapOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
Maybe<void> GenGrapOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
bool need_grad = false;
const int32_t in_size = op.input_size("in");
FOR_RANGE(int32_t, i, 0, in_size) {
Expand All @@ -90,6 +90,7 @@ void GenGrapOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
}
AddOp(grad_op);
}
return Maybe<void>::Ok();
}

Maybe<void> InferDataType(user_op::InferContext* ctx) {
Expand Down
7 changes: 4 additions & 3 deletions oneflow/user/ops/conv_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ Maybe<void> CheckAttr(const user_op::UserOpDefWrapper& def,
}
}

void GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
Maybe<void> GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
const auto& padding_before = op.attr<std::vector<int32_t>>("padding_before");
std::string data_format = op.attr<std::string>("data_format");
std::vector<int32_t> kernel_size = op.attr<std::vector<int32_t>>("kernel_size");
Expand All @@ -168,8 +168,8 @@ void GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_op::AddO
int32_t groups = op.attr<int32_t>("groups");

int32_t ndims = kernel_size.size();
CHECK_EQ(ndims, strides.size());
CHECK_EQ(ndims, dilation_rate.size());
CHECK_EQ_OR_RETURN(ndims, strides.size());
CHECK_EQ_OR_RETURN(ndims, dilation_rate.size());

if (op.user_op_conf().has_input("bias", 0)) {
if (op.NeedGenGradTensor4OpInput("bias", 0)) {
Expand Down Expand Up @@ -224,6 +224,7 @@ void GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_op::AddO
op.BindGradTensorWithOpInput(data_grad_op.output("dx", 0), "in", 0);
AddOp(data_grad_op);
}
return Maybe<void>::Ok();
}

} // namespace
Expand Down
8 changes: 5 additions & 3 deletions oneflow/user/ops/deconv_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ Maybe<void> CheckAttr(const user_op::UserOpDefWrapper& def,
}
}

void GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
Maybe<void> GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) {
const std::string& data_format = op.attr<std::string>("data_format");
const auto& padding_before = op.attr<std::vector<int32_t>>("padding_before");
const auto& kernel_size = op.attr<std::vector<int32_t>>("kernel_size");
Expand All @@ -145,8 +146,8 @@ void GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, user_op::Ad
const Shape& weight_shape = op.TensorDesc4ArgNameAndIndex("weight", 0).shape();

const int32_t ndims = kernel_size.size();
CHECK_EQ(ndims, strides.size());
CHECK_EQ(ndims, dilation_rate.size());
CHECK_EQ_OR_RETURN(ndims, strides.size());
CHECK_EQ_OR_RETURN(ndims, dilation_rate.size());

if (op.NeedGenGradTensor4OpInput("weight", 0)) {
auto filter_grad_op =
Expand Down Expand Up @@ -186,6 +187,7 @@ void GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, user_op::Ad
op.BindGradTensorWithOpInput(data_grad_op.output("out", 0), "in", 0);
AddOp(data_grad_op);
}
return Maybe<void>::Ok();
}

} // namespace
Expand Down
3 changes: 2 additions & 1 deletion oneflow/user/ops/dropout_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ REGISTER_USER_OP("dropout_grad")
});

REGISTER_USER_OP_GRAD("dropout").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) {
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper dropout_grad_op =
Expand All @@ -106,6 +106,7 @@ REGISTER_USER_OP_GRAD("dropout").SetGenBackwardOpConfFn([](const user_op::UserOp
op.BindGradTensorWithOpInput(dropout_grad_op.output("dx", 0), "in", 0);
AddOp(dropout_grad_op);
}
return Maybe<void>::Ok();
});

REGISTER_NO_GRAD_USER_OP("random_mask_like")
Expand Down
4 changes: 3 additions & 1 deletion oneflow/user/ops/expand_dims_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ REGISTER_USER_OP("expand_dims")
});

REGISTER_USER_OP_GRAD("expand_dims")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
Expand All @@ -79,6 +80,7 @@ REGISTER_USER_OP_GRAD("expand_dims")
op.BindGradTensorWithOpInput(grad_op.output("out", 0), "in", 0);
AddOp(grad_op);
}
return Maybe<void>::Ok();
});

} // namespace oneflow
3 changes: 2 additions & 1 deletion oneflow/user/ops/expand_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ REGISTER_USER_OP("expand_grad")
});

REGISTER_USER_OP_GRAD("expand").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) {
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("in", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
Expand All @@ -80,6 +80,7 @@ REGISTER_USER_OP_GRAD("expand").SetGenBackwardOpConfFn([](const user_op::UserOpW
op.BindGradTensorWithOpInput(grad_op.output("out", 0), "in", 0);
AddOp(grad_op);
}
return Maybe<void>::Ok();
});

} // namespace oneflow