Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add boilerplate for CreateMKLDNNMem test
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Jun 14, 2018
1 parent ed353fa commit becedcc
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/cpp/operator/mkldnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -736,4 +736,42 @@ TEST(IMPERATIVE, BinaryOp) {
TestBinaryOp(attrs, VerifySumResult);
}

void VerifyCopyMemory(mkldnn::memory in_mem1, mkldnn::memory out_mem) {
float *in1 = static_cast<float*>(in_mem1.get_data_handle());
float *out = static_cast<float*>(out_mem.get_data_handle());
EXPECT_EQ(in_mem1.get_primitive_desc().get_size(), out_mem.get_primitive_desc().get_size());
EXPECT_EQ(memcmp(in1, out, in_mem1.get_primitive_desc().get_size()),0);
}


TEST(MKLDNN_BASE, CreateMKLDNNMem) {
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(InitDefaultArray);
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

MKLDNNStream *stream = MKLDNNStream::Get();

for (auto in_arr : in_arrs) {

if (!SupportMKLDNN(in_arr.arr) || in_arr.arr.IsView())
continue;

std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds,
InitDefaultArray);
for (auto out_arr : out_arrs) {
auto in_mem = in_arr.arr.GetMKLDNNData();
auto output_mem_t = CreateMKLDNNMem(out_arr.arr, in_mem->get_primitive_desc(), kWriteTo);
const_cast<NDArray &>(out_arr.arr).CopyFrom(*in_mem);
CommitOutput(out_arr.arr, output_mem_t);
stream->Submit();
VerifyCopyMemory(*in_mem, *out_arr.arr.GetMKLDNNData());
}

// in place

}


}

#endif

0 comments on commit becedcc

Please sign in to comment.