From 9308f546ead9f6b51cae63ff0c1c638420e7e40a Mon Sep 17 00:00:00 2001
From: binbinHan
Date: Thu, 5 Aug 2021 12:25:18 +0800
Subject: [PATCH] Refacotr maximum minimum py2cpp (#5724)

* Refactor functional matmul and add apis.
* Export batch matmul and fix python module
* Check inplace valid in C++
* Support scalar add tensor.
* Support inplace when broadcasting add
* Refactor functional sub.
* Refactor mul and div
* Refactor functional sub.
* Fix div
* Fix add
* Fix add
* refacotr_maximum_minimum__py2cpp
* refine
* refine
* minor fix
* refactor
* auto format by CI

Co-authored-by: hjchen2
Co-authored-by: Yao Chi
Co-authored-by: Yinggang Wang
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot
---
 oneflow/core/functional/functional_api.yaml   | 16 ++-----
 .../core/functional/impl/binary_functor.cpp   | 34 -------------
 oneflow/core/functional/impl/math_functor.cpp | 48 +++++++++++++++++++
 python/oneflow/nn/modules/math_ops.py         | 42 +---------------
 4 files changed, 54 insertions(+), 86 deletions(-)

diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 44b662e794b..82c5e63a499 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -835,20 +835,12 @@
                      Int32List stride, Bool ceil_mode, Bool count_include_pad, Int64 divisor_override=0)"
   bind_python: False
 
-- name: "broadcast_min"
-  signature: "Tensor BroadcastMin(Tensor x, Tensor y)"
+- name: "minimum"
+  signature: "Tensor Minimum(Tensor x, Tensor y)"
   bind_python: True
 
-- name: "broadcast_max"
-  signature: "Tensor BroadcastMax(Tensor x, Tensor y)"
-  bind_python: True
-
-- name: "elementwise_min"
-  signature: "Tensor ElementwiseMin(Tensor x, Tensor y)"
-  bind_python: True
-
-- name: "elementwise_max"
-  signature: "Tensor ElementwiseMax(Tensor x, Tensor y)"
+- name: "maximum"
+  signature: "Tensor Maximum(Tensor x, Tensor y)"
   bind_python: True
 
 - name: "elementwise_min_grad"
diff --git a/oneflow/core/functional/impl/binary_functor.cpp b/oneflow/core/functional/impl/binary_functor.cpp
index bf711283ae2..8c84197f442 100644
--- a/oneflow/core/functional/impl/binary_functor.cpp
+++ b/oneflow/core/functional/impl/binary_functor.cpp
@@ -194,36 +194,6 @@ class ScalarDivByTensorFunctor : public BinaryFunctor {
   }
 };
 
-class BroadcastMinimumFunctor : public BinaryFunctor {
- public:
-  BroadcastMinimumFunctor() {
-    op_ = CHECK_JUST(one::OpBuilder("broadcast_minimum").Input("x").Input("y").Output("z").Build());
-  }
-};
-
-class BroadcastMaximumFunctor : public BinaryFunctor {
- public:
-  BroadcastMaximumFunctor() {
-    op_ = CHECK_JUST(one::OpBuilder("broadcast_maximum").Input("x").Input("y").Output("z").Build());
-  }
-};
-
-class ElementwiseMinimumFunctor : public BinaryFunctor {
- public:
-  ElementwiseMinimumFunctor() {
-    op_ =
-        CHECK_JUST(one::OpBuilder("elementwise_minimum").Input("x").Input("y").Output("z").Build());
-  }
-};
-
-class ElementwiseMaximumFunctor : public BinaryFunctor {
- public:
-  ElementwiseMaximumFunctor() {
-    op_ =
-        CHECK_JUST(one::OpBuilder("elementwise_maximum").Input("x").Input("y").Output("z").Build());
-  }
-};
-
 class ReshapeLikeFunctor : public BinaryFunctor {
  public:
   ReshapeLikeFunctor() {
@@ -242,8 +212,6 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::BroadcastSubFunctor>("BroadcastSub");
   m.add_functor<impl::BroadcastMulFunctor>("BroadcastMul");
   m.add_functor<impl::BroadcastDivFunctor>("BroadcastDiv");
-  m.add_functor<impl::BroadcastMinimumFunctor>("BroadcastMin");
-  m.add_functor<impl::BroadcastMaximumFunctor>("BroadcastMax");
   m.add_functor<impl::BroadcastEqualFunctor>("BroadcastEqual");
   m.add_functor<impl::BroadcastNotEqualFunctor>("BroadcastNotEqual");
m.add_functor("BroadcastGreater"); @@ -256,8 +224,6 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("ScalarSubByTensor"); m.add_functor("ScalarMulByTensor"); m.add_functor("ScalarDivByTensor"); - m.add_functor("ElementwiseMin"); - m.add_functor("ElementwiseMax"); m.add_functor("BroadcastFMod"); m.add_functor("ReshapeLike"); }; diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index 020625793a7..4e0ac71acd1 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -401,6 +401,52 @@ class SelectFirstFunctor { std::shared_ptr op_; }; +class MinimumFunctor { + public: + MinimumFunctor() { + elementwise_minimum_op_ = + CHECK_JUST(one::OpBuilder("elementwise_minimum").Input("x").Input("y").Output("z").Build()); + broadcast_minimum_op_ = + CHECK_JUST(one::OpBuilder("broadcast_minimum").Input("x").Input("y").Output("z").Build()); + } + + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& y) const { + if (*x->shape() == *y->shape()) { + return OpInterpUtil::Dispatch(*elementwise_minimum_op_, {x, y}); + } else { + return OpInterpUtil::Dispatch(*broadcast_minimum_op_, {x, y}); + } + } + + private: + std::shared_ptr elementwise_minimum_op_; + std::shared_ptr broadcast_minimum_op_; +}; + +class MaximumFunctor { + public: + MaximumFunctor() { + elementwise_maximum_op_ = + CHECK_JUST(one::OpBuilder("elementwise_maximum").Input("x").Input("y").Output("z").Build()); + broadcast_maximum_op_ = + CHECK_JUST(one::OpBuilder("broadcast_maximum").Input("x").Input("y").Output("z").Build()); + } + + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& y) const { + if (*x->shape() == *y->shape()) { + return OpInterpUtil::Dispatch(*elementwise_maximum_op_, {x, y}); + } else { + return OpInterpUtil::Dispatch(*broadcast_maximum_op_, {x, y}); + } + } + + private: + std::shared_ptr elementwise_maximum_op_; + std::shared_ptr broadcast_maximum_op_; +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -421,6 +467,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("ClipByScalarMax"); m.add_functor("ClipByScalarMaxGrad"); m.add_functor("SelectFirst"); + m.add_functor("Minimum"); + m.add_functor("Maximum"); }; } // namespace functional diff --git a/python/oneflow/nn/modules/math_ops.py b/python/oneflow/nn/modules/math_ops.py index c7d288a7d73..507eee43833 100644 --- a/python/oneflow/nn/modules/math_ops.py +++ b/python/oneflow/nn/modules/math_ops.py @@ -1483,22 +1483,6 @@ def topk_op(input, k, dim: int = None, largest: bool = True, sorted: bool = True return Topk(k=k, dim=dim, largest=largest, sorted=sorted)(input) -class ElementwiseMinimum(Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x, y): - return flow.F.elementwise_min(x, y) - - -class BroadcastMinimum(Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x, y): - return flow.F.broadcast_min(x, y) - - @register_tensor_op("minimum") def minimum(x, y): r"""Computes the element-wise minimum of x and y. 
@@ -1520,26 +1504,7 @@ def minimum(x, y):
         >>> flow.minimum(x, y)
         tensor([1., 0., 1.], dtype=oneflow.float32)
     """
-    if x.shape == y.shape:
-        return ElementwiseMinimum()(x, y)
-    else:
-        return BroadcastMinimum()(x, y)
-
-
-class ElementwiseMaximum(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x, y):
-        return flow.F.elementwise_max(x, y)
-
-
-class BroadcastMaximum(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x, y):
-        return flow.F.broadcast_max(x, y)
+    return flow.F.minimum(x, y)
 
 
 @register_tensor_op("maximum")
@@ -1563,10 +1528,7 @@ def maximum(x, y):
         >>> flow.maximum(x, y)
         tensor([3., 1., 4.], dtype=oneflow.float32)
     """
-    if x.shape == y.shape:
-        return ElementwiseMaximum()(x, y)
-    else:
-        return BroadcastMaximum()(x, y)
+    return flow.F.maximum(x, y)
 
 
 if __name__ == "__main__":
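
Reviewer note: the snippet below is a minimal usage sketch, not part of the patch. It assumes an OneFlow build that already contains this change (so flow.F.minimum and flow.F.maximum are bound from functional_api.yaml); the tensor values and variable names are made up for illustration. It shows that the shape check now happens inside the C++ MinimumFunctor/MaximumFunctor, so both the same-shape and the broadcasting cases go through the single Python entry point that replaces the old ElementwiseMinimum/BroadcastMinimum modules.

    # Hypothetical example; assumes an OneFlow build containing this patch.
    import numpy as np
    import oneflow as flow

    x = flow.Tensor(np.array([1, 0, 9], dtype=np.float32))
    y = flow.Tensor(np.array([3, 1, 1], dtype=np.float32))

    # Same shape: the functor dispatches to the elementwise_minimum/elementwise_maximum ops.
    print(flow.minimum(x, y))  # values [1., 0., 1.]
    print(flow.maximum(x, y))  # values [3., 1., 9.]

    # Different but broadcastable shapes: the same functor falls back to the
    # broadcast_minimum/broadcast_maximum ops, which previously required a
    # separate Python-side branch in math_ops.py.
    m = flow.Tensor(np.full((2, 3), 2, dtype=np.float32))
    print(flow.minimum(x, m).shape)  # broadcast result has shape (2, 3)

    # register_tensor_op still exposes the refactored ops as tensor methods.
    print(x.maximum(y))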