From 0dc88d5f6d130a311ec2eb251873249dfd064efc Mon Sep 17 00:00:00 2001 From: luqiang guo <702572275@qq.com> Date: Mon, 14 Feb 2022 11:12:27 -0600 Subject: [PATCH] Add oneDNN binary op (#7319) * add * merge master * Solve the thread pool problem * add device local logical cores * fix error * Delete threadpool * fix include file * fix clang -lopm * fix clang error omp.h * fix omp cmake * omp.h * fix #ifdef * test clang13 -lomp * test -fopenmp * add fopenmp * rename OMP_FLAGS * static analysis libopm-12-dev * add tbb * refien * refine * refine * refine * revert * add tbb * success add tbb * tbb onednn ok * fix ninja onednn * component * install tbb include file * updata tbb master zip * fix md5 * refine * refjine * fix * cmake option * modified clang 10 OMP * add line * fix add OMP flags * fix tbb * fix * fix * fix' * fix * fix * fix OF_RUNTIME_TBB * fix * modified binary op * fix * fix * fux error * fix * fix * fix * refine * refine * fix * add seq * refine * fix * fix * fix * add set_num_threads * fix * fi * fix error * fix * refine * refine * fix * refine * fix * refine * refine * refine * refine * refine * fix * refine * fix * fix * fix * fix * fix * refine * refine * refine * refine * refine * refine * refine * fix * fix * fix * refine * refine * auto format by CI * fix * rename mm_, dynamic_cast * auto format by CI * fix MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY * fix 0-dim tensor * fix onednn format tag * auto format by CI Co-authored-by: jackalcooper Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: oneflow-ci-bot --- oneflow/core/common/preprocessor.h | 2 + oneflow/core/common/preprocessor_internal.h | 3 + oneflow/core/ep/cpu/primitive/add.cpp | 4 + .../broadcast_elementwise_binary.cpp | 210 ++++++++++++++++++ oneflow/core/ep/cpu/primitive/type_seq.h | 2 + .../kernels/math_binary_broadcast_kernels.cpp | 112 ++++++++-- 6 files changed, 316 insertions(+), 17 deletions(-) diff --git a/oneflow/core/common/preprocessor.h b/oneflow/core/common/preprocessor.h index 62a9d12dd72..8f219375b02 100644 --- a/oneflow/core/common/preprocessor.h +++ b/oneflow/core/common/preprocessor.h @@ -27,6 +27,8 @@ limitations under the License. #define OF_PP_PAIR_SECOND(pair) OF_PP_INTERNAL_PAIR_SECOND(pair) +#define OF_PP_PAIR_THIRD(pair) OF_PP_INTERNAL_PAIR_THIRD(pair) + #define OF_PP_TUPLE_SIZE(t) OF_PP_INTERNAL_TUPLE_SIZE(t) #define OF_PP_TUPLE_ELEM(n, t) OF_PP_INTERNAL_TUPLE_ELEM(n, t) diff --git a/oneflow/core/common/preprocessor_internal.h b/oneflow/core/common/preprocessor_internal.h index a399d069101..78ccc47c861 100644 --- a/oneflow/core/common/preprocessor_internal.h +++ b/oneflow/core/common/preprocessor_internal.h @@ -91,9 +91,12 @@ limitations under the License. #define OF_PP_INTERNAL_PAIR_FIRST_I(t) OF_PP_INTERNAL_FIRST_ARG t #define OF_PP_INTERNAL_PAIR_SECOND(t) OF_PP_INTERNAL_PAIR_SECOND_I(t) #define OF_PP_INTERNAL_PAIR_SECOND_I(t) OF_PP_INTERNAL_SECOND_ARG t +#define OF_PP_INTERNAL_PAIR_THIRD(t) OF_PP_INTERNAL_PAIR_THIRD_I(t) +#define OF_PP_INTERNAL_PAIR_THIRD_I(t) OF_PP_INTERNAL_THIRD_ARG t #define OF_PP_INTERNAL_FIRST_ARG(x, ...) x #define OF_PP_INTERNAL_SECOND_ARG(x, y, ...) y +#define OF_PP_INTERNAL_THIRD_ARG(x, y, z, ...) z #define OF_PP_INTERNAL_MAKE_TUPLE(...) (__VA_ARGS__) #define OF_PP_INTERNAL_MAKE_TUPLE_SEQ(...) (OF_PP_INTERNAL_MAKE_TUPLE(__VA_ARGS__)) diff --git a/oneflow/core/ep/cpu/primitive/add.cpp b/oneflow/core/ep/cpu/primitive/add.cpp index 95b2b472b49..a90bb666655 100644 --- a/oneflow/core/ep/cpu/primitive/add.cpp +++ b/oneflow/core/ep/cpu/primitive/add.cpp @@ -88,6 +88,10 @@ class AddOneDnnImpl : public Add { for (int i = 1; i < arity; i++) { if (srcs[i] == dst) { LOG(FATAL) << "Only the first parameter can be operated inplace"; } } + CpuStream* cpu_stream = stream->As(); + size_t num_threads = static_cast(cpu_stream->device())->GetNumThreads(); + CpuNumThreadsGuard guard(num_threads); + dnnl::engine* onednn_engine = stream->As()->onednn_engine(); dnnl::stream* onednn_stream = stream->As()->onednn_stream(); diff --git a/oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp b/oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp index 03c4b82566f..df26698afc6 100644 --- a/oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp +++ b/oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp @@ -15,11 +15,14 @@ limitations under the License. */ #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" +#include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/cpu/primitive/binary_functor.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/ep/cpu/cpu_stream.h" +#include "oneflow/core/ep/cpu/cpu_device.h" namespace oneflow { @@ -130,6 +133,180 @@ std::unique_ptr NewBroadcastElementwiseBinary() { OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr, OR) \ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor, XOR) +#ifdef WITH_ONEDNN + +uint32_t OnednnFormatTagMap[kMaxNumDims] = {dnnl_a, dnnl_ab, dnnl_abc, dnnl_abcd, + dnnl_abcde, dnnl_abcdef, dnnl_abcdefg, dnnl_abcdefgh}; + +inline void OneDnnBroadcastDims(dnnl::memory::dims* src0, size_t num_src0_dims, + const int64_t* src0_dims, dnnl::memory::dims* src1, + size_t num_src1_dims, const int64_t* src1_dims, + dnnl::memory::dims& dst) { + const int64_t num_dims = dst.size(); + const int64_t num_src0_padding_dims = num_dims - num_src0_dims; + const int64_t num_src1_padding_dims = num_dims - num_src1_dims; + for (int64_t i = 0; i < num_dims; i++) { + int64_t src0_dim = i < num_src0_padding_dims ? 1 : src0_dims[i - num_src0_padding_dims]; + int64_t src1_dim = i < num_src1_padding_dims ? 1 : src1_dims[i - num_src1_padding_dims]; + CHECK((src0_dim == src1_dim || src0_dim == 1 || src1_dim == 1)); + (*src0)[i] = src0_dim; + (*src1)[i] = src1_dim; + dst[i] = std::max(src0_dim, src1_dim); + } +} + +template +class OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary { + public: + OF_DISALLOW_COPY_AND_MOVE(OneDnnBroadcastElementwiseBinaryImpl); + OneDnnBroadcastElementwiseBinaryImpl(){}; + ~OneDnnBroadcastElementwiseBinaryImpl() override = default; + + void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims, + const void* src1, void* dst) override { + T scalar_val = GetValue(src0); + const int64_t src0_dims = 1; + Launch(stream, num_src1_dims, src1_dims, src1, 1, &src0_dims, &scalar_val, dst); + } + void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, + Scalar src1, void* dst) override { + T scalar_val = GetValue(src1); + const int64_t src1_dims = 1; + Launch(stream, num_src0_dims, src0_dims, src0, 1, &src1_dims, &scalar_val, dst); + } + void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, + size_t num_src1_dims, const int64_t* src1_dims, const void* src1, + void* dst) override { + CpuStream* cpu_stream = stream->As(); + size_t num_threads = static_cast(cpu_stream->device())->GetNumThreads(); + CpuNumThreadsGuard guard(num_threads); + + dnnl::engine* onednn_engine = stream->As()->onednn_engine(); + dnnl::stream* onednn_stream = stream->As()->onednn_stream(); + size_t num_dims = std::max(num_src0_dims, num_src1_dims); + dnnl::memory::dims src_0_dims(num_dims); + dnnl::memory::dims src_1_dims(num_dims); + dnnl::memory::dims dst_dims(num_dims); + const void* onednn_src0 = nullptr; + const void* onednn_src1 = nullptr; + + // OneDNN inplace operations only support src_0 + if (src1 == dst) { + onednn_src0 = src1; + onednn_src1 = src0; + OneDnnBroadcastDims(&src_0_dims, num_src1_dims, src1_dims, &src_1_dims, num_src0_dims, + src0_dims, dst_dims); + } else { + onednn_src0 = src0; + onednn_src1 = src1; + OneDnnBroadcastDims(&src_0_dims, num_src0_dims, src0_dims, &src_1_dims, num_src1_dims, + src1_dims, dst_dims); + } + + CheckInplace(num_dims, src_0_dims.data(), onednn_src0, src_1_dims.data(), onednn_src1, + dst_dims.data(), dst); + + auto src_0_md = + dnnl::memory::desc(src_0_dims, src_onednn, + static_cast(OnednnFormatTagMap[num_dims - 1])); + auto src_1_md = + dnnl::memory::desc(src_1_dims, src_onednn, + static_cast(OnednnFormatTagMap[num_dims - 1])); + auto dst_md = + dnnl::memory::desc(dst_dims, dst_onednn, + static_cast(OnednnFormatTagMap[num_dims - 1])); + + auto src_0_mem = dnnl::memory(src_0_md, *onednn_engine, (void*)onednn_src0); + auto src_1_mem = dnnl::memory(src_1_md, *onednn_engine, (void*)onednn_src1); + auto dst_mem = dnnl::memory(dst_md, *onednn_engine, dst); + + auto binary_d = dnnl::binary::desc(algorithm, src_0_md, src_1_md, dst_md); + auto binary_pd = dnnl::binary::primitive_desc(binary_d, *onednn_engine); + auto binary_prim = dnnl::binary(binary_pd); + + std::unordered_map binary_args{ + {DNNL_ARG_SRC_0, src_0_mem}, {DNNL_ARG_SRC_1, src_1_mem}, {DNNL_ARG_DST, dst_mem}}; + + binary_prim.execute(*onednn_stream, binary_args); + onednn_stream->wait(); + } +}; + +#define CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8, int8_t) \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool, bool) \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kUInt8, uint8_t) \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat, float) \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f16, DataType::kFloat16, float16) + +// OneDNN binary op does not support s32 +// CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ + +#define CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ \ + CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \ + CPU_PRIMITIVE_INT32_TYPE_SEQ \ + CPU_PRIMITIVE_INT64_TYPE_SEQ + +#define BINARY_ONEDNN_ADD OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd, dnnl::algorithm::binary_add) +#define BINARY_ONEDNN_SUB OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub, dnnl::algorithm::binary_sub) +#define BINARY_ONEDNN_MUL OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul, dnnl::algorithm::binary_mul) +#define BINARY_ONEDNN_DIV OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv, dnnl::algorithm::binary_div) +#define BINARY_ONEDNN_MAX OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax, dnnl::algorithm::binary_max) +#define BINARY_ONEDNN_MIN OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin, dnnl::algorithm::binary_min) + +#define BINARY_ONEDNN_EQ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual, dnnl::algorithm::binary_eq) +#define BINARY_ONEDNN_NE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual, dnnl::algorithm::binary_ne) +#define BINARY_ONEDNN_LT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan, dnnl::algorithm::binary_lt) +#define BINARY_ONEDNN_LE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual, dnnl::algorithm::binary_le) +#define BINARY_ONEDNN_GT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan, dnnl::algorithm::binary_gt) +#define BINARY_ONEDNN_GE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual, dnnl::algorithm::binary_ge) + +#define BINARY_MATH_OP_ONEDNN_PAIR \ + BINARY_ONEDNN_ADD \ + BINARY_ONEDNN_SUB \ + BINARY_ONEDNN_MUL \ + BINARY_ONEDNN_DIV \ + BINARY_ONEDNN_MAX \ + BINARY_ONEDNN_MIN + +#define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR \ + BINARY_ONEDNN_EQ \ + BINARY_ONEDNN_NE \ + BINARY_ONEDNN_LT \ + BINARY_ONEDNN_LE \ + BINARY_ONEDNN_GT \ + BINARY_ONEDNN_GE + +#define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_UNIMPLEMENTED \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd, AND) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr, OR) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor, XOR) + +template +std::unique_ptr NewOneDnnBroadcastElementwiseBinary() { + return std::unique_ptr( + new OneDnnBroadcastElementwiseBinaryImpl()); +} + +#define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op_pair, data_type_pair) \ + {std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(data_type_pair), \ + OF_PP_PAIR_SECOND(data_type_pair)), \ + NewOneDnnBroadcastElementwiseBinary< \ + OF_PP_PAIR_THIRD(data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair), \ + OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>}, + +#define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \ + binary_op_pair, src_data_type_pair, dst_data_type_pair) \ + {std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(src_data_type_pair), \ + OF_PP_PAIR_SECOND(dst_data_type_pair)), \ + NewOneDnnBroadcastElementwiseBinary< \ + OF_PP_PAIR_THIRD(src_data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair), \ + OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>}, + +#endif // WITH_ONEDNN + class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory { public: OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl); @@ -158,6 +335,38 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF &NdarrayUtil::OF_PP_CAT( \ Broadcast, OF_PP_PAIR_SECOND(binary_op_pair))>}, +#ifdef WITH_ONEDNN + static const std::map, + std::function()>> + new_broadcast_elementwise_binary_handle{ + // For oneDNN binary op + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_OP_ONEDNN_PAIR, + CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ) + // For OneDNN comparasion binary op + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, + BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR, CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ, + CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ) + // OneDNN unimplemented binary op + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow, Pow), + NDARRAY_BINARY_TYPE_SEQ) + // OneDNN unimplemented comparasion binary op + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, + BINARY_LOGICAL_COMPARISION_OP_ONEDNN_UNIMPLEMENTED, NDARRAY_BINARY_TYPE_SEQ, + CPU_PRIMITIVE_BOOL_TYPE_SEQ) + // OneDNN unimplemented data type binary op + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, + BINARY_MATH_OP_NDARRAY_PAIR, + CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ) + // OneDNN unimplemented data type comparasion binary op + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, + BINARY_LOGICAL_COMPARISION_OP_NDARRAY_PAIR, + CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)}; +#else static const std::map, std::function()>> new_broadcast_elementwise_binary_handle{ @@ -167,6 +376,7 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, BINARY_LOGICAL_COMPARISION_OP_NDARRAY_PAIR, NDARRAY_BINARY_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)}; +#endif #undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY #undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY diff --git a/oneflow/core/ep/cpu/primitive/type_seq.h b/oneflow/core/ep/cpu/primitive/type_seq.h index 16588ee8e0c..c5a5db6f9ae 100644 --- a/oneflow/core/ep/cpu/primitive/type_seq.h +++ b/oneflow/core/ep/cpu/primitive/type_seq.h @@ -34,6 +34,8 @@ limitations under the License. #define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16) +#define CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool) #define CPU_PRIMITIVE_ONEDNN_INT8_TYPE_SEQ \ OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8) #define CPU_PRIMITIVE_ONEDNN_UINT8_TYPE_SEQ \ diff --git a/oneflow/user/kernels/math_binary_broadcast_kernels.cpp b/oneflow/user/kernels/math_binary_broadcast_kernels.cpp index b02d0618f5a..5d71785320b 100644 --- a/oneflow/user/kernels/math_binary_broadcast_kernels.cpp +++ b/oneflow/user/kernels/math_binary_broadcast_kernels.cpp @@ -20,9 +20,97 @@ limitations under the License. #include "oneflow/core/ndarray/xpu_var_ndarray.h" #include "oneflow/user/ops/math_binary_broadcast_seq.h" #include "oneflow/core/kernel/cuda_graph_support.h" - +#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" namespace oneflow { +template +std::unique_ptr NewBroadcastElementwiseBinaryPrimitive( + Context* ctx) { + const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex("x", 0); + const user_op::TensorDesc* z = ctx->TensorDesc4ArgNameAndIndex("z", 0); + size_t num_axes = z->shape().NumAxes(); + return ep::primitive::NewPrimitive( + ctx->device_type(), binary_op, x->data_type(), z->data_type(), num_axes); +} + +template +class MathBinaryBroadcastEpKernel final : public user_op::OpKernel, + public user_op::CudaGraphSupport { + public: + MathBinaryBroadcastEpKernel() = default; + ~MathBinaryBroadcastEpKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); + user_op::Tensor* z = ctx->Tensor4ArgNameAndIndex("z", 0); + + auto primitive = + NewBroadcastElementwiseBinaryPrimitive(ctx); + CHECK(primitive.get() != nullptr) << "Exceeds maximum supported dimensions"; + + const int64_t x_elem_cnt = x->shape().elem_cnt(); + const int64_t y_elem_cnt = y->shape().elem_cnt(); + size_t num_src0_dims = x->shape().NumAxes(); + size_t num_src1_dims = y->shape().NumAxes(); + + int64_t zero_dim = 1; + int64_t* src0_dims = const_cast(x->shape().ptr()); + int64_t* src1_dims = const_cast(y->shape().ptr()); + + if (x_elem_cnt != 0 && y_elem_cnt != 0) { + if (num_src0_dims == 0) { + num_src0_dims = 1; + src0_dims = &zero_dim; + } + if (num_src1_dims == 0) { + num_src1_dims = 1; + src1_dims = &zero_dim; + } + + primitive->Launch(ctx->stream(), num_src0_dims, src0_dims, x->dptr(), num_src1_dims, + src1_dims, y->dptr(), z->mut_dptr()); + } else { + // For 0-d Tensor + return; + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +template +auto MathBinaryBroadcastPrimitiveExists() { + return hob::make_custom("MathBinaryBroadcastPrimitiveExists", [](const user_op::KernelRegContext& + ctx) { + return NewBroadcastElementwiseBinaryPrimitive(&ctx). + operator bool(); + }); +} + +#define REGISTER_BINARY_BROADCAST_EP_KERNEL(math_type_pair, binary_op) \ + REGISTER_USER_KERNEL(math_type_pair) \ + .SetCreateFn>() \ + .SetIsMatchedHob(MathBinaryBroadcastPrimitiveExists() == true); + +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_add", ep::primitive::BinaryOp::kAdd) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_sub", ep::primitive::BinaryOp::kSub) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_mul", ep::primitive::BinaryOp::kMul) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_div", ep::primitive::BinaryOp::kDiv) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_minimum", ep::primitive::BinaryOp::kMin) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_maximum", ep::primitive::BinaryOp::kMax) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_pow", ep::primitive::BinaryOp::kPow) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_equal", ep::primitive::BinaryOp::kEqual) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_not_equal", ep::primitive::BinaryOp::kNotEqual) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_greater", ep::primitive::BinaryOp::kGreaterThan) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_greater_equal", + ep::primitive::BinaryOp::kGreaterEqual) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_less", ep::primitive::BinaryOp::kLessThan) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_less_equal", ep::primitive::BinaryOp::kLessEqual) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_logical_and", ep::primitive::BinaryOp::kLogicalAnd) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_logical_or", ep::primitive::BinaryOp::kLogicalOr) +REGISTER_BINARY_BROADCAST_EP_KERNEL("broadcast_logical_xor", ep::primitive::BinaryOp::kLogicalXor) + template& z, const XpuVarNdarray& x, const XpuVarNdarray& y)> @@ -47,6 +135,10 @@ class MathBinaryBroadcastKernel final : public user_op::OpKernel, public user_op bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; +#define MATH_BINARY_BROADCAST_DEFAULT_FUNC_SEQ \ + OF_PP_MAKE_TUPLE_SEQ("broadcast_floor_mod", FloorMod) \ + OF_PP_MAKE_TUPLE_SEQ("broadcast_fmod", FMod) + #define REGISTER_MATH_BINARY_BROADCAST_KERNEL(math_type_pair, device, data_type_pair) \ REGISTER_USER_KERNEL(OF_PP_PAIR_FIRST(math_type_pair)) \ .SetCreateFn::OF_PP_CAT( \ - Broadcast, OF_PP_PAIR_SECOND(math_type_pair))>>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == device) \ - && (user_op::HobDataType("x", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \ - && (user_op::HobDataType("z", 0) == DataType::kBool)); - -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( - REGISTER_MATH_BINARY_BROADCAST_LOGICAL_KERNEL, MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ, - DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ) - } // namespace oneflow