From cc8a86dc3f11fc248fdc7fc9addbb337465a1c4e Mon Sep 17 00:00:00 2001
From: zhangyunze
Date: Tue, 9 Jan 2024 16:09:44 +0800
Subject: [PATCH] feat: add Erf cpu/cuda kernel

---
 .../src/kernels/simple_unary/cpu_kernel.cc    | 16 ++++++++++++
 .../src/kernels/simple_unary/cuda_kernel.cc   | 16 +++++++++++-
 .../test/kernels/simple_unary/test_cpu.cpp    |  1 +
 .../test/kernels/simple_unary/test_cuda.cpp   |  1 +
 src/07onnx/test/test_simple_unary.cpp         | 25 +++++++++++++++++++
 5 files changed, 58 insertions(+), 1 deletion(-)
 create mode 100644 src/07onnx/test/test_simple_unary.cpp

diff --git a/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc b/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
index d34528569..5e83938df 100644
--- a/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
+++ b/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
@@ -18,6 +18,7 @@ namespace refactor::kernel {
             Op::Sigmoid,
             Op::Tanh,
             Op::Neg,
+            Op::Erf,// Erf is added to the set of ops this CPU builder accepts
         };
         return supportedOp.contains(op) && a.dataType.isCpuNumberic()
                    ? std::make_unique(op, a.dataType, a.elementsSize())
@@ -155,6 +156,21 @@ namespace refactor::kernel {
                     default:
                         UNREACHABLE();
                 }
+            case Op::Erf:// new branch: pick the std::erf-based routine by element type
+                switch (dataType) {
+                    CASE(std::erf, F32);// NOTE(review): CASE is defined outside this patch; presumably it yields the routine applying the function to elements of this type — confirm
+                    CASE(std::erf, F64);
+                    CASE(std::erf, I8);// integer element types are also routed through std::erf; NOTE(review): how the result is converted back to the integer type is decided inside CASE — confirm that is intended
+                    CASE(std::erf, I16);
+                    CASE(std::erf, I32);
+                    CASE(std::erf, I64);
+                    CASE(std::erf, U8);
+                    CASE(std::erf, U16);
+                    CASE(std::erf, U32);
+                    CASE(std::erf, U64);
+                    default:// any data type not listed above is treated as unreachable
+                        UNREACHABLE();
+                }
             default:
                 UNREACHABLE();
         }
diff --git a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
index e3c260dbc..7403b0dde 100644
--- a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
+++ b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
@@ -18,7 +18,8 @@ namespace refactor::kernel {
     auto K::build(Op op, Tensor const &a) noexcept -> KernelBox {
         static const std::unordered_set supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
-                                                    Op::Sigmoid, Op::Tanh, Op::Neg};
+                                                    Op::Sigmoid, Op::Tanh, Op::Neg,
+                                                    Op::Erf};// Erf is added to the set of ops this CUDA builder accepts
 #ifndef USE_CUDA
         return nullptr;
 #endif
@@ -140,6 +141,19 @@ extern "C" __global__ void kernel(
         {__(Op::Neg, DT::BF16), "-x"},
         {__(Op::Neg, DT::F32 ), "-x"},
         {__(Op::Neg, DT::F64 ), "-x"},
+
+        {__(Op::Erf, DT::F32 ), "erff(x)"},// float elements use erff; NOTE(review): the strings are presumably spliced into the generated kernel source with x as the element — confirm
+        {__(Op::Erf, DT::F64 ), "erf(x)"},// double elements use erf
+        {__(Op::Erf, DT::U8  ), "erff(static_cast(x))"},// NOTE(review): static_cast shows no target type in this copy of the patch (the template argument appears to have been lost) — confirm against the repository
+        {__(Op::Erf, DT::I8  ), "erff(static_cast(x))"},
+        {__(Op::Erf, DT::U16 ), "erff(static_cast(x))"},
+        {__(Op::Erf, DT::I16 ), "erff(static_cast(x))"},
+        {__(Op::Erf, DT::U32 ), "erf(static_cast(x))"},// 32- and 64-bit integers are routed to erf rather than erff
+        {__(Op::Erf, DT::I32 ), "erf(static_cast(x))"},
+        {__(Op::Erf, DT::U64 ), "erf(static_cast(x))"},
+        {__(Op::Erf, DT::I64 ), "erf(static_cast(x))"},
+        {__(Op::Erf, DT::FP16), "__float2half(erff(__half2float(x)))"},// half -> float, erff, then back to half
+        {__(Op::Erf, DT::BF16), "__float2bfloat16(erff(__bfloat162float(x)))"},// bfloat16 -> float, erff, then back to bfloat16
     };
     // clang-format on
diff --git a/src/04kernel/test/kernels/simple_unary/test_cpu.cpp b/src/04kernel/test/kernels/simple_unary/test_cpu.cpp
index da1cb6f83..e24d2091f 100644
--- a/src/04kernel/test/kernels/simple_unary/test_cpu.cpp
+++ b/src/04kernel/test/kernels/simple_unary/test_cpu.cpp
@@ -31,4 +31,5 @@ TEST(kernel, SimpleUnaryCpu) {
     testOp(SimpleUnaryType::Abs, std::abs);
     testOp(SimpleUnaryType::Sqrt, std::sqrt);
     testOp(SimpleUnaryType::Tanh, std::tanh);
+    testOp(SimpleUnaryType::Erf, std::erf);// NOTE(review): testOp is defined outside this hunk; presumably it checks the CPU kernel against the given reference function — confirm
 }
diff --git a/src/04kernel/test/kernels/simple_unary/test_cuda.cpp b/src/04kernel/test/kernels/simple_unary/test_cuda.cpp
index 6ff5d798b..ce8d66f8c 100644
--- a/src/04kernel/test/kernels/simple_unary/test_cuda.cpp
+++ b/src/04kernel/test/kernels/simple_unary/test_cuda.cpp
@@ -51,6 +51,7 @@ TEST(kernel, SimpleUnaryCuda) {
     testOp(SimpleUnaryType::Sqrt);
     testOp(SimpleUnaryType::Sigmoid);
     testOp(SimpleUnaryType::Tanh);
+    testOp(SimpleUnaryType::Erf);// Erf joins the ops exercised by the CUDA kernel test
 }
 #endif
diff --git a/src/07onnx/test/test_simple_unary.cpp b/src/07onnx/test/test_simple_unary.cpp
new file mode 100644
index 000000000..12529a1a5
--- /dev/null
+++ b/src/07onnx/test/test_simple_unary.cpp
@@ -0,0 +1,25 @@
+#include "../src/operators/simple_unary.hh"
+#include "onnx/operators.h"
+#include// NOTE(review): the header name is missing in this copy of the patch; presumably the test framework header (e.g. <gtest/gtest.h>) — confirm
+
+using namespace refactor;
+using namespace onnx;
+
+TEST(infer, SimpleUnary) {// inference test for the SimpleUnary operator; currently covers Erf only
+    onnx::register_();// NOTE(review): presumably registers the ONNX operators before inference is called — confirm
+
+    {
+        // Erf Test
+        auto edges = Edges{
+            {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""},// one F32 input tensor with shape 2 x 3
+        };
+        count_t inputs[]{0};// the operator reads edge 0 as its only input
+        auto infered = SimpleUnary(SimpleUnaryType::Erf).infer(TensorRefs(edges, inputs), {true});// NOTE(review): the meaning of the {true} argument is not visible in this patch — confirm
+        ASSERT_TRUE(infered.isOk());// inference must succeed
+        auto outputs = std::move(infered.unwrap());
+        ASSERT_EQ(outputs.size(), 1);// exactly one output tensor is expected
+        auto y = std::move(outputs[0]);
+        ASSERT_EQ(y->dataType, DataType::F32);// output data type equals the input data type
+        ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));// output shape equals the input shape
+    }
+}