Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[primitive] add binary tests #8109

Merged
merged 32 commits into from
May 7, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1f2cd7d
[Primitive] Unit test
liujuncheng Apr 26, 2022
a5fd294
Merge branch 'master' into dev_primitive_unit_test
liujuncheng Apr 27, 2022
9a62b31
add copy_nd test
guo-ran Apr 27, 2022
38475ce
add test
guo-ran Apr 28, 2022
518684a
add softmax test
guo-ran Apr 29, 2022
71579b5
add test
guo-ran Apr 29, 2022
5d7cd02
binary test
guo-ran Apr 29, 2022
f92d31a
Merge branch 'master' into dev_add_primitive_tests
guo-ran May 1, 2022
3205ac7
add broadcast
guo-ran May 1, 2022
d5c82b9
fix softmax test
guo-ran May 1, 2022
3eff841
add
guo-ran May 1, 2022
8f1e03e
refine
guo-ran May 1, 2022
194a208
add test
guo-ran May 1, 2022
adc9904
add test
guo-ran May 1, 2022
426dc34
Merge branch 'dev_add_primitive_tests' of https://github.com/Oneflow-…
guo-ran May 1, 2022
eab44fe
add test
guo-ran May 1, 2022
b94388a
Merge branch 'dev_add_primitive_tests' of work24:/home/guoran/git_rep…
guo-ran May 1, 2022
e719b91
Merge branch 'master' into dev_add_primitive_tests
guo-ran May 1, 2022
f7723ea
fix of_tidy
guo-ran May 1, 2022
919447b
Merge branch 'dev_add_primitive_tests' of work24:/home/guoran/git_rep…
guo-ran May 1, 2022
5e13dc6
Merge branch 'dev_add_primitive_tests' of https://github.com/Oneflow-…
guo-ran May 1, 2022
1b07bff
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
guo-ran May 1, 2022
5f05ff1
revert copy_nd softmax
guo-ran May 5, 2022
78bd8a8
fix float (#8146)
luqiang-guo May 6, 2022
4aed021
Merge branch 'master' into dev_add_primitive_tests
guo-ran May 6, 2022
cb3ee1f
Merge branch 'master' into dev_add_primitive_tests
mergify[bot] May 6, 2022
8f37c28
Merge branch 'master' into dev_add_primitive_tests
mergify[bot] May 6, 2022
8d6ac76
Merge branch 'master' into dev_add_primitive_tests
mergify[bot] May 6, 2022
6fcdada
Merge branch 'master' into dev_add_primitive_tests
mergify[bot] May 6, 2022
6bb3d34
Merge branch 'master' into dev_add_primitive_tests
mergify[bot] May 7, 2022
dd18461
Merge branch 'master' into dev_add_primitive_tests
guo-ran May 7, 2022
b416375
Merge branch 'master' into dev_add_primitive_tests
mergify[bot] May 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,18 +237,19 @@ class OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
}
};

#define CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8, int8_t) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool, bool) \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kUInt8, uint8_t) \
Copy link
Contributor Author

@guo-ran guo-ran May 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

onednn实现的int8和uint8类型过不了primitive单测 @luqiang-guo

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

计算策略不同,onednn在溢出的时候是截断处理,torch 直接溢出

OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat, float) \
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

float类型过不了left_scalar和right_scalar的单测,是不是实现时没考虑位置不同计算的区别 @luqiang-guo

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修复

OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f16, DataType::kFloat16, float16)
#define CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool, bool)
// OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat, float)

// OneDNN binary op does not support s32
// CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ

#define CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT16_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \
CPU_PRIMITIVE_INT8_TYPE_SEQ \
CPU_PRIMITIVE_UINT8_TYPE_SEQ \
CPU_PRIMITIVE_INT32_TYPE_SEQ \
CPU_PRIMITIVE_INT64_TYPE_SEQ

Expand Down
270 changes: 270 additions & 0 deletions oneflow/core/ep/test/primitive/binary_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <gtest/gtest.h>
#include "oneflow/core/ep/test/primitive/primitive_test.h"
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/include/primitive/memcpy.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include <Eigen/Core>
#include <unsupported/Eigen/CXX11/Tensor>

namespace oneflow {

namespace ep {

namespace primitive {

namespace test {

namespace {

template<typename T>
Scalar GetScalar(T value) {
return Scalar(value);
}

template<>
Scalar GetScalar<Eigen::half>(Eigen::half value) {
return Scalar(static_cast<float>(value));
}

template<BinaryOp binary_op, DataType src_data_type, typename Src, DataType dst_data_type,
typename Dst>
void TestElementwiseBroadcastBinary(DeviceManagerRegistry* registry,
const std::set<DeviceType>& device_types, int test_type) {
const int num_axes = 4;
const int broadcast_dim0 = 16;
const int broadcast_dim1 = 3;
const int broadcast_dim2 = 4;
const int broadcast_dim3 = 8;
bool is_broadcast = false;
bool left_scalar = false;
bool right_scalar = false;
if (test_type == 0) {
// do nothing
} else if (test_type == 1) {
is_broadcast = true;
} else if (test_type == 2) {
left_scalar = true;
} else if (test_type == 3) {
right_scalar = true;
} else {
UNIMPLEMENTED();
}
const int a_dim0 = left_scalar ? 1 : broadcast_dim0;
const int a_dim1 = left_scalar ? 1 : broadcast_dim1;
const int a_dim2 = left_scalar ? 1 : broadcast_dim2;
const int a_dim3 = left_scalar ? 1 : (is_broadcast ? 1 : broadcast_dim3);
const int b_dim0 = right_scalar ? 1 : broadcast_dim0;
const int b_dim1 = right_scalar ? 1 : (is_broadcast ? 1 : broadcast_dim1);
const int b_dim2 = right_scalar ? 1 : broadcast_dim2;
const int b_dim3 = right_scalar ? 1 : broadcast_dim3;
const int a_broadcast0 = left_scalar ? broadcast_dim0 : 1;
const int a_broadcast1 = left_scalar ? broadcast_dim1 : 1;
const int a_broadcast2 = left_scalar ? broadcast_dim2 : 1;
const int a_broadcast3 = left_scalar ? broadcast_dim3 : (is_broadcast ? broadcast_dim3 : 1);
const int b_broadcast0 = right_scalar ? broadcast_dim0 : 1;
const int b_broadcast1 = right_scalar ? broadcast_dim1 : (is_broadcast ? broadcast_dim1 : 1);
const int b_broadcast2 = right_scalar ? broadcast_dim2 : 1;
const int b_broadcast3 = right_scalar ? broadcast_dim3 : 1;
const Eigen::array<int, 4> a_broadcast = {a_broadcast0, a_broadcast1, a_broadcast2, a_broadcast3};
const Eigen::array<int, 4> b_broadcast = {b_broadcast0, b_broadcast1, b_broadcast2, b_broadcast3};
Eigen::Tensor<Src, 4, Eigen::RowMajor> a(a_dim0, a_dim1, a_dim2, a_dim3);
Eigen::Tensor<Src, 4, Eigen::RowMajor> b(b_dim0, b_dim1, b_dim2, b_dim3);
Eigen::Tensor<Dst, 4, Eigen::RowMajor> c(broadcast_dim0, broadcast_dim1, broadcast_dim2,
broadcast_dim3);
a.setRandom();
b.setRandom();
if (binary_op == BinaryOp::kAdd) {
c = (a.broadcast(a_broadcast) + b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kSub) {
c = (a.broadcast(a_broadcast) - b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kMul) {
c = (a.broadcast(a_broadcast) * b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kDiv) {
Eigen::Tensor<Src, 4, Eigen::RowMajor> constant_value(b_dim0, b_dim1, b_dim2, b_dim3);
// avoid div 0
if (src_data_type == kInt8 || src_data_type == kUInt8) {
int rand_value = std::rand() % 127;
constant_value.setConstant(static_cast<Src>(rand_value));
b = constant_value;
} else {
constant_value.setConstant(static_cast<Src>(1));
b += constant_value;
}
c = (a.broadcast(a_broadcast) / b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kMax) {
c = (a.broadcast(a_broadcast).cwiseMax(b.broadcast(b_broadcast))).template cast<Dst>();
} else if (binary_op == BinaryOp::kMin) {
c = (a.broadcast(a_broadcast).cwiseMin(b.broadcast(b_broadcast))).template cast<Dst>();
} else if (binary_op == BinaryOp::kEqual) {
c = (a.broadcast(a_broadcast) == b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kNotEqual) {
c = (a.broadcast(a_broadcast) != b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kLessThan) {
c = (a.broadcast(a_broadcast) < b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kLessEqual) {
c = (a.broadcast(a_broadcast) <= b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kGreaterThan) {
c = (a.broadcast(a_broadcast) > b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kGreaterEqual) {
c = (a.broadcast(a_broadcast) >= b.broadcast(b_broadcast)).template cast<Dst>();
} else if (binary_op == BinaryOp::kLogicalAnd) {
c = (a.broadcast(a_broadcast).template cast<bool>()
&& b.broadcast(b_broadcast).template cast<bool>())
.template cast<Dst>();
} else if (binary_op == BinaryOp::kLogicalOr) {
c = (a.broadcast(a_broadcast).template cast<bool>()
|| b.broadcast(b_broadcast).template cast<bool>())
.template cast<Dst>();
} else if (binary_op == BinaryOp::kLogicalXor) {
c = (a.broadcast(a_broadcast).template cast<bool>()
^ b.broadcast(b_broadcast).template cast<bool>())
.template cast<Dst>();
} else {
UNIMPLEMENTED();
}
std::vector<int64_t> a_dims = {a.dimension(0), a.dimension(1), a.dimension(2), a.dimension(3)};
std::vector<int64_t> b_dims = {b.dimension(0), b.dimension(1), b.dimension(2), b.dimension(3)};
std::vector<int64_t> c_dims = {c.dimension(0), c.dimension(1), c.dimension(2), c.dimension(3)};
int64_t a_size = a.size() * sizeof(Src);
int64_t b_size = b.size() * sizeof(Src);
int64_t c_size = c.size() * sizeof(Dst);

for (const auto& device_type : device_types) {
LOG(ERROR) << "device " << device_type << " dtype " << src_data_type << " binary " << binary_op
<< " test " << test_type;
auto device = registry->GetDevice(device_type, 0);
ep::test::PinnedMemoryGuard input_a(device.get(), a_size);
ep::test::PinnedMemoryGuard input_b(device.get(), b_size);
std::memcpy(input_a.ptr(), a.data(), a_size);
std::memcpy(input_b.ptr(), b.data(), b_size);

ep::test::PinnedMemoryGuard output(device.get(), c_size);
ep::test::DeviceMemoryGuard device_a(device.get(), a_size);
ep::test::DeviceMemoryGuard device_b(device.get(), b_size);
ep::test::DeviceMemoryGuard device_c(device.get(), c_size);
ep::test::StreamGuard stream(device.get());
std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);
std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);
std::unique_ptr<BroadcastElementwiseBinary> binary =
NewPrimitive<BroadcastElementwiseBinaryFactory>(device_type, binary_op, src_data_type,
dst_data_type, num_axes);
ASSERT_TRUE(d2h.operator bool());
ASSERT_TRUE(h2d.operator bool());
ASSERT_TRUE(binary.operator bool());
h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size);
h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size);
if (left_scalar) {
Src a_value = *reinterpret_cast<Src*>(input_a.ptr());
binary->Launch(stream.stream(), GetScalar(a_value), num_axes, b_dims.data(), device_b.ptr(),
device_c.ptr());
} else if (right_scalar) {
Src b_value = *reinterpret_cast<Src*>(input_b.ptr());
binary->Launch(stream.stream(), num_axes, a_dims.data(), device_a.ptr(), GetScalar(b_value),
device_c.ptr());
} else {
binary->Launch(stream.stream(), num_axes, a_dims.data(), device_a.ptr(), num_axes,
b_dims.data(), device_b.ptr(), device_c.ptr());
}
d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size);
CHECK_JUST(stream.stream()->Sync());

Eigen::Map<Eigen::Matrix<Dst, 1, Eigen::Dynamic>, Eigen::Unaligned> eigen_out(c.data(),
c.size());
Eigen::Map<Eigen::Matrix<Dst, 1, Eigen::Dynamic>, Eigen::Unaligned> of_out(
reinterpret_cast<Dst*>(output.ptr()), c.size());
ASSERT_TRUE(eigen_out.template isApprox(of_out));
}
}

template<BinaryOp binary_op, DataType src_data_type, typename Src, DataType dst_data_type,
typename Dst>
void TestElementwiseBroadcastBinary(DeviceManagerRegistry* registry,
const std::set<DeviceType>& device_types) {
TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(
registry, device_types, 0);
TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(
registry, device_types, 1);
TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(
registry, device_types, 2);
TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(
registry, device_types, 3);
}

template<BinaryOp binary_op>
void TestComputeBinary(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {
TestElementwiseBroadcastBinary<binary_op, DataType::kInt8, int8_t, DataType::kInt8, int8_t>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kUInt8, uint8_t, DataType::kUInt8, uint8_t>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kInt32, int32_t, DataType::kInt32, int32_t>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kInt64, int64_t, DataType::kInt64, int64_t>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kDouble, double, DataType::kDouble, double>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kFloat, float, DataType::kFloat, float>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kFloat16, Eigen::half, DataType::kFloat16,
Eigen::half>(registry, device_types);
}

template<BinaryOp binary_op>
void TestLogicalBinary(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {
TestElementwiseBroadcastBinary<binary_op, DataType::kInt8, int8_t, DataType::kBool, bool>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kUInt8, uint8_t, DataType::kBool, bool>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kInt32, int32_t, DataType::kBool, bool>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kInt64, int64_t, DataType::kBool, bool>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kDouble, double, DataType::kBool, bool>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kFloat, float, DataType::kBool, bool>(
registry, device_types);
TestElementwiseBroadcastBinary<binary_op, DataType::kFloat16, Eigen::half, DataType::kBool, bool>(
registry, device_types);
}

} // namespace

TEST_F(PrimitiveTest, TestBinary) {
TestComputeBinary<BinaryOp::kAdd>(&device_manager_registry_, available_device_types_);
TestComputeBinary<BinaryOp::kSub>(&device_manager_registry_, available_device_types_);
TestComputeBinary<BinaryOp::kMul>(&device_manager_registry_, available_device_types_);
TestComputeBinary<BinaryOp::kDiv>(&device_manager_registry_, available_device_types_);
TestComputeBinary<BinaryOp::kMax>(&device_manager_registry_, available_device_types_);
TestComputeBinary<BinaryOp::kMin>(&device_manager_registry_, available_device_types_);
TestLogicalBinary<BinaryOp::kEqual>(&device_manager_registry_, available_device_types_);
TestLogicalBinary<BinaryOp::kNotEqual>(&device_manager_registry_, available_device_types_);
TestLogicalBinary<BinaryOp::kLessThan>(&device_manager_registry_, available_device_types_);
TestLogicalBinary<BinaryOp::kLessEqual>(&device_manager_registry_, available_device_types_);
TestLogicalBinary<BinaryOp::kGreaterThan>(&device_manager_registry_, available_device_types_);
TestLogicalBinary<BinaryOp::kGreaterEqual>(&device_manager_registry_, available_device_types_);
TestLogicalBinary<BinaryOp::kLogicalAnd>(&device_manager_registry_, available_device_types_);
TestLogicalBinary<BinaryOp::kLogicalOr>(&device_manager_registry_, available_device_types_);
TestLogicalBinary<BinaryOp::kLogicalXor>(&device_manager_registry_, available_device_types_);
}

} // namespace test

} // namespace primitive

} // namespace ep

} // namespace oneflow
Loading