diff --git a/docs/api/python/intrin.rst b/docs/api/python/intrin.rst index a455f20a5874..3942c57f1a04 100644 --- a/docs/api/python/intrin.rst +++ b/docs/api/python/intrin.rst @@ -14,7 +14,7 @@ tvm.intrin tvm.ceil tvm.trunc tvm.round - + tvm.abs .. autofunction:: tvm.call_packed .. autofunction:: tvm.call_pure_intrin @@ -26,3 +26,4 @@ tvm.intrin .. autofunction:: tvm.ceil .. autofunction:: tvm.trunc .. autofunction:: tvm.round +.. autofunction:: tvm.abs diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index a14355d0f796..7f150ddbf7cd 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -13,6 +13,7 @@ List of operators topi.ceil topi.trunc topi.round + topi.abs topi.exp topi.tanh topi.log @@ -84,6 +85,7 @@ topi .. autofunction:: topi.ceil .. autofunction:: topi.trunc .. autofunction:: topi.round +.. autofunction:: topi.abs .. autofunction:: topi.exp .. autofunction:: topi.tanh .. autofunction:: topi.log diff --git a/docs/nnvm_top.rst b/docs/nnvm_top.rst index 4e1e536dbb26..96a37b779e1e 100644 --- a/docs/nnvm_top.rst +++ b/docs/nnvm_top.rst @@ -79,6 +79,7 @@ This level enables typical convnet models. nnvm.symbol.ceil nnvm.symbol.round nnvm.symbol.trunc + nnvm.symbol.abs nnvm.symbol.leaky_relu nnvm.symbol.__add_scalar__ nnvm.symbol.__sub_scalar__ @@ -157,6 +158,7 @@ Detailed Definitions .. autofunction:: nnvm.symbol.ceil .. autofunction:: nnvm.symbol.round .. autofunction:: nnvm.symbol.trunc +.. autofunction:: nnvm.symbol.abs .. autofunction:: nnvm.symbol.leaky_relu .. autofunction:: nnvm.symbol.__add_scalar__ .. autofunction:: nnvm.symbol.__sub_scalar__ diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index 268205ecb07b..947c3b736d80 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -18,7 +18,6 @@ using HalideIR::likely_if_innermost; using HalideIR::cast; using HalideIR::min; using HalideIR::max; -using HalideIR::abs; using HalideIR::select; /*! @@ -71,6 +70,26 @@ inline Expr pow(Expr x, Expr y) { return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); } +/*! + * \brief Calculate absolute value of x, elementwise + * \param x The input data + * + * \return The aboslute value of input data x + */ +inline Expr abs(Expr x) { + if (x.type().is_int()) { + return select(x >= make_zero(x.type()), x, -x); + } else if (x.type().is_float()) { + return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic); + } else if (x.type().is_uint()) { + return x; + } else { + LOG(WARNING) << "Warning: Data type " << x.type() + <<" not supported for absolute op. Skipping absolute op..."; + return x; + } +} + } // namespace tvm #endif // TVM_IR_OPERATOR_H_ diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py index 3ff59e5d0042..9a8caab29479 100644 --- a/nnvm/python/nnvm/top/tensor.py +++ b/nnvm/python/nnvm/top/tensor.py @@ -68,6 +68,10 @@ def _compute(attrs, x, _): reg.register_pattern("round", OpPattern.ELEMWISE) reg.register_schedule("round", _fschedule_broadcast) +# abs +reg.register_pattern("abs", OpPattern.ELEMWISE) +reg.register_schedule("abs", _fschedule_broadcast) + # trunc reg.register_pattern("trunc", OpPattern.ELEMWISE) reg.register_schedule("trunc", _fschedule_broadcast) diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc index 57e7137f5151..403e08f95653 100644 --- a/nnvm/src/top/tensor/elemwise.cc +++ b/nnvm/src/top/tensor/elemwise.cc @@ -81,6 +81,18 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(round) return Array{ topi::round(inputs[0]) }; }); +// abs +NNVM_REGISTER_ELEMWISE_UNARY_OP(abs) +.describe(R"code(Take absolute value of elements of the input. +)code" NNVM_ADD_FILELINE) +.set_support_level(3) +.set_attr( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + return Array{ topi::abs(inputs[0]) }; +}); + // sigmoid NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid) .describe(R"code(Computes sigmoid. diff --git a/nnvm/tests/python/compiler/test_top_level3.py b/nnvm/tests/python/compiler/test_top_level3.py index 125836a7848e..c8bd37c38e5b 100644 --- a/nnvm/tests/python/compiler/test_top_level3.py +++ b/nnvm/tests/python/compiler/test_top_level3.py @@ -28,6 +28,10 @@ def test_trunc(): def test_round(): check_map(sym.round, np.round) +def test_abs(): + check_map(sym.abs, np.abs) + check_map(sym.abs, np.abs, dtype = "int32") + check_map(sym.abs, np.abs, dtype = "int8") def test_shift(): n = 3 @@ -40,4 +44,5 @@ def test_shift(): test_floor() test_ceil() test_round() + test_abs() test_trunc() diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index fe4fb8faa9c1..422f2d682d2b 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -285,6 +285,22 @@ def trunc(x): return call_pure_intrin(x.dtype, "trunc", x) +def abs(x): + """Get absolute value of the input element-wise. + + Parameters + ---------- + x : Expr + Input argument. + + Returns + ------- + y : Expr + The result. + """ + return _make.abs(x) + + def round(x): """Round elements of the array to the nearest integer. diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index f8fbe902ca0b..bc9293c20b7a 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -7,6 +7,7 @@ #include #include #include +#include namespace tvm { namespace ir { @@ -16,6 +17,11 @@ TVM_REGISTER_API("_Var") *ret = Variable::make(args[1], args[0]); }); +TVM_REGISTER_API("make.abs") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = tvm::abs(args[0]); + }); + TVM_REGISTER_API("make._range_by_min_extent") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = Range::make_by_min_extent(args[0], args[1]); diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc index d291463bce9f..43461a15932d 100644 --- a/src/codegen/intrin_rule_cuda.cc +++ b/src/codegen/intrin_rule_cuda.cc @@ -64,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round") .set_body(DispatchExtern); diff --git a/src/codegen/intrin_rule_metal.cc b/src/codegen/intrin_rule_metal.cc index 659a82df0757..3c210919132e 100644 --- a/src/codegen/intrin_rule_metal.cc +++ b/src/codegen/intrin_rule_metal.cc @@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round") .set_body(DispatchExtern); diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc index 127fe47d5b75..d91deaeda5fe 100644 --- a/src/codegen/intrin_rule_opencl.cc +++ b/src/codegen/intrin_rule_opencl.cc @@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round") .set_body(DispatchExtern); diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index 0d6f8d07b79f..4b2a3ca5bd02 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -34,6 +34,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); diff --git a/src/codegen/llvm/intrin_rule_nvptx.cc b/src/codegen/llvm/intrin_rule_nvptx.cc index 60957198d5f3..d0b9f3693192 100644 --- a/src/codegen/llvm/intrin_rule_nvptx.cc +++ b/src/codegen/llvm/intrin_rule_nvptx.cc @@ -39,6 +39,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc") .set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs") +.set_body(DispatchExternLibDevice); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp") .set_body(DispatchExternLibDevice); diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc index 9d105a35385f..b9bee94e9c24 100644 --- a/src/codegen/llvm/intrin_rule_rocm.cc +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -38,6 +38,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") .set_body(DispatchExternOCML); diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index cf9c446ea3f8..a7fa46bda60a 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -41,6 +41,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc") .set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs") +.set_body(DispatchGLSLPureIntrin); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp") .set_body(DispatchGLSLPureIntrin); diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 7082106d09ee..88c77f0afc52 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -35,6 +35,7 @@ TOPI_DECLARE_UNARY_OP(floor); TOPI_DECLARE_UNARY_OP(ceil); TOPI_DECLARE_UNARY_OP(round); TOPI_DECLARE_UNARY_OP(trunc); +TOPI_DECLARE_UNARY_OP(abs); /*! * \brief Creates an operation that returns identity of a given tensor diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 62b1a34a2c96..a5d28d351719 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -125,6 +125,23 @@ def trunc(x): return tvm.compute(x.shape, lambda *i: tvm.trunc(x(*i))) +@tvm.tag_scope(tag=tag.ELEMWISE) +def abs(x): + """Take absolute value of the input of x, element-wise. + + Parameters + ---------- + x : tvm.Tensor + Input argument. + + Returns + ------- + y : tvm.Tensor + The result. + """ + return tvm.compute(x.shape, lambda *i: tvm.abs(x(*i))) + + @tvm.tag_scope(tag=tag.ELEMWISE) def round(x): """Round elements of x to nearest integer. diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index 0205407790d6..4190c8e1d213 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -46,6 +46,7 @@ def check_device(device): test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100) test_apply(topi.trunc, "trunc", np.trunc, -100, 100) + test_apply(topi.abs, "fabs", np.abs, -100, 100) test_apply(topi.round, "round", np.round, -100, 100) test_apply(topi.exp, "exp", np.exp, -1, 1) test_apply(topi.tanh, "tanh", np.tanh, -10, 10)