Dev scalar op #5778

Merged: 98 commits into master from dev_scalar_op, Aug 15, 2021

Commits (98)
86c7f34
add logical scalar kernel
MARD1NO Aug 7, 2021
be02f4e
add logical scalar op register
MARD1NO Aug 7, 2021
7b4ad8e
add functional api yaml
MARD1NO Aug 7, 2021
f8a23bc
modify math functor
MARD1NO Aug 7, 2021
460348b
fix
MARD1NO Aug 7, 2021
dde46ce
reuse functor
MARD1NO Aug 7, 2021
3e1a778
fix
MARD1NO Aug 7, 2021
41874d9
modify equal
MARD1NO Aug 7, 2021
1f65f61
modify greater
MARD1NO Aug 7, 2021
72ed00d
modify greater equal
MARD1NO Aug 7, 2021
7021121
modify less equal
MARD1NO Aug 7, 2021
f0f6c68
modify less than
MARD1NO Aug 7, 2021
af2edb1
add not equal
MARD1NO Aug 7, 2021
6443025
modify not equal
MARD1NO Aug 7, 2021
ac29ca8
fix format
MARD1NO Aug 7, 2021
685982d
remove partial sum
MARD1NO Aug 7, 2021
ce35587
add newline
MARD1NO Aug 7, 2021
551aaf7
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 8, 2021
63a0e45
reuse base class
MARD1NO Aug 8, 2021
00a2ea9
fix bin_op to binary_op
MARD1NO Aug 8, 2021
5f1dd97
modify to Scalar
MARD1NO Aug 8, 2021
f9a03a8
first restruct and anotate cuda
MARD1NO Aug 8, 2021
94fd74a
modify to no grad user op
MARD1NO Aug 9, 2021
f60dc78
restruct code and add dtype
MARD1NO Aug 9, 2021
a58d8f0
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 9, 2021
ed20010
export to pybind
MARD1NO Aug 9, 2021
f345e51
remove redundant logic in python
MARD1NO Aug 9, 2021
e3f92d9
bind python as false
MARD1NO Aug 9, 2021
e2ad126
remove annotation
MARD1NO Aug 9, 2021
1ba1a6b
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 9, 2021
ab35e6f
fix dtype
MARD1NO Aug 9, 2021
3b25ea2
support scalar in input or output
MARD1NO Aug 9, 2021
1e0616c
fix
MARD1NO Aug 9, 2021
e056745
Add magic method
MARD1NO Aug 9, 2021
22fc133
add docs
MARD1NO Aug 9, 2021
1199469
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 9, 2021
3c36a12
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 9, 2021
7b67e07
auto format by CI
oneflow-ci-bot Aug 9, 2021
dfddb2e
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 10, 2021
e4e04f8
Merge branch 'dev_scalar_op' of https://github.com/Oneflow-Inc/oneflo…
MARD1NO Aug 10, 2021
abb9224
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 10, 2021
57ad348
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 10, 2021
a46eacf
fix randn test
MARD1NO Aug 10, 2021
a885ce6
Merge branch 'dev_scalar_op' of https://github.com/Oneflow-Inc/oneflo…
MARD1NO Aug 10, 2021
8dee020
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 10, 2021
a54e317
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 10, 2021
63c0a3c
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 10, 2021
8aed8db
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 10, 2021
2be3f94
modify back
MARD1NO Aug 10, 2021
de9055a
Merge branch 'dev_scalar_op' of https://github.com/Oneflow-Inc/oneflo…
MARD1NO Aug 10, 2021
a7fef83
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 10, 2021
0b5bcbf
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 11, 2021
3388534
small fix
MARD1NO Aug 11, 2021
0a35e77
fix 0d tensor
MARD1NO Aug 11, 2021
ad57791
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 11, 2021
b52bf60
auto format by CI
oneflow-ci-bot Aug 11, 2021
69cf7db
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 11, 2021
45868cb
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 11, 2021
5de2674
fix 0d test
MARD1NO Aug 11, 2021
329f883
auto format by CI
oneflow-ci-bot Aug 11, 2021
f66a82d
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 11, 2021
89513a1
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 11, 2021
25f2486
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 11, 2021
898c15f
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 11, 2021
d58ada5
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 11, 2021
3ea3100
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 12, 2021
bf11877
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 12, 2021
6393f48
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 12, 2021
71b4cad
fix ddp bug
daquexian Aug 13, 2021
84ea407
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 13, 2021
5c269e8
Merge branch 'dev_scalar_op' of https://github.com/Oneflow-Inc/oneflo…
MARD1NO Aug 13, 2021
2e5454d
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 13, 2021
f051144
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 13, 2021
5e01878
Merge branch 'master' into dev_scalar_op
chengtbf Aug 13, 2021
016d8f7
fix to use is
MARD1NO Aug 13, 2021
88b210b
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 13, 2021
83eda84
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 13, 2021
5d8e953
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 14, 2021
b0bed30
remove [0] in ddp.py
daquexian Aug 14, 2021
c1940bc
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 14, 2021
413c4c9
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 14, 2021
205dd5b
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 14, 2021
733d81b
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 14, 2021
9b7fd2c
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 14, 2021
1412bc7
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 14, 2021
86f36e6
fix unittest
MARD1NO Aug 14, 2021
646f0ee
fix format
MARD1NO Aug 14, 2021
e6e3117
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 14, 2021
58bd751
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 14, 2021
5cfc731
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 14, 2021
87cbb70
fix wrong unittest
MARD1NO Aug 14, 2021
353fa85
skip free eager test
MARD1NO Aug 15, 2021
7f67ae4
Merge branch 'master' into dev_scalar_op
MARD1NO Aug 15, 2021
bb2a615
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 15, 2021
5679757
fix to use is not none
MARD1NO Aug 15, 2021
1a51a38
Merge branch 'dev_scalar_op' of https://github.com/Oneflow-Inc/oneflo…
MARD1NO Aug 15, 2021
b1ff78c
merge graph test
MARD1NO Aug 15, 2021
4866bc3
Merge branch 'master' into dev_scalar_op
oneflow-ci-bot Aug 15, 2021
234 changes: 234 additions & 0 deletions oneflow/api/python/functional/python_functions.cpp
@@ -339,6 +339,234 @@ py::object PyScatter(py::args py_args, py::kwargs py_kwargs) {
return py::cast(result.GetPtrOrThrow());
}

py::object PyEqual(py::args py_args, py::kwargs py_kwargs) {
// "broadcast_equal(Tensor x, Tensor y)"
// "scalar_logical_equal(Tensor in, Scalar scalar)"
PyObject* args = py_args.ptr();
size_t nargs = PyTuple_Size(args);
CHECK_EQ_OR_THROW(nargs, 2) << "2 positional inputs are required.";
const auto& result = [&]() -> Maybe<Tensor> { // NOLINT
PyObject* input = PyTuple_GetItem(args, 0);
PyObject* other = PyTuple_GetItem(args, 1);
bool input_is_tensor = PyTensorCheck(input);
bool other_is_tensor = PyTensorCheck(other);
CHECK_OR_RETURN(input_is_tensor || other_is_tensor) << "At least one input must be a tensor.";
CHECK_OR_RETURN(PyTensorCheck(input) || PyScalarCheck(input))
<< "The first input should be a tensor or scalar.";
CHECK_OR_RETURN(PyTensorCheck(other) || PyScalarCheck(other))
<< "The second input should be a tensor or scalar.";

if (PyTensorCheck(input) && PyTensorCheck(other)) {
auto a = JUST(PyUnpackTensor(input));
auto b = JUST(PyUnpackTensor(other));
return functional::BroadcastEqual(a, b);
} else {
if (PyTensorCheck(input)) {
CHECK_OR_RETURN(PyScalarCheck(other)) << "The second input should be a scalar.";
auto a = JUST(PyUnpackTensor(input));
auto b = *JUST(PyUnpackScalar(other));
return functional::ScalarLogicalEqual(a, b);
} else {
CHECK_OR_RETURN(PyScalarCheck(input)) << "The first input should be a scalar.";
auto a = *JUST(PyUnpackScalar(input));
auto b = JUST(PyUnpackTensor(other));
return functional::ScalarLogicalEqual(b, a);
}
}
}();
return py::cast(result.GetPtrOrThrow());
}

py::object PyNotEqual(py::args py_args, py::kwargs py_kwargs) {
// "broadcast_not_equal(Tensor x, Tensor y)"
// "scalar_logical_not_equal(Tensor in, Scalar scalar)"
PyObject* args = py_args.ptr();
size_t nargs = PyTuple_Size(args);
CHECK_EQ_OR_THROW(nargs, 2) << "2 positional inputs are required.";
const auto& result = [&]() -> Maybe<Tensor> { // NOLINT
PyObject* input = PyTuple_GetItem(args, 0);
PyObject* other = PyTuple_GetItem(args, 1);
bool input_is_tensor = PyTensorCheck(input);
bool other_is_tensor = PyTensorCheck(other);
CHECK_OR_RETURN(input_is_tensor || other_is_tensor) << "At least one input must be a tensor.";
CHECK_OR_RETURN(PyTensorCheck(input) || PyScalarCheck(input))
<< "The first input should be a tensor or scalar.";
CHECK_OR_RETURN(PyTensorCheck(other) || PyScalarCheck(other))
<< "The second input should be a tensor or scalar.";

if (PyTensorCheck(input) && PyTensorCheck(other)) {
auto a = JUST(PyUnpackTensor(input));
auto b = JUST(PyUnpackTensor(other));
return functional::BroadcastNotEqual(a, b);
} else {
if (PyTensorCheck(input)) {
CHECK_OR_RETURN(PyScalarCheck(other)) << "The second input should be a scalar.";
auto a = JUST(PyUnpackTensor(input));
auto b = *JUST(PyUnpackScalar(other));
return functional::ScalarLogicalNotEqual(a, b);
} else {
CHECK_OR_RETURN(PyScalarCheck(input)) << "The first input should be a scalar.";
auto a = *JUST(PyUnpackScalar(input));
auto b = JUST(PyUnpackTensor(other));
return functional::ScalarLogicalNotEqual(b, a);
}
}
}();
return py::cast(result.GetPtrOrThrow());
}

py::object PyGreater(py::args py_args, py::kwargs py_kwargs) {
// "broadcast_greater(Tensor x, Tensor y)"
// "scalar_logical_greater(Tensor in, Scalar scalar)"
PyObject* args = py_args.ptr();
size_t nargs = PyTuple_Size(args);
CHECK_EQ_OR_THROW(nargs, 2) << "2 positional inputs are required.";
const auto& result = [&]() -> Maybe<Tensor> { // NOLINT
PyObject* input = PyTuple_GetItem(args, 0);
PyObject* other = PyTuple_GetItem(args, 1);
bool input_is_tensor = PyTensorCheck(input);
bool other_is_tensor = PyTensorCheck(other);
CHECK_OR_RETURN(input_is_tensor || other_is_tensor) << "At least one input must be a tensor.";
CHECK_OR_RETURN(PyTensorCheck(input) || PyScalarCheck(input))
<< "The first input should be a tensor or scalar.";
CHECK_OR_RETURN(PyTensorCheck(other) || PyScalarCheck(other))
<< "The second input should be a tensor or scalar.";

if (PyTensorCheck(input) && PyTensorCheck(other)) {
auto a = JUST(PyUnpackTensor(input));
auto b = JUST(PyUnpackTensor(other));
return functional::BroadcastGreater(a, b);
} else {
if (PyTensorCheck(input)) {
CHECK_OR_RETURN(PyScalarCheck(other)) << "The second input should be a scalar.";
auto a = JUST(PyUnpackTensor(input));
auto b = *JUST(PyUnpackScalar(other));
return functional::ScalarLogicalGreater(a, b);
} else {
CHECK_OR_RETURN(PyScalarCheck(input)) << "The first input should be a scalar.";
auto a = *JUST(PyUnpackScalar(input));
auto b = JUST(PyUnpackTensor(other));
// scalar > tensor mirrors to tensor < scalar for the scalar kernel.
return functional::ScalarLogicalLess(b, a);
}
}
}();
return py::cast(result.GetPtrOrThrow());
}

py::object PyGreaterEqual(py::args py_args, py::kwargs py_kwargs) {
// "broadcast_greater_equal(Tensor x, Tensor y)"
// "scalar_logical_greater_equal(Tensor in, Scalar scalar)"
PyObject* args = py_args.ptr();
size_t nargs = PyTuple_Size(args);
CHECK_EQ_OR_THROW(nargs, 2) << "2 positional inputs are required.";
const auto& result = [&]() -> Maybe<Tensor> { // NOLINT
PyObject* input = PyTuple_GetItem(args, 0);
PyObject* other = PyTuple_GetItem(args, 1);
bool input_is_tensor = PyTensorCheck(input);
bool other_is_tensor = PyTensorCheck(other);
CHECK_OR_RETURN(input_is_tensor || other_is_tensor) << "At least one input must be a tensor.";
CHECK_OR_RETURN(PyTensorCheck(input) || PyScalarCheck(input))
<< "The first input should be a tensor or scalar.";
CHECK_OR_RETURN(PyTensorCheck(other) || PyScalarCheck(other))
<< "The second input should be a tensor or scalar.";

if (PyTensorCheck(input) && PyTensorCheck(other)) {
auto a = JUST(PyUnpackTensor(input));
auto b = JUST(PyUnpackTensor(other));
return functional::BroadcastGreaterEqual(a, b);
} else {
if (PyTensorCheck(input)) {
CHECK_OR_RETURN(PyScalarCheck(other)) << "The second input should be a scalar.";
auto a = JUST(PyUnpackTensor(input));
auto b = *JUST(PyUnpackScalar(other));
return functional::ScalarLogicalGreaterEqual(a, b);
} else {
CHECK_OR_RETURN(PyScalarCheck(input)) << "The first input should be a scalar.";
auto a = *JUST(PyUnpackScalar(input));
auto b = JUST(PyUnpackTensor(other));
// scalar >= tensor mirrors to tensor <= scalar for the scalar kernel.
return functional::ScalarLogicalLessEqual(b, a);
}
}
}();
return py::cast(result.GetPtrOrThrow());
}

py::object PyLess(py::args py_args, py::kwargs py_kwargs) {
// "broadcast_less(Tensor x, Tensor y)"
// "scalar_logical_less(Tensor in, Scalar scalar)"
PyObject* args = py_args.ptr();
size_t nargs = PyTuple_Size(args);
CHECK_EQ_OR_THROW(nargs, 2) << "2 positional inputs are required.";
const auto& result = [&]() -> Maybe<Tensor> { // NOLINT
PyObject* input = PyTuple_GetItem(args, 0);
PyObject* other = PyTuple_GetItem(args, 1);
bool input_is_tensor = PyTensorCheck(input);
bool other_is_tensor = PyTensorCheck(other);
CHECK_OR_RETURN(input_is_tensor || other_is_tensor) << "At least one input must be a tensor.";
CHECK_OR_RETURN(PyTensorCheck(input) || PyScalarCheck(input))
<< "The first input should be a tensor or scalar.";
CHECK_OR_RETURN(PyTensorCheck(other) || PyScalarCheck(other))
<< "The second input should be a tensor or scalar.";

if (PyTensorCheck(input) && PyTensorCheck(other)) {
auto a = JUST(PyUnpackTensor(input));
auto b = JUST(PyUnpackTensor(other));
return functional::BroadcastLess(a, b);
} else {
if (PyTensorCheck(input)) {
CHECK_OR_RETURN(PyScalarCheck(other)) << "The second input should be a scalar.";
auto a = JUST(PyUnpackTensor(input));
auto b = *JUST(PyUnpackScalar(other));
return functional::ScalarLogicalLess(a, b);
} else {
CHECK_OR_RETURN(PyScalarCheck(input)) << "The first input should be a scalar.";
auto a = *JUST(PyUnpackScalar(input));
auto b = JUST(PyUnpackTensor(other));
// scalar < tensor mirrors to tensor > scalar for the scalar kernel.
return functional::ScalarLogicalGreater(b, a);
}
}
}();
return py::cast(result.GetPtrOrThrow());
}

py::object PyLessEqual(py::args py_args, py::kwargs py_kwargs) {
// "broadcast_less_equal(Tensor x, Tensor y)"
// "scalar_logical_less_equal(Tensor in, Scalar scalar)"
PyObject* args = py_args.ptr();
size_t nargs = PyTuple_Size(args);
CHECK_EQ_OR_THROW(nargs, 2) << "2 positional inputs are required.";
const auto& result = [&]() -> Maybe<Tensor> { // NOLINT
PyObject* input = PyTuple_GetItem(args, 0);
PyObject* other = PyTuple_GetItem(args, 1);
bool input_is_tensor = PyTensorCheck(input);
bool other_is_tensor = PyTensorCheck(other);
CHECK_OR_RETURN(input_is_tensor || other_is_tensor) << "At least one input must be a tensor.";
CHECK_OR_RETURN(PyTensorCheck(input) || PyScalarCheck(input))
<< "The first input should be a tensor or scalar.";
CHECK_OR_RETURN(PyTensorCheck(other) || PyScalarCheck(other))
<< "The second input should be a tensor or scalar.";

if (PyTensorCheck(input) && PyTensorCheck(other)) {
auto a = JUST(PyUnpackTensor(input));
auto b = JUST(PyUnpackTensor(other));
return functional::BroadcastLessEqual(a, b);
} else {
if (PyTensorCheck(input)) {
CHECK_OR_RETURN(PyScalarCheck(other)) << "The second input should be a scalar.";
auto a = JUST(PyUnpackTensor(input));
auto b = *JUST(PyUnpackScalar(other));
return functional::ScalarLogicalLessEqual(a, b);
} else {
CHECK_OR_RETURN(PyScalarCheck(input)) << "The first input should be a scalar.";
auto a = *JUST(PyUnpackScalar(input));
auto b = JUST(PyUnpackTensor(other));
// scalar <= tensor mirrors to tensor >= scalar for the scalar kernel.
return functional::ScalarLogicalGreaterEqual(b, a);
}
}
}();
return py::cast(result.GetPtrOrThrow());
}

} // namespace functional
} // namespace one

@@ -352,6 +580,12 @@ ONEFLOW_API_PYBIND11_MODULE("F", m) {
m.def("pow", &functional::PyPow);
m.def("clamp", &functional::PyClamp);
m.def("scatter", &functional::PyScatter);
m.def("equal", &functional::PyEqual);
m.def("not_equal", &functional::PyNotEqual);
m.def("greater", &functional::PyGreater);
m.def("greater_equal", &functional::PyGreaterEqual);
m.def("less", &functional::PyLess);
m.def("less_equal", &functional::PyLessEqual);
}

} // namespace oneflow
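All six wrappers above follow one dispatch rule: a tensor/tensor pair goes to the broadcast kernel, a tensor/scalar pair to the scalar kernel, and a scalar/tensor pair is swapped before dispatching to the scalar kernel. A minimal Python sketch of the resulting behavior, assuming a build that includes this PR and that the module registered by `ONEFLOW_API_PYBIND11_MODULE("F", m)` is reachable as `flow.F`:

```python
import oneflow as flow

x = flow.tensor([1.0, 2.0, 3.0])
y = flow.tensor([1.0, 0.0, 3.0])

# Tensor vs. tensor -> broadcast_equal kernel.
print(flow.F.equal(x, y))    # elementwise mask [1, 0, 1]; the exact dtype is an assumption

# Tensor vs. scalar -> scalar_logical_equal kernel.
print(flow.F.equal(x, 2.0))  # [0, 1, 0]

# Scalar vs. tensor -> operands swapped, same scalar kernel.
print(flow.F.equal(2.0, x))  # [0, 1, 0]
```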
26 changes: 26 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -950,6 +950,32 @@
signature: "Tensor ReduceSumLike(Tensor in, Tensor like, *,Int32List axis)"
bind_python: True

- name: "scalar_logical_equal"
signature: "Tensor ScalarLogicalEqual(Tensor in, Scalar scalar)"
bind_python: False

- name: "scalar_logical_not_equal"
signature: "Tensor ScalarLogicalNotEqual(Tensor in, Scalar scalar)"
bind_python: False

- name: "scalar_logical_greater"
signature: "Tensor ScalarLogicalGreater(Tensor in, Scalar scalar)"
bind_python: False

- name: "scalar_logical_greater_equal"
signature: "Tensor ScalarLogicalGreaterEqual(Tensor in, Scalar scalar)"
bind_python: False

- name: "scalar_logical_less"
signature: "Tensor ScalarLogicalLess(Tensor in, Scalar scalar)"
bind_python: False

- name: "scalar_logical_less_equal"
signature: "Tensor ScalarLogicalLessEqual(Tensor in, Scalar scalar)"
bind_python: False

- name: "split"
signature: "TensorTuple Split(Tensor x, *, Int64 split_size, Int64 dim=0)"
- name: "rand"
signature: "Tensor Rand(*, Shape shape, DataType dtype=None, Device device=None, Generator generator=None)"
bind_python: True
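With `bind_python: False`, these generated entries are not exported to Python directly; they are reachable only through the hand-written wrappers in `python_functions.cpp`. In user code they surface through the tensor comparison magic methods this PR adds (commit e056745). A sketch of the expected routing; the mapping comments are assumptions based on the wrapper names:

```python
import oneflow as flow

x = flow.tensor([1, 2, 3])

# Comparisons with a Python scalar route through the scalar_logical_*
# kernels rather than broadcasting the scalar into a full tensor first.
print(x == 2)   # flow.F.equal(x, 2)         -> scalar_logical_equal
print(x != 2)   # flow.F.not_equal(x, 2)     -> scalar_logical_not_equal
print(x > 2)    # flow.F.greater(x, 2)       -> scalar_logical_greater
print(x >= 2)   # flow.F.greater_equal(x, 2) -> scalar_logical_greater_equal
print(x < 2)    # flow.F.less(x, 2)          -> scalar_logical_less
print(x <= 2)   # flow.F.less_equal(x, 2)    -> scalar_logical_less_equal
```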
67 changes: 67 additions & 0 deletions oneflow/core/functional/impl/math_functor.cpp
@@ -447,6 +447,67 @@ class MaximumFunctor {
std::shared_ptr<OpExpr> broadcast_maximum_op_;
};

class ScalarLogicalBaseFunctor {
public:
explicit ScalarLogicalBaseFunctor(std::string op_name) {
op_ = CHECK_JUST(one::OpBuilder(op_name).Input("in").Output("out").Build());
}
virtual ~ScalarLogicalBaseFunctor() = default;
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {
MutableAttrMap attrs;

if (IsFloatingDataType(x->dtype())) {
JUST(attrs.SetAttr<double>("float_operand", JUST(scalar.As<double>())));
JUST(attrs.SetAttr<bool>("has_float_operand", true));
JUST(attrs.SetAttr<bool>("has_int_operand", false));
} else if (IsIntegralDataType(x->dtype())) {
JUST(attrs.SetAttr<int64_t>("int_operand", JUST(scalar.As<int64_t>())));
JUST(attrs.SetAttr<bool>("has_float_operand", false));
JUST(attrs.SetAttr<bool>("has_int_operand", true));
} else {
UNIMPLEMENTED_THEN_RETURN() << "The scalar in scalar logical ops should be float or int.";
}

return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class ScalarLogicalEqualFunctor : public ScalarLogicalBaseFunctor {
public:
ScalarLogicalEqualFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_equal") {}
};

class ScalarLogicalNotEqualFunctor : public ScalarLogicalBaseFunctor {
public:
ScalarLogicalNotEqualFunctor()
: ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_not_equal") {}
};

class ScalarLogicalGreaterFunctor : public ScalarLogicalBaseFunctor {
public:
ScalarLogicalGreaterFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_greater") {}
};

class ScalarLogicalGreaterEqualFunctor : public ScalarLogicalBaseFunctor {
public:
ScalarLogicalGreaterEqualFunctor()
: ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_greater_equal") {}
};

class ScalarLogicalLessFunctor : public ScalarLogicalBaseFunctor {
public:
ScalarLogicalLessFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_less") {}
};

class ScalarLogicalLessEqualFunctor : public ScalarLogicalBaseFunctor {
public:
ScalarLogicalLessEqualFunctor()
: ScalarLogicalBaseFunctor(/*op_name=*/"scalar_logical_less_equal") {}
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -469,6 +530,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::SelectFirstFunctor>("SelectFirst");
m.add_functor<impl::MinimumFunctor>("Minimum");
m.add_functor<impl::MaximumFunctor>("Maximum");
m.add_functor<impl::ScalarLogicalEqualFunctor>("ScalarLogicalEqual");
m.add_functor<impl::ScalarLogicalNotEqualFunctor>("ScalarLogicalNotEqual");
m.add_functor<impl::ScalarLogicalGreaterFunctor>("ScalarLogicalGreater");
m.add_functor<impl::ScalarLogicalGreaterEqualFunctor>("ScalarLogicalGreaterEqual");
m.add_functor<impl::ScalarLogicalLessFunctor>("ScalarLogicalLess");
m.add_functor<impl::ScalarLogicalLessEqualFunctor>("ScalarLogicalLessEqual");
};

} // namespace functional
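`ScalarLogicalBaseFunctor` picks the operand attribute from the tensor's dtype rather than from the Python scalar's type: a floating tensor carries the scalar as `float_operand` (cast to double), an integral tensor as `int_operand` (cast to int64). A rough Python illustration of the consequence; the truncation behavior is an assumption worth verifying:

```python
import oneflow as flow

xf = flow.tensor([1.0, 2.5], dtype=flow.float32)
xi = flow.tensor([1, 2, 3], dtype=flow.int32)

# Float tensor: the scalar 2 is stored as float_operand (2.0).
print(xf > 2)    # compares each element against 2.0

# Int tensor: the scalar 2.5 is stored as int_operand via an int64_t
# cast, so it may be truncated to 2 before the comparison runs.
print(xi > 2.5)  # assumption: behaves like xi > 2
```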