Add oneDNN binary op (#7319)
* add

* merge master

* Solve the thread pool problem

* add device local logical cores

* fix error

* Delete threadpool

* fix include file

* fix clang -lomp

* fix clang error omp.h

* fix omp cmake

* omp.h

* fix #ifdef

* test clang13 -lomp

* test -fopenmp

* add fopenmp

* rename OMP_FLAGS

* static analysis libomp-12-dev

* add tbb

* refine

* refine

* refine

* refine

* revert

* add tbb

* successfully add tbb

* tbb onednn ok

* fix ninja onednn

* component

* install tbb include file

* update tbb master zip

* fix md5

* refine

* refine

* fix

* cmake option

* modified clang 10 OMP

* add line

* fix add OMP flags

* fix tbb

* fix

* fix

* fix

* fix

* fix

* fix OF_RUNTIME_TBB

* fix

* modified binary op

* fix

* fix

* fix error

* fix

* fix

* fix

* refine

* refine

* fix

* add seq

* refine

* fix

* fix

* fix

* add set_num_threads

* fix

* fix

* fix error

* fix

* refine

* refine

* fix

* refine

* fix

* refine

* refine

* refine

* refine

* refine

* fix

* refine

* fix

* fix

* fix

* fix

* fix

* refine

* refine

* refine

* refine

* refine

* refine

* refine

* fix

* fix

* fix

* refine

* refine

* auto format by CI

* fix

* rename mm_, dynamic_cast

* auto format by CI

* fix MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY

* fix 0-dim tensor

* fix onednn format tag

* auto format by CI

Co-authored-by: jackalcooper <jackalcooper@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
4 people committed Feb 14, 2022
1 parent 404a7de commit 0dc88d5
Showing 6 changed files with 316 additions and 17 deletions.
2 changes: 2 additions & 0 deletions oneflow/core/common/preprocessor.h
@@ -27,6 +27,8 @@ limitations under the License.

#define OF_PP_PAIR_SECOND(pair) OF_PP_INTERNAL_PAIR_SECOND(pair)

#define OF_PP_PAIR_THIRD(pair) OF_PP_INTERNAL_PAIR_THIRD(pair)

#define OF_PP_TUPLE_SIZE(t) OF_PP_INTERNAL_TUPLE_SIZE(t)

#define OF_PP_TUPLE_ELEM(n, t) OF_PP_INTERNAL_TUPLE_ELEM(n, t)
3 changes: 3 additions & 0 deletions oneflow/core/common/preprocessor_internal.h
@@ -91,9 +91,12 @@ limitations under the License.
#define OF_PP_INTERNAL_PAIR_FIRST_I(t) OF_PP_INTERNAL_FIRST_ARG t
#define OF_PP_INTERNAL_PAIR_SECOND(t) OF_PP_INTERNAL_PAIR_SECOND_I(t)
#define OF_PP_INTERNAL_PAIR_SECOND_I(t) OF_PP_INTERNAL_SECOND_ARG t
#define OF_PP_INTERNAL_PAIR_THIRD(t) OF_PP_INTERNAL_PAIR_THIRD_I(t)
#define OF_PP_INTERNAL_PAIR_THIRD_I(t) OF_PP_INTERNAL_THIRD_ARG t

#define OF_PP_INTERNAL_FIRST_ARG(x, ...) x
#define OF_PP_INTERNAL_SECOND_ARG(x, y, ...) y
#define OF_PP_INTERNAL_THIRD_ARG(x, y, z, ...) z

#define OF_PP_INTERNAL_MAKE_TUPLE(...) (__VA_ARGS__)
#define OF_PP_INTERNAL_MAKE_TUPLE_SEQ(...) (OF_PP_INTERNAL_MAKE_TUPLE(__VA_ARGS__))
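For readers unfamiliar with these helpers, here is a minimal illustration (not part of the commit): OF_PP_PAIR_THIRD extracts the third element of a tuple, which the new dispatch tables below use to pull the C++ type out of 3-tuples such as the entries of CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ.

#include <type_traits>  // only needed for the hypothetical check below
// Illustrative expansions for a 3-tuple like the ones introduced later in this commit:
//   OF_PP_PAIR_FIRST((dnnl::memory::data_type::f32, DataType::kFloat, float))  -> dnnl::memory::data_type::f32
//   OF_PP_PAIR_SECOND((dnnl::memory::data_type::f32, DataType::kFloat, float)) -> DataType::kFloat
//   OF_PP_PAIR_THIRD((dnnl::memory::data_type::f32, DataType::kFloat, float))  -> float
// Compile-time check of the last expansion (hypothetical usage; the discarded
// elements never reach the compiler, so no dnnl header is required here):
static_assert(std::is_same<OF_PP_PAIR_THIRD((dnnl::memory::data_type::f32, DataType::kFloat, float)),
                           float>::value,
              "the third tuple element is the C++ type");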
4 changes: 4 additions & 0 deletions oneflow/core/ep/cpu/primitive/add.cpp
@@ -88,6 +88,10 @@ class AddOneDnnImpl : public Add {
for (int i = 1; i < arity; i++) {
if (srcs[i] == dst) { LOG(FATAL) << "Only the first parameter can be operated inplace"; }
}
CpuStream* cpu_stream = stream->As<CpuStream>();
size_t num_threads = static_cast<CpuDevice*>(cpu_stream->device())->GetNumThreads();
CpuNumThreadsGuard guard(num_threads);

dnnl::engine* onednn_engine = stream->As<CpuStream>()->onednn_engine();
dnnl::stream* onednn_stream = stream->As<CpuStream>()->onednn_stream();

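For context, the guard added above caps the number of worker threads oneDNN may use for the duration of the call. The real CpuNumThreadsGuard is defined in the CPU stream/device code and, per the commit history, depends on whether OneFlow was built against OpenMP or TBB; the snippet below is only a minimal OpenMP-flavored sketch of the RAII pattern it follows, not the actual implementation.

#include <omp.h>

// Hypothetical sketch: remember the current OpenMP thread cap, apply the
// requested one, and restore the old value when the guard leaves scope.
class NumThreadsGuardSketch {
 public:
  explicit NumThreadsGuardSketch(size_t num_threads) : saved_(omp_get_max_threads()) {
    omp_set_num_threads(static_cast<int>(num_threads));
  }
  ~NumThreadsGuardSketch() { omp_set_num_threads(saved_); }

 private:
  int saved_;
};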
210 changes: 210 additions & 0 deletions oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp
@@ -15,11 +15,14 @@ limitations under the License.
*/

#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/cpu/primitive/binary_functor.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h"
#include "oneflow/core/ndarray/ndarray_util.h"
#include "oneflow/core/ndarray/xpu_var_ndarray.h"
#include "oneflow/core/ep/cpu/cpu_stream.h"
#include "oneflow/core/ep/cpu/cpu_device.h"

namespace oneflow {

@@ -130,6 +133,180 @@ std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary() {
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr, OR) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor, XOR)

#ifdef WITH_ONEDNN

uint32_t OnednnFormatTagMap[kMaxNumDims] = {dnnl_a, dnnl_ab, dnnl_abc, dnnl_abcd,
dnnl_abcde, dnnl_abcdef, dnnl_abcdefg, dnnl_abcdefgh};
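// The table above maps (num_dims - 1) to the plain row-major oneDNN format tag:
// dnnl_a for 1-D tensors up through dnnl_abcdefgh for 8-D.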

inline void OneDnnBroadcastDims(dnnl::memory::dims* src0, size_t num_src0_dims,
const int64_t* src0_dims, dnnl::memory::dims* src1,
size_t num_src1_dims, const int64_t* src1_dims,
dnnl::memory::dims& dst) {
const int64_t num_dims = dst.size();
const int64_t num_src0_padding_dims = num_dims - num_src0_dims;
const int64_t num_src1_padding_dims = num_dims - num_src1_dims;
for (int64_t i = 0; i < num_dims; i++) {
int64_t src0_dim = i < num_src0_padding_dims ? 1 : src0_dims[i - num_src0_padding_dims];
int64_t src1_dim = i < num_src1_padding_dims ? 1 : src1_dims[i - num_src1_padding_dims];
CHECK((src0_dim == src1_dim || src0_dim == 1 || src1_dim == 1));
(*src0)[i] = src0_dim;
(*src1)[i] = src1_dim;
dst[i] = std::max(src0_dim, src1_dim);
}
}
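A worked example of the alignment performed above (illustrative shapes only, not part of the commit): the shorter shape is left-padded with 1s and the destination takes the per-axis maximum of the two padded shapes.

// Hypothetical helper showing OneDnnBroadcastDims on shapes {8, 1, 5} and {3, 1}.
inline void BroadcastDimsExample() {
  const int64_t src0_shape[] = {8, 1, 5};
  const int64_t src1_shape[] = {3, 1};
  dnnl::memory::dims src0(3), src1(3), dst(3);
  OneDnnBroadcastDims(&src0, 3, src0_shape, &src1, 2, src1_shape, dst);
  // src0 -> {8, 1, 5}, src1 -> {1, 3, 1}, dst -> {8, 3, 5}
}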

template<typename T, dnnl::algorithm algorithm, dnnl::memory::data_type src_onednn,
dnnl::memory::data_type dst_onednn>
class OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
public:
OF_DISALLOW_COPY_AND_MOVE(OneDnnBroadcastElementwiseBinaryImpl);
OneDnnBroadcastElementwiseBinaryImpl() = default;
~OneDnnBroadcastElementwiseBinaryImpl() override = default;

void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,
const void* src1, void* dst) override {
T scalar_val = GetValue<T>(src0);
const int64_t src0_dims = 1;
Launch(stream, num_src1_dims, src1_dims, src1, 1, &src0_dims, &scalar_val, dst);
}
void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
Scalar src1, void* dst) override {
T scalar_val = GetValue<T>(src1);
const int64_t src1_dims = 1;
Launch(stream, num_src0_dims, src0_dims, src0, 1, &src1_dims, &scalar_val, dst);
}
void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
size_t num_src1_dims, const int64_t* src1_dims, const void* src1,
void* dst) override {
CpuStream* cpu_stream = stream->As<CpuStream>();
size_t num_threads = static_cast<CpuDevice*>(cpu_stream->device())->GetNumThreads();
CpuNumThreadsGuard guard(num_threads);

dnnl::engine* onednn_engine = stream->As<CpuStream>()->onednn_engine();
dnnl::stream* onednn_stream = stream->As<CpuStream>()->onednn_stream();
size_t num_dims = std::max(num_src0_dims, num_src1_dims);
dnnl::memory::dims src_0_dims(num_dims);
dnnl::memory::dims src_1_dims(num_dims);
dnnl::memory::dims dst_dims(num_dims);
const void* onednn_src0 = nullptr;
const void* onednn_src1 = nullptr;

// oneDNN supports in-place execution only through src_0, so swap the operands when src1 aliases dst
if (src1 == dst) {
onednn_src0 = src1;
onednn_src1 = src0;
OneDnnBroadcastDims(&src_0_dims, num_src1_dims, src1_dims, &src_1_dims, num_src0_dims,
src0_dims, dst_dims);
} else {
onednn_src0 = src0;
onednn_src1 = src1;
OneDnnBroadcastDims(&src_0_dims, num_src0_dims, src0_dims, &src_1_dims, num_src1_dims,
src1_dims, dst_dims);
}

CheckInplace(num_dims, src_0_dims.data(), onednn_src0, src_1_dims.data(), onednn_src1,
dst_dims.data(), dst);

auto src_0_md =
dnnl::memory::desc(src_0_dims, src_onednn,
static_cast<dnnl::memory::format_tag>(OnednnFormatTagMap[num_dims - 1]));
auto src_1_md =
dnnl::memory::desc(src_1_dims, src_onednn,
static_cast<dnnl::memory::format_tag>(OnednnFormatTagMap[num_dims - 1]));
auto dst_md =
dnnl::memory::desc(dst_dims, dst_onednn,
static_cast<dnnl::memory::format_tag>(OnednnFormatTagMap[num_dims - 1]));

auto src_0_mem = dnnl::memory(src_0_md, *onednn_engine, (void*)onednn_src0);
auto src_1_mem = dnnl::memory(src_1_md, *onednn_engine, (void*)onednn_src1);
auto dst_mem = dnnl::memory(dst_md, *onednn_engine, dst);

auto binary_d = dnnl::binary::desc(algorithm, src_0_md, src_1_md, dst_md);
auto binary_pd = dnnl::binary::primitive_desc(binary_d, *onednn_engine);
auto binary_prim = dnnl::binary(binary_pd);

std::unordered_map<int, dnnl::memory> binary_args{
{DNNL_ARG_SRC_0, src_0_mem}, {DNNL_ARG_SRC_1, src_1_mem}, {DNNL_ARG_DST, dst_mem}};

binary_prim.execute(*onednn_stream, binary_args);
onednn_stream->wait();
}
};

#define CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8, int8_t) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool, bool) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kUInt8, uint8_t) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat, float) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f16, DataType::kFloat16, float16)

// OneDNN binary op does not support s32
// CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ

#define CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \
CPU_PRIMITIVE_INT32_TYPE_SEQ \
CPU_PRIMITIVE_INT64_TYPE_SEQ

#define BINARY_ONEDNN_ADD OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd, dnnl::algorithm::binary_add)
#define BINARY_ONEDNN_SUB OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub, dnnl::algorithm::binary_sub)
#define BINARY_ONEDNN_MUL OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul, dnnl::algorithm::binary_mul)
#define BINARY_ONEDNN_DIV OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv, dnnl::algorithm::binary_div)
#define BINARY_ONEDNN_MAX OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax, dnnl::algorithm::binary_max)
#define BINARY_ONEDNN_MIN OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin, dnnl::algorithm::binary_min)

#define BINARY_ONEDNN_EQ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual, dnnl::algorithm::binary_eq)
#define BINARY_ONEDNN_NE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual, dnnl::algorithm::binary_ne)
#define BINARY_ONEDNN_LT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan, dnnl::algorithm::binary_lt)
#define BINARY_ONEDNN_LE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual, dnnl::algorithm::binary_le)
#define BINARY_ONEDNN_GT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan, dnnl::algorithm::binary_gt)
#define BINARY_ONEDNN_GE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual, dnnl::algorithm::binary_ge)

#define BINARY_MATH_OP_ONEDNN_PAIR \
BINARY_ONEDNN_ADD \
BINARY_ONEDNN_SUB \
BINARY_ONEDNN_MUL \
BINARY_ONEDNN_DIV \
BINARY_ONEDNN_MAX \
BINARY_ONEDNN_MIN

#define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR \
BINARY_ONEDNN_EQ \
BINARY_ONEDNN_NE \
BINARY_ONEDNN_LT \
BINARY_ONEDNN_LE \
BINARY_ONEDNN_GT \
BINARY_ONEDNN_GE

#define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_UNIMPLEMENTED \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd, AND) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr, OR) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor, XOR)

template<typename T, dnnl::algorithm algorithm, dnnl::memory::data_type src_onednn,
dnnl::memory::data_type dst_onednn>
std::unique_ptr<BroadcastElementwiseBinary> NewOneDnnBroadcastElementwiseBinary() {
return std::unique_ptr<BroadcastElementwiseBinary>(
new OneDnnBroadcastElementwiseBinaryImpl<T, algorithm, src_onednn, dst_onednn>());
}

#define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op_pair, data_type_pair) \
{std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND(data_type_pair)), \
NewOneDnnBroadcastElementwiseBinary< \
OF_PP_PAIR_THIRD(data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair), \
OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>},

#define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \
binary_op_pair, src_data_type_pair, dst_data_type_pair) \
{std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(src_data_type_pair), \
OF_PP_PAIR_SECOND(dst_data_type_pair)), \
NewOneDnnBroadcastElementwiseBinary< \
OF_PP_PAIR_THIRD(src_data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair), \
OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>},

#endif // WITH_ONEDNN

class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl);
@@ -158,6 +335,38 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF
&NdarrayUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(src_data_type_pair)>::OF_PP_CAT( \
Broadcast, OF_PP_PAIR_SECOND(binary_op_pair))>},

#ifdef WITH_ONEDNN
static const std::map<std::tuple<BinaryOp, DataType, DataType>,
std::function<std::unique_ptr<BroadcastElementwiseBinary>()>>
new_broadcast_elementwise_binary_handle{
// For oneDNN binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_OP_ONEDNN_PAIR,
CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ)
// For OneDNN comparison binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR, CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ,
CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ)
// OneDNN unimplemented binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow, Pow),
NDARRAY_BINARY_TYPE_SEQ)
// OneDNN unimplemented comparison binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_LOGICAL_COMPARISION_OP_ONEDNN_UNIMPLEMENTED, NDARRAY_BINARY_TYPE_SEQ,
CPU_PRIMITIVE_BOOL_TYPE_SEQ)
// OneDNN unimplemented data type binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_OP_NDARRAY_PAIR,
CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ)
// OneDNN unimplemented data type comparison binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_LOGICAL_COMPARISION_OP_NDARRAY_PAIR,
CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)};
#else
static const std::map<std::tuple<BinaryOp, DataType, DataType>,
std::function<std::unique_ptr<BroadcastElementwiseBinary>()>>
new_broadcast_elementwise_binary_handle{
@@ -167,6 +376,7 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_LOGICAL_COMPARISION_OP_NDARRAY_PAIR, NDARRAY_BINARY_TYPE_SEQ,
CPU_PRIMITIVE_BOOL_TYPE_SEQ)};
#endif

#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
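For reference, each new entry is keyed by the tuple (BinaryOp, source DataType, destination DataType): math ops are registered with the destination type equal to the source type, while comparison and logical ops are registered with a bool (oneDNN u8) destination. Below is a minimal caller-side sketch; how the primitive is obtained (normally through BroadcastElementwiseBinaryFactory) is elided, and the function, shapes, and arguments are assumptions for illustration only.

// Hypothetical helper, assuming binary points to a BroadcastElementwiseBinary
// created for (BinaryOp::kAdd, DataType::kFloat, DataType::kFloat), stream is a
// CPU ep::Stream*, and <vector> is available.
inline void BroadcastAddExample(BroadcastElementwiseBinary* binary, Stream* stream) {
  const int64_t a_dims[] = {8, 1, 5};
  const int64_t b_dims[] = {3, 1};
  std::vector<float> a(8 * 1 * 5, 1.0f), b(3 * 1, 2.0f), out(8 * 3 * 5);
  binary->Launch(stream, /*num_src0_dims=*/3, a_dims, a.data(),
                 /*num_src1_dims=*/2, b_dims, b.data(), out.data());
  // Every element of out is now 3.0f: 1.0f + 2.0f broadcast to shape {8, 3, 5}.
}

The out-of-place call above is the general path; when the destination aliases an input, the implementation swaps operands so that the aliased buffer becomes oneDNN's src_0, since oneDNN only executes binary ops in place through its first source.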
2 changes: 2 additions & 0 deletions oneflow/core/ep/cpu/primitive/type_seq.h
@@ -34,6 +34,8 @@ limitations under the License.
#define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)

#define CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool)
#define CPU_PRIMITIVE_ONEDNN_INT8_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8)
#define CPU_PRIMITIVE_ONEDNN_UINT8_TYPE_SEQ \