Mean tests both keep dims and not keep dims.
PiperOrigin-RevId: 618855694
alankelly authored and xnnpack-bot committed May 16, 2024
1 parent 7fabcac commit be6c51b
Showing 1 changed file with 49 additions and 30 deletions.
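The change converts the matches_operator_api tests from plain TEST_F fixtures into GoogleTest value-parameterized tests, so each one now runs once with XNN_FLAG_KEEP_DIMS and once without. A minimal standalone sketch of that pattern, with an illustrative fixture name not taken from the diff:

#include <cstdint>
#include <gtest/gtest.h>

// Hypothetical fixture standing in for MeanTestBase<T>.
class KeepDimsDemo : public ::testing::TestWithParam<bool> {};

// testing::Bool() supplies the values {false, true}, so every TEST_P below
// runs twice, once per keep_dims setting.
INSTANTIATE_TEST_SUITE_P(KeepDims, KeepDimsDemo, testing::Bool());

TEST_P(KeepDimsDemo, runs_for_both_settings) {
  const bool keep_dims = GetParam();
  // Stand-in value for illustration; the real XNN_FLAG_KEEP_DIMS comes
  // from xnnpack.h.
  const uint32_t kKeepDimsFlag = 0x1;
  const uint32_t flags = keep_dims ? kKeepDimsFlag : 0;
  EXPECT_EQ(flags != 0, keep_dims);
}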
test/static-mean.cc (49 additions, 30 deletions)
@@ -26,7 +26,8 @@
#include <fp16/fp16.h>

namespace xnnpack {
-template <class T> class MeanTestBase : public ::testing::Test {
+template <class T>
+class MeanTestBase : public ::testing::TestWithParam<bool> {
protected:
MeanTestBase() {
f32dist = std::uniform_real_distribution<float>(-1.0f, 1.0f);
@@ -43,7 +44,7 @@ template <class T> class MeanTestBase : public ::testing::Test {
auto end = std::unique(reduction_axes.begin(), reduction_axes.end());
reduction_axes.erase(end, reduction_axes.end());

-    auto shape_dist = std::uniform_int_distribution<size_t>(2, 15);
+    auto shape_dist = std::uniform_int_distribution<size_t>(2, 9);
input_shape.resize(num_input_dims);
std::generate(input_shape.begin(), input_shape.end(), [&]() { return shape_dist(rng); });
num_input_elements = std::accumulate(input_shape.cbegin(), input_shape.cend(), size_t(1), std::multiplies<size_t>());
@@ -76,6 +77,9 @@ template <class T> class MeanTestBase : public ::testing::Test {
using MeanTestF16 = MeanTestBase<uint16_t>;
using MeanTestF32 = MeanTestBase<float>;

+INSTANTIATE_TEST_SUITE_P(KeepDims, MeanTestF16, testing::Bool());
+INSTANTIATE_TEST_SUITE_P(KeepDims, MeanTestF32, testing::Bool());
+
TEST_F(MeanTestF16, define)
{
ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
@@ -160,8 +164,8 @@ TEST_F(MeanTestF32, define)
ASSERT_EQ(node->flags, 0);
}

-TEST_F(MeanTestF16, matches_operator_api)
-{
+TEST_P(MeanTestF16, matches_operator_api) {
+  bool keep_dims = GetParam();
ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

xnn_operator_t op = nullptr;
@@ -170,8 +174,9 @@ TEST_F(MeanTestF16, matches_operator_api)
std::fill(operator_output.begin(), operator_output.end(), UINT16_C(0x7E00) /* NaN */);
std::fill(subgraph_output.begin(), subgraph_output.end(), UINT16_C(0x7E00) /* NaN */);

+  uint32_t flags = keep_dims ? XNN_FLAG_KEEP_DIMS : 0;
// Call operator API.
-  const xnn_status status = xnn_create_mean_nd_f16(/*flags=*/XNN_FLAG_KEEP_DIMS, &op);
+  const xnn_status status = xnn_create_mean_nd_f16(flags, &op);
if (status == xnn_status_unsupported_hardware) {
GTEST_SKIP();
}
@@ -202,17 +207,21 @@ TEST_F(MeanTestF16, matches_operator_api)
ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

uint32_t output_id = XNN_INVALID_NODE_ID;
-  ASSERT_EQ(xnn_status_success,
-            xnn_define_tensor_value(subgraph, xnn_datatype_fp16, output_shape.size(), output_shape.data(),
-                                    nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
+  int output_num_dims = input_shape.size();
+  if (!keep_dims) {
+    output_num_dims -= reduction_axes.size();
+  }
+  ASSERT_EQ(
+      xnn_status_success,
+      xnn_define_tensor_value(subgraph, xnn_datatype_fp16, output_num_dims,
+                              output_shape.data(), nullptr, /*external_id=*/1,
+                              XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

ASSERT_EQ(xnn_status_success,
-            xnn_define_static_mean(
-                subgraph,
-                reduction_axes.size(), reduction_axes.data(),
-                input_id, output_id,
-                /*flags=*/XNN_FLAG_KEEP_DIMS));
+            xnn_define_static_mean(subgraph, reduction_axes.size(),
+                                   reduction_axes.data(), input_id, output_id,
+                                   flags));

xnn_runtime_t runtime = nullptr;
ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
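The hunk above is the heart of the change: when keep_dims is false, the output tensor loses one dimension per reduction axis, so the test must pass a smaller rank to xnn_define_tensor_value. A sketch of the shape rule, assuming unique, in-range axes (the helper name is illustrative, not part of the diff):

#include <algorithm>
#include <cstddef>
#include <vector>

// Computes the mean's output shape from the input shape and reduced axes.
std::vector<size_t> ExpectedMeanShape(const std::vector<size_t>& input_shape,
                                      const std::vector<size_t>& axes,
                                      bool keep_dims) {
  std::vector<size_t> out;
  for (size_t dim = 0; dim < input_shape.size(); ++dim) {
    const bool reduced =
        std::find(axes.begin(), axes.end(), dim) != axes.end();
    if (!reduced) {
      out.push_back(input_shape[dim]);  // non-reduced axes pass through
    } else if (keep_dims) {
      out.push_back(1);                 // reduced axis kept with extent 1
    }                                   // else: axis dropped, rank shrinks
  }
  return out;
}
// e.g. input {4, 5, 6}, axes {1}: keep_dims gives {4, 1, 6}, otherwise {4, 6}.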
@@ -228,12 +237,14 @@ TEST_F(MeanTestF16, matches_operator_api)

// Check outputs match.
for (size_t i = 0; i < operator_output.size(); i++) {
-    ASSERT_EQ(subgraph_output[i], operator_output[i]);
+    float sub_out = fp16_ieee_to_fp32_value(subgraph_output[i]);
+    float op_out = fp16_ieee_to_fp32_value(operator_output[i]);
+    ASSERT_NEAR(sub_out, op_out, std::abs(0.05f * std::min(sub_out, op_out)));
}
}

-TEST_F(MeanTestF32, matches_operator_api)
-{
+TEST_P(MeanTestF32, matches_operator_api) {
+  bool keep_dims = GetParam();
ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

xnn_operator_t op = nullptr;
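The F16 comparison above also moved from exact equality to a ~5% relative tolerance after widening both halves to fp32: with only ~10 mantissa bits, the operator and subgraph paths can legitimately round differently. The check is equivalent to this sketch (helper name illustrative; fp16_ieee_to_fp32_value comes from the <fp16/fp16.h> header the test already includes):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <fp16/fp16.h>

// Mirrors the ASSERT_NEAR above: allow |a - b| up to 5% of the smaller
// value. Note the bound shrinks toward zero when the smaller value does.
bool F16Close(uint16_t a_half, uint16_t b_half) {
  const float a = fp16_ieee_to_fp32_value(a_half);
  const float b = fp16_ieee_to_fp32_value(b_half);
  return std::abs(a - b) <= std::abs(0.05f * std::min(a, b));
}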
@@ -242,8 +253,9 @@ TEST_F(MeanTestF32, matches_operator_api)
std::fill(operator_output.begin(), operator_output.end(), nanf(""));
std::fill(subgraph_output.begin(), subgraph_output.end(), nanf(""));

+  uint32_t flags = keep_dims ? XNN_FLAG_KEEP_DIMS : 0;
// Call operator API.
-  const xnn_status status = xnn_create_mean_nd_f32(/*flags=*/XNN_FLAG_KEEP_DIMS, &op);
+  const xnn_status status = xnn_create_mean_nd_f32(flags, &op);
if (status == xnn_status_unsupported_hardware) {
GTEST_SKIP();
}
@@ -274,17 +286,21 @@ TEST_F(MeanTestF32, matches_operator_api)
ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

uint32_t output_id = XNN_INVALID_NODE_ID;
-  ASSERT_EQ(xnn_status_success,
-            xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_shape.size(), output_shape.data(),
-                                    nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
+  int output_num_dims = input_shape.size();
+  if (!keep_dims) {
+    output_num_dims -= reduction_axes.size();
+  }
+  ASSERT_EQ(
+      xnn_status_success,
+      xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_num_dims,
+                              output_shape.data(), nullptr, /*external_id=*/1,
+                              XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

ASSERT_EQ(xnn_status_success,
-            xnn_define_static_mean(
-                subgraph,
-                reduction_axes.size(), reduction_axes.data(),
-                input_id, output_id,
-                /*flags=*/XNN_FLAG_KEEP_DIMS));
+            xnn_define_static_mean(subgraph, reduction_axes.size(),
+                                   reduction_axes.data(), input_id, output_id,
+                                   flags));

xnn_runtime_t runtime = nullptr;
ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
@@ -299,8 +315,8 @@ TEST_F(MeanTestF32, matches_operator_api)
ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

// Check outputs match.
-  for (size_t i = 0; i < operator_output.size(); i++) {
-    ASSERT_EQ(subgraph_output[i], operator_output[i]);
+  for (int i = 0; i < subgraph_output.size(); ++i) {
+    ASSERT_NEAR(subgraph_output[i], operator_output[i], 2.5f * std::numeric_limits<float>::epsilon()) << " i " << i;
}
}

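For F32 the two paths are still expected to agree almost exactly: the new bound of 2.5 × machine epsilon presumably only absorbs minor rounding differences from reassociated accumulation. A quick check of the magnitude involved:

#include <cstdio>
#include <limits>

int main() {
  // fp32 epsilon is 2^-23, about 1.19e-7; the test's bound is 2.5x that.
  const float bound = 2.5f * std::numeric_limits<float>::epsilon();
  std::printf("fp32 tolerance: %g\n", bound);  // prints ~2.98e-07
}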
@@ -402,9 +418,12 @@ TEST_F(MeanTestF32, reshape_output_no_keep_dims)
ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

uint32_t output_id = XNN_INVALID_NODE_ID;
-  ASSERT_EQ(xnn_status_success,
-            xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_shape.size(), output_shape.data(),
-                                    nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
+  int output_num_dims = input_shape.size() - reduction_axes.size();
+  ASSERT_EQ(
+      xnn_status_success,
+      xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_num_dims,
+                              output_shape.data(), nullptr, /*external_id=*/1,
+                              XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

ASSERT_EQ(xnn_status_success,
Expand Down
