Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for absolute opeartion #1406

Merged
merged 3 commits into from
Jul 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/api/python/intrin.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ tvm.intrin
tvm.ceil
tvm.trunc
tvm.round

tvm.abs

.. autofunction:: tvm.call_packed
.. autofunction:: tvm.call_pure_intrin
Expand All @@ -26,3 +26,4 @@ tvm.intrin
.. autofunction:: tvm.ceil
.. autofunction:: tvm.trunc
.. autofunction:: tvm.round
.. autofunction:: tvm.abs
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ List of operators
topi.ceil
topi.trunc
topi.round
topi.abs
topi.exp
topi.tanh
topi.log
Expand Down Expand Up @@ -84,6 +85,7 @@ topi
.. autofunction:: topi.ceil
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.abs
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.log
Expand Down
2 changes: 2 additions & 0 deletions docs/nnvm_top.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ This level enables typical convnet models.
nnvm.symbol.ceil
nnvm.symbol.round
nnvm.symbol.trunc
nnvm.symbol.abs
nnvm.symbol.leaky_relu
nnvm.symbol.__add_scalar__
nnvm.symbol.__sub_scalar__
Expand Down Expand Up @@ -157,6 +158,7 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.ceil
.. autofunction:: nnvm.symbol.round
.. autofunction:: nnvm.symbol.trunc
.. autofunction:: nnvm.symbol.abs
.. autofunction:: nnvm.symbol.leaky_relu
.. autofunction:: nnvm.symbol.__add_scalar__
.. autofunction:: nnvm.symbol.__sub_scalar__
Expand Down
21 changes: 20 additions & 1 deletion include/tvm/ir_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ using HalideIR::likely_if_innermost;
using HalideIR::cast;
using HalideIR::min;
using HalideIR::max;
using HalideIR::abs;
using HalideIR::select;

/*!
Expand Down Expand Up @@ -71,6 +70,26 @@ inline Expr pow(Expr x, Expr y) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
}

/*!
* \brief Calculate absolute value of x, elementwise
* \param x The input data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add \return

*
* \return The aboslute value of input data x
*/
inline Expr abs(Expr x) {
if (x.type().is_int()) {
return select(x >= make_zero(x.type()), x, -x);
} else if (x.type().is_float()) {
return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
} else if (x.type().is_uint()) {
return x;
} else {
LOG(WARNING) << "Warning: Data type " << x.type()
<<" not supported for absolute op. Skipping absolute op...";
return x;
}
}

} // namespace tvm

#endif // TVM_IR_OPERATOR_H_
4 changes: 4 additions & 0 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def _compute(attrs, x, _):
reg.register_pattern("round", OpPattern.ELEMWISE)
reg.register_schedule("round", _fschedule_broadcast)

# abs
reg.register_pattern("abs", OpPattern.ELEMWISE)
reg.register_schedule("abs", _fschedule_broadcast)

# trunc
reg.register_pattern("trunc", OpPattern.ELEMWISE)
reg.register_schedule("trunc", _fschedule_broadcast)
Expand Down
12 changes: 12 additions & 0 deletions nnvm/src/top/tensor/elemwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(round)
return Array<Tensor>{ topi::round(inputs[0]) };
});

// abs
NNVM_REGISTER_ELEMWISE_UNARY_OP(abs)
.describe(R"code(Take absolute value of elements of the input.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::abs(inputs[0]) };
});

// sigmoid
NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
.describe(R"code(Computes sigmoid.
Expand Down
5 changes: 5 additions & 0 deletions nnvm/tests/python/compiler/test_top_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def test_trunc():
def test_round():
check_map(sym.round, np.round)

def test_abs():
check_map(sym.abs, np.abs)
check_map(sym.abs, np.abs, dtype = "int32")
check_map(sym.abs, np.abs, dtype = "int8")

def test_shift():
n = 3
Expand All @@ -40,4 +44,5 @@ def test_shift():
test_floor()
test_ceil()
test_round()
test_abs()
test_trunc()
16 changes: 16 additions & 0 deletions python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ def trunc(x):
return call_pure_intrin(x.dtype, "trunc", x)


def abs(x):
"""Get absolute value of the input element-wise.

Parameters
----------
x : Expr
Input argument.

Returns
-------
y : Expr
The result.
"""
return _make.abs(x)


def round(x):
"""Round elements of the array to the nearest integer.

Expand Down
6 changes: 6 additions & 0 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <tvm/ir.h>
#include <ir/IROperator.h>
#include <tvm/api_registry.h>
#include <tvm/ir_operator.h>

namespace tvm {
namespace ir {
Expand All @@ -16,6 +17,11 @@ TVM_REGISTER_API("_Var")
*ret = Variable::make(args[1], args[0]);
});

TVM_REGISTER_API("make.abs")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::abs(args[0]);
});

TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]);
Expand Down
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
.set_body(DispatchExtern<CUDAMath>);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round")
.set_body(DispatchExtern<Direct>);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round")
.set_body(DispatchExtern<Direct>);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/llvm/intrin_rule_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
.set_body(DispatchExternLibDevice);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs")
.set_body(DispatchGLSLPureIntrin<GLSLstd450FAbs>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);

Expand Down
1 change: 1 addition & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);

/*!
* \brief Creates an operation that returns identity of a given tensor
Expand Down
17 changes: 17 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@ def trunc(x):
return tvm.compute(x.shape, lambda *i: tvm.trunc(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def abs(x):
"""Take absolute value of the input of x, element-wise.

Parameters
----------
x : tvm.Tensor
Input argument.

Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.abs(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def round(x):
"""Round elements of x to nearest integer.
Expand Down
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def check_device(device):
test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100)
test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
Expand Down