Add oneDNN binary op #7319

Merged: 170 commits, Feb 14, 2022
Changes from 150 commits
Commits
c79f6e8
add
luqiang-guo Dec 13, 2021
63eec78
merge master
luqiang-guo Dec 15, 2021
95311b3
merge master
luqiang-guo Dec 15, 2021
d068bbf
Solve the thread pool problem
luqiang-guo Dec 20, 2021
85c0163
merge master
luqiang-guo Dec 21, 2021
4783a91
add device local logical cores
luqiang-guo Dec 21, 2021
26cd0e7
fix error
luqiang-guo Dec 21, 2021
624d7e9
Delete threadpool
luqiang-guo Dec 21, 2021
7b6e4d2
Merge branch 'master' into dev_parallel_loop
luqiang-guo Dec 21, 2021
543c726
fix include file
luqiang-guo Dec 21, 2021
36a755b
Merge branch 'dev_parallel_loop' of https://github.com/Oneflow-Inc/on…
luqiang-guo Dec 21, 2021
2288a11
fix clang -lopm
luqiang-guo Dec 21, 2021
5fec766
fix clang error omp.h
luqiang-guo Dec 21, 2021
3de09b6
fix omp cmake
luqiang-guo Dec 21, 2021
9c05b6c
omp.h
luqiang-guo Dec 21, 2021
17bd1bb
fix #ifdef
luqiang-guo Dec 21, 2021
c226521
test clang13 -lomp
luqiang-guo Dec 22, 2021
c4a5179
test -fopenmp
luqiang-guo Dec 22, 2021
4b028a6
add fopenmp
luqiang-guo Dec 22, 2021
0eb1059
Merge branch 'master' into dev_parallel_loop
luqiang-guo Dec 23, 2021
784badd
rename OMP_FLAGS
luqiang-guo Dec 23, 2021
e500586
Merge branch 'dev_parallel_loop' of https://github.com/Oneflow-Inc/on…
luqiang-guo Dec 23, 2021
b689676
Merge branch 'master' into dev_parallel_loop
luqiang-guo Dec 23, 2021
bafea64
Merge branch 'master' into dev_parallel_loop
luqiang-guo Dec 24, 2021
3d6c191
Merge branch 'master' into dev_parallel_loop
luqiang-guo Dec 24, 2021
72c39e9
Merge branch 'dev_parallel_loop' of https://github.com/Oneflow-Inc/on…
luqiang-guo Dec 24, 2021
d00a1da
static analysis libopm-12-dev
luqiang-guo Dec 24, 2021
18f363f
add tbb
luqiang-guo Dec 31, 2021
71dac72
refien
jackalcooper Dec 31, 2021
6eee93e
refine
jackalcooper Dec 31, 2021
267900b
refine
jackalcooper Dec 31, 2021
fbad306
refine
jackalcooper Dec 31, 2021
76c169c
revert
jackalcooper Jan 1, 2022
0f43fd3
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
jackalcooper Jan 1, 2022
86baed9
add tbb
luqiang-guo Jan 4, 2022
5006973
Merge branch 'master' into dev_test_tbb
luqiang-guo Jan 4, 2022
c8740bf
success add tbb
luqiang-guo Jan 7, 2022
9e18b53
tbb onednn ok
luqiang-guo Jan 9, 2022
cdac12d
Merge branch 'master' into dev_test_tbb
luqiang-guo Jan 9, 2022
85ed101
fix ninja onednn
luqiang-guo Jan 9, 2022
2cdca83
component
luqiang-guo Jan 10, 2022
5bf716e
install tbb include file
luqiang-guo Jan 10, 2022
9eb33c0
Merge branch 'master' into dev_test_tbb
luqiang-guo Jan 10, 2022
5eed042
Merge branch 'master' into dev_test_tbb
luqiang-guo Jan 11, 2022
e0e704a
updata tbb master zip
luqiang-guo Jan 12, 2022
f7ad349
Merge branch 'dev_test_tbb' of https://github.com/Oneflow-Inc/oneflow…
luqiang-guo Jan 12, 2022
ec8b5e0
Merge branch 'master' into dev_test_tbb
luqiang-guo Jan 12, 2022
d4d02c7
fix md5
luqiang-guo Jan 12, 2022
12faccd
Merge branch 'dev_test_tbb' of https://github.com/Oneflow-Inc/oneflow…
luqiang-guo Jan 12, 2022
1b396e5
Merge branch 'master' into dev_test_tbb
luqiang-guo Jan 12, 2022
a6d9f0b
refine
jackalcooper Jan 12, 2022
bb4141e
refjine
jackalcooper Jan 12, 2022
afd426a
fix
jackalcooper Jan 12, 2022
22157d1
Merge branch 'master' into dev_test_tbb
luqiang-guo Jan 12, 2022
4cc7858
Merge branch 'master' into dev_test_tbb
luqiang-guo Jan 13, 2022
53ec863
Merge branch 'master' into dev_parallel_loop
luqiang-guo Jan 13, 2022
35e6105
cmake option
luqiang-guo Jan 13, 2022
134c5f0
Merge branch 'dev_test_tbb' of https://github.com/Oneflow-Inc/oneflow…
luqiang-guo Jan 13, 2022
32623f7
modified clang 10 OMP
luqiang-guo Jan 13, 2022
e375b4c
add line
luqiang-guo Jan 13, 2022
33497de
fix add OMP flags
luqiang-guo Jan 16, 2022
ef9c444
fix tbb
luqiang-guo Jan 16, 2022
964f0a8
fix
luqiang-guo Jan 16, 2022
6eb32c8
fix
luqiang-guo Jan 16, 2022
689bcd7
fix'
luqiang-guo Jan 16, 2022
9367609
fix
luqiang-guo Jan 17, 2022
a8f8bc6
Merge branch 'dev_test_tbb' into dev_parallel_loop
luqiang-guo Jan 17, 2022
eb38712
Merge branch 'dev_test_tbb' into dev_parallel_loop
luqiang-guo Jan 17, 2022
34cc993
fix
luqiang-guo Jan 17, 2022
bfca0d6
fix OF_RUNTIME_TBB
luqiang-guo Jan 17, 2022
91fb423
fix
luqiang-guo Jan 17, 2022
934b567
Merge branch 'dev_test_tbb' into dev_parallel_loop
luqiang-guo Jan 17, 2022
5b51aa0
modified binary op
luqiang-guo Jan 20, 2022
47040c1
fix
luqiang-guo Jan 21, 2022
de6e03a
fix
luqiang-guo Jan 21, 2022
0c9002b
fux error
luqiang-guo Jan 24, 2022
0aad3df
Merge branch 'master' into dev_fix_tbb_error
luqiang-guo Jan 24, 2022
00d5dac
fix
luqiang-guo Jan 24, 2022
68fe7d4
Merge branch 'dev_fix_tbb_error' of https://github.com/Oneflow-Inc/on…
luqiang-guo Jan 24, 2022
707f892
fix
luqiang-guo Jan 24, 2022
e535ab1
fix
luqiang-guo Jan 24, 2022
5037f19
Merge branch 'master' into dev_fix_tbb_error
luqiang-guo Jan 24, 2022
3836c29
refine
jackalcooper Jan 24, 2022
c8ea8b8
refine
jackalcooper Jan 24, 2022
b488247
Merge branch 'master' into dev_fix_tbb_error
luqiang-guo Jan 24, 2022
b4dbd08
Merge branch 'master' into dev_fix_tbb_error
oneflow-ci-bot Jan 24, 2022
722d003
fix
luqiang-guo Jan 24, 2022
5307906
Merge branch 'dev_fix_tbb_error' of https://github.com/Oneflow-Inc/on…
luqiang-guo Jan 24, 2022
0ebb4c0
Merge branch 'master' into dev_fix_tbb_error
oneflow-ci-bot Jan 24, 2022
412ba88
Merge branch 'master' into dev_fix_tbb_error
luqiang-guo Jan 25, 2022
c355bb2
add seq
luqiang-guo Jan 25, 2022
fee0ac2
Merge branch 'dev_fix_tbb_error' of https://github.com/Oneflow-Inc/on…
luqiang-guo Jan 25, 2022
2638976
Merge branch 'master' into dev_fix_tbb_error
luqiang-guo Jan 25, 2022
eda5a8d
refine
luqiang-guo Jan 25, 2022
ee357df
Merge branch 'dev_parallel_loop' of https://github.com/Oneflow-Inc/on…
luqiang-guo Jan 25, 2022
9c6db4a
merge master
luqiang-guo Jan 25, 2022
db21828
Merge branch 'master' into dev_fix_tbb_error
oneflow-ci-bot Jan 25, 2022
a8af104
fix
luqiang-guo Jan 25, 2022
68e3cbf
Merge branch 'dev_fix_tbb_error' of https://github.com/Oneflow-Inc/on…
luqiang-guo Jan 25, 2022
d1995f2
Merge branch 'dev_fix_tbb_error' into dev_parallel_loop
luqiang-guo Jan 25, 2022
e4c760f
fix
luqiang-guo Jan 25, 2022
866c1e7
Merge branch 'master' into dev_parallel_loop
luqiang-guo Jan 25, 2022
1809cb6
fix
luqiang-guo Jan 25, 2022
da9e173
add set_num_threads
luqiang-guo Jan 25, 2022
2d52493
Merge branch 'master' into dev_parallel_loop
luqiang-guo Jan 25, 2022
0e8f7cc
merge master
luqiang-guo Jan 25, 2022
fe7d975
fix
luqiang-guo Jan 25, 2022
08f76a7
Merge branch 'dev_add_onednn_binary' of https://github.com/Oneflow-In…
luqiang-guo Jan 25, 2022
73a8fe6
fi
luqiang-guo Jan 25, 2022
c0010d1
fix error
luqiang-guo Jan 26, 2022
0baa117
Merge branch 'dev_parallel_loop' of https://github.com/Oneflow-Inc/on…
luqiang-guo Jan 26, 2022
d3ca51e
Merge branch 'dev_parallel_loop' into dev_add_onednn_binary
luqiang-guo Jan 26, 2022
e01fc91
fix
luqiang-guo Jan 26, 2022
8d5b327
refine
luqiang-guo Jan 26, 2022
6719ea1
refine
luqiang-guo Jan 26, 2022
a10e6ac
fix
luqiang-guo Jan 26, 2022
494957f
refine
luqiang-guo Jan 26, 2022
62fcd18
fix
luqiang-guo Jan 26, 2022
731c1d9
refine
luqiang-guo Jan 26, 2022
84d8786
refine
luqiang-guo Jan 26, 2022
209216a
refine
luqiang-guo Jan 26, 2022
8d78ccf
refine
luqiang-guo Jan 26, 2022
2494287
refine
luqiang-guo Jan 26, 2022
7ee8afa
Merge branch 'master' into dev_parallel_loop
luqiang-guo Jan 26, 2022
9284d15
fix
luqiang-guo Jan 27, 2022
47b17d6
Merge branch 'master' into dev_parallel_loop
luqiang-guo Jan 27, 2022
624f6d2
Merge branch 'dev_parallel_loop' into dev_add_onednn_binary
luqiang-guo Jan 27, 2022
58e6ad6
Merge branch 'dev_add_onednn_binary' of https://github.com/Oneflow-In…
luqiang-guo Jan 27, 2022
0a568d8
refine
luqiang-guo Jan 27, 2022
d37bd5c
fix
luqiang-guo Jan 27, 2022
9d9df5b
Merge branch 'dev_parallel_loop' of https://github.com/Oneflow-Inc/on…
luqiang-guo Jan 27, 2022
3269a43
fix
luqiang-guo Jan 27, 2022
c1ed873
fix
luqiang-guo Jan 27, 2022
3d1f687
fix
luqiang-guo Jan 27, 2022
f9d3e18
fix
luqiang-guo Jan 27, 2022
16dd419
refine
luqiang-guo Jan 27, 2022
d1c6373
refine
luqiang-guo Jan 27, 2022
d99dd82
refine
luqiang-guo Jan 27, 2022
8fcd664
refine
luqiang-guo Jan 27, 2022
b8cc3f3
refine
luqiang-guo Jan 27, 2022
f333d1b
refine
luqiang-guo Jan 27, 2022
fe6ca0f
refine
luqiang-guo Jan 27, 2022
1377294
Merge branch 'master' into dev_parallel_loop
luqiang-guo Jan 27, 2022
7cb0cdb
fix
luqiang-guo Jan 28, 2022
42850af
Merge branch 'master' into dev_parallel_loop
luqiang-guo Jan 28, 2022
47cfa64
merge parallel loop
luqiang-guo Jan 28, 2022
06f500a
fix
luqiang-guo Jan 28, 2022
6940844
fix
luqiang-guo Jan 28, 2022
b55d21f
merge master
luqiang-guo Jan 29, 2022
7134dd8
refine
luqiang-guo Jan 29, 2022
337ef44
refine
luqiang-guo Feb 10, 2022
26c59cf
Merge branch 'master' into dev_add_onednn_binary
luqiang-guo Feb 10, 2022
543f934
auto format by CI
oneflow-ci-bot Feb 10, 2022
1778912
Merge branch 'dev_add_onednn_binary' of https://github.com/Oneflow-In…
luqiang-guo Feb 11, 2022
7f22f82
fix
luqiang-guo Feb 11, 2022
dbdc88d
rename mm_, dynamic_cast
luqiang-guo Feb 11, 2022
38ed4d1
Merge branch 'master' into dev_add_onednn_binary
luqiang-guo Feb 11, 2022
1251b0b
auto format by CI
oneflow-ci-bot Feb 11, 2022
af757ce
Merge branch 'master' into dev_add_onednn_binary
luqiang-guo Feb 12, 2022
c5a9296
fix MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGI…
luqiang-guo Feb 12, 2022
dcc4dd7
Merge branch 'master' into dev_add_onednn_binary
luqiang-guo Feb 12, 2022
040978c
fix 0-dim tensor
luqiang-guo Feb 13, 2022
5ec6d58
Merge branch 'dev_add_onednn_binary' of https://github.com/Oneflow-In…
luqiang-guo Feb 13, 2022
031a81f
Merge branch 'master' into dev_add_onednn_binary
luqiang-guo Feb 13, 2022
d58c8bd
fix onednn format tag
luqiang-guo Feb 14, 2022
d470a1d
Merge branch 'dev_add_onednn_binary' of https://github.com/Oneflow-In…
luqiang-guo Feb 14, 2022
5e09fd3
Merge branch 'master' into dev_add_onednn_binary
luqiang-guo Feb 14, 2022
9381fb0
auto format by CI
oneflow-ci-bot Feb 14, 2022
533b220
Merge branch 'master' into dev_add_onednn_binary
luqiang-guo Feb 14, 2022
637df74
Merge branch 'master' into dev_add_onednn_binary
oneflow-ci-bot Feb 14, 2022
2 changes: 2 additions & 0 deletions oneflow/core/common/preprocessor.h
@@ -27,6 +27,8 @@ limitations under the License.

#define OF_PP_PAIR_SECOND(pair) OF_PP_INTERNAL_PAIR_SECOND(pair)

#define OF_PP_PAIR_THIRD(pair) OF_PP_INTERNAL_PAIR_THIRD(pair)

#define OF_PP_TUPLE_SIZE(t) OF_PP_INTERNAL_TUPLE_SIZE(t)

#define OF_PP_TUPLE_ELEM(n, t) OF_PP_INTERNAL_TUPLE_ELEM(n, t)

3 changes: 3 additions & 0 deletions oneflow/core/common/preprocessor_internal.h
@@ -91,9 +91,12 @@ limitations under the License.
#define OF_PP_INTERNAL_PAIR_FIRST_I(t) OF_PP_INTERNAL_FIRST_ARG t
#define OF_PP_INTERNAL_PAIR_SECOND(t) OF_PP_INTERNAL_PAIR_SECOND_I(t)
#define OF_PP_INTERNAL_PAIR_SECOND_I(t) OF_PP_INTERNAL_SECOND_ARG t
#define OF_PP_INTERNAL_PAIR_THIRD(t) OF_PP_INTERNAL_PAIR_THIRD_I(t)
#define OF_PP_INTERNAL_PAIR_THIRD_I(t) OF_PP_INTERNAL_THIRD_ARG t

#define OF_PP_INTERNAL_FIRST_ARG(x, ...) x
#define OF_PP_INTERNAL_SECOND_ARG(x, y, ...) y
#define OF_PP_INTERNAL_THIRD_ARG(x, y, z, ...) z

#define OF_PP_INTERNAL_MAKE_TUPLE(...) (__VA_ARGS__)
#define OF_PP_INTERNAL_MAKE_TUPLE_SEQ(...) (OF_PP_INTERNAL_MAKE_TUPLE(__VA_ARGS__))
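
The new `OF_PP_PAIR_THIRD` accessor uses the same two-level expansion as the existing FIRST and SECOND accessors: the outer macro forwards the tuple, and the inner `_I` macro places the tuple's own parentheses directly after a selector macro so that they become its argument list. The following is a minimal sketch of that mechanism. All `DEMO_`-prefixed names are hypothetical simplified copies written for illustration, not the OneFlow headers themselves.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Selector macros: pick the 1st, 2nd or 3rd argument and ignore the rest.
#define DEMO_FIRST_ARG(x, ...) x
#define DEMO_SECOND_ARG(x, y, ...) y
#define DEMO_THIRD_ARG(x, y, z, ...) z

// Two-level accessors: `t` is a parenthesized tuple such as (a, b, c).
// Writing `SELECTOR t` turns the tuple's parentheses into the call syntax.
#define DEMO_PAIR_FIRST(t) DEMO_PAIR_FIRST_I(t)
#define DEMO_PAIR_FIRST_I(t) DEMO_FIRST_ARG t
#define DEMO_PAIR_SECOND(t) DEMO_PAIR_SECOND_I(t)
#define DEMO_PAIR_SECOND_I(t) DEMO_SECOND_ARG t
#define DEMO_PAIR_THIRD(t) DEMO_PAIR_THIRD_I(t)
#define DEMO_PAIR_THIRD_I(t) DEMO_THIRD_ARG t

// A three-element tuple: two integer tags followed by a C++ type.
#define DEMO_TUPLE (10, 20, std::int8_t)

constexpr int kFirst = DEMO_PAIR_FIRST(DEMO_TUPLE);    // expands to 10
constexpr int kSecond = DEMO_PAIR_SECOND(DEMO_TUPLE);  // expands to 20
using Third = DEMO_PAIR_THIRD(DEMO_TUPLE);             // expands to std::int8_t

static_assert(std::is_same<Third, std::int8_t>::value,
              "the third element is usable as a type");
```

A third accessor is what allows one tuple to carry a oneDNN data type, a OneFlow `DataType`, and the matching C++ type together, as the oneDNN type sequence later in this diff does.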

4 changes: 4 additions & 0 deletions oneflow/core/ep/cpu/primitive/add.cpp
@@ -88,6 +88,10 @@ class AddOneDnnImpl : public Add {
for (int i = 1; i < arity; i++) {
if (srcs[i] == dst) { LOG(FATAL) << "Only the first parameter can be operated inplace"; }
}
CpuStream* cpu_stream = stream->As<CpuStream>();
size_t num_threads = dynamic_cast<CpuDevice*>(cpu_stream->device())->GetNumThreads();
Contributor

dynamic_cast performs the cast at runtime, using the RTTI information in the vtbl that the object's vptr points to, so it is relatively slow. In this case, if the object is already known to be a CpuDevice, static_cast is recommended instead.

On the other hand, even where dynamic_cast is the right tool, it may return null, so the return value has to be checked. Calling other functions directly on the unchecked result, as is done here, is a problem.

Contributor Author

I'll look into this carefully.

Contributor

@daquexian, Feb 11, 2022

Could we first implement what we discussed last time, where the device() method of CpuStream/CudaStream returns CpuDevice/CudaDevice directly? Then no cast would be needed here.
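
The two options raised in this review thread can be sketched as follows. This is a minimal illustration under stated assumptions: `Device`, `CpuDevice` and `OtherDevice` are hypothetical stand-ins, not the real OneFlow classes, and returning 0 for a non-CPU device is an arbitrary choice made only for the example.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for the real device hierarchy.
class Device {
 public:
  virtual ~Device() = default;
};

class CpuDevice : public Device {
 public:
  explicit CpuDevice(std::size_t num_threads) : num_threads_(num_threads) {}
  std::size_t GetNumThreads() const { return num_threads_; }

 private:
  std::size_t num_threads_;
};

class OtherDevice : public Device {};

// Option 1: the caller guarantees the dynamic type, so static_cast is enough
// and no RTTI lookup happens. Passing any other device type is undefined.
std::size_t NumThreadsStatic(Device* device) {
  return static_cast<CpuDevice*>(device)->GetNumThreads();
}

// Option 2: the dynamic type is not guaranteed, so the dynamic_cast result
// is checked before use. Returns 0 when the device is not a CpuDevice.
std::size_t NumThreadsChecked(Device* device) {
  auto* cpu_device = dynamic_cast<CpuDevice*>(device);
  if (cpu_device == nullptr) { return 0; }
  return cpu_device->GetNumThreads();
}
```

The third suggestion in the thread, a `device()` method that already returns the concrete device type, would remove the cast from the call site altogether.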

CpuNumThreadsGuard guard(num_threads);

dnnl::engine* onednn_engine = stream->As<CpuStream>()->onednn_engine();
dnnl::stream* onednn_stream = stream->As<CpuStream>()->onednn_stream();
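
`CpuNumThreadsGuard` is constructed on the stack before the oneDNN calls, which suggests a scope-bound (RAII) setting. Its implementation is not part of this diff, so the following is only a sketch of the general pattern. `NumThreadsGuard` and `g_num_threads` are hypothetical names, and restoring the previous value in the destructor is an assumption about how such a guard would typically behave.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical per-thread setting standing in for the runtime's thread count.
thread_local std::size_t g_num_threads = 1;

// Sets the thread count for the lifetime of the guard and restores the
// previous value when the guard goes out of scope.
class NumThreadsGuard {
 public:
  explicit NumThreadsGuard(std::size_t num_threads) : saved_(g_num_threads) {
    g_num_threads = num_threads;
  }
  ~NumThreadsGuard() { g_num_threads = saved_; }

  NumThreadsGuard(const NumThreadsGuard&) = delete;
  NumThreadsGuard& operator=(const NumThreadsGuard&) = delete;

 private:
  std::size_t saved_;
};

// Returns the thread count that code inside the guarded scope observes.
std::size_t ThreadsSeenInsideScope(std::size_t requested) {
  NumThreadsGuard guard(requested);
  return g_num_threads;
}
```

With this shape, every exit path out of `Launch`, including early returns, leaves the setting as it was found.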

224 changes: 223 additions & 1 deletion oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp
@@ -15,11 +15,14 @@ limitations under the License.
*/

#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/cpu/primitive/binary_functor.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h"
#include "oneflow/core/ndarray/ndarray_util.h"
#include "oneflow/core/ndarray/xpu_var_ndarray.h"
#include "oneflow/core/ep/cpu/cpu_stream.h"
#include "oneflow/core/ep/cpu/cpu_device.h"

namespace oneflow {

@@ -34,9 +37,29 @@ T GetValue(Scalar value) {
return value.Value<T>();
}

template<>
int8_t GetValue<int8_t>(Scalar value) {
return static_cast<int8_t>(GetValue<int64_t>(value));
}

template<>
int32_t GetValue<int32_t>(Scalar value) {
return static_cast<int32_t>(GetValue<int64_t>(value));
}

template<>
uint8_t GetValue<uint8_t>(Scalar value) {
return static_cast<uint8_t>(GetValue<uint64_t>(value));
}

template<>
float16 GetValue<float16>(Scalar value) {
- return static_cast<float16>(GetValue<float>(value));
+ return static_cast<float16>(GetValue<double>(value));
}

template<>
float GetValue<float>(Scalar value) {
return static_cast<float>(GetValue<double>(value));
}

template<BinaryOp binary_op, typename Src, typename Dst,
@@ -67,6 +90,7 @@ class BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
size_t num_src1_dims, const int64_t* src1_dims, const void* src1,
void* dst) override {

DimVector src0_dim_vec;
DimVector src1_dim_vec;
DimVector dst_dim_vec;
@@ -130,6 +154,171 @@ std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary() {
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr, OR) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor, XOR)

#ifdef WITH_ONEDNN

inline void OneDnnBroadcastDims(dnnl::memory::dims& src0, size_t num_src0_dims,
const int64_t* src0_dims, dnnl::memory::dims& src1,
size_t num_src1_dims, const int64_t* src1_dims,
dnnl::memory::dims& dst) {
const int64_t num_dims = dst.size();
const int64_t num_src0_padding_dims = num_dims - num_src0_dims;
const int64_t num_src1_padding_dims = num_dims - num_src1_dims;
for (int64_t i = 0; i < num_dims; i++) {
size_t src0_dim = i < num_src0_padding_dims ? 1 : src0_dims[i - num_src0_padding_dims];
size_t src1_dim = i < num_src1_padding_dims ? 1 : src1_dims[i - num_src1_padding_dims];
CHECK((src0_dim == src1_dim || src0_dim == 1 || src1_dim == 1));
src0[i] = src0_dim;
src1[i] = src1_dim;
dst[i] = std::max(src0_dim, src1_dim);
}
}
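
The shape alignment done by `OneDnnBroadcastDims` can be tried in isolation: the shorter shape is left-padded with 1s, each pair of extents must either match or contain a 1, and the output takes the larger extent. The sketch below reproduces that rule with `std::vector` in place of `dnnl::memory::dims`, so it builds without oneDNN. `BroadcastShapes` and `BroadcastDims` are hypothetical names, and throwing an exception replaces the `CHECK` used in the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Both padded input shapes plus the broadcast output shape.
struct BroadcastShapes {
  std::vector<std::int64_t> src0;
  std::vector<std::int64_t> src1;
  std::vector<std::int64_t> dst;
};

BroadcastShapes BroadcastDims(const std::vector<std::int64_t>& src0_dims,
                              const std::vector<std::int64_t>& src1_dims) {
  const std::size_t num_dims = std::max(src0_dims.size(), src1_dims.size());
  // Number of leading 1s needed to bring each shape up to num_dims.
  const std::size_t pad0 = num_dims - src0_dims.size();
  const std::size_t pad1 = num_dims - src1_dims.size();

  BroadcastShapes out{std::vector<std::int64_t>(num_dims),
                      std::vector<std::int64_t>(num_dims),
                      std::vector<std::int64_t>(num_dims)};
  for (std::size_t i = 0; i < num_dims; ++i) {
    const std::int64_t d0 = i < pad0 ? 1 : src0_dims[i - pad0];
    const std::int64_t d1 = i < pad1 ? 1 : src1_dims[i - pad1];
    // Extents are compatible when they are equal or one of them is 1.
    if (d0 != d1 && d0 != 1 && d1 != 1) {
      throw std::invalid_argument("shapes are not broadcastable");
    }
    out.src0[i] = d0;
    out.src1[i] = d1;
    out.dst[i] = std::max(d0, d1);
  }
  return out;
}
```

For example, shapes `{3, 1}` and `{2, 1, 4}` are aligned to `{1, 3, 1}` and `{2, 1, 4}`, giving an output shape of `{2, 3, 4}`.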

template<typename T, dnnl::algorithm algorithm, dnnl::memory::data_type src_onednn,
dnnl::memory::data_type dst_onednn>
class OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
public:
OF_DISALLOW_COPY_AND_MOVE(OneDnnBroadcastElementwiseBinaryImpl);
OneDnnBroadcastElementwiseBinaryImpl(){};
~OneDnnBroadcastElementwiseBinaryImpl() override = default;

void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,
const void* src1, void* dst) override {
T scalar_val = GetValue<T>(src0);
const int64_t src0_dims = 1;
Launch(stream, num_src1_dims, src1_dims, src1, 1, &src0_dims, &scalar_val, dst);
}
void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
Scalar src1, void* dst) override {
T scalar_val = GetValue<T>(src1);
const int64_t src1_dims = 1;
Launch(stream, num_src0_dims, src0_dims, src0, 1, &src1_dims, &scalar_val, dst);
}
void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
size_t num_src1_dims, const int64_t* src1_dims, const void* src1,
void* dst) override {
CpuStream* cpu_stream = stream->As<CpuStream>();
size_t num_threads = dynamic_cast<CpuDevice*>(cpu_stream->device())->GetNumThreads();
CpuNumThreadsGuard guard(num_threads);

dnnl::engine* onednn_engine = stream->As<CpuStream>()->onednn_engine();
dnnl::stream* onednn_stream = stream->As<CpuStream>()->onednn_stream();
size_t num_dims = std::max(num_src0_dims, num_src1_dims);
dnnl::memory::dims src_0_dims(num_dims);
dnnl::memory::dims src_1_dims(num_dims);
dnnl::memory::dims dst_dims(num_dims);

const void* mm_src0 = nullptr;
const void* mm_src1 = nullptr;

if (src1 == dst) {
mm_src0 = src1;
mm_src1 = src0;
OneDnnBroadcastDims(src_0_dims, num_src1_dims, src1_dims, src_1_dims, num_src0_dims,
src0_dims, dst_dims);
} else {
mm_src0 = src0;
mm_src1 = src1;
OneDnnBroadcastDims(src_0_dims, num_src0_dims, src0_dims, src_1_dims, num_src1_dims,
src1_dims, dst_dims);
}

auto src_0_md = dnnl::memory::desc(src_0_dims, src_onednn,
static_cast<dnnl::memory::format_tag>(num_dims + 1));
auto src_1_md = dnnl::memory::desc(src_1_dims, src_onednn,
static_cast<dnnl::memory::format_tag>(num_dims + 1));
auto dst_md = dnnl::memory::desc(dst_dims, dst_onednn,
static_cast<dnnl::memory::format_tag>(num_dims + 1));

auto src_0_mem = dnnl::memory(src_0_md, *onednn_engine, (void*)mm_src0);
auto src_1_mem = dnnl::memory(src_1_md, *onednn_engine, (void*)mm_src1);
auto dst_mem = dnnl::memory(dst_md, *onednn_engine, dst);

auto binary_d = dnnl::binary::desc(algorithm, src_0_md, src_1_md, dst_md);
auto binary_pd = dnnl::binary::primitive_desc(binary_d, *onednn_engine);
auto binary_prim = dnnl::binary(binary_pd);

std::unordered_map<int, dnnl::memory> binary_args{
{DNNL_ARG_SRC_0, src_0_mem}, {DNNL_ARG_SRC_1, src_1_mem}, {DNNL_ARG_DST, dst_mem}};

binary_prim.execute(*onednn_stream, binary_args);
onednn_stream->wait();
}
};

#define CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8, int8_t) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool, bool) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kUInt8, uint8_t) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat, float) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f16, DataType::kFloat16, float16)

// OneDNN binary op does not support s32
// CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ

#define CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \
CPU_PRIMITIVE_INT32_TYPE_SEQ \
CPU_PRIMITIVE_INT64_TYPE_SEQ

#define BINARY_ONEDNN_ADD OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd, dnnl::algorithm::binary_add)
#define BINARY_ONEDNN_SUB OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub, dnnl::algorithm::binary_sub)
#define BINARY_ONEDNN_MUL OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul, dnnl::algorithm::binary_mul)
#define BINARY_ONEDNN_DIV OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv, dnnl::algorithm::binary_div)
#define BINARY_ONEDNN_MAX OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax, dnnl::algorithm::binary_max)
#define BINARY_ONEDNN_MIN OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin, dnnl::algorithm::binary_min)

#define BINARY_ONEDNN_EQ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual, dnnl::algorithm::binary_eq)
#define BINARY_ONEDNN_NE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual, dnnl::algorithm::binary_ne)
#define BINARY_ONEDNN_LT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan, dnnl::algorithm::binary_lt)
#define BINARY_ONEDNN_LE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual, dnnl::algorithm::binary_le)
#define BINARY_ONEDNN_GT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan, dnnl::algorithm::binary_gt)
#define BINARY_ONEDNN_GE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual, dnnl::algorithm::binary_ge)

#define BINARY_MATH_OP_ONEDNN_PAIR \
BINARY_ONEDNN_ADD \
BINARY_ONEDNN_SUB \
BINARY_ONEDNN_MUL \
BINARY_ONEDNN_DIV \
BINARY_ONEDNN_MAX \
BINARY_ONEDNN_MIN

#define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR \
BINARY_ONEDNN_EQ \
BINARY_ONEDNN_NE \
BINARY_ONEDNN_LT \
BINARY_ONEDNN_LE \
BINARY_ONEDNN_GT \
BINARY_ONEDNN_GE

#define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_UNIMPLEMENTED \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd, AND) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr, OR) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor, XOR)

template<typename T, dnnl::algorithm algorithm, dnnl::memory::data_type src_onednn,
dnnl::memory::data_type dst_onednn>
std::unique_ptr<BroadcastElementwiseBinary> NewOneDnnBroadcastElementwiseBinary() {
return std::unique_ptr<BroadcastElementwiseBinary>(
new OneDnnBroadcastElementwiseBinaryImpl<T, algorithm, src_onednn, dst_onednn>());
}

#define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op_pair, data_type_pair) \
{std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND(data_type_pair)), \
NewOneDnnBroadcastElementwiseBinary< \
OF_PP_PAIR_THIRD(data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair), \
OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>},

#define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \
binary_op_pair, src_data_type_pair, dst_data_type_pair) \
{std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(src_data_type_pair), \
OF_PP_PAIR_SECOND(src_data_type_pair)), \
NewOneDnnBroadcastElementwiseBinary< \
OF_PP_PAIR_THIRD(src_data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair), \
OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>},

#endif // WITH_ONEDNN

class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl);
@@ -158,6 +347,38 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF
&NdarrayUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(src_data_type_pair)>::OF_PP_CAT( \
Broadcast, OF_PP_PAIR_SECOND(binary_op_pair))>},

#ifdef WITH_ONEDNN
static const std::map<std::tuple<BinaryOp, DataType, DataType>,
std::function<std::unique_ptr<BroadcastElementwiseBinary>()>>
new_broadcast_elementwise_binary_handle{
// For oneDNN binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_OP_ONEDNN_PAIR,
CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ)
// For oneDNN comparison binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR, CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ,
CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ)
// OneDNN unimplemented binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow, Pow),
NDARRAY_BINARY_TYPE_SEQ)
// OneDNN unimplemented comparison binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_LOGICAL_COMPARISION_OP_ONEDNN_UNIMPLEMENTED, NDARRAY_BINARY_TYPE_SEQ,
CPU_PRIMITIVE_BOOL_TYPE_SEQ)
// OneDNN unimplemented data type binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_OP_NDARRAY_PAIR,
CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ)
// OneDNN unimplemented data type comparison binary op
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_LOGICAL_COMPARISION_OP_NDARRAY_PAIR,
CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)};
#else
static const std::map<std::tuple<BinaryOp, DataType, DataType>,
std::function<std::unique_ptr<BroadcastElementwiseBinary>()>>
new_broadcast_elementwise_binary_handle{
@@ -167,6 +388,7 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_LOGICAL_COMPARISION_OP_NDARRAY_PAIR, NDARRAY_BINARY_TYPE_SEQ,
CPU_PRIMITIVE_BOOL_TYPE_SEQ)};
#endif

#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
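
The handle table built above is a static map from an (op, source type, destination type) key to a factory function, with the macros generating one entry per combination. A hand-written miniature of the same dispatch pattern is sketched below. The `Binary`, `AddImpl`, `MulImpl` and `NewBinaryFor` names are hypothetical, the enums are reduced to two values each, and returning `nullptr` for an unregistered key is an assumption, since the lookup code is not shown in this diff.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <tuple>

enum class BinaryOp { kAdd, kMul };
enum class DataType { kFloat, kInt32 };

// Minimal primitive interface, reduced to a scalar operation for brevity.
class Binary {
 public:
  virtual ~Binary() = default;
  virtual float Apply(float a, float b) const = 0;
};

class AddImpl : public Binary {
 public:
  float Apply(float a, float b) const override { return a + b; }
};

class MulImpl : public Binary {
 public:
  float Apply(float a, float b) const override { return a * b; }
};

// One factory function per implementation type.
template<typename Impl>
std::unique_ptr<Binary> NewBinary() {
  return std::unique_ptr<Binary>(new Impl());
}

// Looks up the factory registered for the key and invokes it.
// Returns nullptr when no implementation is registered (an assumption).
std::unique_ptr<Binary> NewBinaryFor(BinaryOp op, DataType src, DataType dst) {
  static const std::map<std::tuple<BinaryOp, DataType, DataType>,
                        std::function<std::unique_ptr<Binary>()>>
      handlers{
          {std::make_tuple(BinaryOp::kAdd, DataType::kFloat, DataType::kFloat),
           NewBinary<AddImpl>},
          {std::make_tuple(BinaryOp::kMul, DataType::kFloat, DataType::kFloat),
           NewBinary<MulImpl>},
      };
  const auto it = handlers.find(std::make_tuple(op, src, dst));
  if (it == handlers.end()) { return nullptr; }
  return it->second();
}
```

Because the key includes the data types, the same op can be routed to a oneDNN-backed implementation for some types and to the existing NdarrayUtil path for others, which is how the `WITH_ONEDNN` table in this diff is laid out.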

2 changes: 2 additions & 0 deletions oneflow/core/ep/cpu/primitive/type_seq.h
@@ -34,6 +34,8 @@ limitations under the License.
#define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)

#define CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool)
#define CPU_PRIMITIVE_ONEDNN_INT8_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8)
#define CPU_PRIMITIVE_ONEDNN_UINT8_TYPE_SEQ \