Refactor maximum minimum py2cpp (#5724)
* Refactor functional matmul and add apis.

* Export batch matmul and fix python module

* Check inplace valid in C++

* Support scalar add tensor.

* Support inplace when broadcasting add

* Refactor functional sub.

* Refactor mul and div

* Refactor functional sub.

* Fix div

* Fix add

* Fix add

* refacotr_maximum_minimum__py2cpp

* refine

* refine

* minor fix

* refactor

* auto format by CI

Co-authored-by: hjchen2 <chenhoujiangcug@gmail.com>
Co-authored-by: Yao Chi <later@usopp.net>
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
6 people committed Aug 5, 2021
1 parent 0160f26 · commit 9308f54
Showing 4 changed files with 54 additions and 86 deletions.
oneflow/core/functional/functional_api.yaml: 16 changes (4 additions, 12 deletions)
@@ -835,20 +835,12 @@
       Int32List stride, Bool ceil_mode, Bool count_include_pad, Int64 divisor_override=0)"
   bind_python: False
 
-- name: "broadcast_min"
-  signature: "Tensor BroadcastMin(Tensor x, Tensor y)"
+- name: "minimum"
+  signature: "Tensor Minimum(Tensor x, Tensor y)"
   bind_python: True
 
-- name: "broadcast_max"
-  signature: "Tensor BroadcastMax(Tensor x, Tensor y)"
-  bind_python: True
-
-- name: "elementwise_min"
-  signature: "Tensor ElementwiseMin(Tensor x, Tensor y)"
-  bind_python: True
-
-- name: "elementwise_max"
-  signature: "Tensor ElementwiseMax(Tensor x, Tensor y)"
+- name: "maximum"
+  signature: "Tensor Maximum(Tensor x, Tensor y)"
   bind_python: True
 
 - name: "elementwise_min_grad"
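With bind_python: True, these two entries are exposed to Python as flow.F.minimum and flow.F.maximum, which is what the rewritten math_ops.py below calls. A minimal sketch of calling the bindings directly; the tensor values and the flow.Tensor/NumPy construction are illustrative assumptions, not taken from this diff:

import numpy as np
import oneflow as flow

x = flow.Tensor(np.array([1.0, 4.0, 2.0], dtype=np.float32))
y = flow.Tensor(np.array([3.0, 0.0, 2.0], dtype=np.float32))

# Calls routed through the functional API declared in functional_api.yaml.
flow.F.minimum(x, y)  # element-wise minimum -> values [1., 0., 2.]
flow.F.maximum(x, y)  # element-wise maximum -> values [3., 4., 2.]

In normal user code you would call flow.minimum / flow.maximum instead; the functional form is what the Python wrappers delegate to.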
oneflow/core/functional/impl/binary_functor.cpp: 34 changes (0 additions, 34 deletions)
@@ -194,36 +194,6 @@ class ScalarDivByTensorFunctor : public BinaryFunctor {
   }
 };
 
-class BroadcastMinimumFunctor : public BinaryFunctor {
- public:
-  BroadcastMinimumFunctor() {
-    op_ = CHECK_JUST(one::OpBuilder("broadcast_minimum").Input("x").Input("y").Output("z").Build());
-  }
-};
-
-class BroadcastMaximumFunctor : public BinaryFunctor {
- public:
-  BroadcastMaximumFunctor() {
-    op_ = CHECK_JUST(one::OpBuilder("broadcast_maximum").Input("x").Input("y").Output("z").Build());
-  }
-};
-
-class ElementwiseMinimumFunctor : public BinaryFunctor {
- public:
-  ElementwiseMinimumFunctor() {
-    op_ =
-        CHECK_JUST(one::OpBuilder("elementwise_minimum").Input("x").Input("y").Output("z").Build());
-  }
-};
-
-class ElementwiseMaximumFunctor : public BinaryFunctor {
- public:
-  ElementwiseMaximumFunctor() {
-    op_ =
-        CHECK_JUST(one::OpBuilder("elementwise_maximum").Input("x").Input("y").Output("z").Build());
-  }
-};
-
 class ReshapeLikeFunctor : public BinaryFunctor {
  public:
   ReshapeLikeFunctor() {
@@ -242,8 +212,6 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::BroadcastSubFunctor>("BroadcastSub");
   m.add_functor<impl::BroadcastMulFunctor>("BroadcastMul");
   m.add_functor<impl::BroadcastDivFunctor>("BroadcastDiv");
-  m.add_functor<impl::BroadcastMinimumFunctor>("BroadcastMin");
-  m.add_functor<impl::BroadcastMaximumFunctor>("BroadcastMax");
   m.add_functor<impl::BroadcastEqualFunctor>("BroadcastEqual");
   m.add_functor<impl::BroadcastNotEqualFunctor>("BroadcastNotEqual");
   m.add_functor<impl::BroadcastGreaterFunctor>("BroadcastGreater");
@@ -256,8 +224,6 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::ScalarSubByTensorFunctor>("ScalarSubByTensor");
   m.add_functor<impl::ScalarMulByTensorFunctor>("ScalarMulByTensor");
   m.add_functor<impl::ScalarDivByTensorFunctor>("ScalarDivByTensor");
-  m.add_functor<impl::ElementwiseMinimumFunctor>("ElementwiseMin");
-  m.add_functor<impl::ElementwiseMaximumFunctor>("ElementwiseMax");
   m.add_functor<impl::BroadcastFModFunctor>("BroadcastFMod");
   m.add_functor<impl::ReshapeLikeFunctor>("ReshapeLike");
 };
oneflow/core/functional/impl/math_functor.cpp: 48 changes (48 additions, 0 deletions)
@@ -401,6 +401,52 @@ class SelectFirstFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class MinimumFunctor {
+ public:
+  MinimumFunctor() {
+    elementwise_minimum_op_ =
+        CHECK_JUST(one::OpBuilder("elementwise_minimum").Input("x").Input("y").Output("z").Build());
+    broadcast_minimum_op_ =
+        CHECK_JUST(one::OpBuilder("broadcast_minimum").Input("x").Input("y").Output("z").Build());
+  }
+
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const std::shared_ptr<one::Tensor>& y) const {
+    if (*x->shape() == *y->shape()) {
+      return OpInterpUtil::Dispatch<Tensor>(*elementwise_minimum_op_, {x, y});
+    } else {
+      return OpInterpUtil::Dispatch<Tensor>(*broadcast_minimum_op_, {x, y});
+    }
+  }
+
+ private:
+  std::shared_ptr<OpExpr> elementwise_minimum_op_;
+  std::shared_ptr<OpExpr> broadcast_minimum_op_;
+};
+
+class MaximumFunctor {
+ public:
+  MaximumFunctor() {
+    elementwise_maximum_op_ =
+        CHECK_JUST(one::OpBuilder("elementwise_maximum").Input("x").Input("y").Output("z").Build());
+    broadcast_maximum_op_ =
+        CHECK_JUST(one::OpBuilder("broadcast_maximum").Input("x").Input("y").Output("z").Build());
+  }
+
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const std::shared_ptr<one::Tensor>& y) const {
+    if (*x->shape() == *y->shape()) {
+      return OpInterpUtil::Dispatch<Tensor>(*elementwise_maximum_op_, {x, y});
+    } else {
+      return OpInterpUtil::Dispatch<Tensor>(*broadcast_maximum_op_, {x, y});
+    }
+  }
+
+ private:
+  std::shared_ptr<OpExpr> elementwise_maximum_op_;
+  std::shared_ptr<OpExpr> broadcast_maximum_op_;
+};
+
 }  // namespace impl
 
 ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -421,6 +467,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::ClipByScalarMaxFunctor>("ClipByScalarMax");
   m.add_functor<impl::ClipByScalarMaxGradFunctor>("ClipByScalarMaxGrad");
   m.add_functor<impl::SelectFirstFunctor>("SelectFirst");
+  m.add_functor<impl::MinimumFunctor>("Minimum");
+  m.add_functor<impl::MaximumFunctor>("Maximum");
 };
 
 }  // namespace functional
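The new MinimumFunctor and MaximumFunctor move the shape check into C++: when the operand shapes match they dispatch the elementwise_* op, otherwise the broadcast_* op. An illustrative sketch of both paths through the public API; the shapes and values are arbitrary examples, not taken from this commit:

import numpy as np
import oneflow as flow

a = flow.Tensor(np.full((2, 3), 2.0, dtype=np.float32))
b = flow.Tensor(np.full((2, 3), 5.0, dtype=np.float32))
c = flow.Tensor(np.full((1, 3), 3.0, dtype=np.float32))

# Same shapes: dispatched to the elementwise_minimum / elementwise_maximum ops.
flow.minimum(a, b)  # every element is 2.0
flow.maximum(a, b)  # every element is 5.0

# Different but broadcastable shapes: dispatched to broadcast_minimum / broadcast_maximum.
flow.minimum(a, c)  # c is broadcast from (1, 3) to (2, 3); every element is 2.0
flow.maximum(a, c)  # every element is 3.0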
python/oneflow/nn/modules/math_ops.py: 42 changes (2 additions, 40 deletions)
@@ -1483,22 +1483,6 @@ def topk_op(input, k, dim: int = None, largest: bool = True, sorted: bool = True
     return Topk(k=k, dim=dim, largest=largest, sorted=sorted)(input)
 
 
-class ElementwiseMinimum(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x, y):
-        return flow.F.elementwise_min(x, y)
-
-
-class BroadcastMinimum(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x, y):
-        return flow.F.broadcast_min(x, y)
-
-
 @register_tensor_op("minimum")
 def minimum(x, y):
     r"""Computes the element-wise minimum of x and y.
@@ -1520,26 +1504,7 @@ def minimum(x, y):
     >>> flow.minimum(x, y)
     tensor([1., 0., 1.], dtype=oneflow.float32)
     """
-    if x.shape == y.shape:
-        return ElementwiseMinimum()(x, y)
-    else:
-        return BroadcastMinimum()(x, y)
-
-
-class ElementwiseMaximum(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x, y):
-        return flow.F.elementwise_max(x, y)
-
-
-class BroadcastMaximum(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x, y):
-        return flow.F.broadcast_max(x, y)
+    return flow.F.minimum(x, y)
 
 
 @register_tensor_op("maximum")
@@ -1563,10 +1528,7 @@ def maximum(x, y):
     >>> flow.maximum(x, y)
     tensor([3., 1., 4.], dtype=oneflow.float32)
     """
-    if x.shape == y.shape:
-        return ElementwiseMaximum()(x, y)
-    else:
-        return BroadcastMaximum()(x, y)
+    return flow.F.maximum(x, y)
 
 
 if __name__ == "__main__":
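Because minimum and maximum are wrapped with @register_tensor_op, they should also stay available as tensor methods after the refactor; the sketch below assumes that decorator attaches the function to Tensor (as its name suggests), alongside the functional form shown in the docstrings above:

import numpy as np
import oneflow as flow

x = flow.Tensor(np.array([1.0, 4.0, 2.0], dtype=np.float32))
y = flow.Tensor(np.array([3.0, 0.0, 2.0], dtype=np.float32))

# Functional form, now a thin wrapper over flow.F.minimum / flow.F.maximum.
flow.minimum(x, y)  # values [1., 0., 2.]

# Method form registered by @register_tensor_op("minimum") / ("maximum").
x.minimum(y)        # same result as flow.minimum(x, y)
x.maximum(y)        # values [3., 4., 2.]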
