[primitive] add binary tests #8109
Merged

Commits (32)
1f2cd7d  [Primitive] Unit test  (liujuncheng)
a5fd294  Merge branch 'master' into dev_primitive_unit_test  (liujuncheng)
9a62b31  add copy_nd test  (guo-ran)
38475ce  add test  (guo-ran)
518684a  add softmax test  (guo-ran)
71579b5  add test  (guo-ran)
5d7cd02  binary test  (guo-ran)
f92d31a  Merge branch 'master' into dev_add_primitive_tests  (guo-ran)
3205ac7  add broadcast  (guo-ran)
d5c82b9  fix softmax test  (guo-ran)
3eff841  add  (guo-ran)
8f1e03e  refine  (guo-ran)
194a208  add test  (guo-ran)
adc9904  add test  (guo-ran)
426dc34  Merge branch 'dev_add_primitive_tests' of https://github.com/Oneflow-…  (guo-ran)
eab44fe  add test  (guo-ran)
b94388a  Merge branch 'dev_add_primitive_tests' of work24:/home/guoran/git_rep…  (guo-ran)
e719b91  Merge branch 'master' into dev_add_primitive_tests  (guo-ran)
f7723ea  fix of_tidy  (guo-ran)
919447b  Merge branch 'dev_add_primitive_tests' of work24:/home/guoran/git_rep…  (guo-ran)
5e13dc6  Merge branch 'dev_add_primitive_tests' of https://github.com/Oneflow-…  (guo-ran)
1b07bff  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …  (guo-ran)
5f05ff1  revert copy_nd softmax  (guo-ran)
78bd8a8  fix float (#8146)  (luqiang-guo)
4aed021  Merge branch 'master' into dev_add_primitive_tests  (guo-ran)
cb3ee1f  Merge branch 'master' into dev_add_primitive_tests  (mergify[bot])
8f37c28  Merge branch 'master' into dev_add_primitive_tests  (mergify[bot])
8d6ac76  Merge branch 'master' into dev_add_primitive_tests  (mergify[bot])
6fcdada  Merge branch 'master' into dev_add_primitive_tests  (mergify[bot])
6bb3d34  Merge branch 'master' into dev_add_primitive_tests  (mergify[bot])
dd18461  Merge branch 'master' into dev_add_primitive_tests  (guo-ran)
b416375  Merge branch 'master' into dev_add_primitive_tests  (mergify[bot])
@@ -237,18 +237,19 @@ class OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
  }
};

#define CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ                                      \
  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8, int8_t)      \
  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool, bool)        \
  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kUInt8, uint8_t)    \
  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat, float)     \
  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f16, DataType::kFloat16, float16)
#define CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ \
  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool, bool)
// OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat, float)

// OneDNN binary op does not support s32
// CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ

#define CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ \
  CPU_PRIMITIVE_FLOAT16_TYPE_SEQ                           \
  CPU_PRIMITIVE_FLOAT_TYPE_SEQ                             \
  CPU_PRIMITIVE_DOUBLE_TYPE_SEQ                            \
  CPU_PRIMITIVE_INT8_TYPE_SEQ                              \
  CPU_PRIMITIVE_UINT8_TYPE_SEQ                             \
  CPU_PRIMITIVE_INT32_TYPE_SEQ                             \
  CPU_PRIMITIVE_INT64_TYPE_SEQ

Review comment (on the f32 line): The float type fails the left_scalar and right_scalar unit tests. Could it be that the implementation did not account for the computation differing depending on which operand is the scalar? @luqiang-guo
Reply: Fixed.
@@ -0,0 +1,270 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <gtest/gtest.h>
#include "oneflow/core/ep/test/primitive/primitive_test.h"
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/include/primitive/memcpy.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include <Eigen/Core>
#include <unsupported/Eigen/CXX11/Tensor>
||
namespace oneflow { | ||
|
||
namespace ep { | ||
|
||
namespace primitive { | ||
|
||
namespace test { | ||
|
||
namespace { | ||
|
||
template<typename T> | ||
Scalar GetScalar(const T& value) { | ||
return Scalar(value); | ||
} | ||
|
||
template<> | ||
Scalar GetScalar<Eigen::half>(const Eigen::half& value) { | ||
return Scalar(static_cast<float>(value)); | ||
} | ||
|
||
template<BinaryOp binary_op, DataType src_data_type, typename Src, DataType dst_data_type, | ||
typename Dst> | ||
void TestElementwiseBroadcastBinary(DeviceManagerRegistry* registry, | ||
const std::set<DeviceType>& device_types, int test_type) { | ||
const int num_axes = 4; | ||
const int broadcast_dim0 = 16; | ||
const int broadcast_dim1 = 3; | ||
const int broadcast_dim2 = 4; | ||
const int broadcast_dim3 = 8; | ||
bool is_broadcast = false; | ||
bool left_scalar = false; | ||
bool right_scalar = false; | ||
if (test_type == 0) { | ||
// do nothing | ||
} else if (test_type == 1) { | ||
is_broadcast = true; | ||
} else if (test_type == 2) { | ||
left_scalar = true; | ||
} else if (test_type == 3) { | ||
right_scalar = true; | ||
} else { | ||
UNIMPLEMENTED(); | ||
} | ||
  const int a_dim0 = left_scalar ? 1 : broadcast_dim0;
  const int a_dim1 = left_scalar ? 1 : broadcast_dim1;
  const int a_dim2 = left_scalar ? 1 : broadcast_dim2;
  const int a_dim3 = left_scalar ? 1 : (is_broadcast ? 1 : broadcast_dim3);
  const int b_dim0 = right_scalar ? 1 : broadcast_dim0;
  const int b_dim1 = right_scalar ? 1 : (is_broadcast ? 1 : broadcast_dim1);
  const int b_dim2 = right_scalar ? 1 : broadcast_dim2;
  const int b_dim3 = right_scalar ? 1 : broadcast_dim3;
  const int a_broadcast0 = left_scalar ? broadcast_dim0 : 1;
  const int a_broadcast1 = left_scalar ? broadcast_dim1 : 1;
  const int a_broadcast2 = left_scalar ? broadcast_dim2 : 1;
  const int a_broadcast3 = left_scalar ? broadcast_dim3 : (is_broadcast ? broadcast_dim3 : 1);
  const int b_broadcast0 = right_scalar ? broadcast_dim0 : 1;
  const int b_broadcast1 = right_scalar ? broadcast_dim1 : (is_broadcast ? broadcast_dim1 : 1);
  const int b_broadcast2 = right_scalar ? broadcast_dim2 : 1;
  const int b_broadcast3 = right_scalar ? broadcast_dim3 : 1;
  const Eigen::array<int, 4> a_broadcast = {a_broadcast0, a_broadcast1, a_broadcast2, a_broadcast3};
  const Eigen::array<int, 4> b_broadcast = {b_broadcast0, b_broadcast1, b_broadcast2, b_broadcast3};
  Eigen::Tensor<Src, 4, Eigen::RowMajor> a(a_dim0, a_dim1, a_dim2, a_dim3);
  Eigen::Tensor<Src, 4, Eigen::RowMajor> b(b_dim0, b_dim1, b_dim2, b_dim3);
  Eigen::Tensor<Dst, 4, Eigen::RowMajor> c(broadcast_dim0, broadcast_dim1, broadcast_dim2,
                                           broadcast_dim3);
  a.setRandom();
  b.setRandom();
  if (binary_op == BinaryOp::kAdd) {
    c = (a.broadcast(a_broadcast) + b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kSub) {
    c = (a.broadcast(a_broadcast) - b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kMul) {
    c = (a.broadcast(a_broadcast) * b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kDiv) {
    Eigen::Tensor<Src, 4, Eigen::RowMajor> constant_value(b_dim0, b_dim1, b_dim2, b_dim3);
    // avoid div 0
    if (src_data_type == kInt8 || src_data_type == kUInt8) {
      int rand_value = std::rand() % 127;
      constant_value.setConstant(static_cast<Src>(rand_value));
      b = constant_value;
    } else {
      constant_value.setConstant(static_cast<Src>(1));
      b += constant_value;
    }
    c = (a.broadcast(a_broadcast) / b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kMax) {
    c = (a.broadcast(a_broadcast).cwiseMax(b.broadcast(b_broadcast))).template cast<Dst>();
  } else if (binary_op == BinaryOp::kMin) {
    c = (a.broadcast(a_broadcast).cwiseMin(b.broadcast(b_broadcast))).template cast<Dst>();
  } else if (binary_op == BinaryOp::kEqual) {
    c = (a.broadcast(a_broadcast) == b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kNotEqual) {
    c = (a.broadcast(a_broadcast) != b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kLessThan) {
    c = (a.broadcast(a_broadcast) < b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kLessEqual) {
    c = (a.broadcast(a_broadcast) <= b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kGreaterThan) {
    c = (a.broadcast(a_broadcast) > b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kGreaterEqual) {
    c = (a.broadcast(a_broadcast) >= b.broadcast(b_broadcast)).template cast<Dst>();
  } else if (binary_op == BinaryOp::kLogicalAnd) {
    c = (a.broadcast(a_broadcast).template cast<bool>()
         && b.broadcast(b_broadcast).template cast<bool>())
            .template cast<Dst>();
  } else if (binary_op == BinaryOp::kLogicalOr) {
    c = (a.broadcast(a_broadcast).template cast<bool>()
         || b.broadcast(b_broadcast).template cast<bool>())
            .template cast<Dst>();
  } else if (binary_op == BinaryOp::kLogicalXor) {
    c = (a.broadcast(a_broadcast).template cast<bool>()
         ^ b.broadcast(b_broadcast).template cast<bool>())
            .template cast<Dst>();
  } else {
    UNIMPLEMENTED();
  }
  std::vector<int64_t> a_dims = {a.dimension(0), a.dimension(1), a.dimension(2), a.dimension(3)};
  std::vector<int64_t> b_dims = {b.dimension(0), b.dimension(1), b.dimension(2), b.dimension(3)};
  std::vector<int64_t> c_dims = {c.dimension(0), c.dimension(1), c.dimension(2), c.dimension(3)};
  int64_t a_size = a.size() * sizeof(Src);
  int64_t b_size = b.size() * sizeof(Src);
  int64_t c_size = c.size() * sizeof(Dst);

  for (const auto& device_type : device_types) {
    LOG(ERROR) << "device " << device_type << " dtype " << src_data_type << " binary " << binary_op
               << " test " << test_type;
    auto device = registry->GetDevice(device_type, 0);
    ep::test::PinnedMemoryGuard input_a(device.get(), a_size);
    ep::test::PinnedMemoryGuard input_b(device.get(), b_size);
    std::memcpy(input_a.ptr(), a.data(), a_size);
    std::memcpy(input_b.ptr(), b.data(), b_size);

    ep::test::PinnedMemoryGuard output(device.get(), c_size);
    ep::test::DeviceMemoryGuard device_a(device.get(), a_size);
    ep::test::DeviceMemoryGuard device_b(device.get(), b_size);
    ep::test::DeviceMemoryGuard device_c(device.get(), c_size);
    ep::test::StreamGuard stream(device.get());
    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);
    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);
    std::unique_ptr<BroadcastElementwiseBinary> binary =
        NewPrimitive<BroadcastElementwiseBinaryFactory>(device_type, binary_op, src_data_type,
                                                        dst_data_type, num_axes);
    ASSERT_TRUE(d2h.operator bool());
    ASSERT_TRUE(h2d.operator bool());
    ASSERT_TRUE(binary.operator bool());
    h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size);
    h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size);
    if (left_scalar) {
      Src a_value = *reinterpret_cast<Src*>(input_a.ptr());
      binary->Launch(stream.stream(), GetScalar(a_value), num_axes, b_dims.data(), device_b.ptr(),
                     device_c.ptr());
    } else if (right_scalar) {
      Src b_value = *reinterpret_cast<Src*>(input_b.ptr());
      binary->Launch(stream.stream(), num_axes, a_dims.data(), device_a.ptr(), GetScalar(b_value),
                     device_c.ptr());
    } else {
      binary->Launch(stream.stream(), num_axes, a_dims.data(), device_a.ptr(), num_axes,
                     b_dims.data(), device_b.ptr(), device_c.ptr());
    }
    d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size);
    CHECK_JUST(stream.stream()->Sync());

    Eigen::Map<Eigen::Matrix<Dst, 1, Eigen::Dynamic>, Eigen::Unaligned> eigen_out(c.data(),
                                                                                  c.size());
    Eigen::Map<Eigen::Matrix<Dst, 1, Eigen::Dynamic>, Eigen::Unaligned> of_out(
        reinterpret_cast<Dst*>(output.ptr()), c.size());
    ASSERT_TRUE(eigen_out.template isApprox(of_out));
  }
}

template<BinaryOp binary_op, DataType src_data_type, typename Src, DataType dst_data_type,
         typename Dst>
void TestElementwiseBroadcastBinary(DeviceManagerRegistry* registry,
                                    const std::set<DeviceType>& device_types) {
  TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(
      registry, device_types, 0);
  TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(
      registry, device_types, 1);
  TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(
      registry, device_types, 2);
  TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(
      registry, device_types, 3);
}

template<BinaryOp binary_op>
void TestComputeBinary(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {
  TestElementwiseBroadcastBinary<binary_op, DataType::kInt8, int8_t, DataType::kInt8, int8_t>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kUInt8, uint8_t, DataType::kUInt8, uint8_t>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kInt32, int32_t, DataType::kInt32, int32_t>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kInt64, int64_t, DataType::kInt64, int64_t>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kDouble, double, DataType::kDouble, double>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kFloat, float, DataType::kFloat, float>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kFloat16, Eigen::half, DataType::kFloat16,
                                 Eigen::half>(registry, device_types);
}

template<BinaryOp binary_op>
void TestLogicalBinary(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {
  TestElementwiseBroadcastBinary<binary_op, DataType::kInt8, int8_t, DataType::kBool, bool>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kUInt8, uint8_t, DataType::kBool, bool>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kInt32, int32_t, DataType::kBool, bool>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kInt64, int64_t, DataType::kBool, bool>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kDouble, double, DataType::kBool, bool>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kFloat, float, DataType::kBool, bool>(
      registry, device_types);
  TestElementwiseBroadcastBinary<binary_op, DataType::kFloat16, Eigen::half, DataType::kBool, bool>(
      registry, device_types);
}

}  // namespace

TEST_F(PrimitiveTest, TestBinary) {
  TestComputeBinary<BinaryOp::kAdd>(&device_manager_registry_, available_device_types_);
  TestComputeBinary<BinaryOp::kSub>(&device_manager_registry_, available_device_types_);
  TestComputeBinary<BinaryOp::kMul>(&device_manager_registry_, available_device_types_);
  TestComputeBinary<BinaryOp::kDiv>(&device_manager_registry_, available_device_types_);
  TestComputeBinary<BinaryOp::kMax>(&device_manager_registry_, available_device_types_);
  TestComputeBinary<BinaryOp::kMin>(&device_manager_registry_, available_device_types_);
  TestLogicalBinary<BinaryOp::kEqual>(&device_manager_registry_, available_device_types_);
  TestLogicalBinary<BinaryOp::kNotEqual>(&device_manager_registry_, available_device_types_);
  TestLogicalBinary<BinaryOp::kLessThan>(&device_manager_registry_, available_device_types_);
  TestLogicalBinary<BinaryOp::kLessEqual>(&device_manager_registry_, available_device_types_);
  TestLogicalBinary<BinaryOp::kGreaterThan>(&device_manager_registry_, available_device_types_);
  TestLogicalBinary<BinaryOp::kGreaterEqual>(&device_manager_registry_, available_device_types_);
  TestLogicalBinary<BinaryOp::kLogicalAnd>(&device_manager_registry_, available_device_types_);
  TestLogicalBinary<BinaryOp::kLogicalOr>(&device_manager_registry_, available_device_types_);
  TestLogicalBinary<BinaryOp::kLogicalXor>(&device_manager_registry_, available_device_types_);
}

}  // namespace test

}  // namespace primitive

}  // namespace ep

}  // namespace oneflow
Review comment: The int8 and uint8 types in the oneDNN implementation fail the primitive unit tests @luqiang-guo
Reply: The computation strategies differ: on overflow, oneDNN saturates (clamps to the representable range), while torch simply lets the value overflow (wrap around).