From a2e86ecea7f29846b996e40d565db9f528249af8 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Wed, 16 Jun 2021 22:48:37 +0800 Subject: [PATCH 01/42] Add partial unary and math functional apis. --- .../autograd/gradient_funcs/scalar_mul.cpp | 69 ++++++ oneflow/core/functional/functional_api.yaml | 180 ++++++++++++++- oneflow/core/functional/impl/add_functor.cpp | 51 +---- oneflow/core/functional/impl/common.h | 30 +++ oneflow/core/functional/impl/math_functor.cpp | 205 ++++++++++++++++++ ...rmalization_functor.cpp => nn_functor.cpp} | 0 .../core/functional/impl/unary_functor.cpp | 80 +++++++ oneflow/core/functional/impl/unary_functor.h | 49 +++++ oneflow/python/nn/modules/abs.py | 4 +- oneflow/python/nn/modules/acos.py | 7 +- oneflow/python/nn/modules/acosh.py | 3 +- oneflow/python/nn/modules/activation.py | 29 +-- oneflow/python/nn/modules/exp.py | 3 +- oneflow/python/nn/modules/loss.py | 22 +- oneflow/python/nn/modules/math_ops.py | 92 ++------ oneflow/python/nn/modules/negative.py | 3 +- oneflow/python/nn/modules/permute.py | 10 +- oneflow/python/nn/modules/round.py | 3 +- oneflow/python/nn/modules/sign.py | 3 +- oneflow/python/nn/modules/sinh.py | 3 +- oneflow/python/nn/modules/transpose.py | 10 +- 21 files changed, 648 insertions(+), 208 deletions(-) create mode 100644 oneflow/core/autograd/gradient_funcs/scalar_mul.cpp create mode 100644 oneflow/core/functional/impl/common.h create mode 100644 oneflow/core/functional/impl/math_functor.cpp rename oneflow/core/functional/impl/{normalization_functor.cpp => nn_functor.cpp} (100%) create mode 100644 oneflow/core/functional/impl/unary_functor.cpp create mode 100644 oneflow/core/functional/impl/unary_functor.h diff --git a/oneflow/core/autograd/gradient_funcs/scalar_mul.cpp b/oneflow/core/autograd/gradient_funcs/scalar_mul.cpp new file mode 100644 index 00000000000..bc64ff10126 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/scalar_mul.cpp @@ -0,0 +1,69 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct ScalarMulInterpState : public OpExprInterpState { + bool requires_grad; + functional::Scalar operand; +}; + +class ScalarMul : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); + } + + Maybe Capture(ScalarMulInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + ComposedAttrMap composed_attrs(attrs, base_attrs_); + bool has_float_operand = JUST(composed_attrs.GetAttr("has_float_operand")); + if (has_float_operand) { + ctx->operand = functional::Scalar(JUST(composed_attrs.GetAttr("float_operand"))); + } else { + ctx->operand = functional::Scalar(JUST(composed_attrs.GetAttr("int_operand"))); + } + return Maybe::Ok(); + } + + Maybe Apply(const ScalarMulInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + in_grads->at(0) = JUST(functional::ScalarMul(out_grads.at(0), ctx->operand)); + } + return Maybe::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_mul", ScalarMul); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 3e2dca8c370..db9121432f6 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -12,26 +12,188 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# The following data types are allowed: +# The following data types are allowed, # { # "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool", -# "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList", "BoolList" +# "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList", +# "BoolList", DataType, Shape # } +- name: "add_n" + signature: "Tensor AddN(TensorTuple inputs)" + bind_python: True + - name: "add" signature: "Tensor Add(Tensor x, Tensor y)" bind_python: True -- name: "add_n" - signature: "Tensor AddN(TensorTuple inputs)" +- name: "add_scalar" + signature: "Tensor ScalarAdd(Tensor x, *, Scalar alpha)" bind_python: True -- name: "add_scalar" - signature: "Tensor AddScalar(Tensor x, *, Scalar alpha)" +- name: "mul_scalar" + signature: "Tensor ScalarMul(Tensor x, *, Scalar alpha)" + bind_python: True + +- name: "pow" + signature: "Tensor Pow(Tensor x, Tensor y)" + bind_python: True + +- name: "pow_scalar" + signature: "Tensor ScalarPow(Tensor x, *, Scalar alpha)" + bind_python: True + +- name: "reduce_sum" + signature: "Tensor ReduceSum(Tensor x, *, Int32List axis, Bool keepdims=False)" + bind_python: True + +- name: "transpose" + signature: "Tensor Transpose(Tensor x, *, Int32List perm)" + bind_python: True + +- name: "reciprocal" + signature: "Tensor Reciprocal(Tensor x)" + bind_python: True + +- name: "reciprocal_no_nan" + signature: "Tensor ReciprocalNoNan(Tensor x)" + bind_python: True + +- name: "sin" + signature: "Tensor Sin(Tensor x)" + bind_python: True + +- name: "cos" + signature: "Tensor Cos(Tensor x)" + bind_python: True + +- name: "cosh" + signature: "Tensor Cosh(Tensor x)" + bind_python: True + +- name: "log" + signature: "Tensor Log(Tensor x)" + bind_python: True + +- name: "sqrt" + signature: "Tensor Sqrt(Tensor x)" + bind_python: True + +- name: "rsqrt" + signature: "Tensor Rsqrt(Tensor x)" + bind_python: True + +- name: "square" + signature: "Tensor Square(Tensor x)" + bind_python: True + +- name: "tanh" + signature: "Tensor Tanh(Tensor x)" + bind_python: True + +- name: "sigmoid" + signature: "Tensor Sigmoid(Tensor x)" bind_python: True - name: "normalization" - signature: "Tensor Normalization(Tensor x, Tensor moving_mean, Tensor moving_variance, - Tensor gamma, Tensor beta, *, Int32 axis=1, Float epsilon=1e-5, - Float momentum=0.9, Bool is_training=False)" + signature: + "Tensor Normalization(Tensor x, Tensor moving_mean, Tensor moving_variance, + Tensor gamma, Tensor beta, *, Int32 axis=1, Float epsilon=1e-5, + Float momentum=0.9, Bool is_training=False)" + bind_python: True + +- name: "range" + signature: "Tensor Range(*, Int64 start, Int64 limit, Int64 delta, DataType dtype=kInt64)" + bind_python: True + +- name: "argmax" + signature: "Tensor ArgMax(Tensor x)" + bind_python: True + +- name: "cast" + signature: "Tensor Cast(Tensor x, *, DataType dtype)" + bind_python: True + +- name: "exp" + signature: "Tensor Exp(Tensor x)" + bind_python: True + +- name: "negative" + signature: "Tensor Negative(Tensor x)" + bind_python: True + +- name: "abs" + signature: "Tensor Abs(Tensor x)" + bind_python: True + +- name: "acos" + signature: "Tensor Acos(Tensor x)" + bind_python: True + +- name: "acosh" + signature: "Tensor Acosh(Tensor x)" + bind_python: True + +- name: "asin" + signature: "Tensor Asin(Tensor x)" + bind_python: True + +- name: "asinh" + signature: "Tensor Asinh(Tensor x)" + bind_python: True + +- name: "atan" + signature: "Tensor Atan(Tensor x)" + 
bind_python: True + +- name: "atanh" + signature: "Tensor Atanh(Tensor x)" + bind_python: True + +- name: "ceil" + signature: "Tensor Ceil(Tensor x)" + bind_python: True + +- name: "erf" + signature: "Tensor Erf(Tensor x)" + bind_python: True + +- name: "expm1" + signature: "Tensor Expm1(Tensor x)" + bind_python: True + +- name: "floor" + signature: "Tensor Floor(Tensor x)" + bind_python: True + +- name: "lgamma" + signature: "Tensor Lgamma(Tensor x)" + bind_python: True + +- name: "log1p" + signature: "Tensor Log1p(Tensor x)" + bind_python: True + +- name: "log_sigmoid" + signature: "Tensor LogSigmoid(Tensor x)" + bind_python: True + +- name: "rint" + signature: "Tensor Rint(Tensor x)" + bind_python: True + +- name: "round" + signature: "Tensor Round(Tensor x)" + bind_python: True + +- name: "sign" + signature: "Tensor Sign(Tensor x)" + bind_python: True + +- name: "sinh" + signature: "Tensor Sinh(Tensor x)" + bind_python: True + +- name: "softplus" + signature: "Tensor Softplus(Tensor x)" bind_python: True diff --git a/oneflow/core/functional/impl/add_functor.cpp b/oneflow/core/functional/impl/add_functor.cpp index 522d25dd85d..16a9ccc4c04 100644 --- a/oneflow/core/functional/impl/add_functor.cpp +++ b/oneflow/core/functional/impl/add_functor.cpp @@ -43,58 +43,9 @@ class AddFunctor { std::shared_ptr add_op_; }; -class AddNFunctor { - public: - AddNFunctor() { - add_n_op_.resize(128 /*the maximum number of inputs*/); - for (int n = 2; n < add_n_op_.size(); ++n) { - add_n_op_[n] = CHECK_JUST(one::OpBuilder("add_n").Input("in", n).Output("out").Build()); - } - } - Maybe operator()(const TensorTuple& inputs) const { - CHECK_GE_OR_RETURN(inputs.size(), 2); - CHECK_LT_OR_RETURN(inputs.size(), add_n_op_.size()) - << "The maximum number supported of inputs is " << add_n_op_.size(); - return OpInterpUtil::Dispatch(*add_n_op_.at(inputs.size()), inputs); - } - - private: - std::vector> add_n_op_; -}; - -class AddScalarFunctor { - public: - AddScalarFunctor() { - add_scalar_op_ = CHECK_JUST(one::OpBuilder("scalar_add").Input("in").Output("out").Build()); - } - Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { - MutableAttrMap attrs; - if (scalar.IsFloatingPoint()) { - JUST(attrs.SetAttr("float_operand", JUST(scalar.As()))); - JUST(attrs.SetAttr("has_float_operand", true)); - JUST(attrs.SetAttr("has_int_operand", false)); - return OpInterpUtil::Dispatch(*add_scalar_op_, {x}, attrs); - } else if (scalar.IsIntegral()) { - JUST(attrs.SetAttr("int_operand", JUST(scalar.As()))); - JUST(attrs.SetAttr("has_float_operand", false)); - JUST(attrs.SetAttr("has_int_operand", true)); - return OpInterpUtil::Dispatch(*add_scalar_op_, {x}, attrs); - } else { - UNIMPLEMENTED_THEN_RETURN(); - } - } - - private: - std::shared_ptr add_scalar_op_; -}; - } // namespace impl -ONEFLOW_FUNCTION_LIBRARY(m) { - m.add_functor("Add"); - m.add_functor("AddN"); - m.add_functor("AddScalar"); -}; +ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Add"); } } // namespace functional } // namespace one diff --git a/oneflow/core/functional/impl/common.h b/oneflow/core/functional/impl/common.h new file mode 100644 index 00000000000..f1c651bba88 --- /dev/null +++ b/oneflow/core/functional/impl/common.h @@ -0,0 +1,30 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_ +#define ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_ + +namespace oneflow { +namespace one { +namespace functional { + +static constexpr size_t kMaxInputCount = 128; +static constexpr size_t kMaxOutputCount = 128; + +} // namespace functional +} // namespace one +} // namespace oneflow + +#endif // ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_ diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp new file mode 100644 index 00000000000..ab42f42742f --- /dev/null +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -0,0 +1,205 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/impl/common.h" +#include "oneflow/core/functional/impl/unary_functor.h" +#include "oneflow/core/functional/scalar.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class AddNFunctor { + public: + AddNFunctor() { + op_.resize(kMaxInputCount /*the maximum number of inputs*/); + for (int n = 2; n < op_.size(); ++n) { + op_[n] = CHECK_JUST(one::OpBuilder("add_n").Input("in", n).Output("out").Build()); + } + } + Maybe operator()(const TensorTuple& inputs) const { + CHECK_GE_OR_RETURN(inputs.size(), 2); + CHECK_LT_OR_RETURN(inputs.size(), op_.size()) + << "The maximum number supported of inputs is " << op_.size(); + return OpInterpUtil::Dispatch(*op_.at(inputs.size()), inputs); + } + + private: + std::vector> op_; +}; + +class ScalarAddFunctor { + public: + ScalarAddFunctor() { + op_ = CHECK_JUST(one::OpBuilder("scalar_add").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { + MutableAttrMap attrs; + if (scalar.IsFloatingPoint()) { + JUST(attrs.SetAttr("float_operand", JUST(scalar.As()))); + JUST(attrs.SetAttr("has_float_operand", true)); + JUST(attrs.SetAttr("has_int_operand", false)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } else if (scalar.IsIntegral()) { + JUST(attrs.SetAttr("int_operand", JUST(scalar.As()))); + JUST(attrs.SetAttr("has_float_operand", false)); + JUST(attrs.SetAttr("has_int_operand", true)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } else { + 
UNIMPLEMENTED_THEN_RETURN(); + } + } + + private: + std::shared_ptr op_; +}; + +class ScalarMulFunctor { + public: + ScalarMulFunctor() { + op_ = CHECK_JUST(one::OpBuilder("scalar_mul").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { + MutableAttrMap attrs; + if (scalar.IsFloatingPoint()) { + JUST(attrs.SetAttr("float_operand", JUST(scalar.As()))); + JUST(attrs.SetAttr("has_float_operand", true)); + JUST(attrs.SetAttr("has_int_operand", false)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } else if (scalar.IsIntegral()) { + JUST(attrs.SetAttr("int_operand", JUST(scalar.As()))); + JUST(attrs.SetAttr("has_float_operand", false)); + JUST(attrs.SetAttr("has_int_operand", true)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } + + private: + std::shared_ptr op_; +}; + +class ScalarPowFunctor { + public: + ScalarPowFunctor() { + op_ = CHECK_JUST(one::OpBuilder("scalar_pow").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const Scalar& scalar) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("exponent", JUST(scalar.As()))); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class ReduceSumFunctor { + public: + ReduceSumFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("reduce_sum").Input("input_tensor").Output("output_tensor").Build()); + } + Maybe operator()(const std::shared_ptr& x, const std::vector& axis, + const bool& keepdims) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr>("axis", axis)); + JUST(attrs.SetAttr("keepdims", keepdims)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class TransposeFunctor { + public: + TransposeFunctor() { + op_ = CHECK_JUST(one::OpBuilder("transpose").Input("input").Output("output").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::vector& permute) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr>("perm", permute)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class RangeFunctor { + public: + RangeFunctor() { op_ = CHECK_JUST(one::OpBuilder("range").Output("out").Build()); } + Maybe operator()(const int64_t& start, const int64_t& limit, const int64_t& delta, + const DataType& dtype) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("start", start)); + JUST(attrs.SetAttr("limit", limit)); + JUST(attrs.SetAttr("delta", delta)); + JUST(attrs.SetAttr("dtype", dtype)); + return OpInterpUtil::Dispatch(*op_, {}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class ArgMaxFunctor : public UnaryFunctor { + public: + ArgMaxFunctor() { op_ = CHECK_JUST(one::OpBuilder("argmax").Input("in").Output("out").Build()); } +}; + +class CastFunctor { + public: + CastFunctor() { op_ = CHECK_JUST(one::OpBuilder("cast").Input("in").Output("out").Build()); } + Maybe operator()(const std::shared_ptr& x, const DataType& dtype) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("dtype", dtype)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +} // namespace impl + +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor("AddN"); + m.add_functor("ScalarAdd"); + m.add_functor("ScalarMul"); + m.add_functor("ScalarPow"); + m.add_functor("ReduceSum"); + m.add_functor("Transpose"); + m.add_functor("Range"); + m.add_functor("ArgMax"); + m.add_functor("Cast"); +}; + +} // namespace functional +} // 
namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/impl/normalization_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp similarity index 100% rename from oneflow/core/functional/impl/normalization_functor.cpp rename to oneflow/core/functional/impl/nn_functor.cpp diff --git a/oneflow/core/functional/impl/unary_functor.cpp b/oneflow/core/functional/impl/unary_functor.cpp new file mode 100644 index 00000000000..1f261e23fdb --- /dev/null +++ b/oneflow/core/functional/impl/unary_functor.cpp @@ -0,0 +1,80 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/functional/impl/unary_functor.h" + +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/user/ops/math_unary_elementwise_seq.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +#define UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name) \ + class class_name##Functor : public UnaryFunctor { \ + public: \ + class_name##Functor() { \ + op_ = CHECK_JUST(one::OpBuilder(op_type_name).Input("x").Output("y").Build()); \ + } \ + }; + +OF_PP_FOR_EACH_TUPLE(UNARY_ELEMENTWISE_FUNCTOR, MATH_UNARY_ELEMENTWISE_FUNC_SEQ); + +} // namespace impl + +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor("Abs"); + m.add_functor("Acos"); + m.add_functor("Acosh"); + m.add_functor("Asin"); + m.add_functor("Asinh"); + m.add_functor("Atan"); + m.add_functor("Atanh"); + m.add_functor("Ceil"); + m.add_functor("Cos"); + m.add_functor("Cosh"); + m.add_functor("Erf"); + m.add_functor("Erfc"); + m.add_functor("Exp"); + m.add_functor("Expm1"); + m.add_functor("Floor"); + m.add_functor("Lgamma"); + m.add_functor("Log"); + m.add_functor("Log1p"); + m.add_functor("LogSigmoid"); + m.add_functor("Negative"); + m.add_functor("Reciprocal"); + m.add_functor("ReciprocalNoNan"); + m.add_functor("Rint"); + m.add_functor("Round"); + m.add_functor("Rsqrt"); + m.add_functor("Sigmoid"); + m.add_functor("Sign"); + m.add_functor("Sin"); + m.add_functor("Sinh"); + m.add_functor("Softplus"); + m.add_functor("Sqrt"); + m.add_functor("Square"); + m.add_functor("Tan"); + m.add_functor("Tanh"); +}; + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/impl/unary_functor.h b/oneflow/core/functional/impl/unary_functor.h new file mode 100644 index 00000000000..f481f5e43a2 --- /dev/null +++ b/oneflow/core/functional/impl/unary_functor.h @@ -0,0 +1,49 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_CORE_FUNCTIONAL_IMPL_UNARY_FUNCTOR_H_ +#define ONEFLOW_CORE_FUNCTIONAL_IMPL_UNARY_FUNCTOR_H_ + +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class UnaryFunctor { + public: + Maybe operator()(const std::shared_ptr& x) const { + return OpInterpUtil::Dispatch(*op_, {x}); + } + + protected: + UnaryFunctor() = default; + virtual ~UnaryFunctor() = default; + + std::shared_ptr op_; +}; + +} // namespace impl + +} // namespace functional +} // namespace one +} // namespace oneflow + +#endif // ONEFLOW_CORE_FUNCTIONAL_IMPL_UNARY_FUNCTOR_H_ diff --git a/oneflow/python/nn/modules/abs.py b/oneflow/python/nn/modules/abs.py index 6bcb1ca6bbf..a3a09a6ad9a 100644 --- a/oneflow/python/nn/modules/abs.py +++ b/oneflow/python/nn/modules/abs.py @@ -23,11 +23,9 @@ class Abs(Module): def __init__(self): super().__init__() - self._op = flow.builtin_op("abs").Input("x").Output("y").Build() def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.abs(x) @oneflow_export("abs") diff --git a/oneflow/python/nn/modules/acos.py b/oneflow/python/nn/modules/acos.py index 4a65c5edf9a..3f5b40872fa 100644 --- a/oneflow/python/nn/modules/acos.py +++ b/oneflow/python/nn/modules/acos.py @@ -19,17 +19,12 @@ from oneflow.python.framework.tensor import register_tensor_op -def _build_math_binary_elementwise_op(math_op): - return flow.builtin_op(math_op).Input("x").Input("y").Output("z").Build() - - class Acos(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("acos").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.acos(x) @oneflow_export("acos") diff --git a/oneflow/python/nn/modules/acosh.py b/oneflow/python/nn/modules/acosh.py index 4c13b0f6af2..cbbb1a59ede 100644 --- a/oneflow/python/nn/modules/acosh.py +++ b/oneflow/python/nn/modules/acosh.py @@ -23,10 +23,9 @@ class Acosh(Module): def __init__(self): super().__init__() - self._op = flow.builtin_op("acosh").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.acosh(x) @oneflow_export("acosh") diff --git a/oneflow/python/nn/modules/activation.py b/oneflow/python/nn/modules/activation.py index 49eacb21097..ab900285dc1 100644 --- a/oneflow/python/nn/modules/activation.py +++ b/oneflow/python/nn/modules/activation.py @@ -170,11 +170,9 @@ class Tanh(Module): def __init__(self): super().__init__() - self._op = flow.builtin_op("tanh").Input("x").Output("y").Build() def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.tanh(x) @oneflow_export("tanh") @@ -379,10 +377,9 @@ class Sigmoid(Module): def __init__(self): super().__init__() - self._op = flow.builtin_op("sigmoid").Input("in").Output("out").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.sigmoid(x) @oneflow_export("sigmoid") @@ -471,22 +468,15 @@ def __init__(self, dim: Optional[int] = None): super().__init__() self.axis = -1 if dim is None else dim self._op = flow.builtin_op("softmax").Input("in").Output("out").Build() - self._transpose_op = ( - flow.builtin_op("transpose") - .Input("input") - .Output("output") - .Attr("perm", []) - .Build() - ) def forward(self, x): need_transpose, permute = _softmax_need_transpose(x, self.axis) if 
need_transpose: - x = self._transpose_op(x, perm=permute)[0] + x = flow.F.transpose(x, perm=permute) res = self._op(x)[0] if need_transpose: - res = self._transpose_op(res, perm=permute)[0] + res = flow.F.transpose(res, perm=permute) return res @@ -586,13 +576,6 @@ def __init__( ): super().__init__() self.dim = dim - self._op = ( - flow.builtin_op("transpose") - .Input("input") - .Output("output") - .Attr("perm", []) - .Build() - ) def __setstate__(self, state): self.__dict__.update(state) @@ -602,13 +585,13 @@ def __setstate__(self, state): def forward(self, x): need_transpose, permute = _softmax_need_transpose(x, self.dim) if need_transpose: - x = self._op(x, perm=permute)[0] + x = flow.F.transpose(x, perm=permute) x = x.softmax() res = x.log() if need_transpose: - res = self._op(res, perm=permute)[0] + res = flow.F.transpose(res, perm=permute) return res diff --git a/oneflow/python/nn/modules/exp.py b/oneflow/python/nn/modules/exp.py index 42c7c7f3204..f230b30cab3 100644 --- a/oneflow/python/nn/modules/exp.py +++ b/oneflow/python/nn/modules/exp.py @@ -22,10 +22,9 @@ class Exp(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("exp").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.exp(x) @oneflow_export("exp") diff --git a/oneflow/python/nn/modules/loss.py b/oneflow/python/nn/modules/loss.py index cff63122f50..24f4cd8faef 100644 --- a/oneflow/python/nn/modules/loss.py +++ b/oneflow/python/nn/modules/loss.py @@ -108,13 +108,6 @@ def __init__( .Output("out") .Build() ) - self._transpose_op = ( - flow.builtin_op("transpose") - .Input("input") - .Output("output") - .Attr("perm", []) - .Build() - ) def forward(self, input, target): assert len(input.shape) <= 4 @@ -122,12 +115,12 @@ def forward(self, input, target): input_shape_len = len(input.shape) if input_shape_len == 3: b, c, h = input.shape[0], input.shape[1], input.shape[2] - input = self._transpose_op(input, perm=(0, 2, 1))[0] + input = flow.F.transpose(input, perm=(0, 2, 1)) input = input.reshape(shape=[-1, input.shape[2]]) target = target.flatten() elif input_shape_len == 4: b, c, h, w = input.shape[0], input.shape[1], input.shape[2], input.shape[3] - input = self._transpose_op(input, perm=(0, 2, 3, 1))[0] + input = flow.F.transpose(input, perm=(0, 2, 3, 1)) input = input.reshape(shape=[-1, input.shape[3]]) target = target.flatten() elif input_shape_len >= 5: @@ -247,13 +240,6 @@ def __init__( .Attr("dim", 1) .Build() ) - self._transpose_op = ( - flow.builtin_op("transpose") - .Input("input") - .Output("output") - .Attr("perm", []) - .Build() - ) def nllloss_1d(self, input, target): target = flow.experimental.reshape(target, (target.shape[0], 1)) @@ -269,14 +255,14 @@ def forward(self, input, target): res = self.nllloss_1d(input, target) elif len(input.shape) == 3: b, c, h = input.shape[0], input.shape[1], input.shape[2] - input = self._transpose_op(input, perm=(0, 2, 1))[0] + input = flow.F.transpose(input, perm=(0, 2, 1)) input = input.reshape(shape=[-1, input.shape[2]]) target = target.flatten() res = self.nllloss_1d(input, target) res = res.reshape((b, h)) elif len(input.shape) == 4: b, c, h, w = input.shape[0], input.shape[1], input.shape[2], input.shape[3] - input = self._transpose_op(input, perm=(0, 2, 3, 1))[0] + input = flow.F.transpose(input, perm=(0, 2, 3, 1)) input = input.reshape(shape=[-1, input.shape[3]]) target = target.flatten() res = self.nllloss_1d(input, target) diff --git a/oneflow/python/nn/modules/math_ops.py 
b/oneflow/python/nn/modules/math_ops.py index 1fcf4624dd9..3c08eb2ac23 100644 --- a/oneflow/python/nn/modules/math_ops.py +++ b/oneflow/python/nn/modules/math_ops.py @@ -28,35 +28,15 @@ ) -def _build_math_binary_elementwise_op(math_op): - return flow.builtin_op(math_op).Input("x").Input("y").Output("z").Build() - - class ScalarMul(Module): - def __init__(self, operand) -> None: + def __init__(self, alpha) -> None: super().__init__() - self._op = flow.builtin_op("scalar_mul").Input("in").Output("out") - if isinstance(operand, int): - self._op = ( - self._op.Attr("has_int_operand", True) - .Attr("has_float_operand", False) - .Attr("int_operand", operand) - .Attr("float_operand", 0.0) - .Build() - ) - elif isinstance(operand, float): - self._op = ( - self._op.Attr("has_int_operand", False) - .Attr("has_float_operand", True) - .Attr("int_operand", 0) - .Attr("float_operand", operand) - .Build() - ) - else: - raise ValueError("operand type can only be int or float") + if not isinstance(alpha, int) and not isinstance(alpha, float): + raise ValueError("alpha type can only be int or float") + self.alpha = alpha def forward(self, x): - return self._op(x)[0] + return flow.F.mul_scalar(x, self.alpha) class ScalarMulByTensor(Module): @@ -388,10 +368,9 @@ def _div(x, y): class Reciprocal(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("reciprocal_no_nan").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.reciprocal_no_nan(x) @oneflow_export("reciprocal") @@ -511,10 +490,9 @@ def _add(x, y): class Asin(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("asin").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.asin(x) @oneflow_export("asin") @@ -586,10 +564,9 @@ def arcsin_op_tensor(input): class Asinh(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("asinh").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.asinh(x) @oneflow_export("asinh") @@ -662,10 +639,9 @@ def arcsinh_op_tensor(input): class Sin(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("sin").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.sin(x) @oneflow_export("sin") @@ -719,10 +695,9 @@ def sin_op_tensor(tensor): class Cos(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("cos").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.cos(x) @oneflow_export("cos") @@ -756,10 +731,9 @@ def cos_op(tensor): class Atan(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("atan").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.atan(x) @oneflow_export("atan") @@ -825,10 +799,9 @@ def arctan_op_tensor(tensor): class Log(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("log").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.log(x) @oneflow_export("log") @@ -883,10 +856,9 @@ def forward(self, x, y): class Sqrt(Module): def __init__(self) -> None: super().__init__() - self.sqrt_op = flow.builtin_op("sqrt").Input("x").Output("y").Build() def forward(self, input): - return self.sqrt_op(input)[0] + return flow.F.sqrt(input) @oneflow_export("rsqrt") @@ -921,10 +893,9 @@ def rsqrt_op(input): class Rsqrt(Module): def 
__init__(self) -> None: super().__init__() - self.rsqrt_op = flow.builtin_op("rsqrt").Input("x").Output("y").Build() def forward(self, input): - return self.rsqrt_op(input)[0] + return flow.F.rsqrt(input) @oneflow_export("sqrt") @@ -959,10 +930,9 @@ def sqrt_op(input): class Square(Module): def __init__(self) -> None: super().__init__() - self.square_op = flow.builtin_op("square").Input("x").Output("y").Build() def forward(self, input): - return self.square_op(input)[0] + return flow.F.square(input) @oneflow_export("square") @@ -1072,16 +1042,12 @@ def std_op(tensor, dim, unbiased=True, keepdim=False): class Pow(Module): def __init__(self) -> None: super().__init__() - self._scalar_pow_op = ( - flow.builtin_op("scalar_pow").Input("in").Output("out").Build() - ) - self._elementwise_pow_op = _build_math_binary_elementwise_op("pow") def forward(self, x, y): if isinstance(y, (int, float)): - return self._scalar_pow_op(x, exponent=float(y))[0] + return flow.F.pow_scalar(x, alpha=y) else: - return self._elementwise_pow_op(x, y)[0] + return flow.F.pow(x, y) @oneflow_export("pow") @@ -1328,10 +1294,9 @@ def clip_op_tensor(tensor, min=None, max=None): class Cosh(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("cosh").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.cosh(x) @oneflow_export("cosh") @@ -1368,10 +1333,9 @@ def cosh_op(tensor): class Erf(Module): def __init__(self) -> None: super().__init__() - self.erf_op = flow.builtin_op("erf").Input("x").Output("y").Build() def forward(self, input): - return self.erf_op(input)[0] + return flow.F.erf(input) @oneflow_export("erf") @@ -1507,10 +1471,9 @@ def erfc_op_tensor(input): class Ceil(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("ceil").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.ceil(x) @oneflow_export("ceil") @@ -1587,10 +1550,9 @@ def ceil_op_tensor(x): class Expm1(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("expm1").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.expm1(x) @oneflow_export("expm1") @@ -1676,14 +1638,6 @@ def __init__( .Attr("sorted", sorted) .Build() ) - self._transpose_op = ( - flow.builtin_op("transpose") - .Input("input") - .Output("output") - .Attr("perm", []) - .Build() - ) - self.dim = dim self.largest = largest @@ -1703,13 +1657,13 @@ def forward(self, input): return (flow.experimental.gather(input, indices, dim=axis), indices) else: perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis) - x = self._transpose_op(input, perm=perm)[0] + x = flow.F.transpose(input, perm=perm) if self.largest: indices = self._op_topk_last_dim(x)[0] else: neg_input = flow.experimental.mul(x, -1) indices = self._op_topk_last_dim(neg_input)[0] - indices = self._transpose_op(indices, perm=get_inversed_perm(perm))[0] + indices = flow.F.transpose(indices, perm=get_inversed_perm(perm)) return (flow.experimental.gather(input, indices, dim=axis), indices) diff --git a/oneflow/python/nn/modules/negative.py b/oneflow/python/nn/modules/negative.py index fc6532b8747..a3e0041ee72 100644 --- a/oneflow/python/nn/modules/negative.py +++ b/oneflow/python/nn/modules/negative.py @@ -22,10 +22,9 @@ class Negative(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("negative").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return 
flow.F.negative(x) @oneflow_export("negative", "neg") diff --git a/oneflow/python/nn/modules/permute.py b/oneflow/python/nn/modules/permute.py index 1e4cdac869a..0c2eb599295 100644 --- a/oneflow/python/nn/modules/permute.py +++ b/oneflow/python/nn/modules/permute.py @@ -25,14 +25,6 @@ def __init__(self, *dims) -> None: super().__init__() self.perm = list(*dims) - self._op = ( - flow.builtin_op("transpose") - .Input("input") - .Output("output") - .Attr("perm", []) - .Build() - ) - def forward(self, x): assert len(self.perm) == len(x.shape) new_perm = [] @@ -43,7 +35,7 @@ def forward(self, x): x.shape ), "Invalid dim0 {}, len(shape): {}".format(dim, len(x.shape)) new_perm.append(dim) - return self._op(x, perm=new_perm)[0] + return flow.F.transpose(x, perm=new_perm) @register_tensor_op("permute") diff --git a/oneflow/python/nn/modules/round.py b/oneflow/python/nn/modules/round.py index 726926379e8..ff5356cdeb4 100644 --- a/oneflow/python/nn/modules/round.py +++ b/oneflow/python/nn/modules/round.py @@ -23,10 +23,9 @@ class Round(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("round").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.round(x) @oneflow_export("round") diff --git a/oneflow/python/nn/modules/sign.py b/oneflow/python/nn/modules/sign.py index 5e83b2d1971..bbcdde77612 100644 --- a/oneflow/python/nn/modules/sign.py +++ b/oneflow/python/nn/modules/sign.py @@ -23,10 +23,9 @@ class Sign(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("sign").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.sign(x) @oneflow_export("sign") diff --git a/oneflow/python/nn/modules/sinh.py b/oneflow/python/nn/modules/sinh.py index 82100990b4e..9708cb65dd1 100644 --- a/oneflow/python/nn/modules/sinh.py +++ b/oneflow/python/nn/modules/sinh.py @@ -22,10 +22,9 @@ class Sinh(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("sinh").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.sinh(x) @oneflow_export("sinh") diff --git a/oneflow/python/nn/modules/transpose.py b/oneflow/python/nn/modules/transpose.py index f19aea4697a..223ba3ad7f4 100644 --- a/oneflow/python/nn/modules/transpose.py +++ b/oneflow/python/nn/modules/transpose.py @@ -32,14 +32,6 @@ def __init__( if batch_axis_non_change: raise NotImplementedError - self._op = ( - flow.builtin_op("transpose") - .Input("input") - .Output("output") - .Attr("perm", []) - .Build() - ) - self.dim0 = dim0 self.dim1 = dim1 @@ -63,7 +55,7 @@ def forward(self, x): perm.append(i) perm[dim0], perm[dim1] = perm[dim1], perm[dim0] - return self._op(x, perm=perm)[0] + return flow.F.transpose(x, perm=perm) @oneflow_export("transpose") From 6987c58c41460b78fdb1af67f3658a15b5102c39 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Wed, 16 Jun 2021 22:57:45 +0800 Subject: [PATCH 02/42] Revert elementwise pow. 
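
Only the functional binding for elementwise pow is reverted here; the scalar variant keeps using the new functional path. A minimal sketch of the resulting dispatch in the Pow module (illustrative only, restating the hunks below):

    def forward(self, x, y):
        # scalar exponent stays on the new functional API
        if isinstance(y, (int, float)):
            return flow.F.pow_scalar(x, alpha=y)
        # tensor exponent falls back to the builtin elementwise op
        return self._elementwise_pow_op(x, y)[0]
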
--- oneflow/core/functional/functional_api.yaml | 4 ---- oneflow/python/nn/modules/math_ops.py | 3 ++- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index db9121432f6..39baa6979c0 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -35,10 +35,6 @@ signature: "Tensor ScalarMul(Tensor x, *, Scalar alpha)" bind_python: True -- name: "pow" - signature: "Tensor Pow(Tensor x, Tensor y)" - bind_python: True - - name: "pow_scalar" signature: "Tensor ScalarPow(Tensor x, *, Scalar alpha)" bind_python: True diff --git a/oneflow/python/nn/modules/math_ops.py b/oneflow/python/nn/modules/math_ops.py index 3c08eb2ac23..6d76f2bf432 100644 --- a/oneflow/python/nn/modules/math_ops.py +++ b/oneflow/python/nn/modules/math_ops.py @@ -1042,12 +1042,13 @@ def std_op(tensor, dim, unbiased=True, keepdim=False): class Pow(Module): def __init__(self) -> None: super().__init__() + self._elementwise_pow_op = flow.builtin_op("pow").Input("x").Input("y").Output("z").Build() def forward(self, x, y): if isinstance(y, (int, float)): return flow.F.pow_scalar(x, alpha=y) else: - return flow.F.pow(x, y) + return self._elementwise_pow_op(x, y)[0] @oneflow_export("pow") From de6427dfd58b88bce8963da490276129eb8f682f Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Wed, 16 Jun 2021 15:03:24 +0000 Subject: [PATCH 03/42] auto format by CI --- oneflow/python/nn/modules/math_ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/oneflow/python/nn/modules/math_ops.py b/oneflow/python/nn/modules/math_ops.py index 6d76f2bf432..4a382454a7b 100644 --- a/oneflow/python/nn/modules/math_ops.py +++ b/oneflow/python/nn/modules/math_ops.py @@ -1042,7 +1042,9 @@ def std_op(tensor, dim, unbiased=True, keepdim=False): class Pow(Module): def __init__(self) -> None: super().__init__() - self._elementwise_pow_op = flow.builtin_op("pow").Input("x").Input("y").Output("z").Build() + self._elementwise_pow_op = ( + flow.builtin_op("pow").Input("x").Input("y").Output("z").Build() + ) def forward(self, x, y): if isinstance(y, (int, float)): From 3354fe3a9cda14e8954057ec19f1130cca3935d7 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Fri, 18 Jun 2021 14:11:26 +0800 Subject: [PATCH 04/42] Support add with large number of inputs. --- oneflow/core/functional/impl/math_functor.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index ab42f42742f..bf106fac051 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -35,15 +35,21 @@ class AddNFunctor { public: AddNFunctor() { op_.resize(kMaxInputCount /*the maximum number of inputs*/); - for (int n = 2; n < op_.size(); ++n) { - op_[n] = CHECK_JUST(one::OpBuilder("add_n").Input("in", n).Output("out").Build()); + for (int n = 1; n < op_.size(); ++n) { + op_[n] = CHECK_JUST(one::OpBuilder("add_n").Input("in", n + 1).Output("out").Build()); } } Maybe operator()(const TensorTuple& inputs) const { CHECK_GE_OR_RETURN(inputs.size(), 2); - CHECK_LT_OR_RETURN(inputs.size(), op_.size()) - << "The maximum number supported of inputs is " << op_.size(); - return OpInterpUtil::Dispatch(*op_.at(inputs.size()), inputs); + TensorTuple outputs; + for (int i = 0; i < inputs.size(); i += kMaxInputCount) { + size_t size = (i + kMaxInputCount) < inputs.size() ? 
kMaxInputCount : inputs.size() - i; + TensorTuple partial_inputs(size); + for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; } + outputs.push_back(JUST(OpInterpUtil::Dispatch(*op_.at(size - 1), partial_inputs))); + } + if (outputs.size() == 1) { return outputs.at(0); } + return this->operator()(outputs); } private: From cffeeb3e7298e055bf49d73bd7492f88a07f4589 Mon Sep 17 00:00:00 2001 From: Houjiang Chen Date: Fri, 18 Jun 2021 16:38:16 +0800 Subject: [PATCH 05/42] Update oneflow/python/nn/modules/math_ops.py Co-authored-by: Yinggang Wang --- oneflow/python/nn/modules/math_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/python/nn/modules/math_ops.py b/oneflow/python/nn/modules/math_ops.py index 4a382454a7b..2078f880add 100644 --- a/oneflow/python/nn/modules/math_ops.py +++ b/oneflow/python/nn/modules/math_ops.py @@ -31,7 +31,7 @@ class ScalarMul(Module): def __init__(self, alpha) -> None: super().__init__() - if not isinstance(alpha, int) and not isinstance(alpha, float): + if not isinstance(alpha, (int, float)): raise ValueError("alpha type can only be int or float") self.alpha = alpha From 840e1a661354c4a54dcebf22e8ee0c2169a53374 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Fri, 18 Jun 2021 16:46:05 +0800 Subject: [PATCH 06/42] Refine --- oneflow/core/functional/functional_api.yaml | 2 +- oneflow/core/functional/impl/math_functor.cpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 39baa6979c0..649a25a86a5 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -16,7 +16,7 @@ # { # "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool", # "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList", -# "BoolList", DataType, Shape +# "BoolList", "DataType", "Shape" # } - name: "add_n" diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index bf106fac051..d6d101bce3c 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -45,7 +45,7 @@ class AddNFunctor { for (int i = 0; i < inputs.size(); i += kMaxInputCount) { size_t size = (i + kMaxInputCount) < inputs.size() ? kMaxInputCount : inputs.size() - i; TensorTuple partial_inputs(size); - for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; } + std::copy(inputs.begin() + i, inputs.begin() + i + size, partial_inputs.begin()); outputs.push_back(JUST(OpInterpUtil::Dispatch(*op_.at(size - 1), partial_inputs))); } if (outputs.size() == 1) { return outputs.at(0); } @@ -74,7 +74,7 @@ class ScalarAddFunctor { JUST(attrs.SetAttr("has_int_operand", true)); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } else { - UNIMPLEMENTED_THEN_RETURN(); + UNIMPLEMENTED_THEN_RETURN() << "The scalar in ScalarAdd shoule be float or int."; } } @@ -100,7 +100,7 @@ class ScalarMulFunctor { JUST(attrs.SetAttr("has_int_operand", true)); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } else { - UNIMPLEMENTED_THEN_RETURN(); + UNIMPLEMENTED_THEN_RETURN() << "The scalar in ScalarMul shoule be float or int."; } } From ab998e73e6f3ecde988e00dcd5e58918eeafda0e Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 09:09:23 +0800 Subject: [PATCH 07/42] Migrate binary and activation ops. 
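
This commit moves the binary and activation kernels onto the functional API so Python modules call them directly instead of building and caching builtin ops. A rough before/after sketch for a typical module forward (illustrative, mirroring the ReLU change in the hunks below):

    # before: construct the op in __init__ and dispatch manually
    #     self._op = flow.builtin_op("relu").Input("in").Output("out").Build()
    #     return self._op(x)[0]
    # after: a single functional call
    return flow.F.relu(x)
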
--- oneflow/core/functional/functional_api.yaml | 100 ++++++++++ .../functional/impl/activation_functor.cpp | 187 ++++++++++++++++++ .../core/functional/impl/binary_functor.cpp | 155 +++++++++++++++ .../{add_functor.cpp => binary_functor.h} | 26 ++- oneflow/python/nn/modules/activation.py | 66 ++----- oneflow/python/nn/modules/math_ops.py | 66 +------ oneflow/python/nn/modules/prelu.py | 3 +- 7 files changed, 479 insertions(+), 124 deletions(-) create mode 100644 oneflow/core/functional/impl/activation_functor.cpp create mode 100644 oneflow/core/functional/impl/binary_functor.cpp rename oneflow/core/functional/impl/{add_functor.cpp => binary_functor.h} (66%) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 649a25a86a5..3b0d964400e 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -31,10 +31,62 @@ signature: "Tensor ScalarAdd(Tensor x, *, Scalar alpha)" bind_python: True +- name: "add_scalar_by_tensor" + signature: "Tensor ScalarAddByTensor(Tensor x, Tensor scalar)" + bind_python: True + +- name: "broadcast_add" + signature: "Tensor BroadcastAdd(Tensor x, Tensor y)" + bind_python: True + +- name: "sub_scalar_by_tensor" + signature: "Tensor ScalarSubByTensor(Tensor x, Tensor scalar)" + bind_python: True + +- name: "broadcast_sub" + signature: "Tensor BroadcastSub(Tensor x, Tensor y)" + bind_python: True + +- name: "mul" + signature: "Tensor Multiply(Tensor x, Tensor y)" + bind_python: True + - name: "mul_scalar" signature: "Tensor ScalarMul(Tensor x, *, Scalar alpha)" bind_python: True +- name: "mul_scalar_by_tensor" + signature: "Tensor ScalarMulByTensor(Tensor x, Tensor scalar)" + bind_python: True + +- name: "broadcast_mul" + signature: "Tensor BroadcastMul(Tensor x, Tensor y)" + bind_python: True + +- name: "div_scalar_by_tensor" + signature: "Tensor ScalarDivByTensor(Tensor x, Tensor scalar)" + bind_python: True + +- name: "broadcast_div" + signature: "Tensor BroadcastDiv(Tensor x, Tensor y)" + bind_python: True + +- name: "broadcast_equal" + signature: "Tensor BroadcastEqual(Tensor x, Tensor y)" + bind_python: True + +- name: "broadcast_greater" + signature: "Tensor BroadcastGreater(Tensor x, Tensor y)" + bind_python: True + +- name: "broadcast_less" + signature: "Tensor BroadcastLess(Tensor x, Tensor y)" + bind_python: True + +- name: "pow" + signature: "Tensor Pow(Tensor x, Tensor y)" + bind_python: True + - name: "pow_scalar" signature: "Tensor ScalarPow(Tensor x, *, Scalar alpha)" bind_python: True @@ -83,14 +135,58 @@ signature: "Tensor Square(Tensor x)" bind_python: True +- name: "relu" + signature: "Tensor Relu(Tensor x)" + bind_python: True + +- name: "hardtanh" + signature: "Tensor HardTanh(Tensor x, *, Double min_val, Double max_val)" + bind_python: True + +- name: "hardtanh_grad" + signature: "Tensor HardTanhGrad(Tensor y, Tensor dy, *, Double min_val, Double max_val)" + bind_python: False + - name: "tanh" signature: "Tensor Tanh(Tensor x)" bind_python: True +- name: "elu" + signature: "Tensor Elu(Tensor x, *, Double alpha)" + bind_python: True + +- name: "elu_grad" + signature: "Tensor EluGrad(Tensor x, Tensor dy, *, Double alpha)" + bind_python: False + +- name: "gelu" + signature: "Tensor Gelu(Tensor x)" + bind_python: True + - name: "sigmoid" signature: "Tensor Sigmoid(Tensor x)" bind_python: True +- name: "hardsigmoid" + signature: "Tensor HardSigmoid(Tensor x)" + bind_python: True + +- name: "softmax" + signature: "Tensor Softmax(Tensor x)" + 
bind_python: True + +- name: "hardswish" + signature: "Tensor HardSwish(Tensor x)" + bind_python: True + +- name: "leaky_relu" + signature: "Tensor LeakyRelu(Tensor x, *, Float alpha)" + bind_python: True + +- name: "leaky_relu_grad" + signature: "Tensor LeakyReluGrad(Tensor x, Tensor dy, *, Float alpha)" + bind_python: False + - name: "normalization" signature: "Tensor Normalization(Tensor x, Tensor moving_mean, Tensor moving_variance, @@ -118,6 +214,10 @@ signature: "Tensor Negative(Tensor x)" bind_python: True +- name: "prelu" + signature: "Tensor PRelu(Tensor x, Tensor alpha)" + bind_python: True + - name: "abs" signature: "Tensor Abs(Tensor x)" bind_python: True diff --git a/oneflow/core/functional/impl/activation_functor.cpp b/oneflow/core/functional/impl/activation_functor.cpp new file mode 100644 index 00000000000..c1efd54a10f --- /dev/null +++ b/oneflow/core/functional/impl/activation_functor.cpp @@ -0,0 +1,187 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/functional/impl/unary_functor.h" +#include "oneflow/core/functional/impl/binary_functor.h" + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/scalar.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class ReluFunctor : public UnaryFunctor { + public: + ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("in").Output("out").Build()); } +}; + +class PReluFunctor : public BinaryFunctor { + public: + PReluFunctor() { + op_ = CHECK_JUST(one::OpBuilder("prelu").Input("x").Input("alpha").Output("y").Build()); + } +}; + +class HardTanhFunctor { + public: + HardTanhFunctor() { + op_ = CHECK_JUST(one::OpBuilder("hardtanh").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const double& min_val, + const double& max_val) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("min_val", min_val)); + JUST(attrs.SetAttr("max_val", max_val)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class HardTanhGradFunctor { + public: + HardTanhGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("hardtanh_grad").Input("y").Input("dy").Output("dx").Build()); + } + Maybe operator()(const std::shared_ptr& y, + const std::shared_ptr& dy, const double& min_val, + const double& max_val) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("min_val", min_val)); + JUST(attrs.SetAttr("max_val", max_val)); + return OpInterpUtil::Dispatch(*op_, {y, dy}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class EluFunctor { + public: + EluFunctor() { op_ = 
CHECK_JUST(one::OpBuilder("elu").Input("in").Output("out").Build()); } + Maybe operator()(const std::shared_ptr& x, const double& alpha) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("alpha", alpha)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class EluGradFunctor { + public: + EluGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("elu_grad").Input("x").Input("dy").Output("dx").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& dy, const double& alpha) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("alpha", alpha)); + return OpInterpUtil::Dispatch(*op_, {x, dy}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class GeluFunctor : public UnaryFunctor { + public: + GeluFunctor() { op_ = CHECK_JUST(one::OpBuilder("gelu").Input("in").Output("out").Build()); } +}; + +class HardSigmoidFunctor : public UnaryFunctor { + public: + HardSigmoidFunctor() { + op_ = CHECK_JUST(one::OpBuilder("hardsigmoid").Input("in").Output("out").Build()); + } +}; + +class SoftmaxFunctor : public UnaryFunctor { + public: + SoftmaxFunctor() { + op_ = CHECK_JUST(one::OpBuilder("softmax").Input("in").Output("out").Build()); + } +}; + +class HardSwishFunctor : public UnaryFunctor { + public: + HardSwishFunctor() { + op_ = CHECK_JUST(one::OpBuilder("hardswish").Input("in").Output("out").Build()); + } +}; + +class LeakyReluFunctor { + public: + LeakyReluFunctor() { + op_ = CHECK_JUST(one::OpBuilder("leaky_relu").Input("x").Output("y").Build()); + } + Maybe operator()(const std::shared_ptr& x, const float& alpha) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("alpha", alpha)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class LeakyReluGradFunctor { + public: + LeakyReluGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("leaky_relu_grad").Input("x").Input("dy").Output("dx").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& dy, const float& alpha) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("alpha", alpha)); + return OpInterpUtil::Dispatch(*op_, {x, dy}, attrs); + } + + private: + std::shared_ptr op_; +}; + +} // namespace impl + +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor("Relu"); + m.add_functor("PRelu"); + m.add_functor("HardTanh"); + m.add_functor("HardTanhGrad"); + m.add_functor("Elu"); + m.add_functor("EluGrad"); + m.add_functor("Gelu"); + m.add_functor("HardSigmoid"); + m.add_functor("Softmax"); + m.add_functor("HardSwish"); + m.add_functor("LeakyRelu"); + m.add_functor("LeakyReluGrad"); +}; + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/impl/binary_functor.cpp b/oneflow/core/functional/impl/binary_functor.cpp new file mode 100644 index 00000000000..d61b0bdd975 --- /dev/null +++ b/oneflow/core/functional/impl/binary_functor.cpp @@ -0,0 +1,155 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/functional/impl/binary_functor.h" + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/scalar.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class AddFunctor : public BinaryFunctor { + public: + AddFunctor() { op_ = CHECK_JUST(one::OpBuilder("add_n").Input("in", 2).Output("out").Build()); } +}; + +class MultiplyFunctor : public BinaryFunctor { + public: + MultiplyFunctor() { + op_ = CHECK_JUST(one::OpBuilder("multiply").Input("x").Input("y").Output("out").Build()); + } +}; + +class PowFunctor : public BinaryFunctor { + public: + PowFunctor() { + op_ = CHECK_JUST(one::OpBuilder("pow").Input("x").Input("y").Output("z").Build()); + } +}; + +class BroadcastAddFunctor : public BinaryFunctor { + public: + BroadcastAddFunctor() { + op_ = CHECK_JUST(one::OpBuilder("broadcast_add").Input("x").Input("y").Output("z").Build()); + } +}; + +class BroadcastSubFunctor : public BinaryFunctor { + public: + BroadcastSubFunctor() { + op_ = CHECK_JUST(one::OpBuilder("broadcast_sub").Input("x").Input("y").Output("z").Build()); + } +}; + +class BroadcastMulFunctor : public BinaryFunctor { + public: + BroadcastMulFunctor() { + op_ = CHECK_JUST(one::OpBuilder("broadcast_mul").Input("x").Input("y").Output("z").Build()); + } +}; + +class BroadcastDivFunctor : public BinaryFunctor { + public: + BroadcastDivFunctor() { + op_ = CHECK_JUST(one::OpBuilder("broadcast_div").Input("x").Input("y").Output("z").Build()); + } +}; + +class BroadcastEqualFunctor : public BinaryFunctor { + public: + BroadcastEqualFunctor() { + op_ = CHECK_JUST(one::OpBuilder("broadcast_equal").Input("x").Input("y").Output("z").Build()); + } +}; + +class BroadcastGreaterFunctor : public BinaryFunctor { + public: + BroadcastGreaterFunctor() { + op_ = CHECK_JUST(one::OpBuilder("broadcast_greater").Input("x").Input("y").Output("z").Build()); + } +}; + +class BroadcastLessFunctor : public BinaryFunctor { + public: + BroadcastLessFunctor() { + op_ = CHECK_JUST(one::OpBuilder("broadcast_less").Input("x").Input("y").Output("z").Build()); + } +}; + +class ScalarAddByTensorFunctor : public BinaryFunctor { + public: + ScalarAddByTensorFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("scalar_add_by_tensor").Input("x").Input("scalar").Output("y").Build()); + } +}; + +class ScalarSubByTensorFunctor : public BinaryFunctor { + public: + ScalarSubByTensorFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("scalar_sub_by_tensor").Input("x").Input("scalar").Output("y").Build()); + } +}; + +class ScalarMulByTensorFunctor : public BinaryFunctor { + public: + ScalarMulByTensorFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("scalar_mul_by_tensor").Input("x").Input("scalar").Output("y").Build()); + } +}; + +class ScalarDivByTensorFunctor : public BinaryFunctor { + public: + ScalarDivByTensorFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("scalar_div_by_tensor").Input("x").Input("scalar").Output("y").Build()); + } +}; + +} // namespace impl + +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor("Add"); + m.add_functor("Multiply"); + m.add_functor("Pow"); + m.add_functor("BroadcastAdd"); + m.add_functor("BroadcastSub"); + m.add_functor("BroadcastMul"); + 
m.add_functor("BroadcastDiv"); + m.add_functor("BroadcastEqual"); + m.add_functor("BroadcastGreater"); + m.add_functor("BroadcastLess"); + m.add_functor("ScalarAddByTensor"); + m.add_functor("ScalarSubByTensor"); + m.add_functor("ScalarMulByTensor"); + m.add_functor("ScalarDivByTensor"); +}; + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/impl/add_functor.cpp b/oneflow/core/functional/impl/binary_functor.h similarity index 66% rename from oneflow/core/functional/impl/add_functor.cpp rename to oneflow/core/functional/impl/binary_functor.h index 16a9ccc4c04..9090b01f4d0 100644 --- a/oneflow/core/functional/impl/add_functor.cpp +++ b/oneflow/core/functional/impl/binary_functor.h @@ -14,14 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/framework/attr_map.h" -#include "oneflow/core/framework/op_builder.h" +#ifndef ONEFLOW_CORE_FUNCTIONAL_IMPL_BINARY_FUNCTOR_H_ +#define ONEFLOW_CORE_FUNCTIONAL_IMPL_BINARY_FUNCTOR_H_ + #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" -#include "oneflow/core/framework/tensor_tuple.h" -#include "oneflow/core/functional/function_library.h" -#include "oneflow/core/functional/scalar.h" namespace oneflow { namespace one { @@ -29,24 +27,24 @@ namespace functional { namespace impl { -class AddFunctor { +class BinaryFunctor { public: - AddFunctor() { - add_op_ = CHECK_JUST(one::OpBuilder("add_n").Input("in", 2).Output("out").Build()); - } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& y) const { - return OpInterpUtil::Dispatch(*add_op_, {x, y}); + return OpInterpUtil::Dispatch(*op_, {x, y}); } - private: - std::shared_ptr add_op_; + protected: + BinaryFunctor() = default; + virtual ~BinaryFunctor() = default; + + std::shared_ptr op_; }; } // namespace impl -ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Add"); } - } // namespace functional } // namespace one } // namespace oneflow + +#endif // ONEFLOW_CORE_FUNCTIONAL_IMPL_BINARY_FUNCTOR_H_ diff --git a/oneflow/python/nn/modules/activation.py b/oneflow/python/nn/modules/activation.py index ab900285dc1..efa1e786173 100644 --- a/oneflow/python/nn/modules/activation.py +++ b/oneflow/python/nn/modules/activation.py @@ -72,11 +72,9 @@ class ReLU(Module): def __init__(self, inplace: bool = False): super().__init__() - self._op = flow.builtin_op("relu").Input("in").Output("out").Build() def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.relu(x) @oneflow_export("nn.ReLU6") @@ -120,18 +118,9 @@ class ReLU6(Module): def __init__(self, inplace: bool = False): super().__init__() - self._op = ( - flow.builtin_op("hardtanh") - .Input("in") - .Attr("min_val", 0.0) - .Attr("max_val", 6.0) - .Output("out") - .Build() - ) def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.hardtanh(x, min_val=0.0, max_val=6.0) @oneflow_export("nn.Tanh") @@ -254,17 +243,10 @@ class ELU(Module): def __init__(self, alpha: float = 1.0, inplace: bool = False): super().__init__() - self._op = ( - flow.builtin_op("elu") - .Input("in") - .Attr("alpha", alpha) - .Output("out") - .Build() - ) + self.alpha = alpha def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.elu(x, alpha=self.alpha) @oneflow_export("nn.GELU") @@ -303,11 +285,9 @@ class GELU(Module): def __init__(self): super().__init__() - self._op = 
flow.builtin_op("gelu").Input("in").Output("out").Build() def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.gelu(x) @oneflow_export("gelu") @@ -454,11 +434,9 @@ class Hardsigmoid(Module): def __init__(self, inplace: bool = False): super().__init__() - self._op = flow.builtin_op("hardsigmoid").Input("in").Output("out").Build() def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.hardsigmoid(x) @oneflow_export("nn.Softmax") @@ -467,14 +445,13 @@ class Softmax(Module): def __init__(self, dim: Optional[int] = None): super().__init__() self.axis = -1 if dim is None else dim - self._op = flow.builtin_op("softmax").Input("in").Output("out").Build() def forward(self, x): need_transpose, permute = _softmax_need_transpose(x, self.axis) if need_transpose: x = flow.F.transpose(x, perm=permute) - res = self._op(x)[0] + res = flow.F.softmax(x) if need_transpose: res = flow.F.transpose(res, perm=permute) return res @@ -736,11 +713,9 @@ class Hardswish(Module): def __init__(self, inplace: bool = False): super().__init__() - self._op = flow.builtin_op("hardswish").Input("in").Output("out").Build() def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.hardswish(x) @oneflow_export("nn.Hardtanh") @@ -811,18 +786,12 @@ def __init__( "keyword argument max_value is deprecated and rename to max_val" ) max_val = max_value - self._op = ( - flow.builtin_op("hardtanh") - .Input("in") - .Attr("min_val", min_val) - .Attr("max_val", max_val) - .Output("out") - .Build() - ) + + self.min_val = min_val + self.max_val = max_val def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.hardtanh(x, min_val=self.min_val, max_val=self.max_val) @oneflow_export("nn.LeakyReLU") @@ -863,17 +832,10 @@ class LeakyReLU(Module): def __init__(self, negative_slope: float = 1e-2, inplace: bool = False): super().__init__() - self._op = ( - flow.builtin_op("leaky_relu") - .Input("x") - .Attr("alpha", negative_slope) - .Output("y") - .Build() - ) + self.negative_slope = negative_slope def forward(self, x): - res = self._op(x)[0] - return res + return flow.F.leaky_relu(x, alpha=self.negative_slope) if __name__ == "__main__": diff --git a/oneflow/python/nn/modules/math_ops.py b/oneflow/python/nn/modules/math_ops.py index 2078f880add..fb0070b44b0 100644 --- a/oneflow/python/nn/modules/math_ops.py +++ b/oneflow/python/nn/modules/math_ops.py @@ -42,38 +42,25 @@ def forward(self, x): class ScalarMulByTensor(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("scalar_mul_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .Build() - ) def forward(self, x, y): - return self._op(x, y)[0] + return flow.F.mul_scalar_by_tensor(x, y) class ElementwiseMul(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("multiply").Input("x").Input("y").Output("out").Build() - ) def forward(self, x, y): - return self._op(x, y)[0] + return flow.F.mul(x, y) class BroadcastMul(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("broadcast_mul").Input("x").Input("y").Output("z").Build() - ) def forward(self, x, y): - return self._op(x, y)[0] + return flow.F.broadcast_mul(x, y) @oneflow_export("mul") @@ -190,27 +177,17 @@ def variance_op(input, dim=None, keepdim=False): class ScalarSubByTensor(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("scalar_sub_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .Build() - ) def 
forward(self, x, y): - return self._op(x, y)[0] + return flow.F.sub_scalar_by_tensor(x, y) class BroadcastSub(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("broadcast_sub").Input("x").Input("y").Output("z").Build() - ) def forward(self, x, y): - return self._op(x, y)[0] + return flow.F.broadcast_sub(x, y) class ScalarAdd(Module): @@ -281,27 +258,17 @@ def _sub(x, y): class BroadcastDiv(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("broadcast_div").Input("x").Input("y").Output("z").Build() - ) def forward(self, x, y): - return self._op(x, y)[0] + return flow.F.broadcast_div(x, y) class ScalarDivByTensor(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("scalar_div_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .Build() - ) def forward(self, x, scalar): - return self._op(x, scalar)[0] + return flow.F.div_scalar_by_tensor(x, scalar) @oneflow_export("div") @@ -401,16 +368,9 @@ def _reciprocal(x): class ScalarAddByTensor(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("scalar_add_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .Build() - ) def forward(self, x, y): - return self._op(x, y)[0] + return flow.F.add_scalar_by_tensor(x, y) class ElementwiseAdd(Module): @@ -424,12 +384,9 @@ def forward(self, x, y): class BroadcastAdd(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("broadcast_add").Input("x").Input("y").Output("z").Build() - ) def forward(self, x, y): - return self._op(x, y)[0] + return flow.F.broadcast_add(x, y) @oneflow_export("add") @@ -1042,15 +999,12 @@ def std_op(tensor, dim, unbiased=True, keepdim=False): class Pow(Module): def __init__(self) -> None: super().__init__() - self._elementwise_pow_op = ( - flow.builtin_op("pow").Input("x").Input("y").Output("z").Build() - ) def forward(self, x, y): if isinstance(y, (int, float)): return flow.F.pow_scalar(x, alpha=y) else: - return self._elementwise_pow_op(x, y)[0] + return flow.F.pow(x, y) @oneflow_export("pow") diff --git a/oneflow/python/nn/modules/prelu.py b/oneflow/python/nn/modules/prelu.py index 43968c593ac..b461bb42c81 100644 --- a/oneflow/python/nn/modules/prelu.py +++ b/oneflow/python/nn/modules/prelu.py @@ -71,13 +71,12 @@ def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None: super().__init__() self.num_parameters = num_parameters self.weight = flow.nn.Parameter(flow.Tensor(num_parameters, 1, 1).fill_(init)) - self.op = flow.builtin_op("prelu").Input("x").Input("alpha").Output("y").Build() def forward(self, x): assert ( self.num_parameters == 1 or self.num_parameters == x.shape[1] ), f"num_parameters in prelu must be 1 or {x.shape[1]}" - return self.op(x, self.weight)[0] + return flow.F.prelu(x, self.weight) if __name__ == "__main__": From ae100846d8c9b82bab8bfcdf65475cc83c5d8c5d Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 09:54:14 +0800 Subject: [PATCH 08/42] Migrate array ops. 
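This commit moves the array ops behind the C++ functional layer and switches the Python
modules from builtin_op builders to the flow.F.* bindings declared in functional_api.yaml.
A minimal usage sketch of the migrated call pattern (illustrative only; it assumes the
flow.F namespace is exported as used in the nn modules below, and the tensor values are
placeholders):

    import oneflow as flow

    x = flow.Tensor(2, 3, 4, 5)
    y = flow.F.flatten(x, start_dim=1, end_dim=-1)  # was builtin_op("flatten")
    y = flow.F.cast(y, dtype=flow.int32)            # was builtin_op("cast")
    y = flow.F.reshape(y, shape=(4, 30))            # shape tuples convert via python_arg.cpp

Each flow.F.* entry dispatches to the functors added in array_functor.cpp below.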
--- oneflow/api/python/functional/python_arg.cpp | 16 + .../core/autograd/gradient_funcs/gather.cpp | 18 +- oneflow/core/functional/functional_api.yaml | 87 ++++ .../core/functional/impl/array_functor.cpp | 418 ++++++++++++++++++ oneflow/core/functional/value_types.h | 5 + oneflow/python/nn/modules/arange.py | 8 +- oneflow/python/nn/modules/argmax.py | 29 +- oneflow/python/nn/modules/argwhere.py | 14 +- oneflow/python/nn/modules/broadcast_like.py | 10 +- oneflow/python/nn/modules/cast.py | 10 +- oneflow/python/nn/modules/concat.py | 8 +- oneflow/python/nn/modules/constant.py | 47 +- oneflow/python/nn/modules/expand.py | 5 +- oneflow/python/nn/modules/flatten.py | 12 +- oneflow/python/nn/modules/gather.py | 12 +- oneflow/python/nn/modules/reshape.py | 10 +- oneflow/python/nn/modules/slice.py | 47 +- oneflow/python/nn/modules/sparse.py | 10 +- oneflow/python/nn/modules/squeeze.py | 10 +- oneflow/python/nn/modules/to.py | 8 +- oneflow/python/nn/modules/unsqueeze.py | 3 +- oneflow/python/nn/modules/upsample.py | 23 +- oneflow/python/nn/modules/where.py | 10 +- oneflow/user/kernels/constant_kernel.cpp | 22 +- tools/generate_functional_api.py | 6 +- 25 files changed, 597 insertions(+), 251 deletions(-) create mode 100644 oneflow/core/functional/impl/array_functor.cpp diff --git a/oneflow/api/python/functional/python_arg.cpp b/oneflow/api/python/functional/python_arg.cpp index 7c27719f96a..c572a2885fd 100644 --- a/oneflow/api/python/functional/python_arg.cpp +++ b/oneflow/api/python/functional/python_arg.cpp @@ -134,6 +134,22 @@ Maybe PythonArg::ObjectAs() const { return kInvalidDataType; } +template<> +Maybe PythonArg::ObjectAs() const { + py::object obj = Borrow(); + if (detail::isinstance(obj)) { + return *JUST(detail::cast>(obj)); + } else if (detail::isinstance(obj) || detail::isinstance(obj)) { + const auto& shape = JUST(ObjectAs>()); + DimVector dim_vec(shape->size()); + for (int i = 0; i < shape->size(); ++i) { dim_vec[i] = shape->at(i); } + return std::make_shared(std::move(dim_vec)); + } else { + UNIMPLEMENTED_THEN_RETURN() << "Can not convert object to Shape from " + << *JUST(detail::cast(py::str(py::type::of(obj)))); + } +} + } // namespace functional } // namespace one } // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/gather.cpp b/oneflow/core/autograd/gradient_funcs/gather.cpp index 33fdb9ed8cf..0f83aceb71f 100644 --- a/oneflow/core/autograd/gradient_funcs/gather.cpp +++ b/oneflow/core/autograd/gradient_funcs/gather.cpp @@ -18,6 +18,7 @@ limitations under the License. 
#include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_expr_helper.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { @@ -37,15 +38,12 @@ class Gather : public OpExprGradFunction { private: AttrMap base_attrs_; - std::shared_ptr grad_op_; }; Maybe Gather::Init(const OpExpr& op) { const UserOpExpr* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); - const std::string& op_name = fw_op_expr->op_name(); - grad_op_ = JUST(op_expr_helper::UnsortedSegmentSumLikeOp(/*axis=*/0, GradientOpName(op_name))); return Maybe::Ok(); } @@ -54,8 +52,8 @@ Maybe Gather::Capture(GatherInterpState* ctx, const TensorTuple& inputs, ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } - ctx->SaveTensorForBackward(inputs.at(1)); // indices - ctx->SaveTensorForBackward(inputs.at(0)); // in + ctx->SaveTensorForBackward(inputs.at(0)); + ctx->SaveTensorForBackward(inputs.at(1)); ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr("axis")); @@ -66,14 +64,10 @@ Maybe Gather::Apply(const GatherInterpState* ctx, const TensorTuple& out_g TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); - const std::shared_ptr& indices = ctx->SavedTensors().at(0); - const std::shared_ptr& in = ctx->SavedTensors().at(1); - - MutableAttrMap attrs; - JUST(attrs.SetAttr("axis", ctx->axis)); - in_grads->resize(3); + const auto& x = ctx->SavedTensors().at(0); + const auto& indices = ctx->SavedTensors().at(1); in_grads->at(0) = - JUST(OpInterpUtil::Dispatch(*grad_op_, {out_grads.at(0), indices, in}, attrs)); + JUST(functional::UnsortedSegmentSumLike(out_grads.at(0), indices, x, ctx->axis)); return Maybe::Ok(); } diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 3b0d964400e..17d71dd72a8 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -198,18 +198,66 @@ signature: "Tensor Range(*, Int64 start, Int64 limit, Int64 delta, DataType dtype=kInt64)" bind_python: True +- name: "flatten" + signature: "Tensor Flatten(Tensor x, *, Int32 start_dim=0, Int32 end_dim=-1)" + bind_python: True + - name: "argmax" signature: "Tensor ArgMax(Tensor x)" bind_python: True +- name: "argwhere" + signature: "TensorTuple ArgWhere(Tensor x, *, DataType dtype=kInt32)" + bind_python: True + +- name: "broadcast_like" + signature: "Tensor BroadcastLike(Tensor x, Tensor like, *, Int32List broadcast_axes)" + bind_python: True + - name: "cast" signature: "Tensor Cast(Tensor x, *, DataType dtype)" bind_python: True +- name: "constant" + signature: "Tensor Constant(*, Shape shape, Scalar value, DataType dtype)" + bind_python: True + +- name: "zeros_like" + signature: "Tensor ZerosLike(Tensor x)" + bind_python: True + +- name: "ones_like" + signature: "Tensor OnesLike(Tensor x)" + bind_python: True + +- name: "concat" + signature: "Tensor Concat(TensorTuple inputs, *, Int64 axis, Int64 max_dim_size)" + bind_python: True + +- name: "expand" + signature: "Tensor Expand(Tensor x, *, Int32List in_shape, Int32List out_shape, Int32List stride)" + bind_python: True + +- name: "expand_dims" + signature: "Tensor ExpandDims(Tensor x, *, Int32 axis)" + bind_python: True + - name: "exp" signature: "Tensor Exp(Tensor 
x)" bind_python: True +- name: "gather" + signature: "Tensor Gather(Tensor x, Tensor indices, *, Int64 axis)" + bind_python: True + +- name: "dim_gather" + signature: "Tensor DimGather(Tensor x, Tensor indices, *, Int32 dim)" + bind_python: True + +- name: "where" + signature: "Tensor Where(Tensor condition, Tensor x, Tensor y)" + bind_python: True + - name: "negative" signature: "Tensor Negative(Tensor x)" bind_python: True @@ -218,6 +266,40 @@ signature: "Tensor PRelu(Tensor x, Tensor alpha)" bind_python: True +- name: "reshape" + signature: "Tensor Reshape(Tensor x, *, Shape shape)" + bind_python: True + +- name: "slice" + signature: "Tensor Slice(Tensor x, *, Int64List start, Int64List stop, Int64List step)" + bind_python: True + +- name: "slice_update" + signature: "Tensor SliceUpdate(Tensor x, Tensor update, *, Int64List start, Int64List stop, Int64List step)" + bind_python: True + +- name: "logical_slice" + signature: "Tensor LogicalSlice(Tensor x, *, Int64List start, Int64List stop, Int64List step)" + bind_python: True + +- name: "logical_slice_assign" + signature: "Void LogicalSliceAssign(Tensor ref, Tensor value, *, Int64List start, Int64List stop, Int64List step)" + bind_python: True + +- name: "squeeze" + signature: "Tensor Squeeze(Tensor x, *, Int32List dim)" + bind_python: True + +- name: "copy" + signature: "Tensor Copy(Tensor x, *, String device_type, Int64 device_id)" + bind_python: True + +- name: "upsample" + signature: + "Tensor Upsample(Tensor x, *, Float height_scale, Float width_scale, Bool align_corners, + String interpolation, String data_format=\"channels_first\")" + bind_python: True + - name: "abs" signature: "Tensor Abs(Tensor x)" bind_python: True @@ -293,3 +375,8 @@ - name: "softplus" signature: "Tensor Softplus(Tensor x)" bind_python: True + +- name: "unsorted_segment_sum_like" + signature: + "Tensor UnsortedSegmentSumLike(Tensor x, Tensor segment_ids, Tensor like, *, Int64 axis)" + bind_python: False diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp new file mode 100644 index 00000000000..6f23b4e19df --- /dev/null +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -0,0 +1,418 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/impl/common.h" +#include "oneflow/core/functional/impl/unary_functor.h" +#include "oneflow/core/functional/scalar.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class ConstantFunctor { + public: + ConstantFunctor() { op_ = CHECK_JUST(one::OpBuilder("constant").Output("out").Build()); } + Maybe operator()(const Shape& shape, const Scalar& value, const DataType& dtype) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("shape", shape)); + JUST(attrs.SetAttr("dtype", dtype)); + if (IsIntegralDataType(dtype)) { + JUST(attrs.SetAttr("is_floating_value", false)); + JUST(attrs.SetAttr("integer_value", JUST(value.As()))); + } else { + JUST(attrs.SetAttr("is_floating_value", true)); + JUST(attrs.SetAttr("floating_value", JUST(value.As()))); + } + return OpInterpUtil::Dispatch(*op_, {}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class ZerosLikeFunctor : public UnaryFunctor { + public: + ZerosLikeFunctor() { + op_ = CHECK_JUST(one::OpBuilder("zero_like").Input("like").Output("out").Build()); + } +}; + +class OnesLikeFunctor : public UnaryFunctor { + public: + OnesLikeFunctor() { + op_ = CHECK_JUST(one::OpBuilder("ones_like").Input("like").Output("out").Build()); + } +}; + +class FlattenFunctor { + public: + FlattenFunctor() { + op_ = CHECK_JUST(one::OpBuilder("flatten").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const int32_t& start_dim, + const int32_t& end_dim) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("start_dim", start_dim)); + JUST(attrs.SetAttr("end_dim", end_dim)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class WhereFunctor { + public: + WhereFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("where").Input("condition").Input("x").Input("y").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& condition, + const std::shared_ptr& x, + const std::shared_ptr& y) const { + return OpInterpUtil::Dispatch(*op_, {condition, x, y}); + } + + private: + std::shared_ptr op_; +}; + +class ArgWhereFunctor { + public: + ArgWhereFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("argwhere").Input("input").Output("output").Output("output_size").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const DataType& dtype) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("dtype", dtype)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class BroadcastLikeFunctor { + public: + BroadcastLikeFunctor() { + op_ = CHECK_JUST(one::OpBuilder("broadcast_like").Input("x").Input("like").Output("y").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& like, + const std::vector& broadcast_axes) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr>("broadcast_axes", broadcast_axes)); + return OpInterpUtil::Dispatch(*op_, {x, like}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class ConcatFunctor { + public: + ConcatFunctor() { + ops_.resize(kMaxInputCount); + for (int n = 1; n < ops_.size(); ++n) { + ops_[n] = CHECK_JUST(one::OpBuilder("concat").Input("in", 
n + 1).Output("out").Build()); + } + } + Maybe operator()(const TensorTuple& inputs, const int64_t& axis, + const int64_t& max_dim_size) const { + CHECK_GE_OR_RETURN(inputs.size(), 2); + MutableAttrMap attrs; + JUST(attrs.SetAttr("axis", axis)); + JUST(attrs.SetAttr("max_dim_size", max_dim_size)); + TensorTuple outputs; + for (int i = 0; i < inputs.size(); i += kMaxInputCount) { + size_t size = (i + kMaxInputCount) < inputs.size() ? kMaxInputCount : inputs.size() - i; + TensorTuple partial_inputs(size); + for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; } + outputs.push_back( + JUST(OpInterpUtil::Dispatch(*ops_.at(size - 1), partial_inputs, attrs))); + } + if (outputs.size() == 1) { return outputs.at(0); } + return this->operator()(outputs, axis, max_dim_size); + } + + private: + std::vector> ops_; +}; + +class ExpandFunctor { + public: + ExpandFunctor() { op_ = CHECK_JUST(one::OpBuilder("expand").Input("in").Output("out").Build()); } + Maybe operator()(const std::shared_ptr& x, + const std::vector& in_shape, + const std::vector& out_shape, + const std::vector& stride) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr>("in_shape", in_shape)); + JUST(attrs.SetAttr>("out_shape", out_shape)); + JUST(attrs.SetAttr>("stride", stride)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class ExpandDimsFunctor { + public: + ExpandDimsFunctor() { + op_ = CHECK_JUST(one::OpBuilder("expand_dims").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const int32_t& axis) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("axis", axis)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class GatherFunctor { + public: + GatherFunctor() { + op_ = CHECK_JUST(one::OpBuilder("gather").Input("in").Input("indices").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& indices, const int64_t& axis) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("axis", axis)); + return OpInterpUtil::Dispatch(*op_, {x, indices}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class DimGatherFunctor { + public: + DimGatherFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("dim_gather").Input("input").Input("index").Output("output").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& indices, const int32_t& dim) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("dim", dim)); + return OpInterpUtil::Dispatch(*op_, {x, indices}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class ReshapeFunctor { + public: + ReshapeFunctor() { + op_ = CHECK_JUST(one::OpBuilder("reshape").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const Shape& shape) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("shape", shape)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class SliceBaseFunctor { + public: + SliceBaseFunctor() = default; + virtual ~SliceBaseFunctor() = default; + Maybe operator()(const std::shared_ptr& x, const std::vector& start, + const std::vector& stop, + const std::vector& step) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr>("start", start)); + JUST(attrs.SetAttr>("stop", stop)); + JUST(attrs.SetAttr>("step", step)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + protected: + std::shared_ptr op_; +}; + +class SliceFunctor : public SliceBaseFunctor { + public: + SliceFunctor() { 
op_ = CHECK_JUST(one::OpBuilder("slice").Input("x").Output("y").Build()); } +}; + +class LogicalSliceFunctor : public SliceBaseFunctor { + public: + LogicalSliceFunctor() { + op_ = CHECK_JUST(one::OpBuilder("logical_slice").Input("x").Output("y").Build()); + } +}; + +class LogicalSliceAssignFunctor { + public: + LogicalSliceAssignFunctor() { + op_ = CHECK_JUST(one::OpBuilder("logical_slice_assign").Input("ref").Input("value").Build()); + } + Maybe operator()(const std::shared_ptr& ref, + const std::shared_ptr& value, + const std::vector& start, const std::vector& stop, + const std::vector& step) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr>("start", start)); + JUST(attrs.SetAttr>("stop", stop)); + JUST(attrs.SetAttr>("step", step)); + JUST(OpInterpUtil::Dispatch(*op_, {ref, value}, attrs)); + return Maybe::Ok(); + } + + private: + std::shared_ptr op_; +}; + +class SliceUpdateFunctor { + public: + SliceUpdateFunctor() { + op_ = CHECK_JUST(one::OpBuilder("slice_update").Input("x").Input("update").Output("y").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& update, + const std::vector& start, const std::vector& stop, + const std::vector& step) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr>("start", start)); + JUST(attrs.SetAttr>("stop", stop)); + JUST(attrs.SetAttr>("step", step)); + return OpInterpUtil::Dispatch(*op_, {x, update}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class SqueezeFunctor { + public: + SqueezeFunctor() { + op_ = CHECK_JUST(one::OpBuilder("squeeze").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::vector& axes) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr>("axes", axes)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class CopyFunctor { + public: + CopyFunctor() { op_ = CHECK_JUST(one::OpBuilder("copy").Input("in").Output("out").Build()); } + Maybe operator()(const std::shared_ptr& x, const std::string& device_type, + const int64_t& device_id) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("device_type", device_type)); + JUST(attrs.SetAttr("device_id", device_id)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class UpsampleFunctor { + public: + UpsampleFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample").Input("x").Output("y").Build()); } + Maybe operator()(const std::shared_ptr& x, const float& height_scale, + const float& width_scale, const bool& align_corners, + const std::string& interpolation, const std::string& data_format) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("height_scale", height_scale)); + JUST(attrs.SetAttr("width_scale", width_scale)); + JUST(attrs.SetAttr("align_corners", align_corners)); + JUST(attrs.SetAttr("interpolation", interpolation)); + JUST(attrs.SetAttr("data_format", data_format)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class UnsortedSegmentSumLikeFunctor { + public: + UnsortedSegmentSumLikeFunctor() { + op_ = CHECK_JUST(one::OpBuilder("unsorted_segment_sum_like") + .Input("data") + .Input("segment_ids") + .Input("like") + .Output("out") + .Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& segment_ids, + const std::shared_ptr& like, const int64_t& axis) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("axis", axis)); + return OpInterpUtil::Dispatch(*op_, {x, segment_ids, like}, attrs); + } + + private: + 
std::shared_ptr op_; +}; + +} // namespace impl + +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor("Constant"); + m.add_functor("ZerosLike"); + m.add_functor("OnesLike"); + m.add_functor("Flatten"); + m.add_functor("Where"); + m.add_functor("ArgWhere"); + m.add_functor("BroadcastLike"); + m.add_functor("Concat"); + m.add_functor("Expand"); + m.add_functor("ExpandDims"); + m.add_functor("Gather"); + m.add_functor("DimGather"); + m.add_functor("Reshape"); + m.add_functor("Slice"); + m.add_functor("LogicalSliceAssign"); + m.add_functor("LogicalSlice"); + m.add_functor("SliceUpdate"); + m.add_functor("Squeeze"); + m.add_functor("Copy"); + m.add_functor("Upsample"); + m.add_functor("UnsortedSegmentSumLike"); +}; + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/value_types.h b/oneflow/core/functional/value_types.h index 4d0f1b0ad21..cd8904face4 100644 --- a/oneflow/core/functional/value_types.h +++ b/oneflow/core/functional/value_types.h @@ -61,6 +61,8 @@ enum ValueType { kDOUBLE_LIST, kBOOL_LIST, kSTRING_LIST, + kVOID_MAYBE, + kBOOL_MAYBE, kSCALAR, kTENSOR, kTENSOR_REF, @@ -99,6 +101,9 @@ VALUE_TYPE_OF_IMPL(std::vector, kDOUBLE_LIST); VALUE_TYPE_OF_IMPL(std::vector, kBOOL_LIST); VALUE_TYPE_OF_IMPL(std::vector, kSTRING_LIST); +VALUE_TYPE_OF_IMPL(Maybe, kVOID_MAYBE); +VALUE_TYPE_OF_IMPL(Maybe, kBOOL_MAYBE); + VALUE_TYPE_OF_IMPL(Scalar, kSCALAR); VALUE_TYPE_OF_IMPL(one::Tensor, kTENSOR); VALUE_TYPE_OF_IMPL(std::shared_ptr, kTENSOR_REF); diff --git a/oneflow/python/nn/modules/arange.py b/oneflow/python/nn/modules/arange.py index e11b67a6554..f9be484b69c 100644 --- a/oneflow/python/nn/modules/arange.py +++ b/oneflow/python/nn/modules/arange.py @@ -42,12 +42,10 @@ def __init__( self.device = device self.requires_grad = requires_grad - self._op_arange = ( - flow.builtin_op("range").Output("out").Attr("dtype", flow.int64).Build() - ) - def forward(self): - tmp = self._op_arange(start=self.start, delta=self.step, limit=self.end)[0] + tmp = flow.F.range( + start=self.start, limit=self.end, delta=self.step, dtype=flow.int64 + ) tmp.requires_grad = self.requires_grad if isinstance(self.device, str): diff --git a/oneflow/python/nn/modules/argmax.py b/oneflow/python/nn/modules/argmax.py index b22eebe1996..df38749d384 100644 --- a/oneflow/python/nn/modules/argmax.py +++ b/oneflow/python/nn/modules/argmax.py @@ -26,47 +26,28 @@ class Argmax(Module): def __init__(self, dim: int = None, keepdim: bool = False) -> None: super().__init__() - self._op_softmax_last_dim = ( - flow.builtin_op("argmax").Input("in").Output("out").Build() - ) - self._flatten = ( - flow.builtin_op("flatten") - .Input("in") - .Output("out") - .Attr("start_dim", 0) - .Attr("end_dim", -1) - .Build() - ) - self._transpose_op = ( - flow.builtin_op("transpose") - .Input("input") - .Output("output") - .Attr("perm", []) - .Build() - ) - self.dim = dim self.keepdim = keepdim def forward(self, input): if self.dim == None: - input = self._flatten(input)[0] + input = flow.F.flatten(input) self.dim = 0 num_axes = len(input.shape) axis = self.dim if self.dim >= 0 else self.dim + num_axes assert 0 <= axis < num_axes, "axis out of range" if axis == num_axes - 1: - x = self._op_softmax_last_dim(input)[0] + x = flow.F.argmax(input) if self.keepdim == True: x = flow.experimental.unsqueeze(x, -1) return x else: perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis) - x = self._transpose_op(input, perm=perm)[0] - x = self._op_softmax_last_dim(x)[0] + x = flow.F.transpose(input, perm=perm) + x 
= flow.F.argmax(x) x = flow.experimental.unsqueeze(x, -1) - x = self._transpose_op(x, perm=get_inversed_perm(perm))[0] + x = flow.F.transpose(x, perm=get_inversed_perm(perm)) if self.keepdim == False: x = x.squeeze(dim=[axis]) return x diff --git a/oneflow/python/nn/modules/argwhere.py b/oneflow/python/nn/modules/argwhere.py index 4ae279e7bcf..3f5779e3da6 100644 --- a/oneflow/python/nn/modules/argwhere.py +++ b/oneflow/python/nn/modules/argwhere.py @@ -26,19 +26,11 @@ def __init__(self, dtype) -> None: super().__init__() if dtype == None: dtype = flow.int32 - self._op = ( - flow.builtin_op("argwhere") - .Input("input") - .Output("output") - .Output("output_size") - .Attr("dtype", dtype) - .Build() - ) + self.dtype = dtype def forward(self, x): - size = self._op(x)[1].numpy() - res = self._op(x)[0] - slice_tup_list = [[0, int(size), 1]] + res, size = flow.F.argwhere(x, dtype=self.dtype) + slice_tup_list = [[0, int(size.numpy()), 1]] return flow.experimental.slice(res, slice_tup_list=slice_tup_list) diff --git a/oneflow/python/nn/modules/broadcast_like.py b/oneflow/python/nn/modules/broadcast_like.py index 597475ac3b1..c24846d3337 100644 --- a/oneflow/python/nn/modules/broadcast_like.py +++ b/oneflow/python/nn/modules/broadcast_like.py @@ -21,18 +21,10 @@ class BroadCastLike(Module): def __init__(self, broadcast_axes: None) -> None: super().__init__() - self._op = ( - flow.builtin_op("broadcast_like") - .Input("x") - .Input("like") - .Attr("broadcast_axes", broadcast_axes) - .Output("y") - .Build() - ) self.broadcast_axes = broadcast_axes def forward(self, x, like_tensor): - return self._op(x, like_tensor, broadcast_axes=self.broadcast_axes)[0] + return flow.F.broadcast_like(x, like_tensor, broadcast_axes=self.broadcast_axes) @oneflow_export("broadcast_like") diff --git a/oneflow/python/nn/modules/cast.py b/oneflow/python/nn/modules/cast.py index f7eea305860..dd20c2c2c0f 100644 --- a/oneflow/python/nn/modules/cast.py +++ b/oneflow/python/nn/modules/cast.py @@ -22,16 +22,10 @@ class Cast(Module): def __init__(self, dtype: flow.dtype) -> None: super().__init__() - self._op = ( - flow.builtin_op("cast") - .Input("in") - .Output("out") - .Attr("dtype", dtype) - .Build() - ) + self.dtype = dtype def forward(self, x): - return self._op(x)[0] + return flow.F.cast(x, dtype=self.dtype) @oneflow_export("cast") diff --git a/oneflow/python/nn/modules/concat.py b/oneflow/python/nn/modules/concat.py index 5a14da7a35c..8bbf97a6725 100644 --- a/oneflow/python/nn/modules/concat.py +++ b/oneflow/python/nn/modules/concat.py @@ -22,9 +22,8 @@ class Cat(Module): - def __init__(self, dim=0, n=0) -> None: + def __init__(self, dim=0) -> None: super().__init__() - self._op = flow.builtin_op("concat").Input("in", n).Output("out").Build() self.axis = dim def forward(self, inputs): @@ -49,7 +48,7 @@ def forward(self, inputs): else: assert input.shape[i] == first_input_shape[i] - return self._op(*inputs, axis=axis, max_dim_size=dynamic_dim_size)[0] + return flow.F.concat(inputs, axis=axis, max_dim_size=dynamic_dim_size) @oneflow_export("cat") @@ -83,8 +82,7 @@ def concat_op(inputs, dim=0): flow.Size([2, 18, 5, 3]) """ - n = len(inputs) - return Cat(dim=dim, n=n)(inputs) + return Cat(dim=dim)(inputs) if __name__ == "__main__": diff --git a/oneflow/python/nn/modules/constant.py b/oneflow/python/nn/modules/constant.py index c2c52c9fbba..4d04f421f79 100644 --- a/oneflow/python/nn/modules/constant.py +++ b/oneflow/python/nn/modules/constant.py @@ -46,45 +46,12 @@ def __init__( if device is None: self.device = 
flow.device("cpu") - if dtype in [ - flow.int, - flow.int64, - flow.int32, - flow.char, - flow.int8, - flow.long, - flow.uint8, - ]: - floating_value = float(0) - integer_value = int(value) - is_floating_value = False - elif dtype in [ - flow.float32, - flow.float, - flow.double, - flow.float64, - flow.float16, - flow.half, - ]: - floating_value = float(value) - integer_value = int(0) - is_floating_value = True - else: - raise NotImplementedError("Unsupport data type") - - self._op = ( - flow.builtin_op("constant") - .Output("out") - .Attr("is_floating_value", is_floating_value) - .Attr("floating_value", floating_value) - .Attr("integer_value", integer_value) - .Attr("dtype", dtype) - .Attr("shape", size) - .Build() - ) + self.shape = size + self.value = value + self.dtype = dtype def forward(self): - res = self._op()[0] + res = flow.F.constant(self.shape, self.value, self.dtype) res = res.to(device=self.device) res.requires_grad = self.requires_grad return res @@ -171,10 +138,9 @@ def zeros_op( class ZerosLike(Module): def __init__(self): super().__init__() - self._op = flow.builtin_op("zero_like").Input("like").Output("out").Build() def forward(self, other): - return self._op(other)[0] + return flow.F.zeros_like(other) @oneflow_export("zeros_like") @@ -205,10 +171,9 @@ def zeros_like_op(other): class OnesLike(Module): def __init__(self): super().__init__() - self._op = flow.builtin_op("ones_like").Input("like").Output("out").Build() def forward(self, other): - return self._op(other)[0] + return flow.F.ones_like(other) @oneflow_export("ones_like") diff --git a/oneflow/python/nn/modules/expand.py b/oneflow/python/nn/modules/expand.py index b2f20fd82e1..5385b91d5e0 100644 --- a/oneflow/python/nn/modules/expand.py +++ b/oneflow/python/nn/modules/expand.py @@ -23,7 +23,6 @@ class Expand(Module): def __init__(self, *sizes) -> None: super().__init__() - self._op = flow.builtin_op("expand").Input("in").Output("out").Build() self.expand_size = list(*sizes) def forward(self, x): @@ -59,9 +58,9 @@ def forward(self, x): else: new_stride.insert(0, 0) - return self._op( + return flow.F.expand( x, in_shape=list(x.shape), out_shape=new_size, stride=new_stride - )[0] + ) @oneflow_export("expand") diff --git a/oneflow/python/nn/modules/flatten.py b/oneflow/python/nn/modules/flatten.py index e14119d95e4..05833569143 100644 --- a/oneflow/python/nn/modules/flatten.py +++ b/oneflow/python/nn/modules/flatten.py @@ -45,17 +45,11 @@ class Flatten(Module): def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: super().__init__() - self.op_ = ( - flow.builtin_op("flatten") - .Input("in") - .Output("out") - .Attr("start_dim", start_dim) - .Attr("end_dim", end_dim) - .Build() - ) + self.start_dim = start_dim + self.end_dim = end_dim def forward(self, input): - return self.op_(input)[0] + return flow.F.flatten(input, start_dim=self.start_dim, end_dim=self.end_dim) @oneflow_export("flatten") diff --git a/oneflow/python/nn/modules/gather.py b/oneflow/python/nn/modules/gather.py index 1d145ab37e2..440b6a2993c 100644 --- a/oneflow/python/nn/modules/gather.py +++ b/oneflow/python/nn/modules/gather.py @@ -31,15 +31,6 @@ def __init__( assert sparse_grad is False, "Only support bool = False for now!" 
self.dim = dim - self._gather_op = ( - flow.builtin_op("dim_gather") - .Input("input") - .Input("index") - .Output("output") - .Attr("dim", int(dim)) - .Build() - ) - def forward(self, input, index): assert self.dim < len( index.shape @@ -56,8 +47,7 @@ def forward(self, input, index): input.shape[i] == index.shape[i] ), "Dimensions of input and index should be same except at dim" - res = self._gather_op(input, index)[0] - return res + return flow.F.dim_gather(input, index, dim=self.dim) @oneflow_export("gather") diff --git a/oneflow/python/nn/modules/reshape.py b/oneflow/python/nn/modules/reshape.py index 6fa18280e0d..7658c819118 100644 --- a/oneflow/python/nn/modules/reshape.py +++ b/oneflow/python/nn/modules/reshape.py @@ -42,19 +42,11 @@ def __init__(self, shape: Sequence[int]) -> None: shape = list(shape) assert all(dim == -1 or dim > 0 for dim in shape) assert shape.count(-1) <= 1 - - self._op = ( - flow.builtin_op("reshape") - .Input("in") - .Output("out") - .Attr("shape", shape) - .Build() - ) self.shape = shape def forward(self, x): new_shape = infer_shape(x, self.shape) - return self._op(x, shape=new_shape)[0] + return flow.F.reshape(x, shape=new_shape) @oneflow_export("reshape") diff --git a/oneflow/python/nn/modules/slice.py b/oneflow/python/nn/modules/slice.py index 8a8c10fdfe2..7bc16d9c074 100644 --- a/oneflow/python/nn/modules/slice.py +++ b/oneflow/python/nn/modules/slice.py @@ -26,18 +26,12 @@ def __init__( self, start: Tuple[int, ...], stop: Tuple[int, ...], step: Tuple[int, ...] ) -> None: super().__init__() - self._op = ( - flow.builtin_op("slice") - .Input("x") - .Output("y") - .Attr("start", start) - .Attr("stop", stop) - .Attr("step", step) - .Build() - ) + self.start = start + self.stop = stop + self.step = step def forward(self, x): - return self._op(x)[0] + return flow.F.slice(x, start=self.start, stop=self.stop, step=self.step) @oneflow_export("slice") @@ -74,19 +68,14 @@ def __init__( self, start: Tuple[int, ...], stop: Tuple[int, ...], step: Tuple[int, ...] ) -> None: super().__init__() - self._op = ( - flow.builtin_op("slice_update") - .Input("x") - .Input("update") - .Output("y") - .Attr("start", start) - .Attr("stop", stop) - .Attr("step", step) - .Build() - ) + self.start = start + self.stop = stop + self.step = step def forward(self, x, update): - return self._op(x, update)[0] + return flow.F.slice_update( + x, update, start=self.start, stop=self.stop, step=self.step + ) @oneflow_export("slice_update") @@ -122,18 +111,14 @@ def __init__( self, start: Tuple[int, ...], stop: Tuple[int, ...], step: Tuple[int, ...] 
) -> None: super().__init__() - self._op = ( - flow.builtin_op("logical_slice_assign") - .Input("ref") - .Input("value") - .Attr("start", start) - .Attr("stop", stop) - .Attr("step", step) - .Build() - ) + self.start = start + self.stop = stop + self.step = step def forward(self, x, update): - return self._op(x, update) + return flow.F.logical_slice_assign( + x, update, start=self.start, stop=self.stop, step=self.step + ) # NOTE: conflict with existing userop: flow.experimental.logical_slice_assign, so use tmp.logical_slice_assign diff --git a/oneflow/python/nn/modules/sparse.py b/oneflow/python/nn/modules/sparse.py index fe48a582410..eb95d58bbc8 100644 --- a/oneflow/python/nn/modules/sparse.py +++ b/oneflow/python/nn/modules/sparse.py @@ -96,14 +96,6 @@ def __init__( self.weight = flow.nn.Parameter(_weight) self.sparse = sparse - self._gather_op = ( - flow.builtin_op("gather") - .Input("in") - .Input("indices") - .Output("out") - .Attr("axis", int(0)) - .Build() - ) def reset_parameters(self) -> None: flow.nn.init.normal_(self.weight) @@ -116,7 +108,7 @@ def _fill_padding_idx_with_zero(self) -> None: self.weight[self.padding_idx].fill_(0) def forward(self, indices): - res = self._gather_op(self.weight, indices)[0] + res = flow.F.gather(self.weight, indices, axis=0) return res diff --git a/oneflow/python/nn/modules/squeeze.py b/oneflow/python/nn/modules/squeeze.py index 7563a3be40d..da73627707d 100644 --- a/oneflow/python/nn/modules/squeeze.py +++ b/oneflow/python/nn/modules/squeeze.py @@ -26,18 +26,10 @@ def __init__(self, dim: Optional[Sequence[int]] = None) -> None: super().__init__() self.dim = dim - self._op = ( - flow.builtin_op("squeeze") - .Input("in") - .Output("out") - .Attr("axes", dim) - .Build() - ) - def forward(self, x): if self.dim is None: return x - return self._op(x)[0] + return flow.F.squeeze(x, dim=self.dim) @oneflow_export("squeeze") diff --git a/oneflow/python/nn/modules/to.py b/oneflow/python/nn/modules/to.py index e3328e12771..b2236ee93fd 100644 --- a/oneflow/python/nn/modules/to.py +++ b/oneflow/python/nn/modules/to.py @@ -23,20 +23,16 @@ class To(Module): def __init__(self, copy): super().__init__() - self._copy_op = flow.builtin_op("copy").Input("in").Output("out").Build() - self._cast_op = flow.builtin_op("cast").Input("in").Output("out").Build() self.copy = copy def forward(self, x, device, dtype): result = x if device is not None: if x.device != device or self.copy: - result = self._copy_op( - x, device_type=device.type, device_id=device.index - )[0] + result = flow.F.copy(x, device_type=device.type, device_id=device.index) if dtype is not None: if x.dtype != dtype or self.copy: - result = self._cast_op(result, dtype=dtype)[0] + result = flow.F.cast(result, dtype=dtype) return result diff --git a/oneflow/python/nn/modules/unsqueeze.py b/oneflow/python/nn/modules/unsqueeze.py index 26ad39b7d5c..cc1008039cf 100644 --- a/oneflow/python/nn/modules/unsqueeze.py +++ b/oneflow/python/nn/modules/unsqueeze.py @@ -23,7 +23,6 @@ class Unsqueeze(Module): def __init__(self, dim: int = 0) -> None: super().__init__() self.dim = dim - self._op = flow.builtin_op("expand_dims").Input("in").Output("out").Build() def forward(self, input): assert ( @@ -32,7 +31,7 @@ def forward(self, input): if self.dim < 0: self.dim = 1 + input.ndimension() + self.dim - return self._op(input, axis=self.dim)[0] + return flow.F.expand_dims(input, axis=self.dim) @oneflow_export("unsqueeze") diff --git a/oneflow/python/nn/modules/upsample.py b/oneflow/python/nn/modules/upsample.py index 
4bd4f798d3e..aafabd3e3ae 100644 --- a/oneflow/python/nn/modules/upsample.py +++ b/oneflow/python/nn/modules/upsample.py @@ -121,18 +121,6 @@ def __init__( if self.mode == "nearest" and self.align_corners: raise ValueError('interpolation "nearest" does not support align_corners.') - self._op = ( - flow.builtin_op("upsample") - .Input("x") - .Output("y") - .Attr("height_scale", float(1.0)) - .Attr("width_scale", float(1.0)) - .Attr("align_corners", self.align_corners) - .Attr("data_format", "channels_first") - .Attr("interpolation", self.mode) - .Build() - ) - def forward(self, x): assert ( self.size != None or self.scale_factor != None @@ -149,9 +137,14 @@ def forward(self, x): else: self.width_scale = 1.0 * self.size[1] / w - res = self._op(x, height_scale=self.height_scale, width_scale=self.width_scale)[ - 0 - ] + res = flow.F.upsample( + x, + height_scale=self.height_scale, + width_scale=self.width_scale, + align_corners=self.align_corners, + interpolation=self.mode, + data_format="channels_first", + ) return res diff --git a/oneflow/python/nn/modules/where.py b/oneflow/python/nn/modules/where.py index ec5881a1efb..86c5e8a4305 100644 --- a/oneflow/python/nn/modules/where.py +++ b/oneflow/python/nn/modules/where.py @@ -22,14 +22,6 @@ class Where(Module): def __init__(self) -> None: super().__init__() - self._where_op = ( - flow.builtin_op("where") - .Input("condition") - .Input("x") - .Input("y") - .Output("out") - .Build() - ) def forward(self, condition, x, y): assert condition.dtype == flow.int32 or condition.dtype == flow.int8 @@ -92,7 +84,7 @@ def forward(self, condition, x, y): y, broadcast_like_tensor, broadcast_axes=tuple(broadcast_y_axes) ) - return self._where_op(broadcast_cond, broadcast_x, broadcast_y)[0] + return flow.F.where(broadcast_cond, broadcast_x, broadcast_y) @oneflow_export("where") diff --git a/oneflow/user/kernels/constant_kernel.cpp b/oneflow/user/kernels/constant_kernel.cpp index 0a654bea3ef..2dbe187814d 100644 --- a/oneflow/user/kernels/constant_kernel.cpp +++ b/oneflow/user/kernels/constant_kernel.cpp @@ -19,32 +19,14 @@ limitations under the License. namespace oneflow { namespace user_op { -class ConstState final : public OpKernelState { - public: - ConstState(bool is_init) : is_init_(is_init) {} - ~ConstState() = default; - bool is_inited() const { return is_init_; } - void set_is_inited(bool val) { is_init_ = val; } - - private: - bool is_init_; -}; - template class ConstantKernel final : public OpKernel { public: ConstantKernel() = default; ~ConstantKernel() = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(false); - } - private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* const_state = dynamic_cast(state); - if (const_state->is_inited()) { return; } + void Compute(user_op::KernelComputeContext* ctx) const override { Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); bool is_floating_value = ctx->Attr("is_floating_value"); const int64_t elem_cnt = out_tensor->shape().elem_cnt(); @@ -54,8 +36,6 @@ class ConstantKernel final : public OpKernel { ? 
static_cast(ctx->Attr("floating_value")) : static_cast(ctx->Attr("integer_value")), out_tensor->mut_dptr()); - - const_state->set_is_inited(true); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; diff --git a/tools/generate_functional_api.py b/tools/generate_functional_api.py index 164d2111be7..4164dfc3b6f 100644 --- a/tools/generate_functional_api.py +++ b/tools/generate_functional_api.py @@ -124,6 +124,7 @@ ) types_allowed = { + "Void", "Tensor", "TensorTuple", "Scalar", @@ -174,10 +175,11 @@ } return_type_aliases = { + "Void": "Maybe", "Tensor": "Maybe", "TensorTuple": "Maybe", - "String": "std::string", - **generic_type_aliases, + "String": "Maybe", + **{k: "Maybe<{0}>".format(v) for k, v in generic_type_aliases.items()}, } value_aliases = { From ab44b5043f99064b777add2a10b73fa4daa8c1e4 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 10:02:45 +0800 Subject: [PATCH 09/42] Add or refactor activation grad funcs. --- oneflow/core/autograd/gradient_funcs/elu.cpp | 68 ++++++++++++++++++ .../core/autograd/gradient_funcs/hardtanh.cpp | 71 +++++++++++++++++++ .../autograd/gradient_funcs/leaky_relu.cpp | 68 ++++++++++++++++++ 3 files changed, 207 insertions(+) create mode 100644 oneflow/core/autograd/gradient_funcs/elu.cpp create mode 100644 oneflow/core/autograd/gradient_funcs/hardtanh.cpp create mode 100644 oneflow/core/autograd/gradient_funcs/leaky_relu.cpp diff --git a/oneflow/core/autograd/gradient_funcs/elu.cpp b/oneflow/core/autograd/gradient_funcs/elu.cpp new file mode 100644 index 00000000000..4c3764dc6e8 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/elu.cpp @@ -0,0 +1,68 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct EluInterpState : public OpExprInterpState {
+  bool requires_grad;
+  double alpha;
+};
+
+class Elu : public OpExprGradFunction<EluInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(EluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const EluInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha));
+    }
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("elu", Elu);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/autograd/gradient_funcs/hardtanh.cpp b/oneflow/core/autograd/gradient_funcs/hardtanh.cpp
new file mode 100644
index 00000000000..5c4c71e227d
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/hardtanh.cpp
@@ -0,0 +1,71 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct HardTanhInterpState : public OpExprInterpState { + bool requires_grad; + double min_val; + double max_val; +}; + +class HardTanh : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); + } + + Maybe Capture(HardTanhInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(outputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->min_val = JUST(composed_attrs.GetAttr("min_val")); + ctx->max_val = JUST(composed_attrs.GetAttr("max_val")); + ctx->SaveTensorForBackward(outputs.at(0)); + return Maybe::Ok(); + } + + Maybe Apply(const HardTanhInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + const auto& y = ctx->SavedTensors().at(0); + in_grads->at(0) = + JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val)); + } + return Maybe::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh", HardTanh); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/leaky_relu.cpp b/oneflow/core/autograd/gradient_funcs/leaky_relu.cpp new file mode 100644 index 00000000000..eaab31d6f10 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/leaky_relu.cpp @@ -0,0 +1,68 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct LeakyReluInterpState : public OpExprInterpState { + bool requires_grad; + float alpha; +}; + +class LeakyRelu : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); + } + + Maybe Capture(LeakyReluInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); + ctx->SaveTensorForBackward(inputs.at(0)); + return Maybe::Ok(); + } + + Maybe Apply(const LeakyReluInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha)); + } + return Maybe::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu", LeakyRelu); + +} // namespace one +} // namespace oneflow From 5c1463a7d13166d05836b049b0f94443692c4fd1 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 10:02:45 +0800 Subject: [PATCH 10/42] Add or refactor activation grad funcs. --- oneflow/core/autograd/gradient_funcs/elu.cpp | 68 ++++++++++++++++++ .../core/autograd/gradient_funcs/hardtanh.cpp | 71 +++++++++++++++++++ .../autograd/gradient_funcs/leaky_relu.cpp | 68 ++++++++++++++++++ 3 files changed, 207 insertions(+) create mode 100644 oneflow/core/autograd/gradient_funcs/elu.cpp create mode 100644 oneflow/core/autograd/gradient_funcs/hardtanh.cpp create mode 100644 oneflow/core/autograd/gradient_funcs/leaky_relu.cpp diff --git a/oneflow/core/autograd/gradient_funcs/elu.cpp b/oneflow/core/autograd/gradient_funcs/elu.cpp new file mode 100644 index 00000000000..4c3764dc6e8 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/elu.cpp @@ -0,0 +1,68 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct EluInterpState : public OpExprInterpState { + bool requires_grad; + double alpha; +}; + +class Elu : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); + } + + Maybe Capture(EluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); + ctx->SaveTensorForBackward(inputs.at(0)); + return Maybe::Ok(); + } + + Maybe Apply(const EluInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha)); + } + return Maybe::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("elu", Elu); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/hardtanh.cpp b/oneflow/core/autograd/gradient_funcs/hardtanh.cpp new file mode 100644 index 00000000000..5c4c71e227d --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/hardtanh.cpp @@ -0,0 +1,71 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct HardTanhInterpState : public OpExprInterpState { + bool requires_grad; + double min_val; + double max_val; +}; + +class HardTanh : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); + } + + Maybe Capture(HardTanhInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(outputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->min_val = JUST(composed_attrs.GetAttr("min_val")); + ctx->max_val = JUST(composed_attrs.GetAttr("max_val")); + ctx->SaveTensorForBackward(outputs.at(0)); + return Maybe::Ok(); + } + + Maybe Apply(const HardTanhInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + const auto& y = ctx->SavedTensors().at(0); + in_grads->at(0) = + JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val)); + } + return Maybe::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh", HardTanh); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/leaky_relu.cpp b/oneflow/core/autograd/gradient_funcs/leaky_relu.cpp new file mode 100644 index 00000000000..eaab31d6f10 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/leaky_relu.cpp @@ -0,0 +1,68 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct LeakyReluInterpState : public OpExprInterpState { + bool requires_grad; + float alpha; +}; + +class LeakyRelu : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); + } + + Maybe Capture(LeakyReluInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->alpha = JUST(composed_attrs.GetAttr("alpha")); + ctx->SaveTensorForBackward(inputs.at(0)); + return Maybe::Ok(); + } + + Maybe Apply(const LeakyReluInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha)); + } + return Maybe::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu", LeakyRelu); + +} // namespace one +} // namespace oneflow From 0ceb5ec483b62df45218465050873ec1b9ddedea Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 11:02:53 +0800 Subject: [PATCH 11/42] Revert unpack all --- oneflow/api/python/functional/unpack_call.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/oneflow/api/python/functional/unpack_call.h b/oneflow/api/python/functional/unpack_call.h index eb669727754..a471404f9ee 100644 --- a/oneflow/api/python/functional/unpack_call.h +++ b/oneflow/api/python/functional/unpack_call.h @@ -69,6 +69,10 @@ INSTANCE_MAYBE_UNPACK_CALL(Maybe, std::shared_ptr, ([](const Maybe& t) { return t.GetPtrOrThrow(); })); INSTANCE_MAYBE_UNPACK_CALL(Maybe, std::shared_ptr, ([](const Maybe& t) { return t.GetPtrOrThrow(); })); +INSTANCE_MAYBE_UNPACK_CALL(Maybe, bool, ([](const Maybe& t) { + t.GetOrThrow(); + return true; + })); #undef INSTANCE_MAYBE_UNPACK_CALL From d60b7d238a6cd5a8a4d6579f62bd99b8feab9658 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 11:12:51 +0800 Subject: [PATCH 12/42] Fix masked fill --- oneflow/python/nn/modules/masked_fill.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/oneflow/python/nn/modules/masked_fill.py b/oneflow/python/nn/modules/masked_fill.py index 04e948760f3..28b00766ac7 100644 --- a/oneflow/python/nn/modules/masked_fill.py +++ b/oneflow/python/nn/modules/masked_fill.py @@ -23,20 +23,12 @@ class MaskedFill(Module): def __init__(self, value) -> None: super().__init__() self.value = value - self._where_op = ( - flow.builtin_op("where") - .Input("condition") - .Input("x") - .Input("y") - .Output("out") - .Build() - ) def forward(self, input, mask): in_shape = tuple(input.shape) value_like_x = flow.Tensor(*in_shape, device=input.device) value_like_x.fill_(self.value) - return self._where_op(mask, value_like_x, input)[0] + return flow.F.where(mask, value_like_x, input) @oneflow_export("masked_fill") From c43e0f7c5f57e2aa3248cfa7e6eac766b9e3c321 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 
Jun 2021 11:18:14 +0800 Subject: [PATCH 13/42] Refine --- oneflow/python/nn/modules/atanh.py | 3 +-- oneflow/python/nn/modules/eq.py | 6 +----- oneflow/python/nn/modules/floor.py | 3 +-- oneflow/python/nn/modules/softplus.py | 3 +-- 4 files changed, 4 insertions(+), 11 deletions(-) diff --git a/oneflow/python/nn/modules/atanh.py b/oneflow/python/nn/modules/atanh.py index 1908b391b83..9020773aafb 100644 --- a/oneflow/python/nn/modules/atanh.py +++ b/oneflow/python/nn/modules/atanh.py @@ -22,10 +22,9 @@ class Atanh(Module): def __init__(self): super().__init__() - self._op = flow.builtin_op("atanh").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.atanh(x) @oneflow_export("atanh") diff --git a/oneflow/python/nn/modules/eq.py b/oneflow/python/nn/modules/eq.py index ca19c34f77f..7faefb3b672 100644 --- a/oneflow/python/nn/modules/eq.py +++ b/oneflow/python/nn/modules/eq.py @@ -22,9 +22,6 @@ class Eq(Module): def __init__(self) -> None: super().__init__() - self.eq_op = ( - flow.builtin_op("broadcast_equal").Input("x").Input("y").Output("z").Build() - ) def forward(self, input, other): if isinstance(other, flow.Tensor) or isinstance( @@ -42,8 +39,7 @@ def forward(self, input, other): raise NotImplementedError( "Unsupport data type, The second argument can be a tensor whose shape is broadcastable with the first argument." ) - - return self.eq_op(input, other)[0] + return flow.F.broadcast_equal(input, other) @oneflow_export("eq", "equal") diff --git a/oneflow/python/nn/modules/floor.py b/oneflow/python/nn/modules/floor.py index 9ed9433b2de..853fc49fdd4 100644 --- a/oneflow/python/nn/modules/floor.py +++ b/oneflow/python/nn/modules/floor.py @@ -27,10 +27,9 @@ class Floor(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("floor").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.floor(x) @oneflow_export("floor") diff --git a/oneflow/python/nn/modules/softplus.py b/oneflow/python/nn/modules/softplus.py index 545d27b6440..68cac8ddb10 100644 --- a/oneflow/python/nn/modules/softplus.py +++ b/oneflow/python/nn/modules/softplus.py @@ -22,10 +22,9 @@ class Softplus(Module): def __init__(self) -> None: super().__init__() - self._op = flow.builtin_op("softplus").Input("x").Output("y").Build() def forward(self, x): - return self._op(x)[0] + return flow.F.softplus(x) @oneflow_export("softplus") From 22c69a303a9e3dd1ae09f7cf0b10ef26cf119cfd Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 11:39:39 +0800 Subject: [PATCH 14/42] Add nn ops. --- .../core/autograd/gradient_funcs/bias_add.cpp | 13 +- .../sparse_softmax_cross_entropy.cpp | 4 +- oneflow/core/functional/functional_api.yaml | 37 +++++ oneflow/core/functional/impl/nn_functor.cpp | 146 +++++++++++++++++- oneflow/python/nn/modules/linear.py | 29 +--- oneflow/python/nn/modules/loss.py | 26 +--- oneflow/python/nn/modules/matmul.py | 38 +---- oneflow/python/nn/modules/normalization.py | 33 +--- 8 files changed, 208 insertions(+), 118 deletions(-) diff --git a/oneflow/core/autograd/gradient_funcs/bias_add.cpp b/oneflow/core/autograd/gradient_funcs/bias_add.cpp index 53969595e83..6e59f14982a 100644 --- a/oneflow/core/autograd/gradient_funcs/bias_add.cpp +++ b/oneflow/core/autograd/gradient_funcs/bias_add.cpp @@ -18,7 +18,7 @@ limitations under the License. 
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_expr_helper.h" -#include "oneflow/core/framework/user_op_conf_trait.h" +#include "oneflow/core/framework/attr_map.h" namespace oneflow { namespace one { @@ -26,6 +26,7 @@ namespace one { struct BiasAddInterpState : public OpExprInterpState { bool input_requires_grad; bool bias_requires_grad; + int32_t axis; }; class BiasAdd : public OpExprGradFunction { @@ -33,9 +34,8 @@ class BiasAdd : public OpExprGradFunction { Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); const std::string& op_name = fw_op_expr->op_name(); - op_trait_ = std::make_shared(op_name, fw_op_expr->proto()); - axis_ = JUST(op_trait_->GetAttr("axis")); backward_input_op_ = JUST(op_expr_helper::IdentityOp(GradientOpName(op_name + "_input"))); backward_bias_op_ = JUST( op_expr_helper::ReduceSumOp({0}, /*keepdims=*/false, GradientOpName(op_name + "_bias"))); @@ -47,6 +47,8 @@ class BiasAdd : public OpExprGradFunction { CHECK_EQ_OR_RETURN(inputs.size(), 2); ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->bias_requires_grad = inputs.at(1)->requires_grad(); + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->axis = JUST(composed_attrs.GetAttr("axis")); return Maybe::Ok(); } @@ -58,7 +60,7 @@ class BiasAdd : public OpExprGradFunction { std::vector reduce_axes_vec; reduce_axes_vec.reserve(num_axes); for (int i = 0; i < num_axes; ++i) { - if (i != axis_) { reduce_axes_vec.push_back(i); } + if (i != ctx->axis) { reduce_axes_vec.push_back(i); } } MutableAttrMap attrs; JUST(attrs.SetAttr>("axis", reduce_axes_vec)); @@ -73,8 +75,7 @@ class BiasAdd : public OpExprGradFunction { } private: - std::shared_ptr op_trait_; - int32_t axis_; + AttrMap base_attrs_; std::shared_ptr backward_input_op_; std::shared_ptr backward_bias_op_; }; diff --git a/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp b/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp index 859010ccbbe..cb1c93c09f3 100644 --- a/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp +++ b/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp @@ -58,7 +58,7 @@ Maybe SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyInterpSt ctx->depth = JUST(composed_attrs.GetAttr("depth")); CHECK_EQ_OR_RETURN(inputs.size(), 2); CHECK_EQ_OR_RETURN(outputs.size(), 2); - ctx->SaveTensorForBackward(outputs.at(0)); // prob + ctx->SaveTensorForBackward(outputs.at(1)); // prob ctx->SaveTensorForBackward(inputs.at(1)); // label return Maybe::Ok(); } @@ -67,7 +67,7 @@ Maybe SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyInte const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 2); - const auto& dy = out_grads.at(1); + const auto& dy = out_grads.at(0); const auto& prob = ctx->SavedTensors().at(0); const auto& label = ctx->SavedTensors().at(1); MutableAttrMap attrs; diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 17d71dd72a8..a96e7512654 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -234,6 +234,10 @@ signature: "Tensor Concat(TensorTuple inputs, *, Int64 axis, Int64 max_dim_size)" bind_python: True +- name: "bias_add" + signature: "Tensor 
BiasAdd(Tensor x, Tensor bias, *, Int32 axis=1)" + bind_python: True + - name: "expand" signature: "Tensor Expand(Tensor x, *, Int32List in_shape, Int32List out_shape, Int32List stride)" bind_python: True @@ -254,14 +258,47 @@ signature: "Tensor DimGather(Tensor x, Tensor indices, *, Int32 dim)" bind_python: True +- name: "matmul" + signature: + "Tensor MatMul(Tensor a, Tensor b, *, Bool transpose_a=False, Bool transpose_b=False, + Double alpha=1.0)" + bind_python: True + +- name: "broadcast_matmul" + signature: + "Tensor BroadcastMatMul(Tensor a, Tensor b, *, Bool transpose_a=False, + Bool transpose_b=False, Double alpha=1.0)" + bind_python: True + +- name: "sparse_softmax_cross_entropy" + signature: "Tensor SparseSoftmaxCrossEntropy(Tensor logits, Tensor label, *, Int64 depth)" + bind_python: True + - name: "where" signature: "Tensor Where(Tensor condition, Tensor x, Tensor y)" bind_python: True +- name: "batch_matmul" + signature: + "Tensor BatchMatMul(Tensor a, Tensor b, *, Bool transpose_a=False, Bool transpose_b=False, + Double alpha=1.0)" + bind_python: True + - name: "negative" signature: "Tensor Negative(Tensor x)" bind_python: True +- name: "layer_norm_affine" + signature: + "Tensor LayerNormAffine(Tensor x, Tensor beta, Tensor gamma, *, Int64 begin_norm_axis, + Int64 begin_params_axis, Double epsilon)" + bind_python: True + +- name: "layer_norm" + signature: + "Tensor LayerNorm(Tensor x, *, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon)" + bind_python: True + - name: "prelu" signature: "Tensor PRelu(Tensor x, Tensor alpha)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index ec3c302a9b7..575c68a7307 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -21,6 +21,8 @@ limitations under the License. 
#include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/impl/common.h" +#include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/functional/scalar.h" namespace oneflow { @@ -29,6 +31,139 @@ namespace functional { namespace impl { +class BiasAddFunctor { + public: + BiasAddFunctor() { + op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& bias, const int32_t& axis) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("axis", axis)); + return OpInterpUtil::Dispatch(*op_, {x, bias}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class MatMulBaseFunctor { + public: + MatMulBaseFunctor() = default; + virtual ~MatMulBaseFunctor() = default; + Maybe operator()(const std::shared_ptr& a, + const std::shared_ptr& b, const bool& transpose_a, + const bool& transpose_b, const double& alpha) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("transpose_a", transpose_a)); + JUST(attrs.SetAttr("transpose_b", transpose_b)); + JUST(attrs.SetAttr("alpha", alpha)); + return OpInterpUtil::Dispatch(*op_, {a, b}, attrs); + } + + protected: + std::shared_ptr op_; +}; + +class MatMulFunctor : public MatMulBaseFunctor { + public: + MatMulFunctor() { + op_ = CHECK_JUST(one::OpBuilder("matmul").Input("a").Input("b").Output("out").Build()); + } +}; + +class BatchMatMulFunctor : public MatMulBaseFunctor { + public: + BatchMatMulFunctor() { + op_ = CHECK_JUST(one::OpBuilder("batch_matmul").Input("a").Input("b").Output("out").Build()); + } +}; + +class BroadcastMatMulFunctor : public MatMulBaseFunctor { + public: + BroadcastMatMulFunctor() { + op_ = + CHECK_JUST(one::OpBuilder("broadcast_matmul").Input("a").Input("b").Output("out").Build()); + } +}; + +class LayerNormFunctor { + public: + LayerNormFunctor() { + op_ = CHECK_JUST(one::OpBuilder("layer_norm") + .Input("x") + .Output("y") + .Output("mean") + .Output("inv_variance") + .Build()); + } + Maybe operator()(const std::shared_ptr& x, const int64_t& begin_norm_axis, + const int64_t& begin_params_axis, const double& epsilon) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("begin_norm_axis", begin_norm_axis)); + JUST(attrs.SetAttr("begin_params_axis", begin_params_axis)); + JUST(attrs.SetAttr("epsilon", epsilon)); + JUST(attrs.SetAttr("center", false)); + JUST(attrs.SetAttr("scale", false)); + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class LayerNormAffineFunctor { + public: + LayerNormAffineFunctor() { + op_ = CHECK_JUST(one::OpBuilder("layer_norm") + .Input("x") + .Input("beta") + .Input("gamma") + .Output("y") + .Output("mean") + .Output("inv_variance") + .Output("normalized") + .Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& beta, + const std::shared_ptr& gamma, + const int64_t& begin_norm_axis, const int64_t& begin_params_axis, + const double& epsilon) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("begin_norm_axis", begin_norm_axis)); + JUST(attrs.SetAttr("begin_params_axis", begin_params_axis)); + JUST(attrs.SetAttr("epsilon", epsilon)); + JUST(attrs.SetAttr("center", true)); + JUST(attrs.SetAttr("scale", true)); + return OpInterpUtil::Dispatch(*op_, {x, beta, gamma}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class SparseSoftmaxCrossEntropyFunctor { + public: + 
SparseSoftmaxCrossEntropyFunctor() { + op_ = CHECK_JUST(one::OpBuilder("sparse_softmax_cross_entropy") + .Input("prediction") + .Input("label") + .Output("out") + .Output("prob") + .Build()); + } + Maybe operator()(const std::shared_ptr& logits, + const std::shared_ptr& label, const int64_t& depth) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("depth", depth)); + return OpInterpUtil::Dispatch(*op_, {logits, label}, attrs); + } + + private: + std::shared_ptr op_; +}; + class NormalizationFunctor { public: NormalizationFunctor() { @@ -81,7 +216,16 @@ class NormalizationFunctor { } // namespace impl -ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Normalization"); }; +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor("BiasAdd"); + m.add_functor("MatMul"); + m.add_functor("BatchMatMul"); + m.add_functor("BroadcastMatMul"); + m.add_functor("LayerNorm"); + m.add_functor("LayerNormAffine"); + m.add_functor("SparseSoftmaxCrossEntropy"); + m.add_functor("Normalization"); +}; } // namespace functional } // namespace one diff --git a/oneflow/python/nn/modules/linear.py b/oneflow/python/nn/modules/linear.py index 363a2ec446a..eb359cdf449 100644 --- a/oneflow/python/nn/modules/linear.py +++ b/oneflow/python/nn/modules/linear.py @@ -106,29 +106,6 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True) -> No if bias: self.bias = flow.nn.Parameter(flow.Tensor(out_features)) - - self._matmul_op = ( - flow.builtin_op("matmul") - .Input("a") - .Input("b") - .Output("out") - .Attr("transpose_a", False) - .Attr("transpose_b", True) - .Attr("alpha", 1.0) - .Build() - ) - - self._broadcast_matmul_op = ( - flow.builtin_op("broadcast_matmul") - .Input("a") - .Input("b") - .Output("out") - .Attr("transpose_a", False) - .Attr("transpose_b", True) - .Attr("alpha", 1.0) - .Build() - ) - self.reset_parameters() def reset_parameters(self) -> None: @@ -143,9 +120,11 @@ def forward(self, x): assert len(x.shape) >= 2, "Tensor x's dim should >=2" if len(x.shape) == 2: - res = self._matmul_op(x, self.weight)[0] + res = flow.F.matmul(x, self.weight, transpose_a=False, transpose_b=True) else: - res = self._broadcast_matmul_op(x, self.weight)[0] + res = flow.F.broadcast_matmul( + x, self.weight, transpose_a=False, transpose_b=True + ) if self.use_bias: res += self.bias diff --git a/oneflow/python/nn/modules/loss.py b/oneflow/python/nn/modules/loss.py index 79dba41f34c..b076f7fdba7 100644 --- a/oneflow/python/nn/modules/loss.py +++ b/oneflow/python/nn/modules/loss.py @@ -99,14 +99,6 @@ def __init__( self.ignore_index = ignore_index self.reduction = reduction - self._op = ( - flow.builtin_op("sparse_softmax_cross_entropy") - .Input("prediction") - .Input("label") - .Output("prob") - .Output("out") - .Build() - ) def forward(self, input, target): assert len(input.shape) <= 4 @@ -125,7 +117,9 @@ def forward(self, input, target): elif input_shape_len >= 5: raise NotImplemented - prob, out = self._op(input, target, depth=input.shape[len(input.shape) - 1]) + out = flow.F.sparse_softmax_cross_entropy( + input, target, depth=input.shape[len(input.shape) - 1] + ) if self.ignore_index is not None: zeros = flow.experimental.zeros( size=out.shape, dtype=out.dtype, device=out.device @@ -347,19 +341,11 @@ def __init__( self.ignore_index = ignore_index self.reduction = reduction - self._dim_gather_op = ( - flow.builtin_op("dim_gather") - .Input("input") - .Input("index") - .Output("output") - .Attr("dim", 1) - .Build() - ) def nllloss_1d(self, input, target): - target = flow.experimental.reshape(target, (target.shape[0], 
1)) - res = self._dim_gather_op(input, target)[0] - res = flow.experimental.squeeze(res, dim=[1]) + target = flow.F.reshape(target, shape=(target.shape[0], 1)) + res = flow.F.dim_gather(input, target, dim=1) + res = flow.F.squeeze(res, dim=[1]) return res def forward(self, input, target): diff --git a/oneflow/python/nn/modules/matmul.py b/oneflow/python/nn/modules/matmul.py index 8fd27001d21..220e8a7f052 100644 --- a/oneflow/python/nn/modules/matmul.py +++ b/oneflow/python/nn/modules/matmul.py @@ -24,38 +24,6 @@ class MatMul(Module): def __init__(self) -> None: super().__init__() - self._matmul_op = ( - flow.builtin_op("matmul") - .Input("a") - .Input("b") - .Output("out") - .Attr("transpose_a", False) - .Attr("transpose_b", False) - .Attr("alpha", 1.0) - .Build() - ) - - self._batch_matmul_op = ( - flow.builtin_op("batch_matmul") - .Input("a") - .Input("b") - .Output("out") - .Attr("transpose_a", False) - .Attr("transpose_b", False) - .Attr("alpha", 1.0) - .Build() - ) - - self._broadcast_matmul_op = ( - flow.builtin_op("broadcast_matmul") - .Input("a") - .Input("b") - .Output("out") - .Attr("transpose_a", False) - .Attr("transpose_b", False) - .Attr("alpha", 1.0) - .Build() - ) def forward(self, a, b): assert len(a.shape) >= 2, "Tensor a's dim should >=2" @@ -63,15 +31,15 @@ def forward(self, a, b): if len(a.shape) == len(b.shape): if len(a.shape) == 2: - res = self._matmul_op(a, b)[0] + res = flow.F.matmul(a, b) else: - res = self._batch_matmul_op(a, b)[0] + res = flow.F.batch_matmul(a, b) else: # NOTE: support broadcast b to a only for now assert ( len(b.shape) == 2 ), "Not support number of dimensions of a being less than number of dimensions of b!" - res = self._broadcast_matmul_op(a, b)[0] + res = flow.F.broadcast_matmul(a, b) return res diff --git a/oneflow/python/nn/modules/normalization.py b/oneflow/python/nn/modules/normalization.py index 20353497aa5..7fe2cac1549 100644 --- a/oneflow/python/nn/modules/normalization.py +++ b/oneflow/python/nn/modules/normalization.py @@ -138,27 +138,6 @@ def __init__( # An integer specifies which axis params at, defaults to 1 in 'NCHW' format self.begin_params_axis = 1 - self._op = ( - flow.builtin_op("layer_norm") - .Input("x") - .Input("gamma") - .Input("beta") - .Output("y") - .Output("mean") - .Output("inv_variance") - .Output("normalized") - .Build() - ) - - self._op2 = ( - flow.builtin_op("layer_norm") - .Input("x") - .Output("y") - .Output("mean") - .Output("inv_variance") - .Build() - ) - def reset_parameters(self) -> None: if self.elementwise_affine: init.ones_(self.weight) @@ -220,25 +199,21 @@ def forward(self, x): return affined else: if self.elementwise_affine: - res = self._op( + flow.F.layer_norm_affine( x, self.weight, self.bias, - center=True, - scale=True, begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis, epsilon=self.epsilon, - )[0] + ) else: - res = self._op2( + flow.F.layer_norm( x, - center=False, - scale=False, begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis, epsilon=self.epsilon, - )[0] + ) return res def extra_repr(self) -> str: From 5453ff73055606c910ccda800ddf3e90e7d2c9e2 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 11:45:41 +0800 Subject: [PATCH 15/42] Refine --- oneflow/python/nn/modules/greater.py | 9 +-------- oneflow/python/nn/modules/less.py | 5 +---- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/oneflow/python/nn/modules/greater.py b/oneflow/python/nn/modules/greater.py index b04490e35a4..ab05c6b7ef5 100644 --- 
a/oneflow/python/nn/modules/greater.py +++ b/oneflow/python/nn/modules/greater.py @@ -22,13 +22,6 @@ class Greater(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("broadcast_greater") - .Input("x") - .Input("y") - .Output("z") - .Build() - ) def forward(self, x, y): if x.dtype != flow.float32: @@ -39,7 +32,7 @@ def forward(self, x, y): ) if y.dtype != flow.float32: y = flow.experimental.cast(y, flow.float32) - return self._op(x, y)[0] + return flow.F.broadcast_greater(x, y) @oneflow_export("gt") diff --git a/oneflow/python/nn/modules/less.py b/oneflow/python/nn/modules/less.py index 89cd8fcf09e..f14646c2933 100644 --- a/oneflow/python/nn/modules/less.py +++ b/oneflow/python/nn/modules/less.py @@ -22,9 +22,6 @@ class Less(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("broadcast_less").Input("x").Input("y").Output("z").Build() - ) def forward(self, x, y): if x.dtype != flow.float32: @@ -35,7 +32,7 @@ def forward(self, x, y): ) if y.dtype != flow.float32: y = flow.experimental.cast(y, flow.float32) - return self._op(x, y)[0] + return flow.F.broadcast_less(x, y) @oneflow_export("lt") From 53b1820bb76fef753cbdfa7b3ccfe732f8079c9b Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 11:45:41 +0800 Subject: [PATCH 16/42] Refine --- oneflow/python/nn/modules/greater.py | 9 +-------- oneflow/python/nn/modules/less.py | 5 +---- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/oneflow/python/nn/modules/greater.py b/oneflow/python/nn/modules/greater.py index b04490e35a4..ab05c6b7ef5 100644 --- a/oneflow/python/nn/modules/greater.py +++ b/oneflow/python/nn/modules/greater.py @@ -22,13 +22,6 @@ class Greater(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("broadcast_greater") - .Input("x") - .Input("y") - .Output("z") - .Build() - ) def forward(self, x, y): if x.dtype != flow.float32: @@ -39,7 +32,7 @@ def forward(self, x, y): ) if y.dtype != flow.float32: y = flow.experimental.cast(y, flow.float32) - return self._op(x, y)[0] + return flow.F.broadcast_greater(x, y) @oneflow_export("gt") diff --git a/oneflow/python/nn/modules/less.py b/oneflow/python/nn/modules/less.py index 89cd8fcf09e..f14646c2933 100644 --- a/oneflow/python/nn/modules/less.py +++ b/oneflow/python/nn/modules/less.py @@ -22,9 +22,6 @@ class Less(Module): def __init__(self) -> None: super().__init__() - self._op = ( - flow.builtin_op("broadcast_less").Input("x").Input("y").Output("z").Build() - ) def forward(self, x, y): if x.dtype != flow.float32: @@ -35,7 +32,7 @@ def forward(self, x, y): ) if y.dtype != flow.float32: y = flow.experimental.cast(y, flow.float32) - return self._op(x, y)[0] + return flow.F.broadcast_less(x, y) @oneflow_export("lt") From c2c21e5d1ac10b35b8b0c93e16a097570edd7d0b Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 21 Jun 2021 14:04:29 +0800 Subject: [PATCH 17/42] Migrate conv op --- oneflow/core/autograd/gradient_funcs/conv.cpp | 110 +++++++--------- oneflow/core/functional/functional_api.yaml | 26 ++++ oneflow/core/functional/impl/nn_functor.cpp | 28 ++++ .../core/functional/impl/nn_grad_functor.cpp | 122 ++++++++++++++++++ oneflow/python/nn/modules/conv.py | 77 +++++------ oneflow/user/kernels/conv_cudnn_kernels.cpp | 16 ++- oneflow/user/kernels/conv_kernels.cpp | 46 +++---- 7 files changed, 281 insertions(+), 144 deletions(-) create mode 100644 oneflow/core/functional/impl/nn_grad_functor.cpp diff --git a/oneflow/core/autograd/gradient_funcs/conv.cpp 
b/oneflow/core/autograd/gradient_funcs/conv.cpp index d87ecd8b928..8080a931618 100644 --- a/oneflow/core/autograd/gradient_funcs/conv.cpp +++ b/oneflow/core/autograd/gradient_funcs/conv.cpp @@ -13,103 +13,93 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" -#include "oneflow/core/framework/op_expr_helper.h" -#include "oneflow/core/framework/user_op_conf_trait.h" +#include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { -namespace { - -struct ConvInterpState : public OpExprInterpState { - bool weight_requires_grad = true; - bool input_requires_grad = true; +struct ConvolutionNdInterpState : public OpExprInterpState { + bool input_requires_grad = false; + bool weight_requires_grad = false; + size_t input_index; + size_t weight_index; + + std::string data_format; + std::vector padding_before; + std::vector kernel_size; + std::vector strides; + std::vector dilation_rate; + int32_t groups; }; -class ConvNdGrad : public OpExprGradFunction { +class ConvolutionNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(ConvInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrMap& attrs) const override; - Maybe Apply(const ConvInterpState* ctx, const TensorTuple& out_grads, + Maybe Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: - std::shared_ptr op_trait_; - std::shared_ptr data_format_; - std::shared_ptr> padding_before_; - std::shared_ptr> kernel_size_; - std::shared_ptr> strides_; - std::shared_ptr> dilation_rate_; - int32_t groups_; - - std::shared_ptr data_grad_op_; - std::shared_ptr weight_grad_op_; + AttrMap base_attrs_; }; -Maybe ConvNdGrad::Init(const OpExpr& op) { +Maybe ConvolutionNd::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); - const std::string& op_name = fw_op_expr->op_name(); - op_trait_ = std::make_shared(op_name, fw_op_expr->proto()); - - data_format_ = JUST(op_trait_->GetAttr("data_format")); - padding_before_ = JUST(op_trait_->GetAttr>("padding_before")); - kernel_size_ = JUST(op_trait_->GetAttr>("kernel_size")); - strides_ = JUST(op_trait_->GetAttr>("strides")); - dilation_rate_ = JUST(op_trait_->GetAttr>("dilation_rate")); - groups_ = JUST(op_trait_->GetAttr("groups")); - int32_t ndims = kernel_size_->size(); - CHECK_EQ_OR_RETURN(ndims, strides_->size()); - CHECK_EQ_OR_RETURN(ndims, dilation_rate_->size()); - data_grad_op_ = JUST(op_expr_helper::ConvNdDataGradOp(*kernel_size_, *strides_, *padding_before_, - *dilation_rate_, groups_, *data_format_)); - - weight_grad_op_ = JUST(op_expr_helper::ConvNdFilterGradOp( - *kernel_size_, *strides_, *padding_before_, *dilation_rate_, groups_, *data_format_)); - + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } -Maybe ConvNdGrad::Capture(ConvInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrMap& attrs) const { 
+Maybe ConvolutionNd::Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + CHECK_EQ_OR_RETURN(inputs.size(), 2); ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->weight_requires_grad = inputs.at(1)->requires_grad(); - ctx->SaveTensorForBackward(inputs.at(0)); // x + if (!ctx->input_requires_grad && !ctx->weight_requires_grad) { return Maybe::Ok(); } if (ctx->input_requires_grad) { - ctx->SaveTensorForBackward(inputs.at(1)); // weight + ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1)); // weight } + ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); // input + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->data_format = JUST(composed_attrs.GetAttr("data_format")); + ctx->padding_before = JUST(composed_attrs.GetAttr>("padding_before")); + ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); + ctx->strides = JUST(composed_attrs.GetAttr>("strides")); + ctx->dilation_rate = JUST(composed_attrs.GetAttr>("dilation_rate")); + ctx->groups = JUST(composed_attrs.GetAttr("groups")); return Maybe::Ok(); } -Maybe ConvNdGrad::Apply(const ConvInterpState* ctx, const TensorTuple& out_grads, - TensorTuple* in_grads) const { - CHECK_EQ_OR_RETURN(out_grads.size(), 1); - const auto& dy = out_grads.at(0); - +Maybe ConvolutionNd::Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { in_grads->resize(2); + size_t num_spatial_dims = ctx->kernel_size.size(); if (ctx->input_requires_grad) { - const auto& x = ctx->SavedTensors().at(0); - const auto& weight = ctx->SavedTensors().at(1); - in_grads->at(0) = - JUST(OpInterpUtil::Dispatch(*data_grad_op_, {dy, weight, x}, /*attrs=*/{})); + const auto& weight = ctx->SavedTensors().at(ctx->weight_index); + const auto& input = ctx->SavedTensors().at(ctx->input_index); + in_grads->at(0) = JUST(functional::ConvDataGrad( + out_grads.at(0), weight, input, num_spatial_dims, ctx->kernel_size, ctx->strides, + ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } if (ctx->weight_requires_grad) { - const auto& x = ctx->SavedTensors().at(0); - in_grads->at(1) = JUST(OpInterpUtil::Dispatch(*weight_grad_op_, {dy, x}, /*attrs=*/{})); + const auto& input = ctx->SavedTensors().at(ctx->input_index); + in_grads->at(1) = JUST(functional::ConvFilterGrad( + out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides, + ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } return Maybe::Ok(); } -} // namespace - -REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvNdGrad); -REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvNdGrad); -REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvNdGrad); +REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvolutionNd); +REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvolutionNd); +REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvolutionNd); } // namespace one } // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index a96e7512654..503c3d8bbb4 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -238,6 +238,32 @@ signature: "Tensor BiasAdd(Tensor x, Tensor bias, *, Int32 axis=1)" bind_python: True +- name: "conv2d" + signature: + "Tensor Conv2D(Tensor x, Tensor weight, *, Int32 filters, Int32List kernel_size, + Int32List strides, Int32List padding_before, Int32List dilation_rate, + Int32 groups=1, String 
data_format=\"channels_first\")" + bind_python: True + +- name: "conv_data_grad" + signature: + "Tensor ConvDataGrad(Tensor dy, Tensor weight, Tensor x, *, Int32 num_spatial_dims, + Int32List kernel_size, Int32List strides, Int32List padding_before, + Int32List dilation_rate, Int32 groups=1, + String data_format=\"channels_first\")" + bind_python: False + +- name: "conv_filter_grad" + signature: + "Tensor ConvFilterGrad(Tensor dy, Tensor x, *, Int32 num_spatial_dims, Int32List kernel_size, + Int32List strides, Int32List padding_before, Int32List dilation_rate, + Int32 groups=1, String data_format=\"channels_first\")" + bind_python: False + +- name: "conv_bias_grad" + signature: "Tensor ConvBiasGrad(Tensor dy, *, Int32 num_spatial_dims, String data_format=\"channels_first\")" + bind_python: False + - name: "expand" signature: "Tensor Expand(Tensor x, *, Int32List in_shape, Int32List out_shape, Int32List stride)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 575c68a7307..fe5dd0cf1ba 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -47,6 +47,33 @@ class BiasAddFunctor { std::shared_ptr op_; }; +class Conv2DFunctor { + public: + Conv2DFunctor() { + op_ = CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& weight, const int32_t& filters, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& padding_before, + const std::vector& dilation_rate, const int32_t& groups, + const std::string& data_format) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("filters", filters)); + JUST(attrs.SetAttr>("kernel_size", kernel_size)); + JUST(attrs.SetAttr>("strides", strides)); + JUST(attrs.SetAttr>("padding_before", padding_before)); + JUST(attrs.SetAttr>("dilation_rate", dilation_rate)); + JUST(attrs.SetAttr("groups", groups)); + JUST(attrs.SetAttr("data_format", data_format)); + return OpInterpUtil::Dispatch(*op_, {x, weight}, attrs); + } + + private: + std::shared_ptr op_; +}; + class MatMulBaseFunctor { public: MatMulBaseFunctor() = default; @@ -218,6 +245,7 @@ class NormalizationFunctor { ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("BiasAdd"); + m.add_functor("Conv2D"); m.add_functor("MatMul"); m.add_functor("BatchMatMul"); m.add_functor("BroadcastMatMul"); diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp new file mode 100644 index 00000000000..6892ea4669a --- /dev/null +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -0,0 +1,122 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/impl/common.h" +#include "oneflow/core/functional/impl/unary_functor.h" +#include "oneflow/core/functional/scalar.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class ConvBiasGradFunctor { + public: + ConvBiasGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("conv_bias_grad").Input("dy").Output("bias_diff").Build()); + } + Maybe operator()(const std::shared_ptr& dy, const int32_t& num_spatial_dims, + const std::string& data_format) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("num_spatial_dims", num_spatial_dims)); + JUST(attrs.SetAttr("data_format", data_format)); + return OpInterpUtil::Dispatch(*op_, {dy}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class ConvFilterGradFunctor { + public: + ConvFilterGradFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("conv_filter_grad").Input("dy").Input("x").Output("filter_diff").Build()); + } + Maybe operator()(const std::shared_ptr& dy, + const std::shared_ptr& x, const int32_t& num_spatial_dims, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& padding_before, + const std::vector& dilation_rate, const int32_t& groups, + const std::string& data_format) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("num_spatial_dims", num_spatial_dims)); + JUST(attrs.SetAttr>("kernel_size", kernel_size)); + JUST(attrs.SetAttr>("strides", strides)); + JUST(attrs.SetAttr>("padding_before", padding_before)); + JUST(attrs.SetAttr>("dilation_rate", dilation_rate)); + JUST(attrs.SetAttr("groups", groups)); + JUST(attrs.SetAttr("data_format", data_format)); + return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class ConvDataGradFunctor { + public: + ConvDataGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("conv_data_grad") + .Input("dy") + .Input("filter") + .Input("x_like") + .Output("dx") + .Build()); + } + Maybe operator()(const std::shared_ptr& dy, + const std::shared_ptr& weight, + const std::shared_ptr& x, const int32_t& num_spatial_dims, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& padding_before, + const std::vector& dilation_rate, const int32_t& groups, + const std::string& data_format) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("num_spatial_dims", num_spatial_dims)); + JUST(attrs.SetAttr>("kernel_size", kernel_size)); + JUST(attrs.SetAttr>("strides", strides)); + JUST(attrs.SetAttr>("padding_before", padding_before)); + JUST(attrs.SetAttr>("dilation_rate", dilation_rate)); + JUST(attrs.SetAttr("groups", groups)); + JUST(attrs.SetAttr("data_format", data_format)); + return OpInterpUtil::Dispatch(*op_, {dy, weight, x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +} // namespace impl + +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor("ConvBiasGrad"); + m.add_functor("ConvFilterGrad"); + m.add_functor("ConvDataGrad"); +}; + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 6e98d72a724..0707206fb2a 100644 --- a/oneflow/python/nn/modules/conv.py +++ 
b/oneflow/python/nn/modules/conv.py @@ -209,59 +209,20 @@ def __init__( super().__init__() assert padding_mode == "zeros" - kernel_size = _pair(kernel_size) - self.kernel_size = kernel_size - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) self.groups = groups assert in_channels % groups == 0 assert out_channels % groups == 0 self.out_channels = out_channels self.weight = flow.nn.Parameter( - flow.Tensor(out_channels, in_channels // groups, *kernel_size) + flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) ) self.bias = None - self._bias_add_op = None if bias: self.bias = flow.nn.Parameter(flow.Tensor(out_channels)) - self._bias_add_op = ( - flow.builtin_op("bias_add") - .Input("a") - .Input("b") - .Output("out") - .Attr("axis", 1) - .Build() - ) - - self._op = ( - flow.builtin_op("conv2d") - .Input("in") - .Input("weight") - .Attr("filters", out_channels) - .Attr("padding_before", padding) - .Attr("strides", stride) - .Attr("kernel_size", kernel_size) - .Attr("dilation_rate", dilation) - .Attr("groups", groups) - .Attr("data_format", "channels_first") - .Output("out") - .Build() - ) - self._cpu_op = ( - flow.builtin_op("conv2d") - .Input("in") - .Input("weight") - .Attr("filters", out_channels // groups) - .Attr("padding_before", padding) - .Attr("strides", stride) - .Attr("kernel_size", kernel_size) - .Attr("dilation_rate", dilation) - .Attr("groups", 1) - .Attr("data_format", "channels_first") - .Output("out") - .Build() - ) self.reset_parameters() def reset_parameters(self) -> None: @@ -280,14 +241,34 @@ def forward(self, x): out_list = [] for i in range(len(in_split_list)): out_list.append( - self._cpu_op(in_split_list[i], self.weight[i : i + 1, :, :, :])[0] + flow.F.conv2d( + in_split_list[i], + self.weight[i : i + 1, :, :, :], + filters=self.out_channels // self.groups, + kernel_size=self.kernel_size, + strides=self.stride, + padding_before=self.padding, + dilation_rate=self.dilation, + groups=1, + data_format="channels_first", + ) ) res = flow.experimental.cat(out_list, dim=in_channel_axis) else: - res = self._op(x, self.weight)[0] + res = flow.F.conv2d( + x, + self.weight, + filters=self.out_channels, + kernel_size=self.kernel_size, + strides=self.stride, + padding_before=self.padding, + dilation_rate=self.dilation, + groups=self.groups, + data_format="channels_first", + ) - if self._bias_add_op is not None: - res = self._bias_add_op(res, self.bias)[0] + if self.bias is not None: + res = flow.F.bias_add(res, self.bias, axis=1) return res diff --git a/oneflow/user/kernels/conv_cudnn_kernels.cpp b/oneflow/user/kernels/conv_cudnn_kernels.cpp index a095e240441..acfa34299be 100644 --- a/oneflow/user/kernels/conv_cudnn_kernels.cpp +++ b/oneflow/user/kernels/conv_cudnn_kernels.cpp @@ -149,8 +149,8 @@ class ConvGpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const { + std::shared_ptr CreateConvCudnnOpKernelState( + user_op::KernelComputeContext* ctx) const { const auto& data_format = ctx->Attr("data_format"); int32_t filters = ctx->Attr("filters"); @@ -185,7 +185,8 @@ class ConvGpuKernel final : public user_op::OpKernel { const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); if (bias != nullptr) { - ConvCudnnOpKernelState* conv_state = 
dynamic_cast(state); + auto state = CreateConvCudnnOpKernelState(ctx); + auto* conv_state = state.get(); CHECK_NOTNULL(conv_state); OF_CUDNN_CHECK(cudnnAddTensor(ctx->device_ctx()->cudnn_handle(), CudnnSPOnePtr(), conv_state->bias_desc->Get(), bias->dptr(), @@ -352,8 +353,8 @@ class ConvBiasGradGpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const { + std::shared_ptr CreateConvBiasGradState( + user_op::KernelComputeContext* ctx) const { const auto* bias_diff = ctx->TensorDesc4ArgNameAndIndex("bias_diff", 0); const auto* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const auto& data_format = ctx->Attr("data_format"); @@ -375,7 +376,7 @@ class ConvBiasGradGpuKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* bias_diff = ctx->Tensor4ArgNameAndIndex("bias_diff", 0); CHECK_EQ(bias_diff->shape().NumAxes(), 1); @@ -386,7 +387,8 @@ class ConvBiasGradGpuKernel final : public user_op::OpKernel { std::unique_ptr dy_desc; dy_desc.reset(new CudnnTensorDesc(dy->data_type(), dy->shape(), data_format)); - auto* bias_grad_state = dynamic_cast(state); + auto state = CreateConvBiasGradState(ctx); + auto* bias_grad_state = state.get(); CHECK_NOTNULL(bias_grad_state); OF_CUDNN_CHECK(cudnnConvolutionBackwardBias( ctx->device_ctx()->cudnn_handle(), CudnnSPOnePtr(), dy_desc->Get(), dy->dptr(), diff --git a/oneflow/user/kernels/conv_kernels.cpp b/oneflow/user/kernels/conv_kernels.cpp index 6c5a8573474..8e846b4c43f 100644 --- a/oneflow/user/kernels/conv_kernels.cpp +++ b/oneflow/user/kernels/conv_kernels.cpp @@ -326,10 +326,10 @@ struct ConvOpKernelState final : public user_op::OpKernelState { }; template -std::shared_ptr CreateConvOpKernelState(user_op::KernelInitContext* ctx, - const std::string& in_name, - const std::string& out_name, - const std::string& weight_name) { +std::shared_ptr> CreateConvOpKernelState(user_op::KernelComputeContext* ctx, + const std::string& in_name, + const std::string& out_name, + const std::string& weight_name) { const auto& data_format = ctx->Attr("data_format"); std::shared_ptr> state(new ConvOpKernelState()); @@ -394,13 +394,12 @@ class ConvCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const { - return CreateConvOpKernelState(ctx, "in", "out", "weight"); - } - private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx) const override { + auto state = CreateConvOpKernelState(ctx, "in", "out", "weight"); + auto* conv_state = state.get(); + CHECK_NOTNULL(conv_state); + const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); @@ -408,9 +407,6 @@ class ConvCpuKernel final : public user_op::OpKernel { T* col_buf_dptr = tmp_buffer->mut_dptr(); - auto* conv_state = dynamic_cast*>(state); - conv_state->Update(in->shape(), out->shape()); - CHECK_NOTNULL(conv_state); bool is_bias_mul_inited = false; for 
(int64_t i = 0; i < in->shape().At(0); ++i) { conv_state->im2col_func_(GetImgDptr(in, i), ShapeView(conv_state->in_5d_shape_), @@ -495,20 +491,17 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const { - return CreateConvOpKernelState(ctx, "dx", "dy", "filter"); - } - private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* conv_state = dynamic_cast*>(state); + void Compute(user_op::KernelComputeContext* ctx) const override { + auto state = CreateConvOpKernelState(ctx, "dx", "dy", "filter"); + auto* conv_state = state.get(); CHECK_NOTNULL(conv_state); + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex("filter", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - conv_state->Update(dx->shape(), dy->shape()); + Memset(ctx->device_ctx(), dx->mut_dptr(), 0, dx->shape().elem_cnt() * sizeof(T)); @@ -571,21 +564,16 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const { - return CreateConvOpKernelState(ctx, "x", "dy", "filter_diff"); - } - private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* conv_state = dynamic_cast*>(state); + void Compute(user_op::KernelComputeContext* ctx) const override { + auto state = CreateConvOpKernelState(ctx, "x", "dy", "filter_diff"); + auto* conv_state = state.get(); CHECK_NOTNULL(conv_state); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex("filter_diff", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - conv_state->Update(x->shape(), dy->shape()); Memset(ctx->device_ctx(), filter_diff->mut_dptr(), 0, filter_diff->shape().elem_cnt() * sizeof(T)); From a004a486425aaf6f4e00c6755ffbba1d2179be91 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Wed, 23 Jun 2021 12:05:45 +0800 Subject: [PATCH 18/42] Fix functional normalization. 
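The layer_norm inputs are reordered so that gamma precedes beta in the LayerNormAffine functional signature, the OpBuilder input order, and the dispatch tuple. The LayerNorm module is also fixed to keep the result of the functional call (it was previously computed and then discarded).

A minimal sketch of the corrected call order, assuming eager mode is enabled and that the module's `weight`/`bias` correspond to gamma/beta (shapes are illustrative only):

    import numpy as np
    import oneflow as flow  # flow.F is the functional namespace used inside the modules

    x = flow.Tensor(np.random.randn(2, 5, 8))
    gamma = flow.Tensor(np.ones((5, 8), dtype=np.float32))   # LayerNorm weight
    beta = flow.Tensor(np.zeros((5, 8), dtype=np.float32))   # LayerNorm bias
    # gamma is now passed before beta, matching the reordered signature below
    y = flow.F.layer_norm_affine(
        x, gamma, beta,
        begin_norm_axis=1, begin_params_axis=1, epsilon=1e-5,
    )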
--- oneflow/core/functional/functional_api.yaml | 2 +- oneflow/core/functional/impl/nn_functor.cpp | 6 +++--- oneflow/python/nn/modules/normalization.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index a96e7512654..24884e18f2d 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -290,7 +290,7 @@ - name: "layer_norm_affine" signature: - "Tensor LayerNormAffine(Tensor x, Tensor beta, Tensor gamma, *, Int64 begin_norm_axis, + "Tensor LayerNormAffine(Tensor x, Tensor gamma, Tensor beta, *, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 575c68a7307..93fb685be2e 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -117,8 +117,8 @@ class LayerNormAffineFunctor { LayerNormAffineFunctor() { op_ = CHECK_JUST(one::OpBuilder("layer_norm") .Input("x") - .Input("beta") .Input("gamma") + .Input("beta") .Output("y") .Output("mean") .Output("inv_variance") @@ -126,8 +126,8 @@ class LayerNormAffineFunctor { .Build()); } Maybe operator()(const std::shared_ptr& x, - const std::shared_ptr& beta, const std::shared_ptr& gamma, + const std::shared_ptr& beta, const int64_t& begin_norm_axis, const int64_t& begin_params_axis, const double& epsilon) const { MutableAttrMap attrs; @@ -136,7 +136,7 @@ class LayerNormAffineFunctor { JUST(attrs.SetAttr("epsilon", epsilon)); JUST(attrs.SetAttr("center", true)); JUST(attrs.SetAttr("scale", true)); - return OpInterpUtil::Dispatch(*op_, {x, beta, gamma}, attrs); + return OpInterpUtil::Dispatch(*op_, {x, gamma, beta}, attrs); } private: diff --git a/oneflow/python/nn/modules/normalization.py b/oneflow/python/nn/modules/normalization.py index 14daa257281..6df5cfe1b5b 100644 --- a/oneflow/python/nn/modules/normalization.py +++ b/oneflow/python/nn/modules/normalization.py @@ -308,7 +308,7 @@ def forward(self, x): return affined else: if self.elementwise_affine: - flow.F.layer_norm_affine( + res = flow.F.layer_norm_affine( x, self.weight, self.bias, @@ -317,7 +317,7 @@ def forward(self, x): epsilon=self.epsilon, ) else: - flow.F.layer_norm( + res = flow.F.layer_norm( x, begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis, From 317e28210407f7d8cb88ecdddc28b0c50f2fc772 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Wed, 23 Jun 2021 06:04:22 +0000 Subject: [PATCH 19/42] auto format by CI --- oneflow/core/functional/impl/nn_functor.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 93fb685be2e..3fb8358a724 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -127,9 +127,8 @@ class LayerNormAffineFunctor { } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& gamma, - const std::shared_ptr& beta, - const int64_t& begin_norm_axis, const int64_t& begin_params_axis, - const double& epsilon) const { + const std::shared_ptr& beta, const int64_t& begin_norm_axis, + const int64_t& begin_params_axis, const double& epsilon) const { MutableAttrMap attrs; JUST(attrs.SetAttr("begin_norm_axis", begin_norm_axis)); JUST(attrs.SetAttr("begin_params_axis", begin_params_axis)); From 
aee0ffa95b01a97a98894b41c5b051015293cf09 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Wed, 23 Jun 2021 15:57:26 +0800 Subject: [PATCH 20/42] unfinished --- oneflow/core/functional/functional_api.yaml | 6 ++ oneflow/core/functional/impl/nn_functor.cpp | 25 +++++++ oneflow/python/nn/modules/conv.py | 82 +++++++++++++++++++++ 3 files changed, 113 insertions(+) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index a282f7754b9..73b34aff3da 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -238,6 +238,12 @@ signature: "Tensor BiasAdd(Tensor x, Tensor bias, *, Int32 axis=1)" bind_python: True +- name: "conv1d" + signature: + "Tensor Conv1D(Tensor x, Tensor weight, *, Int32List strides, Int32List padding_before, + Int32List dilation_rate, Int32 groups=1, String data_format=\"channels_first\")" + bind_python: True + - name: "conv2d" signature: "Tensor Conv2D(Tensor x, Tensor weight, *, Int32 filters, Int32List kernel_size, diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index db8ba4b980f..8c3b8b2dfc6 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -47,6 +47,30 @@ class BiasAddFunctor { std::shared_ptr op_; }; +class Conv1DFunctor { + public: + Conv1DFunctor() { + op_ = CHECK_JUST(one::OpBuilder("conv1d").Input("in").Input("weight").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& weight, + const std::vector& strides, + const std::vector& padding_before, + const std::vector& dilation_rate, const int32_t& groups, + const std::string& data_format) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr>("strides", strides)); + JUST(attrs.SetAttr>("padding_before", padding_before)); + JUST(attrs.SetAttr>("dilation_rate", dilation_rate)); + JUST(attrs.SetAttr("groups", groups)); + JUST(attrs.SetAttr("data_format", data_format)); + return OpInterpUtil::Dispatch(*op_, {x, weight}, attrs); + } + + private: + std::shared_ptr op_; +}; + class Conv2DFunctor { public: Conv2DFunctor() { @@ -244,6 +268,7 @@ class NormalizationFunctor { ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("BiasAdd"); + m.add_functor("Conv1D"); m.add_functor("Conv2D"); m.add_functor("MatMul"); m.add_functor("BatchMatMul"); diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 1bf6e8c280f..90f3096b9e9 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -273,6 +273,88 @@ def forward(self, x): return res +@oneflow_export("nn.Conv2d") +@experimental_api +class Conv2d(Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", # TODO: refine this type + ): + super().__init__() + + assert padding_mode == "zeros" + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + assert in_channels % groups == 0 + assert out_channels % groups == 0 + self.out_channels = out_channels + self.weight = flow.nn.Parameter( + flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + self.out_channel_groups = out_channels // groups + self.bias = None + if bias: + self.bias = 
flow.nn.Parameter(flow.Tensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self) -> None: + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def forward(self, x): + if x.device.type == "cpu" and self.groups > 1: + in_channel_axis = 1 + in_split_list = ConvUtil.split( + x, axis=in_channel_axis, split_num=self.groups + ) + out_list = [] + for i in range(len(in_split_list)): + out_list.append( + flow.F.conv2d( + in_split_list[i], + self.weight[i : i + 1, :, :, :], + filters=self.out_channels // self.groups, + kernel_size=self.kernel_size, + strides=self.stride, + padding_before=self.padding, + dilation_rate=self.dilation, + groups=1, + data_format="channels_first", + ) + ) + res = flow.experimental.cat(out_list, dim=in_channel_axis) + else: + res = flow.F.conv2d( + x, + self.weight, + filters=self.out_channels, + kernel_size=self.kernel_size, + strides=self.stride, + padding_before=self.padding, + dilation_rate=self.dilation, + groups=self.groups, + data_format="channels_first", + ) + + if self.bias is not None: + res = flow.F.bias_add(res, self.bias, axis=1) + return res + + if __name__ == "__main__": import doctest From bc4aef7b70175dba73b676a86a495b0fcdd684a7 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Wed, 23 Jun 2021 22:11:57 +0800 Subject: [PATCH 21/42] Refine code style --- oneflow/user/kernels/conv_cudnn_kernels.cpp | 10 ++++------ oneflow/user/kernels/conv_kernels.cpp | 15 ++++++--------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/oneflow/user/kernels/conv_cudnn_kernels.cpp b/oneflow/user/kernels/conv_cudnn_kernels.cpp index acfa34299be..55f9674c36c 100644 --- a/oneflow/user/kernels/conv_cudnn_kernels.cpp +++ b/oneflow/user/kernels/conv_cudnn_kernels.cpp @@ -185,9 +185,8 @@ class ConvGpuKernel final : public user_op::OpKernel { const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); if (bias != nullptr) { - auto state = CreateConvCudnnOpKernelState(ctx); - auto* conv_state = state.get(); - CHECK_NOTNULL(conv_state); + const auto& conv_state = CreateConvCudnnOpKernelState(ctx); + CHECK_NOTNULL(conv_state.get()); OF_CUDNN_CHECK(cudnnAddTensor(ctx->device_ctx()->cudnn_handle(), CudnnSPOnePtr(), conv_state->bias_desc->Get(), bias->dptr(), CudnnSPOnePtr(), args.ydesc.Get(), out->mut_dptr())); @@ -387,9 +386,8 @@ class ConvBiasGradGpuKernel final : public user_op::OpKernel { std::unique_ptr dy_desc; dy_desc.reset(new CudnnTensorDesc(dy->data_type(), dy->shape(), data_format)); - auto state = CreateConvBiasGradState(ctx); - auto* bias_grad_state = state.get(); - CHECK_NOTNULL(bias_grad_state); + const auto& bias_grad_state = CreateConvBiasGradState(ctx); + CHECK_NOTNULL(bias_grad_state.get()); OF_CUDNN_CHECK(cudnnConvolutionBackwardBias( ctx->device_ctx()->cudnn_handle(), CudnnSPOnePtr(), dy_desc->Get(), dy->dptr(), CudnnSPZeroPtr(), bias_grad_state->bias_diff_desc->Get(), bias_diff->mut_dptr())); diff --git a/oneflow/user/kernels/conv_kernels.cpp b/oneflow/user/kernels/conv_kernels.cpp index 8e846b4c43f..bb01035f377 100644 --- a/oneflow/user/kernels/conv_kernels.cpp +++ b/oneflow/user/kernels/conv_kernels.cpp @@ -396,9 +396,8 @@ class ConvCpuKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx) const override { - auto state = CreateConvOpKernelState(ctx, "in", "out", "weight"); - auto* conv_state = state.get(); - 
CHECK_NOTNULL(conv_state); + const auto& conv_state = CreateConvOpKernelState(ctx, "in", "out", "weight"); + CHECK_NOTNULL(conv_state.get()); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); @@ -493,9 +492,8 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx) const override { - auto state = CreateConvOpKernelState(ctx, "dx", "dy", "filter"); - auto* conv_state = state.get(); - CHECK_NOTNULL(conv_state); + const auto& conv_state = CreateConvOpKernelState(ctx, "dx", "dy", "filter"); + CHECK_NOTNULL(conv_state.get()); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex("filter", 0); @@ -566,9 +564,8 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx) const override { - auto state = CreateConvOpKernelState(ctx, "x", "dy", "filter_diff"); - auto* conv_state = state.get(); - CHECK_NOTNULL(conv_state); + const auto& conv_state = CreateConvOpKernelState(ctx, "x", "dy", "filter_diff"); + CHECK_NOTNULL(conv_state.get()); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); From 5573cea955cd7518586f837ff7641275dae81dcd Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 24 Jun 2021 09:27:47 +0800 Subject: [PATCH 22/42] align Torch params --- oneflow/core/functional/functional_api.yaml | 5 +-- oneflow/core/functional/impl/nn_functor.cpp | 41 ++++++++++++--------- oneflow/python/nn/modules/conv.py | 24 +++++------- 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index a282f7754b9..f3e1e22042d 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -240,9 +240,8 @@ - name: "conv2d" signature: - "Tensor Conv2D(Tensor x, Tensor weight, *, Int32 filters, Int32List kernel_size, - Int32List strides, Int32List padding_before, Int32List dilation_rate, - Int32 groups=1, String data_format=\"channels_first\")" + "Tensor Conv2D(Tensor x, Tensor weight, Tensor bias, *, Int32List stride, + Int32List padding, Int32List dilation, Int32 groups=1)" bind_python: True - name: "conv_data_grad" diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index db8ba4b980f..dd3105efa20 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -50,28 +50,35 @@ class BiasAddFunctor { class Conv2DFunctor { public: Conv2DFunctor() { - op_ = CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build()); + conv_op_ = + CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build()); + bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x, - const std::shared_ptr& weight, const int32_t& filters, - const std::vector& kernel_size, - const std::vector& strides, - const std::vector& padding_before, - const std::vector& dilation_rate, const int32_t& groups, - const std::string& data_format) const { - MutableAttrMap attrs; - JUST(attrs.SetAttr("filters", filters)); - JUST(attrs.SetAttr>("kernel_size", kernel_size)); - JUST(attrs.SetAttr>("strides", strides)); 
- JUST(attrs.SetAttr>("padding_before", padding_before)); - JUST(attrs.SetAttr>("dilation_rate", dilation_rate)); - JUST(attrs.SetAttr("groups", groups)); - JUST(attrs.SetAttr("data_format", data_format)); - return OpInterpUtil::Dispatch(*op_, {x, weight}, attrs); + const std::shared_ptr& weight, + const std::shared_ptr& bias, + const std::vector& stride, const std::vector& padding, + const std::vector& dilation, const int32_t& groups) const { + MutableAttrMap conv_attrs; + std::vector kernel_size_vec; + for (int i = 0; i < 2; i++) { kernel_size_vec.push_back((weight->shape())->At(i + 2)); } + JUST(conv_attrs.SetAttr("filters", (weight->shape())->At(0))); + JUST(conv_attrs.SetAttr>("padding_before", padding)); + JUST(conv_attrs.SetAttr>("kernel_size", kernel_size_vec)); + JUST(conv_attrs.SetAttr>("strides", stride)); + JUST(conv_attrs.SetAttr>("dilation_rate", dilation)); + JUST(conv_attrs.SetAttr("groups", groups)); + JUST(conv_attrs.SetAttr("data_format", std::string("channels_first"))); + std::shared_ptr conv_out = + JUST(OpInterpUtil::Dispatch(*conv_op_, {x, weight}, conv_attrs)); + MutableAttrMap bias_attrs; + JUST(bias_attrs.SetAttr("axis", 1)); + return OpInterpUtil::Dispatch(*bias_op_, {conv_out, bias}, bias_attrs); } private: - std::shared_ptr op_; + std::shared_ptr conv_op_; + std::shared_ptr bias_op_; }; class MatMulBaseFunctor { diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 1bf6e8c280f..a43734c7489 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -216,6 +216,7 @@ def __init__( self.groups = groups assert in_channels % groups == 0 assert out_channels % groups == 0 + self.in_channels = in_channels self.out_channels = out_channels self.weight = flow.nn.Parameter( flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) @@ -245,13 +246,11 @@ def forward(self, x): flow.F.conv2d( in_split_list[i], self.weight[i : i + 1, :, :, :], - filters=self.out_channels // self.groups, - kernel_size=self.kernel_size, - strides=self.stride, - padding_before=self.padding, - dilation_rate=self.dilation, + self.bias[i : i + 1, :, :, :], + stride=self.stride, + padding=self.padding, + dilation=self.dilation, groups=1, - data_format="channels_first", ) ) res = flow.experimental.cat(out_list, dim=in_channel_axis) @@ -259,17 +258,12 @@ def forward(self, x): res = flow.F.conv2d( x, self.weight, - filters=self.out_channels, - kernel_size=self.kernel_size, - strides=self.stride, - padding_before=self.padding, - dilation_rate=self.dilation, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, groups=self.groups, - data_format="channels_first", ) - - if self.bias is not None: - res = flow.F.bias_add(res, self.bias, axis=1) return res From b8fc8232f17aeb2c7308e95564286eeca93829eb Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 24 Jun 2021 09:27:47 +0800 Subject: [PATCH 23/42] align Torch params --- oneflow/core/functional/functional_api.yaml | 25 ++++++++++++++ oneflow/core/functional/impl/nn_functor.cpp | 36 +++++++++++++++++++++ oneflow/python/nn/modules/conv.py | 26 +++++++-------- 3 files changed, 72 insertions(+), 15 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 24884e18f2d..f3e1e22042d 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -238,6 +238,31 @@ signature: "Tensor BiasAdd(Tensor x, Tensor bias, *, Int32 axis=1)" 
bind_python: True +- name: "conv2d" + signature: + "Tensor Conv2D(Tensor x, Tensor weight, Tensor bias, *, Int32List stride, + Int32List padding, Int32List dilation, Int32 groups=1)" + bind_python: True + +- name: "conv_data_grad" + signature: + "Tensor ConvDataGrad(Tensor dy, Tensor weight, Tensor x, *, Int32 num_spatial_dims, + Int32List kernel_size, Int32List strides, Int32List padding_before, + Int32List dilation_rate, Int32 groups=1, + String data_format=\"channels_first\")" + bind_python: False + +- name: "conv_filter_grad" + signature: + "Tensor ConvFilterGrad(Tensor dy, Tensor x, *, Int32 num_spatial_dims, Int32List kernel_size, + Int32List strides, Int32List padding_before, Int32List dilation_rate, + Int32 groups=1, String data_format=\"channels_first\")" + bind_python: False + +- name: "conv_bias_grad" + signature: "Tensor ConvBiasGrad(Tensor dy, *, Int32 num_spatial_dims, String data_format=\"channels_first\")" + bind_python: False + - name: "expand" signature: "Tensor Expand(Tensor x, *, Int32List in_shape, Int32List out_shape, Int32List stride)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index ad43ba07949..c18221c087f 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -47,6 +47,42 @@ class BiasAddFunctor { std::shared_ptr op_; }; +<<<<<<< HEAD +======= +class Conv2DFunctor { + public: + Conv2DFunctor() { + conv_op_ = + CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build()); + bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& weight, + const std::shared_ptr& bias, + const std::vector& stride, const std::vector& padding, + const std::vector& dilation, const int32_t& groups) const { + MutableAttrMap conv_attrs; + std::vector kernel_size_vec; + for (int i = 0; i < 2; i++) { kernel_size_vec.push_back((weight->shape())->At(i + 2)); } + JUST(conv_attrs.SetAttr("filters", (weight->shape())->At(0))); + JUST(conv_attrs.SetAttr>("padding_before", padding)); + JUST(conv_attrs.SetAttr>("kernel_size", kernel_size_vec)); + JUST(conv_attrs.SetAttr>("strides", stride)); + JUST(conv_attrs.SetAttr>("dilation_rate", dilation)); + JUST(conv_attrs.SetAttr("groups", groups)); + JUST(conv_attrs.SetAttr("data_format", std::string("channels_first"))); + std::shared_ptr conv_out = + JUST(OpInterpUtil::Dispatch(*conv_op_, {x, weight}, conv_attrs)); + MutableAttrMap bias_attrs; + JUST(bias_attrs.SetAttr("axis", 1)); + return OpInterpUtil::Dispatch(*bias_op_, {conv_out, bias}, bias_attrs); + } + + private: + std::shared_ptr conv_op_; + std::shared_ptr bias_op_; +}; +>>>>>>> 5573cea95... align Torch params class MatMulBaseFunctor { public: diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 9d8f30cb49b..c34492be7fa 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -217,6 +217,7 @@ def __init__( assert in_channels % groups == 0 assert out_channels % groups == 0 self.in_channels = in_channels +<<<<<<< HEAD self.out_channels = out_channels self.weight = flow.nn.Parameter( flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) @@ -292,6 +293,8 @@ def __init__( self.groups = groups assert in_channels % groups == 0 assert out_channels % groups == 0 +======= +>>>>>>> 5573cea95... 
align Torch params self.out_channels = out_channels self.weight = flow.nn.Parameter( flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) @@ -321,13 +324,11 @@ def forward(self, x): flow.F.conv2d( in_split_list[i], self.weight[i : i + 1, :, :, :], - filters=self.out_channels // self.groups, - kernel_size=self.kernel_size, - strides=self.stride, - padding_before=self.padding, - dilation_rate=self.dilation, + self.bias[i : i + 1, :, :, :], + stride=self.stride, + padding=self.padding, + dilation=self.dilation, groups=1, - data_format="channels_first", ) ) res = flow.experimental.cat(out_list, dim=in_channel_axis) @@ -335,17 +336,12 @@ def forward(self, x): res = flow.F.conv2d( x, self.weight, - filters=self.out_channels, - kernel_size=self.kernel_size, - strides=self.stride, - padding_before=self.padding, - dilation_rate=self.dilation, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, groups=self.groups, - data_format="channels_first", ) - - if self.bias is not None: - res = flow.F.bias_add(res, self.bias, axis=1) return res From 2f7eb0774fbf5b65f22b2b005dc7146f85de0afa Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 24 Jun 2021 14:20:03 +0800 Subject: [PATCH 24/42] develop unfinish --- oneflow/core/functional/functional_api.yaml | 8 +- oneflow/core/functional/impl/nn_functor.cpp | 39 ++++- oneflow/python/nn/modules/conv.py | 153 +++----------------- 3 files changed, 57 insertions(+), 143 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index f3e1e22042d..ae80df0a427 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -238,12 +238,18 @@ signature: "Tensor BiasAdd(Tensor x, Tensor bias, *, Int32 axis=1)" bind_python: True +- name: "conv1d" + signature: + "Tensor Conv1D(Tensor x, Tensor weight, Tensor bias, *, Int32List stride, + Int32List padding, Int32List dilation, Int32 groups=1)" + bind_python: True + - name: "conv2d" signature: "Tensor Conv2D(Tensor x, Tensor weight, Tensor bias, *, Int32List stride, Int32List padding, Int32List dilation, Int32 groups=1)" bind_python: True - + - name: "conv_data_grad" signature: "Tensor ConvDataGrad(Tensor dy, Tensor weight, Tensor x, *, Int32 num_spatial_dims, diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index c18221c087f..625b7d58d9a 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -47,8 +47,39 @@ class BiasAddFunctor { std::shared_ptr op_; }; -<<<<<<< HEAD -======= +class Conv1DFunctor { + public: + Conv1DFunctor() { + conv_op_ = + CHECK_JUST(one::OpBuilder("conv1d").Input("in").Input("weight").Output("out").Build()); + bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& weight, + const std::shared_ptr& bias, + const std::vector& stride, const std::vector& padding, + const std::vector& dilation, const int32_t& groups) const { + MutableAttrMap conv_attrs; + std::vector kernel_size_vec{(weight->shape())->At(2)}; + JUST(conv_attrs.SetAttr("filters", int64_t((weight->shape())->At(0)))); + JUST(conv_attrs.SetAttr>("padding_before", padding)); + JUST(conv_attrs.SetAttr>("kernel_size", kernel_size_vec)); + JUST(conv_attrs.SetAttr>("strides", stride)); + JUST(conv_attrs.SetAttr>("dilation_rate", dilation)); + 
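+    // Note: filters and kernel_size are derived from the weight shape
+    // (out_channels, in_channels / groups, kernel_width) rather than passed by the
+    // caller; the remaining attributes (groups, data_format) are set below, and the
+    // bias is applied afterwards with a separate bias_add op on axis 1.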
JUST(conv_attrs.SetAttr("groups", groups)); + JUST(conv_attrs.SetAttr("data_format", std::string("channels_first"))); + std::shared_ptr conv_out = + JUST(OpInterpUtil::Dispatch(*conv_op_, {x, weight}, conv_attrs)); + MutableAttrMap bias_attrs; + JUST(bias_attrs.SetAttr("axis", 1)); + return OpInterpUtil::Dispatch(*bias_op_, {conv_out, bias}, bias_attrs); + } + + private: + std::shared_ptr conv_op_; + std::shared_ptr bias_op_; +}; + class Conv2DFunctor { public: Conv2DFunctor() { @@ -82,7 +113,6 @@ class Conv2DFunctor { std::shared_ptr conv_op_; std::shared_ptr bias_op_; }; ->>>>>>> 5573cea95... align Torch params class MatMulBaseFunctor { public: @@ -254,11 +284,8 @@ class NormalizationFunctor { ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("BiasAdd"); -<<<<<<< HEAD m.add_functor("Conv1D"); m.add_functor("Conv2D"); -======= ->>>>>>> 86b81530a2b807c08c105e1ce72e2680175e324b m.add_functor("MatMul"); m.add_functor("BatchMatMul"); m.add_functor("BroadcastMatMul"); diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index c34492be7fa..6a8269704fd 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -17,8 +17,8 @@ import oneflow as flow from oneflow.python.oneflow_export import oneflow_export, experimental_api from oneflow.python.nn.module import Module -from oneflow.python.nn.modules.utils import _pair -from oneflow.python.nn.common_types import _size_2_t +from oneflow.python.nn.modules.utils import _single, _pair +from oneflow.python.nn.common_types import _size_1_t, _size_2_t from oneflow.python.nn import init @@ -75,133 +75,17 @@ def split(cls, x, axis, split_num): result_list.append(result) return result_list - -@oneflow_export("nn.Conv2d") +@oneflow_export("nn.Conv1d") @experimental_api -class Conv2d(Module): - r"""Applies a 2D convolution over an input signal composed of several input - planes. - - In the simplest case, the output value of the layer with input size - :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` - can be precisely described as: - - .. math:: - \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + - \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) - - - where :math:`\star` is the valid 2D `cross-correlation`_ operator, - :math:`N` is a batch size, :math:`C` denotes a number of channels, - :math:`H` is a height of input planes in pixels, and :math:`W` is - width in pixels. - - - * :attr:`stride` controls the stride for the cross-correlation, a single - number or a tuple. - * :attr:`padding` controls the amount of implicit padding on both - sides for :attr:`padding` number of points for each dimension. - * :attr:`dilation` controls the spacing between the kernel points; also - known as the à trous algorithm. It is harder to describe, but this `link`_ - has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels - and producing half the output channels, and both subsequently - concatenated. 
- * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters (of size - :math:`\frac{\text{out_channels}}{\text{in_channels}}`)., - - The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - - - a single ``int`` -- in which case the same value is used for the height and width dimension - - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, - and the second `int` for the width dimension - - Note: - When `groups == in_channels` and `out_channels == K * in_channels`, - where `K` is a positive integer, this operation is also known as a "depthwise convolution". - - In other words, for an input of size :math:`(N, C_{in}, L_{in})`, - a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments - :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`. - - - Args: - in_channels (int): Number of channels in the input image - out_channels (int): Number of channels produced by the convolution - kernel_size (int or tuple): Size of the convolving kernel - stride (int or tuple, optional): Stride of the convolution. Default: 1 - padding (int or tuple, optional): Zero-padding added to both sides of - the input. Default: 0 - padding_mode (string, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` - dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 - groups (int, optional): Number of blocked connections from input - channels to output channels. Default: 1 - bias (bool, optional): If ``True``, adds a learnable bias to the - output. Default: ``True`` - - Shape: - - Input: :math:`(N, C_{in}, H_{in}, W_{in})` - - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where - - .. math:: - H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] - \times (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor - - .. math:: - W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] - \times (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor - - Attr: - - weight (Tensor): the learnable weights of the module of shape - :math:`(\text{out_channels}, \frac{\text{in_channels}}{\text{groups}},` - :math:`\text{kernel_size[0]}, \text{kernel_size[1]})`. - The values of these weights are sampled from - :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where - :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}` - - - bias (Tensor): the learnable bias of the module of shape - (out_channels). If :attr:`bias` is ``True``, - then the values of these weights are - sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where - :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}` - - For example: - - .. code-block:: python - - >>> import numpy as np - >>> import oneflow.experimental as flow - >>> import oneflow.experimental.nn as nn - >>> flow.enable_eager_execution() - - >>> arr = np.random.randn(20, 16, 50, 100) - >>> input = flow.Tensor(arr) - >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) - >>> output = m(input) - - .. _cross-correlation: - https://en.wikipedia.org/wiki/Cross-correlation - - .. 
_link: - https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md - """ - +class Conv1d(Module): def __init__( self, in_channels: int, out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = 1, - padding: _size_2_t = 0, - dilation: _size_2_t = 1, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", # TODO: refine this type @@ -209,20 +93,18 @@ def __init__( super().__init__() assert padding_mode == "zeros" - self.kernel_size = _pair(kernel_size) - self.stride = _pair(stride) - self.padding = _pair(padding) - self.dilation = _pair(dilation) + self.kernel_size = _single(kernel_size) + self.stride = _single(stride) + self.padding = _single(padding) + self.dilation = _single(dilation) self.groups = groups assert in_channels % groups == 0 assert out_channels % groups == 0 self.in_channels = in_channels -<<<<<<< HEAD self.out_channels = out_channels self.weight = flow.nn.Parameter( flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) ) - self.out_channel_groups = out_channels // groups self.bias = None if bias: self.bias = flow.nn.Parameter(flow.Tensor(out_channels)) @@ -244,10 +126,10 @@ def forward(self, x): out_list = [] for i in range(len(in_split_list)): out_list.append( - flow.F.conv2d( + flow.F.conv1d( in_split_list[i], - self.weight[i : i + 1, :, :, :], - self.bias[i : i + 1, :, :, :], + self.weight[i : i + 1, :, :], + self.bias[i : i + 1], stride=self.stride, padding=self.padding, dilation=self.dilation, @@ -256,7 +138,7 @@ def forward(self, x): ) res = flow.experimental.cat(out_list, dim=in_channel_axis) else: - res = flow.F.conv2d( + res = flow.F.conv1d( x, self.weight, self.bias, @@ -293,8 +175,7 @@ def __init__( self.groups = groups assert in_channels % groups == 0 assert out_channels % groups == 0 -======= ->>>>>>> 5573cea95... align Torch params + self.in_channels = in_channels self.out_channels = out_channels self.weight = flow.nn.Parameter( flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) From cc61549d0dbb106d99e9f6f3c28a387dee57c055 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 24 Jun 2021 15:21:50 +0800 Subject: [PATCH 25/42] add conv1d --- oneflow/python/nn/modules/conv.py | 171 ++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 6a8269704fd..26141a3f6de 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -75,6 +75,7 @@ def split(cls, x, axis, split_num): result_list.append(result) return result_list + @oneflow_export("nn.Conv1d") @experimental_api class Conv1d(Module): @@ -90,6 +91,86 @@ def __init__( bias: bool = True, padding_mode: str = "zeros", # TODO: refine this type ): + r""" + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be + precisely described as: + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) + \star \text{input}(N_i, k) + where :math:`\star` is the valid `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`L` is a length of signal sequence. + + This module supports :ref:`TensorFloat32`. 
+ * :attr:`stride` controls the stride for the cross-correlation, a single + number or a one-element tuple. + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. + * :attr:`dilation` controls the spacing between the kernel points; also + known as the à trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. + {groups_note} + Note: + {depthwise_separable_note} + Note: + {cudnn_reproducibility_note} + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` where + .. math:: + L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, + \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + + For example: + .. code-block:: python + >>> import numpy as np + >>> import oneflow.experimental as flow + >>> import oneflow.experimental.nn as nn + >>> flow.enable_eager_execution() + >>> arr = np.random.randn(20, 16, 50) + >>> input = flow.Tensor(arr) + >>> m = nn.Conv1d(16, 33, 3, stride=2) + >>> output = m(input) + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + super().__init__() assert padding_mode == "zeros" @@ -165,6 +246,96 @@ def __init__( bias: bool = True, padding_mode: str = "zeros", # TODO: refine this type ): + r"""Applies a 2D convolution over an input signal composed of several input + planes. + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` + can be precisely described as: + .. 
math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) + where :math:`\star` is the valid 2D `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`H` is a height of input planes in pixels, and :math:`W` is + width in pixels. + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a tuple. + * :attr:`padding` controls the amount of implicit padding on both + sides for :attr:`padding` number of points for each dimension. + * :attr:`dilation` controls the spacing between the kernel points; also + known as the à trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. + * :attr:`groups` controls the connections between inputs and outputs. + :attr:`in_channels` and :attr:`out_channels` must both be divisible by + :attr:`groups`. For example, + * At groups=1, all inputs are convolved to all outputs. + * At groups=2, the operation becomes equivalent to having two conv + layers side by side, each seeing half the input channels + and producing half the output channels, and both subsequently + concatenated. + * At groups= :attr:`in_channels`, each input channel is convolved with + its own set of filters (of size + :math:`\frac{\text{out_channels}}{\text{in_channels}}`)., + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + Note: + When `groups == in_channels` and `out_channels == K * in_channels`, + where `K` is a positive integer, this operation is also known as a "depthwise convolution". + In other words, for an input of size :math:`(N, C_{in}, L_{in})`, + a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments + :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`. + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] + \times (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + .. 
math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] + \times (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + Attr: + - weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out_channels}, \frac{\text{in_channels}}{\text{groups}},` + :math:`\text{kernel_size[0]}, \text{kernel_size[1]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}` + - bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, + then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}` + For example: + .. code-block:: python + >>> import numpy as np + >>> import oneflow.experimental as flow + >>> import oneflow.experimental.nn as nn + >>> flow.enable_eager_execution() + >>> arr = np.random.randn(20, 16, 50, 100) + >>> input = flow.Tensor(arr) + >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> output = m(input) + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ super().__init__() assert padding_mode == "zeros" From 18ae4dc34bdf314883587100ae6fe10241be86e3 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 24 Jun 2021 16:31:55 +0800 Subject: [PATCH 26/42] add conv1d docs rst --- docs/source/experimental.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst index 95d7dd1cadf..7e2cdd88b1a 100644 --- a/docs/source/experimental.rst +++ b/docs/source/experimental.rst @@ -65,6 +65,7 @@ Experimental features .. autofunction:: oneflow.experimental.nn.ParameterDict .. autofunction:: oneflow.experimental.nn.ModuleList .. autofunction:: oneflow.experimental.nn.ModuleDict +.. autofunction:: oneflow.experimental.nn.Conv1d .. autofunction:: oneflow.experimental.nn.Conv2d .. autofunction:: oneflow.experimental.nn.ConstantPad2d .. autofunction:: oneflow.experimental.nn.ConvTranspose2d From 271f31ca9de47719a22f2a7f2b322fd1ada3eea7 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 24 Jun 2021 16:32:10 +0800 Subject: [PATCH 27/42] add conv1d module and docs --- oneflow/python/nn/modules/conv.py | 372 ++++++++++++++++-------------- 1 file changed, 202 insertions(+), 170 deletions(-) diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 26141a3f6de..6a9f5ebe1c1 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -79,6 +79,94 @@ def split(cls, x, axis, split_num): @oneflow_export("nn.Conv1d") @experimental_api class Conv1d(Module): + r"""Applies a 1D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be + precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) + \star \text{input}(N_i, k) + + where :math:`\star` is the valid `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`L` is a length of signal sequence. 
+ + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a one-element tuple. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. + + * :attr:`dilation` controls the spacing between the kernel points; also + known as the à trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` where + + .. math:: + L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, + \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + + For example: + + .. code-block:: python + + >>> import numpy as np + >>> import oneflow.experimental as flow + >>> import oneflow.experimental.nn as nn + >>> flow.enable_eager_execution() + + >>> arr = np.random.randn(20, 16, 50) + >>> input = flow.Tensor(arr) + >>> m = nn.Conv1d(16, 33, 3, stride=2)) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ def __init__( self, in_channels: int, @@ -91,86 +179,6 @@ def __init__( bias: bool = True, padding_mode: str = "zeros", # TODO: refine this type ): - r""" - In the simplest case, the output value of the layer with input size - :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be - precisely described as: - .. 
math:: - \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + - \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) - \star \text{input}(N_i, k) - where :math:`\star` is the valid `cross-correlation`_ operator, - :math:`N` is a batch size, :math:`C` denotes a number of channels, - :math:`L` is a length of signal sequence. - - This module supports :ref:`TensorFloat32`. - * :attr:`stride` controls the stride for the cross-correlation, a single - number or a one-element tuple. - * :attr:`padding` controls the amount of padding applied to the input. It - can be either a string {{'valid', 'same'}} or a tuple of ints giving the - amount of implicit padding applied on both sides. - * :attr:`dilation` controls the spacing between the kernel points; also - known as the à trous algorithm. It is harder to describe, but this `link`_ - has a nice visualization of what :attr:`dilation` does. - {groups_note} - Note: - {depthwise_separable_note} - Note: - {cudnn_reproducibility_note} - Note: - ``padding='valid'`` is the same as no padding. ``padding='same'`` pads - the input so the output has the shape as the input. However, this mode - doesn't support any stride values other than 1. - Args: - in_channels (int): Number of channels in the input image - out_channels (int): Number of channels produced by the convolution - kernel_size (int or tuple): Size of the convolving kernel - stride (int or tuple, optional): Stride of the convolution. Default: 1 - padding (int, tuple or str, optional): Padding added to both sides of - the input. Default: 0 - padding_mode (string, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` - dilation (int or tuple, optional): Spacing between kernel - elements. Default: 1 - groups (int, optional): Number of blocked connections from input - channels to output channels. Default: 1 - bias (bool, optional): If ``True``, adds a learnable bias to the - output. Default: ``True`` - - Shape: - - Input: :math:`(N, C_{in}, L_{in})` - - Output: :math:`(N, C_{out}, L_{out})` where - .. math:: - L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation} - \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor - Attributes: - weight (Tensor): the learnable weights of the module of shape - :math:`(\text{out\_channels}, - \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`. - The values of these weights are sampled from - :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where - :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` - bias (Tensor): the learnable bias of the module of shape - (out_channels). If :attr:`bias` is ``True``, then the values of these weights are - sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where - :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` - - For example: - .. code-block:: python - >>> import numpy as np - >>> import oneflow.experimental as flow - >>> import oneflow.experimental.nn as nn - >>> flow.enable_eager_execution() - >>> arr = np.random.randn(20, 16, 50) - >>> input = flow.Tensor(arr) - >>> m = nn.Conv1d(16, 33, 3, stride=2) - >>> output = m(input) - .. _cross-correlation: - https://en.wikipedia.org/wiki/Cross-correlation - .. 
_link: - https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md - """ - super().__init__() assert padding_mode == "zeros" @@ -234,6 +242,120 @@ def forward(self, x): @oneflow_export("nn.Conv2d") @experimental_api class Conv2d(Module): + r"""Applies a 2D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` + can be precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) + + + where :math:`\star` is the valid 2D `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`H` is a height of input planes in pixels, and :math:`W` is + width in pixels. + + + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a tuple. + * :attr:`padding` controls the amount of implicit padding on both + sides for :attr:`padding` number of points for each dimension. + * :attr:`dilation` controls the spacing between the kernel points; also + known as the à trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. + * :attr:`groups` controls the connections between inputs and outputs. + :attr:`in_channels` and :attr:`out_channels` must both be divisible by + :attr:`groups`. For example, + + * At groups=1, all inputs are convolved to all outputs. + * At groups=2, the operation becomes equivalent to having two conv + layers side by side, each seeing half the input channels + and producing half the output channels, and both subsequently + concatenated. + * At groups= :attr:`in_channels`, each input channel is convolved with + its own set of filters (of size + :math:`\frac{\text{out_channels}}{\text{in_channels}}`)., + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Note: + When `groups == in_channels` and `out_channels == K * in_channels`, + where `K` is a positive integer, this operation is also known as a "depthwise convolution". + + In other words, for an input of size :math:`(N, C_{in}, L_{in})`, + a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments + :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`. + + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. 
Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] + \times (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] + \times (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + Attr: + - weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out_channels}, \frac{\text{in_channels}}{\text{groups}},` + :math:`\text{kernel_size[0]}, \text{kernel_size[1]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}` + + - bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, + then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}` + + For example: + + .. code-block:: python + + >>> import numpy as np + >>> import oneflow.experimental as flow + >>> import oneflow.experimental.nn as nn + >>> flow.enable_eager_execution() + + >>> arr = np.random.randn(20, 16, 50, 100) + >>> input = flow.Tensor(arr) + >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ def __init__( self, in_channels: int, @@ -246,96 +368,6 @@ def __init__( bias: bool = True, padding_mode: str = "zeros", # TODO: refine this type ): - r"""Applies a 2D convolution over an input signal composed of several input - planes. - In the simplest case, the output value of the layer with input size - :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` - can be precisely described as: - .. math:: - \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + - \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) - where :math:`\star` is the valid 2D `cross-correlation`_ operator, - :math:`N` is a batch size, :math:`C` denotes a number of channels, - :math:`H` is a height of input planes in pixels, and :math:`W` is - width in pixels. - * :attr:`stride` controls the stride for the cross-correlation, a single - number or a tuple. - * :attr:`padding` controls the amount of implicit padding on both - sides for :attr:`padding` number of points for each dimension. - * :attr:`dilation` controls the spacing between the kernel points; also - known as the à trous algorithm. It is harder to describe, but this `link`_ - has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels - and producing half the output channels, and both subsequently - concatenated. 
- * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters (of size - :math:`\frac{\text{out_channels}}{\text{in_channels}}`)., - The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - - a single ``int`` -- in which case the same value is used for the height and width dimension - - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, - and the second `int` for the width dimension - Note: - When `groups == in_channels` and `out_channels == K * in_channels`, - where `K` is a positive integer, this operation is also known as a "depthwise convolution". - In other words, for an input of size :math:`(N, C_{in}, L_{in})`, - a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments - :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`. - Args: - in_channels (int): Number of channels in the input image - out_channels (int): Number of channels produced by the convolution - kernel_size (int or tuple): Size of the convolving kernel - stride (int or tuple, optional): Stride of the convolution. Default: 1 - padding (int or tuple, optional): Zero-padding added to both sides of - the input. Default: 0 - padding_mode (string, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` - dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 - groups (int, optional): Number of blocked connections from input - channels to output channels. Default: 1 - bias (bool, optional): If ``True``, adds a learnable bias to the - output. Default: ``True`` - Shape: - - Input: :math:`(N, C_{in}, H_{in}, W_{in})` - - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where - .. math:: - H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] - \times (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor - .. math:: - W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] - \times (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor - Attr: - - weight (Tensor): the learnable weights of the module of shape - :math:`(\text{out_channels}, \frac{\text{in_channels}}{\text{groups}},` - :math:`\text{kernel_size[0]}, \text{kernel_size[1]})`. - The values of these weights are sampled from - :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where - :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}` - - bias (Tensor): the learnable bias of the module of shape - (out_channels). If :attr:`bias` is ``True``, - then the values of these weights are - sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where - :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}` - For example: - .. code-block:: python - >>> import numpy as np - >>> import oneflow.experimental as flow - >>> import oneflow.experimental.nn as nn - >>> flow.enable_eager_execution() - >>> arr = np.random.randn(20, 16, 50, 100) - >>> input = flow.Tensor(arr) - >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) - >>> output = m(input) - .. _cross-correlation: - https://en.wikipedia.org/wiki/Cross-correlation - .. 
_link: - https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md - """ super().__init__() assert padding_mode == "zeros" From e430e1b4d24480adfb4850b4de89d7671fdd9010 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 24 Jun 2021 16:34:48 +0800 Subject: [PATCH 28/42] fix bias add error --- oneflow/python/nn/modules/conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index a43734c7489..372a8755e62 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -246,7 +246,7 @@ def forward(self, x): flow.F.conv2d( in_split_list[i], self.weight[i : i + 1, :, :, :], - self.bias[i : i + 1, :, :, :], + self.bias[i : i + 1], stride=self.stride, padding=self.padding, dilation=self.dilation, From 926283a92a2c68f2685b55e549c17fdbbf12e2f4 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 24 Jun 2021 17:59:08 +0800 Subject: [PATCH 29/42] fix groups bug --- oneflow/python/nn/modules/conv.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 6a9f5ebe1c1..01744cd23d8 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -167,6 +167,7 @@ class Conv1d(Module): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + def __init__( self, in_channels: int, @@ -209,16 +210,25 @@ def reset_parameters(self) -> None: def forward(self, x): if x.device.type == "cpu" and self.groups > 1: in_channel_axis = 1 + weight_channel_axis = 0 + bias_channel_axis = 0 in_split_list = ConvUtil.split( x, axis=in_channel_axis, split_num=self.groups ) + weight_split_list = ConvUtil.split( + self.weight, axis=weight_channel_axis, split_num=self.groups + ) + bias_split_list = ConvUtil.split( + self.bias, axis=bias_channel_axis, split_num=self.groups + ) + out_list = [] for i in range(len(in_split_list)): out_list.append( flow.F.conv1d( in_split_list[i], - self.weight[i : i + 1, :, :], - self.bias[i : i + 1], + weight_split_list[i], + bias_split_list[i], stride=self.stride, padding=self.padding, dilation=self.dilation, @@ -356,6 +366,7 @@ class Conv2d(Module): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + def __init__( self, in_channels: int, From 51716bcbaa308c35f527e829da0300de997d67ed Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 24 Jun 2021 17:59:21 +0800 Subject: [PATCH 30/42] add test case --- oneflow/python/test/modules/test_conv1d.py | 265 +++++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 oneflow/python/test/modules/test_conv1d.py diff --git a/oneflow/python/test/modules/test_conv1d.py b/oneflow/python/test/modules/test_conv1d.py new file mode 100644 index 00000000000..7615e0518ec --- /dev/null +++ b/oneflow/python/test/modules/test_conv1d.py @@ -0,0 +1,265 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +from collections import OrderedDict + +import numpy as np + +import oneflow.experimental as flow +import oneflow.experimental.nn as nn +from test_util import GenArgList + + +def _test_conv1d_bias_true(test_case, device): + np_arr = np.array( + [ + [ + [0.90499806, -1.11683071, 0.71605605, -0.56754625, 0.61944169], + [-0.31317389, -0.26271924, 0.95579433, 0.52468461, 1.48926127], + ] + ] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [ + [0.01997352, 0.23834395, 0.00526353], + [-0.04861857, -0.22751901, -0.06725175], + ], + [ + [0.13344523, -0.35202524, 0.15168799], + [-0.25714493, -0.17459838, 0.28768948], + ], + [ + [0.10671382, -0.28205597, -0.39752254], + [0.36393702, 0.07843742, -0.33898622], + ], + [ + [0.20485674, 0.04222689, -0.18986180], + [0.22519711, -0.15910202, -0.35057363], + ], + ] + ) + bias = np.array([0.01012857, 0.38912651, -0.01600273, -0.38833040]) + m = nn.Conv1d(2, 4, 3, stride=1, bias=True) + m.weight = flow.nn.Parameter(flow.Tensor(weight)) + m.bias = flow.nn.Parameter(flow.Tensor(bias)) + m = m.to(device) + np_out = np.array( + [ + [ + [-0.22349545, -0.08447243, -0.37358052], + [1.41303730, -0.04644597, 0.86949122], + [-0.34765026, -0.31004351, -0.14158708], + [-0.74985039, -0.87430149, -0.77354753], + ] + ] + ) + output = m(input) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + [ + [ + [0.46498930, 0.11147892, -0.31895390, -0.78394318, -0.43043283], + [0.28337064, -0.19941133, -0.66853344, -0.95190406, -0.46912211], + ] + ] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + +def _test_conv1d_group_bias_true(test_case, device): + np_arr = np.array( + [ + [ + [1.48566079, 0.54937589, 0.62353903, -0.94114172, -0.60260266], + [0.61150503, -0.50289607, 1.41735041, -1.85877609, -1.04875529], + ] + ] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [[0.25576305, 0.40814576, -0.05900212]], + [[-0.24829513, 0.42756805, -0.01354307]], + [[0.44658303, 0.46889144, 0.41060263]], + [[0.30083328, -0.52216130, 0.12215579]], + ] + ) + bias = np.array([-0.03368823, -0.42125040, -0.42130581, -0.17434336]) + m = nn.Conv1d(2, 4, 3, groups=2, stride=1, bias=True) + m.weight = flow.nn.Parameter(flow.Tensor(weight)) + m.bias = flow.nn.Parameter(flow.Tensor(bias)) + m = m.to(device) + np_out = np.array( + [ + [ + [0.53372419, 0.41684598, -0.22277816], + [-0.56368178, -0.27830642, -0.97031319], + [0.19794616, -0.74452549, -1.09052706], + [0.44534814, -1.29277706, 1.09451222], + ] + ] + ) + output = m(input) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + [ + [ + [0.00746793, 0.84318173, 0.77063656, 0.76316863, -0.07254519], + [0.74741632, 0.69414645, 1.22690487, 0.47948855, 0.53275841], + ] + ] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + +def _test_conv1d_group_large_out_bias_true(test_case, device): + np_arr = np.array( + [ + [ + [2.17964911, 0.91623521, 1.24746692, 0.73605931, -0.23738743], + [-0.70412433, 0.10727754, 1.02078640, -0.09711888, -1.10814202], + ] + ] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, 
device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [[-2.07307473e-01, 1.28563240e-01, 3.71991515e-01]], + [[-4.16422307e-01, 3.26921181e-05, -3.85845661e-01]], + [[-1.82592362e-01, 1.43281639e-01, 4.19321984e-01]], + [[-2.71174580e-01, 4.21470925e-02, 3.77335936e-01]], + [[5.46190619e-01, -2.11819887e-01, -2.97858030e-01]], + [[3.34832489e-01, 2.55918801e-01, -5.56600206e-02]], + ] + ) + bias = np.array( + [-0.56865668, 0.17631066, -0.43992457, -0.24307285, -0.53672957, -0.52927947] + ) + m = nn.Conv1d(2, 6, 3, groups=2, stride=1, bias=True) + m.weight = flow.nn.Parameter(flow.Tensor(weight)) + m.bias = flow.nn.Parameter(flow.Tensor(bias)) + m = m.to(device) + np_out = np.array( + [ + [ + [-0.43867296, -0.32441288, -0.82094181], + [-1.21264362, -0.48919463, -0.25154343], + [-0.18354186, -0.11983716, -0.66178048], + [0.33756858, -0.26578707, -0.94211930], + [-1.24808860, -0.66543078, 0.37145507], + [-0.79440582, -0.22671542, -0.15066233], + ] + ] + ) + output = m(input) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + [ + [ + [-0.80632210, -0.53444451, -0.12897667, 0.67734540, 0.40546784], + [0.60984850, 0.69609451, 0.71991241, 0.11006390, 0.02381789], + ] + ] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + +def _test_conv1d_group_large_in_bias_true(test_case, device): + np_arr = np.array( + [ + [ + [0.73829210, 0.32275710, -0.73204273, -0.01697334, 1.72585976], + [0.52866709, 0.28417364, 1.12931311, 1.73048413, -0.60748184], + [0.43222603, 0.78825170, -0.62105948, 0.10097823, 0.81639361], + [0.36671457, 0.24468753, -0.58248740, -0.74464536, -0.38901371], + ] + ] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [ + [-0.29574063, -0.31176069, 0.17234495], + [0.06092392, 0.30691007, -0.36685407], + ], + [ + [0.26149744, 0.07149458, 0.32097560], + [0.18960869, -0.37148297, -0.13602243], + ], + ] + ) + bias = np.array([-0.35048512, -0.00937920]) + m = nn.Conv1d(4, 2, 3, groups=2, stride=1, bias=True) + m.weight = flow.nn.Parameter(flow.Tensor(weight)) + m.bias = flow.nn.Parameter(flow.Tensor(bias)) + m = m.to(device) + np_out = np.array( + [[[-1.09048378, -0.49156523, 0.99150705], [0.01852397, 0.54882324, 0.31657016]]] + ) + output = m(input) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + [ + [ + [-0.29574063, -0.60750133, -0.43515638, -0.13941574, 0.17234495], + [0.06092392, 0.36783397, 0.00097990, -0.05994400, -0.36685407], + [0.26149744, 0.33299202, 0.65396762, 0.39247018, 0.32097560], + [0.18960869, -0.18187428, -0.31789672, -0.50750542, -0.13602243], + ] + ] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestConv1d(flow.unittest.TestCase): + def test_conv1d(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + _test_conv1d_bias_true, + _test_conv1d_group_bias_true, + _test_conv1d_group_large_out_bias_true, + _test_conv1d_group_large_in_bias_true, + ] + arg_dict["device"] = ["cuda", "cpu"] + + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + +if __name__ == "__main__": + unittest.main() From d50209adf0f1c2bb3cf7e92ffeab3f5ee8985264 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: 
Thu, 24 Jun 2021 22:19:40 +0800 Subject: [PATCH 31/42] Support optional parameter. --- oneflow/api/python/functional/python_arg.cpp | 5 + oneflow/api/python/functional/python_arg.h | 17 +- oneflow/core/common/optional.h | 197 +++++++++++++++++++ oneflow/core/functional/functional_api.yaml | 2 +- oneflow/core/functional/impl/nn_functor.cpp | 17 +- oneflow/core/functional/value_types.h | 6 + oneflow/python/nn/modules/conv.py | 2 +- tools/generate_functional_api.py | 37 +++- 8 files changed, 269 insertions(+), 14 deletions(-) create mode 100644 oneflow/core/common/optional.h diff --git a/oneflow/api/python/functional/python_arg.cpp b/oneflow/api/python/functional/python_arg.cpp index c572a2885fd..57d4280f678 100644 --- a/oneflow/api/python/functional/python_arg.cpp +++ b/oneflow/api/python/functional/python_arg.cpp @@ -70,6 +70,11 @@ Maybe> PythonArg::ObjectAs>(Borrow()); } +template<> +Maybe PythonArg::ObjectAs() const { + return *JUST(detail::cast>(Borrow())); +} + template<> Maybe> PythonArg::ObjectAs>() const { diff --git a/oneflow/api/python/functional/python_arg.h b/oneflow/api/python/functional/python_arg.h index 519f07e07a7..1ac8d231ebc 100644 --- a/oneflow/api/python/functional/python_arg.h +++ b/oneflow/api/python/functional/python_arg.h @@ -61,6 +61,21 @@ class PythonArg { virtual ~PythonArg() = default; + template + friend class ObjectAsHelper; + + template + struct ObjectAsHelper { + Maybe operator()(const PythonArg* self) { return self->ObjectAs(); } + }; + template + struct ObjectAsHelper> { + Maybe> operator()(const PythonArg* self) { + if (self->object_ == Py_None) { return std::make_shared>(); } + return std::make_shared>(JUST(self->ObjectAs())); + } + }; + template operator T() const { if (active_tag_ == HAS_IMMEDIATE) { @@ -70,7 +85,7 @@ class PythonArg { return *reinterpret_cast(immediate_->Ptr()); } CHECK_EQ_OR_THROW(active_tag_, HAS_OBJECT); - return this->ObjectAs>().GetOrThrow(); + return ObjectAsHelper>()(this).GetOrThrow(); } private: diff --git a/oneflow/core/common/optional.h b/oneflow/core/common/optional.h new file mode 100644 index 00000000000..a5796548076 --- /dev/null +++ b/oneflow/core/common/optional.h @@ -0,0 +1,197 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_CORE_COMMON_OPTIONAL_H_ +#define ONEFLOW_CORE_COMMON_OPTIONAL_H_ + +#include "oneflow/core/common/type_traits.h" +#include "oneflow/core/common/maybe.h" + +namespace oneflow { + +template +class Storage; + +template +class Storage::value>::type> { + public: + Storage() = default; + + template::value, int>::type = 0> + Storage(Args&&... 
args) { + new (&value_) T(std::forward(args)...); + } + + Storage& operator=(const T& value) { + value_ = value; + return *this; + } + Storage& operator=(T&& value) { + value_ = std::move(value); + return *this; + } + Storage& operator=(const Storage& rhs) { + value_ = rhs.value_; + return *this; + } + Storage& operator=(Storage&& rhs) { + value_ = std::move(rhs.value_); + return *this; + } + + Maybe value() const { return value_; } + + private: + T value_; +}; + +template +class Storage::value>::type> { + public: + Storage() = default; + + template::value, int>::type = 0> + Storage(Args&&... args) { + value_ = std::make_shared(std::forward(args)...); + } + + Storage(const std::shared_ptr& value) : value_(value) {} + + Storage& operator=(const T& value) { + if (value_) { + *value_ = value; + } else { + value_ = std::make_shared(value); + } + return *this; + } + Storage& operator=(T&& value) { + if (value_) { + *value_ = std::move(value); + } else { + value_ = std::make_shared(value); + } + return *this; + } + Storage& operator=(const Storage& rhs) { + value_ = rhs.value_; + return *this; + } + Storage& operator=(Storage&& rhs) { + value_ = std::move(rhs.value_); + return *this; + } + + Maybe value() const { return value_; } + + private: + std::shared_ptr value_; +}; + +template +class Optional { + public: + Optional() : init_(false) {} + + template, Args...>::value, int>::type = 0> + Optional(Args&&... args) : init_(true), storage_(std::forward(args)...) {} + + ~Optional() = default; + + Optional(const Optional& rhs) : init_(rhs.init_) { + if (init_) { storage_ = rhs.storage_; } + } + + Optional(Optional&& rhs) : init_(rhs.init_) { + if (init_) { storage_ = std::move(rhs.storage_); } + } + + Optional& operator=(const T& val) { + init_ = true; + storage_ = val; + return *this; + } + + Optional& operator=(T&& val) { + init_ = true; + storage_ = std::move(val); + return *this; + } + + Optional& operator=(const Optional& rhs) { + init_ = rhs.init_; + if (init_) { storage_ = rhs.storage_; } + return *this; + } + + Optional& operator=(Optional&& rhs) { + init_ = rhs.init_; + if (init_) { storage_ = std::move(rhs.storage_); } + return *this; + } + + Maybe value() const { + CHECK_OR_RETURN(has_value()) << "Optional has no value."; + return storage_.value(); + } + + bool has_value() const { return init_; } + operator bool() const { return has_value(); } + + private: + bool init_; + Storage storage_; +}; + +template +class Optional { + public: + Optional() : value_ptr_(nullptr) {} + + Optional(T& val) : value_ptr_(&val) {} + + ~Optional() = default; + + Optional& operator=(T& val) { + value_ptr_ = &val; + return *this; + } + + Optional& operator=(const Optional& rhs) { + value_ptr_ = rhs.value_ptr_; + return *this; + } + + Maybe value() const { + CHECK_OR_RETURN(has_value()) << "Optional has no value."; + return *value_ptr_; + } + + void Clear() { value_ptr_ = nullptr; } + + bool has_value() const { return value_ptr_ != nullptr; } + operator bool() const { return has_value(); } + + private: + T* value_ptr_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_COMMON_OPTIONAL_H_ diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index f3e1e22042d..2175ce0598b 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -240,7 +240,7 @@ - name: "conv2d" signature: - "Tensor Conv2D(Tensor x, Tensor weight, Tensor bias, *, Int32List stride, + "Tensor Conv2D(Tensor x, Tensor weight, *, Tensor 
bias=None, Int32List stride, Int32List padding, Int32List dilation, Int32 groups=1)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index dd3105efa20..4c3606a0e90 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/optional.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" @@ -56,8 +57,8 @@ class Conv2DFunctor { } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& weight, - const std::shared_ptr& bias, - const std::vector& stride, const std::vector& padding, + const Optional& bias, const std::vector& stride, + const std::vector& padding, const std::vector& dilation, const int32_t& groups) const { MutableAttrMap conv_attrs; std::vector kernel_size_vec; @@ -69,11 +70,15 @@ class Conv2DFunctor { JUST(conv_attrs.SetAttr>("dilation_rate", dilation)); JUST(conv_attrs.SetAttr("groups", groups)); JUST(conv_attrs.SetAttr("data_format", std::string("channels_first"))); - std::shared_ptr conv_out = + const std::shared_ptr& conv_out = JUST(OpInterpUtil::Dispatch(*conv_op_, {x, weight}, conv_attrs)); - MutableAttrMap bias_attrs; - JUST(bias_attrs.SetAttr("axis", 1)); - return OpInterpUtil::Dispatch(*bias_op_, {conv_out, bias}, bias_attrs); + if (bias) { + MutableAttrMap bias_attrs; + JUST(bias_attrs.SetAttr("axis", 1)); + return OpInterpUtil::Dispatch(*bias_op_, {conv_out, JUST(bias.value())}, bias_attrs); + } else { + return conv_out; + } } private: diff --git a/oneflow/core/functional/value_types.h b/oneflow/core/functional/value_types.h index cd8904face4..f47035a8e1d 100644 --- a/oneflow/core/functional/value_types.h +++ b/oneflow/core/functional/value_types.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" namespace oneflow { class Shape; @@ -79,6 +80,11 @@ enum ValueType { #define VALUE_TYPE_OF_IMPL(cpp_type, value_type) \ template::value, int>::type = 0> \ + inline ValueType ValueTypeOf() { \ + return value_type; \ + } \ + template>::value, int>::type = 0> \ inline ValueType ValueTypeOf() { \ return value_type; \ } diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index a43734c7489..6734f9940cb 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -246,7 +246,7 @@ def forward(self, x): flow.F.conv2d( in_split_list[i], self.weight[i : i + 1, :, :, :], - self.bias[i : i + 1, :, :, :], + self.bias[i : i + 1] if self.bias else None, stride=self.stride, padding=self.padding, dilation=self.dilation, diff --git a/tools/generate_functional_api.py b/tools/generate_functional_api.py index 4164dfc3b6f..409858934ad 100644 --- a/tools/generate_functional_api.py +++ b/tools/generate_functional_api.py @@ -61,6 +61,7 @@ #ifndef ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_ #define ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_ +#include "oneflow/core/common/optional.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/scalar.h" @@ -104,6 +105,7 @@ #include "oneflow/api/python/functional/function_def.h" #include "oneflow/api/python/functional/py_function.h" #include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" #include "oneflow/core/functional/functional.h" namespace oneflow {{ @@ -174,6 +176,24 @@ **generic_type_aliases, } +optional_argument_type_aliases = { + "Tensor": "const Optional&", + "TensorTuple": "const Optional&", + "Scalar": "const Optional&", + "ScalarList": "const Optional>&", + "IntList": "const Optional>&", + "Int32List": "const Optional>&", + "Int64List": "const Optional>&", + "FloatList": "const Optional>&", + "DoubleList": "const Optional>&", + "String": "const Optional&", + "StringList": "const Optional>&", + "BoolList": "const Optional>&", + "DataType": "const Optional&", + "Shape": "const Optional&", + **{k: "const Optional<{0}>".format(v) for k, v in generic_type_aliases.items()}, +} + return_type_aliases = { "Void": "Maybe", "Tensor": "Maybe", @@ -254,20 +274,27 @@ def __init__(self, fmt, keyword_allowed=False): self._type = _normalize(fmt[0:sp]) assert self._type in types_allowed, "Unknow type: " + self._type - if self._type in argument_type_aliases: - self._cpp_type = argument_type_aliases[self._type] - else: - self._cpp_type = self._type + optional = False self._name = _normalize(fmt[sp + 1 :]) sp = self._name.find("=") if sp != -1: self._default_value = _normalize(self._name[sp + 1 :]) - if self._default_value in value_aliases: + if self._default_value == "None": + optional = True + self._default_cpp_value = "" + elif self._default_value in value_aliases: self._default_cpp_value = value_aliases[self._default_value] else: self._default_cpp_value = self._default_value self._name = _normalize(self._name[0:sp]) + if not optional and self._type in argument_type_aliases: + self._cpp_type = argument_type_aliases[self._type] + elif optional and self._type in optional_argument_type_aliases: + self._cpp_type = optional_argument_type_aliases[self._type] + else: + self._cpp_type = self._type + @property def has_default_value(self): return self._default_value is not None From 
f2c9e2931f9f7a8c323acb80b567945fc7d32284 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Fri, 25 Jun 2021 10:25:28 +0800 Subject: [PATCH 32/42] fix group bug --- oneflow/python/nn/modules/conv.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 6734f9940cb..b7f9e77f971 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -245,8 +245,21 @@ def forward(self, x): out_list.append( flow.F.conv2d( in_split_list[i], - self.weight[i : i + 1, :, :, :], - self.bias[i : i + 1] if self.bias else None, + self.weight[ + i + * self.out_channel_groups : (i + 1) + * self.out_channel_groups, + :, + :, + :, + ], + self.bias[ + i + * self.out_channel_groups : (i + 1) + * self.out_channel_groups + ] + if self.bias + else None, stride=self.stride, padding=self.padding, dilation=self.dilation, From 2389dbe3d15e27afc12a468aabb7280b5c0207aa Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Fri, 25 Jun 2021 10:25:40 +0800 Subject: [PATCH 33/42] add new test case --- oneflow/python/test/modules/test_conv.py | 124 ++++++++++++++++++++++- 1 file changed, 123 insertions(+), 1 deletion(-) diff --git a/oneflow/python/test/modules/test_conv.py b/oneflow/python/test/modules/test_conv.py index 9057dc77d3c..4f2e3e73b6b 100644 --- a/oneflow/python/test/modules/test_conv.py +++ b/oneflow/python/test/modules/test_conv.py @@ -1478,6 +1478,119 @@ def _test_conv2d_large_in_channel(test_case, device): test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) +def _test_conv2d_large_out_channel(test_case, device): + np_arr = np.array( + [ + [ + [ + [0.56573248, -0.19689320, -0.67875558, 0.34328273, 0.31964567], + [-1.33715475, 0.33422229, -1.27643383, 0.37904647, 0.35891593], + [0.84579802, 2.12729621, -0.51423287, 0.61297560, -1.31156564], + [-0.71047139, 1.02679253, -0.76686019, -0.72969633, 0.73425150], + [-0.13592879, -1.03207183, -0.22554775, 0.74148071, 0.96601510], + ], + [ + [0.51595992, 0.49624804, 0.91145641, 0.49247262, 0.41002217], + [-1.08001196, 1.55497086, -0.81963140, -0.45511565, -0.60269165], + [0.05563145, -0.94318372, -1.17058158, -0.73568577, 0.57810956], + [-0.40260276, -0.10309298, 1.12378800, -0.23510537, -0.73893374], + [-0.52712536, -0.00717016, -1.85051966, -1.50790560, 1.38335907], + ], + ] + ] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [ + [ + [-0.19489679, -0.32377058, 0.21736273], + [0.04095296, -0.21552679, -0.14626531], + [-0.19359522, -0.00742865, -0.19832158], + ] + ], + [ + [ + [0.29926914, 0.00931164, 0.26197660], + [0.27611443, -0.15439281, -0.19027126], + [-0.28909120, 0.30367029, -0.05168664], + ] + ], + [ + [ + [-0.03155736, 0.17610769, 0.22111714], + [0.22790670, -0.32897446, -0.03260243], + [-0.10274851, -0.06903386, -0.19438276], + ] + ], + [ + [ + [-0.24573688, -0.06723209, -0.21363299], + [-0.02136187, -0.24994437, -0.18691199], + [0.12189507, 0.29469389, 0.03398871], + ] + ], + ] + ) + m = flow.nn.Conv2d(2, 4, 3, groups=2, bias=False) + m.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True) + m = m.to(device) + output = m(input) + print(output) + np_out = np.array( + [ + [ + [ + [-0.21170563, 0.03652292, 0.25926736], + [-0.19168918, 0.49044561, 0.25099146], + [-1.02489340, 0.25361472, -0.51828313], + ], + [ + [0.23977707, -0.56090075, -0.19285655], + [-0.17167747, 0.24558367, -0.30935860], 
+ [-0.33303234, 1.52472734, -0.49013454], + ], + [ + [-0.17137986, 1.21333742, 0.18988736], + [0.31785482, -0.12121570, -0.18676008], + [-0.10680684, -0.30298883, 0.41809759], + ], + [ + [-0.87821335, -0.51665992, -0.44061098], + [0.74804580, 0.53107250, 0.50418228], + [-0.00512899, -0.36455840, -0.23643512], + ], + ] + ] + ) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + [ + [ + [ + [0.10437235, -0.21008658, 0.26925275, 0.16488039, 0.47933933], + [0.42143974, -0.26293880, -0.12013602, -0.54157579, 0.14280275], + [-0.06124666, -0.44938356, -0.55658901, -0.49534237, -0.10720548], + [-0.16561902, -0.23929697, -0.82584178, -0.66022277, -0.58654481], + [-0.48268640, -0.18644476, -0.43645298, 0.04623342, -0.25000823], + ], + [ + [-0.27729425, -0.16841865, -0.16093449, 0.11635975, 0.00748415], + [-0.07074942, -0.54079264, -0.75282294, -0.68207347, -0.21203026], + [-0.05160286, -0.29598606, -0.66841042, -0.61680746, -0.37242430], + [0.22569139, -0.12756741, -0.50747585, -0.73316729, -0.37990844], + [0.01914656, 0.24480659, 0.08441254, 0.06526598, -0.16039404], + ], + ] + ] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + @unittest.skipIf( not flow.unittest.env.eager_execution_enabled(), ".numpy() doesn't work in lazy mode", @@ -1677,7 +1790,7 @@ def test_conv2d_dilation_backward(test_case): device=device, ) - def test_large_channel_group_conv(test_case): + def test_large_in_channel_group_conv(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_conv2d_large_in_channel, @@ -1686,6 +1799,15 @@ def test_large_channel_group_conv(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + def test_large_out_channel_group_conv(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + _test_conv2d_large_out_channel, + ] + arg_dict["device"] = ["cuda", "cpu"] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + if __name__ == "__main__": unittest.main() From a4aae7232f940f6fc0a1bfd12fd6eade43fa2278 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Fri, 25 Jun 2021 13:58:36 +0800 Subject: [PATCH 34/42] add more test case --- oneflow/core/functional/functional_api.yaml | 2 +- oneflow/core/functional/impl/nn_functor.cpp | 20 ++- oneflow/python/nn/modules/conv.py | 24 ++- oneflow/python/test/modules/test_conv1d.py | 178 ++++++++++++++++++++ 4 files changed, 206 insertions(+), 18 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 31c9b4d02e6..45730a454f4 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -240,7 +240,7 @@ - name: "conv1d" signature: - "Tensor Conv1D(Tensor x, Tensor weight, Tensor bias, *, Int32List stride, + "Tensor Conv1D(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, Int32List padding, Int32List dilation, Int32 groups=1)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 4d61a910bd3..7cc43b42032 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -57,23 +57,27 @@ class Conv1DFunctor { } Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& weight, - const std::shared_ptr& bias, - const std::vector& stride, const std::vector& padding, + const Optional& bias, const std::vector& stride, + const std::vector& 
padding, const std::vector& dilation, const int32_t& groups) const { MutableAttrMap conv_attrs; - std::vector kernel_size_vec{(weight->shape())->At(2)}; - JUST(conv_attrs.SetAttr("filters", int64_t((weight->shape())->At(0)))); + std::vector kernel_size_vec{(int32_t)(weight->shape())->At(2)}; + JUST(conv_attrs.SetAttr("filters", (weight->shape())->At(0))); JUST(conv_attrs.SetAttr>("padding_before", padding)); JUST(conv_attrs.SetAttr>("kernel_size", kernel_size_vec)); JUST(conv_attrs.SetAttr>("strides", stride)); JUST(conv_attrs.SetAttr>("dilation_rate", dilation)); JUST(conv_attrs.SetAttr("groups", groups)); JUST(conv_attrs.SetAttr("data_format", std::string("channels_first"))); - std::shared_ptr conv_out = + const std::shared_ptr& conv_out = JUST(OpInterpUtil::Dispatch(*conv_op_, {x, weight}, conv_attrs)); - MutableAttrMap bias_attrs; - JUST(bias_attrs.SetAttr("axis", 1)); - return OpInterpUtil::Dispatch(*bias_op_, {conv_out, bias}, bias_attrs); + if (bias) { + MutableAttrMap bias_attrs; + JUST(bias_attrs.SetAttr("axis", 1)); + return OpInterpUtil::Dispatch(*bias_op_, {conv_out, JUST(bias.value())}, bias_attrs); + } else { + return conv_out; + } } private: diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index c94bfdf4e79..07411b71398 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -195,6 +195,7 @@ def __init__( self.weight = flow.nn.Parameter( flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) ) + self.out_channel_groups = out_channels // groups self.bias = None if bias: self.bias = flow.nn.Parameter(flow.Tensor(out_channels)) @@ -215,20 +216,25 @@ def forward(self, x): in_split_list = ConvUtil.split( x, axis=in_channel_axis, split_num=self.groups ) - weight_split_list = ConvUtil.split( - self.weight, axis=weight_channel_axis, split_num=self.groups - ) - bias_split_list = ConvUtil.split( - self.bias, axis=bias_channel_axis, split_num=self.groups - ) - out_list = [] for i in range(len(in_split_list)): out_list.append( flow.F.conv1d( in_split_list[i], - weight_split_list[i], - bias_split_list[i], + self.weight[ + i + * self.out_channel_groups : (i + 1) + * self.out_channel_groups, + :, + :, + ], + self.bias[ + i + * self.out_channel_groups : (i + 1) + * self.out_channel_groups + ] + if self.bias + else None, stride=self.stride, padding=self.padding, dilation=self.dilation, diff --git a/oneflow/python/test/modules/test_conv1d.py b/oneflow/python/test/modules/test_conv1d.py index 7615e0518ec..6eeed068ce1 100644 --- a/oneflow/python/test/modules/test_conv1d.py +++ b/oneflow/python/test/modules/test_conv1d.py @@ -23,6 +23,42 @@ from test_util import GenArgList +def _test_conv1d_bias_false(test_case, device): + np_arr = np.array( + [[[1.28795946, -0.29217920, 0.20338029, 0.78604293, -1.89607573]]] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [[0.10197904, 0.33723050, -0.25743008]], + [[0.27720425, -0.52435774, -0.38381988]], + [[0.56016803, -0.10063095, -0.10760903]], + ] + ) + m = nn.Conv1d(1, 3, 3, stride=1, bias=False) + m.weight = flow.nn.Parameter(flow.Tensor(weight)) + m = m.to(device) + output = m(input) + np_out = np.array( + [ + [ + [-0.01954307, -0.16356121, 0.77392507], + [0.43217283, -0.48933625, 0.37196174], + [0.72899038, -0.26872110, 0.23886177], + ] + ] + ) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + 
[[[0.93935132, 0.65159315, -0.09726584, -1.03661716, -0.74885899]]] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + def _test_conv1d_bias_true(test_case, device): np_arr = np.array( [ @@ -85,6 +121,78 @@ def _test_conv1d_bias_true(test_case, device): test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) +def _test_conv1d_dilation(test_case, device): + np_arr = np.array( + [[[-0.43016902, 1.74619496, -0.57338119, 0.25563857, 0.12575546]]] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [[-0.35057205, -0.31304273, 0.46250814]], + [[-0.40786612, 0.36518192, 0.46280444]], + [[-0.00921835, -0.38710043, 0.47566161]], + ] + ) + m = nn.Conv1d(1, 3, 3, stride=1, bias=False) + m.weight = flow.nn.Parameter(flow.Tensor(weight)) + m = m.to(device) + output = m(input) + np_out = np.array( + [ + [ + [-0.66102189, -0.31443936, 0.17914855], + [0.54776692, -0.80329150, 0.38541752], + [-0.94472277, 0.32745653, -0.03385513], + ] + ] + ) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + [[[-0.76765651, -1.10261774, 0.29835641, 1.06601286, 1.40097415]]] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + +def _test_conv1d_stride(test_case, device): + np_arr = np.array( + [[[-1.01312506, -0.40687919, 1.59853160, 0.53594196, -1.89935565]]] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [[0.57514840, 0.26589182, -0.02654600]], + [[-0.10313249, -0.20797005, -0.48268208]], + [[-0.22216944, -0.14962578, 0.57433963]], + ] + ) + m = nn.Conv1d(1, 3, 3, stride=2, bias=False) + m.weight = flow.nn.Parameter(flow.Tensor(weight)) + m = m.to(device) + output = m(input) + np_out = np.array( + [ + [ + [-0.73331773, 1.11231577], + [-0.58247775, 0.64046454], + [1.20406508, -1.52621090], + ] + ] + ) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + [[[0.24984647, -0.09170401, 0.31495798, -0.09170401, 0.06511152]]] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + def _test_conv1d_group_bias_true(test_case, device): np_arr = np.array( [ @@ -242,6 +350,72 @@ def _test_conv1d_group_large_in_bias_true(test_case, device): test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) +def _test_conv1d_compilcate(test_case, device): + np_arr = np.array( + [ + [ + [-1.00674784, 0.51784992, 0.39896572, 0.11018554, 0.91136694], + [1.95886874, 0.89779067, 0.47482130, 0.33313531, -0.49350029], + [-0.19280219, 0.04023677, 1.66438103, -0.83563608, 0.15925731], + [1.49166429, 1.45189261, -1.86512125, 0.34329697, 0.20413807], + ] + ] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [ + [-0.36045218, 0.37349278, 0.04565236], + [0.02423280, -0.09459515, -0.30684742], + ], + [ + [-0.30345008, -0.11965130, -0.26765293], + [0.09876197, 0.03346226, 0.27484050], + ], + [ + [-0.37798449, 0.00242459, -0.34125558], + [-0.05174343, -0.10443231, 0.09526101], + ], + [ + [0.34196907, -0.32667893, 0.40264183], + [0.38025281, 0.26807079, -0.09074812], + ], + ] + ) + bias = np.array([-0.03499984, -0.21616256, 0.13312563, -0.24104381]) + m = nn.Conv1d(4, 4, 3, groups=2, stride=2, padding=2, 
dilation=2, bias=True) + m.weight = flow.nn.Parameter(flow.Tensor(weight)) + m.bias = flow.nn.Parameter(flow.Tensor(bias)) + m = m.to(device) + np_out = np.array( + [ + [ + [-0.72379637, 0.67248386, 0.21977007], + [-0.00643994, -0.12861520, -0.41589433], + [-0.76877236, 0.29273134, -0.42040929], + [1.06121790, -0.73787093, -0.37839717], + ] + ] + ) + output = m(input) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + [ + [ + [-0.41006082, 0.00000000, -0.63206136, 0.00000000, 0.03184089], + [0.06186188, 0.00000000, 0.02985496, 0.00000000, -0.09313981], + [-0.36026976, 0.00000000, -0.29888350, 0.00000000, -0.26286808], + [0.49214786, 0.00000000, 0.49666074, 0.00000000, 0.16815135], + ] + ] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + @unittest.skipIf( not flow.unittest.env.eager_execution_enabled(), ".numpy() doesn't work in lazy mode", @@ -251,9 +425,13 @@ def test_conv1d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_conv1d_bias_true, + _test_conv1d_bias_false, + _test_conv1d_dilation, + _test_conv1d_stride, _test_conv1d_group_bias_true, _test_conv1d_group_large_out_bias_true, _test_conv1d_group_large_in_bias_true, + _test_conv1d_compilcate, ] arg_dict["device"] = ["cuda", "cpu"] From 65961010fd376c7c9787722aedc296a63195c57e Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Mon, 28 Jun 2021 10:01:01 +0800 Subject: [PATCH 35/42] small fix --- oneflow/core/common/optional.h | 12 ------------ oneflow/core/functional/impl/nn_functor.cpp | 3 --- 2 files changed, 15 deletions(-) diff --git a/oneflow/core/common/optional.h b/oneflow/core/common/optional.h index 9174d36dd44..866b6f5c339 100644 --- a/oneflow/core/common/optional.h +++ b/oneflow/core/common/optional.h @@ -105,24 +105,16 @@ class Storage::value>::type> { std::shared_ptr value_; }; -<<<<<<< HEAD -======= } // namespace internal ->>>>>>> master template class Optional { public: Optional() : init_(false) {} -<<<<<<< HEAD - template, Args...>::value, int>::type = 0> -======= template, Args...>::value, int>::type = 0> ->>>>>>> master Optional(Args&&... args) : init_(true), storage_(std::forward(args)...) 
{} ~Optional() = default; @@ -169,11 +161,7 @@ class Optional { private: bool init_; -<<<<<<< HEAD - Storage storage_; -======= internal::Storage storage_; ->>>>>>> master }; template diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 43e57be78c0..7cc43b42032 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -293,10 +293,7 @@ class NormalizationFunctor { ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("BiasAdd"); -<<<<<<< HEAD m.add_functor("Conv1D"); -======= ->>>>>>> master m.add_functor("Conv2D"); m.add_functor("MatMul"); m.add_functor("BatchMatMul"); From 9df9be3bf735ccab274837a2bb3a511231f6c592 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Mon, 28 Jun 2021 10:31:25 +0800 Subject: [PATCH 36/42] add torch reference --- oneflow/python/nn/modules/conv.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 01d8eca5bfe..4bf2f17c2ca 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -79,7 +79,10 @@ def split(cls, x, axis, split_num): @oneflow_export("nn.Conv1d") @experimental_api class Conv1d(Module): - r"""Applies a 1D convolution over an input signal composed of several input + r"""The interface is consistent with PyTorch. + The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html#conv1d + + Applies a 1D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size @@ -158,7 +161,7 @@ class Conv1d(Module): >>> arr = np.random.randn(20, 16, 50) >>> input = flow.Tensor(arr) - >>> m = nn.Conv1d(16, 33, 3, stride=2)) + >>> m = nn.Conv1d(16, 33, 3, stride=2) >>> output = m(input) .. _cross-correlation: @@ -258,7 +261,10 @@ def forward(self, x): @oneflow_export("nn.Conv2d") @experimental_api class Conv2d(Module): - r"""Applies a 2D convolution over an input signal composed of several input + r"""The interface is consistent with PyTorch. + The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.Conv2d.html#conv2d + + Applies a 2D convolution over an input signal composed of several input planes. 
In the simplest case, the output value of the layer with input size From 1cebcf3e2b30fd59934a39c7c8c3894ea4fb6279 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Tue, 29 Jun 2021 16:29:01 +0800 Subject: [PATCH 37/42] add conv base functor --- oneflow/core/functional/impl/nn_functor.cpp | 62 ++++++++------------- 1 file changed, 22 insertions(+), 40 deletions(-) diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 7cc43b42032..3c7b8926868 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -48,20 +48,20 @@ class BiasAddFunctor { std::shared_ptr op_; }; -class Conv1DFunctor { +class ConvBaseFunctor { public: - Conv1DFunctor() { - conv_op_ = - CHECK_JUST(one::OpBuilder("conv1d").Input("in").Input("weight").Output("out").Build()); - bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); - } + ConvBaseFunctor() = default; + virtual ~ConvBaseFunctor() = default; Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& weight, const Optional& bias, const std::vector& stride, const std::vector& padding, const std::vector& dilation, const int32_t& groups) const { MutableAttrMap conv_attrs; - std::vector kernel_size_vec{(int32_t)(weight->shape())->At(2)}; + std::vector kernel_size_vec; + for (int i = 0; i < kernel_size_shape_num; i++) { + kernel_size_vec.push_back((weight->shape())->At(i + 2)); + } JUST(conv_attrs.SetAttr("filters", (weight->shape())->At(0))); JUST(conv_attrs.SetAttr>("padding_before", padding)); JUST(conv_attrs.SetAttr>("kernel_size", kernel_size_vec)); @@ -80,47 +80,29 @@ class Conv1DFunctor { } } - private: + protected: std::shared_ptr conv_op_; - std::shared_ptr bias_op_; + std::shared_ptr bias_op_ = + CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); + int32_t kernel_size_shape_num; }; -class Conv2DFunctor { +class Conv1DFunctor : public ConvBaseFunctor { + public: + Conv1DFunctor() { + conv_op_ = + CHECK_JUST(one::OpBuilder("conv1d").Input("in").Input("weight").Output("out").Build()); + kernel_size_shape_num = 1; + } +}; + +class Conv2DFunctor : public ConvBaseFunctor { public: Conv2DFunctor() { conv_op_ = CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build()); - bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); - } - Maybe operator()(const std::shared_ptr& x, - const std::shared_ptr& weight, - const Optional& bias, const std::vector& stride, - const std::vector& padding, - const std::vector& dilation, const int32_t& groups) const { - MutableAttrMap conv_attrs; - std::vector kernel_size_vec; - for (int i = 0; i < 2; i++) { kernel_size_vec.push_back((weight->shape())->At(i + 2)); } - JUST(conv_attrs.SetAttr("filters", (weight->shape())->At(0))); - JUST(conv_attrs.SetAttr>("padding_before", padding)); - JUST(conv_attrs.SetAttr>("kernel_size", kernel_size_vec)); - JUST(conv_attrs.SetAttr>("strides", stride)); - JUST(conv_attrs.SetAttr>("dilation_rate", dilation)); - JUST(conv_attrs.SetAttr("groups", groups)); - JUST(conv_attrs.SetAttr("data_format", std::string("channels_first"))); - const std::shared_ptr& conv_out = - JUST(OpInterpUtil::Dispatch(*conv_op_, {x, weight}, conv_attrs)); - if (bias) { - MutableAttrMap bias_attrs; - JUST(bias_attrs.SetAttr("axis", 1)); - return OpInterpUtil::Dispatch(*bias_op_, {conv_out, JUST(bias.value())}, bias_attrs); - } else { - return conv_out; - } + 
kernel_size_shape_num = 2; } - - private: - std::shared_ptr conv_op_; - std::shared_ptr bias_op_; }; class MatMulBaseFunctor { From 8b7d2eb477d9e83e304668d44d6809d0fe65ab7d Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Tue, 29 Jun 2021 16:29:14 +0800 Subject: [PATCH 38/42] remove useless print --- oneflow/python/test/modules/test_conv.py | 1 - 1 file changed, 1 deletion(-) diff --git a/oneflow/python/test/modules/test_conv.py b/oneflow/python/test/modules/test_conv.py index 4f2e3e73b6b..d839fca8b3b 100644 --- a/oneflow/python/test/modules/test_conv.py +++ b/oneflow/python/test/modules/test_conv.py @@ -1538,7 +1538,6 @@ def _test_conv2d_large_out_channel(test_case, device): m.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True) m = m.to(device) output = m(input) - print(output) np_out = np.array( [ [ From b0f10ea7598a3fca59ba0497fa1869722e79942c Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 1 Jul 2021 08:49:55 +0800 Subject: [PATCH 39/42] reorganize code structure --- oneflow/core/functional/impl/nn_functor.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 5cd243fcc6a..1cf22774222 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -50,7 +50,9 @@ class BiasAddFunctor { class ConvBaseFunctor { public: - ConvBaseFunctor() = default; + explicit ConvBaseFunctor(const int& num_spatial_dims) : num_spatial_dims_(num_spatial_dims) { + bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); + } virtual ~ConvBaseFunctor() = default; Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& weight, @@ -59,7 +61,7 @@ class ConvBaseFunctor { const std::vector& dilation, const int32_t& groups) const { MutableAttrMap conv_attrs; std::vector kernel_size_vec; - for (int i = 0; i < kernel_size_shape_num; i++) { + for (int i = 0; i < num_spatial_dims_; i++) { kernel_size_vec.push_back((weight->shape())->At(i + 2)); } JUST(conv_attrs.SetAttr("filters", (weight->shape())->At(0))); @@ -82,26 +84,23 @@ class ConvBaseFunctor { protected: std::shared_ptr conv_op_; - std::shared_ptr bias_op_ = - CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); - int32_t kernel_size_shape_num; + std::shared_ptr bias_op_; + int32_t num_spatial_dims_; }; class Conv1DFunctor : public ConvBaseFunctor { public: - Conv1DFunctor() { + Conv1DFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/1) { conv_op_ = CHECK_JUST(one::OpBuilder("conv1d").Input("in").Input("weight").Output("out").Build()); - kernel_size_shape_num = 1; } }; class Conv2DFunctor : public ConvBaseFunctor { public: - Conv2DFunctor() { + Conv2DFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/2) { conv_op_ = CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build()); - kernel_size_shape_num = 2; } }; From 662dfce2bd3c02461c1dd67377be5cf18f1756db Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 1 Jul 2021 09:51:39 +0800 Subject: [PATCH 40/42] fix name and vector size --- oneflow/core/functional/functional_api.yaml | 4 ++-- oneflow/core/functional/impl/nn_functor.cpp | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 233c6776f6f..8436777406f 100644 --- 
a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -240,13 +240,13 @@ - name: "conv1d" signature: - "Tensor Conv1D(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, + "Tensor Conv1d(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, Int32List padding, Int32List dilation, Int32 groups=1)" bind_python: True - name: "conv2d" signature: - "Tensor Conv2D(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, + "Tensor Conv2d(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, Int32List padding, Int32List dilation, Int32 groups=1)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 1cf22774222..cb1e4e91823 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -60,7 +60,7 @@ class ConvBaseFunctor { const std::vector& padding, const std::vector& dilation, const int32_t& groups) const { MutableAttrMap conv_attrs; - std::vector kernel_size_vec; + std::vector kernel_size_vec(num_spatial_dims_); for (int i = 0; i < num_spatial_dims_; i++) { kernel_size_vec.push_back((weight->shape())->At(i + 2)); } @@ -88,17 +88,17 @@ class ConvBaseFunctor { int32_t num_spatial_dims_; }; -class Conv1DFunctor : public ConvBaseFunctor { +class Conv1dFunctor : public ConvBaseFunctor { public: - Conv1DFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/1) { + Conv1dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/1) { conv_op_ = CHECK_JUST(one::OpBuilder("conv1d").Input("in").Input("weight").Output("out").Build()); } }; -class Conv2DFunctor : public ConvBaseFunctor { +class Conv2dFunctor : public ConvBaseFunctor { public: - Conv2DFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/2) { + Conv2dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/2) { conv_op_ = CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build()); } @@ -313,8 +313,8 @@ class NormalizationFunctor { ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("BiasAdd"); - m.add_functor("Conv1D"); - m.add_functor("Conv2D"); + m.add_functor("Conv1d"); + m.add_functor("Conv2d"); m.add_functor("MatMul"); m.add_functor("BatchMatMul"); m.add_functor("BroadcastMatMul"); From 087638defc4131ddfe917bbbca628556c8fdc66d Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Fri, 2 Jul 2021 10:06:13 +0800 Subject: [PATCH 41/42] fix pushback to at --- oneflow/core/functional/impl/nn_functor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index e45d91d6a5f..687d7e77eb8 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -62,7 +62,7 @@ class ConvBaseFunctor { MutableAttrMap conv_attrs; std::vector kernel_size_vec(num_spatial_dims_); for (int i = 0; i < num_spatial_dims_; i++) { - kernel_size_vec.push_back((weight->shape())->At(i + 2)); + kernel_size_vec.at(i) = ((weight->shape())->At(i + 2)); } JUST(conv_attrs.SetAttr("filters", (weight->shape())->At(0))); JUST(conv_attrs.SetAttr>("padding_before", padding)); From bd56a616aca953e1f15320189fe3dbbc2cafce20 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Fri, 2 Jul 2021 15:19:47 +0800 Subject: [PATCH 42/42] small fix for deconv docs --- oneflow/python/nn/modules/deconv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/python/nn/modules/deconv.py 
b/oneflow/python/nn/modules/deconv.py index 3c0f234abb8..be572ba18da 100644 --- a/oneflow/python/nn/modules/deconv.py +++ b/oneflow/python/nn/modules/deconv.py @@ -114,13 +114,13 @@ class ConvTranspose2d(Module): \times (\text{kernel_size}[1] - 1) + \text{output_padding}[1] + 1 Attributes: - weight (Tensor): the learnable weights of the module of shape + ConvTranspose2d.weight (Tensor): the learnable weights of the module of shape :math:`(\text{in_channels}, \frac{\text{out_channels}}{\text{groups}},` :math:`\text{kernel_size[0]}, \text{kernel_size[1]})`. The values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel_size}[i]}` - bias (Tensor): the learnable bias of the module of shape (out_channels) + ConvTranspose2d.bias (Tensor): the learnable bias of the module of shape (out_channels) If :attr:`bias` is ``True``, then the values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel_size}[i]}`
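
Taken together, the patches in this series add an eager ``nn.Conv1d`` module, make the bias input of the functional conv ops optional (``Tensor bias=None``, backed by the new ``Optional<T>``), and fix the grouped-convolution fallback that splits the input per group on CPU. Below is a minimal usage sketch distilled from the new test cases; the shapes and option values are illustrative only and not taken from any specific test.

.. code-block:: python

    import numpy as np
    import oneflow.experimental as flow
    import oneflow.experimental.nn as nn

    flow.enable_eager_execution()

    # Grouped Conv1d without bias: on CPU the forward pass splits the input per
    # group, slices the weight accordingly, and calls flow.F.conv1d with bias=None.
    x = flow.Tensor(np.random.randn(1, 4, 10))
    m = nn.Conv1d(4, 8, 3, groups=2, bias=False)
    y = m(x)      # shape (1, 8, 8)

    # Conv2d follows the same optional-bias path; bias slices are taken only
    # when a bias parameter actually exists.
    x2 = flow.Tensor(np.random.randn(1, 4, 8, 8))
    m2 = nn.Conv2d(4, 4, 3, groups=2, stride=1, padding=1)
    y2 = m2(x2)   # shape (1, 4, 8, 8)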