diff --git a/tests/cpp/operator/mkldnn.cc b/tests/cpp/operator/mkldnn.cc index 76872d5e6cfc..56f4d0753076 100644 --- a/tests/cpp/operator/mkldnn.cc +++ b/tests/cpp/operator/mkldnn.cc @@ -736,4 +736,42 @@ TEST(IMPERATIVE, BinaryOp) { TestBinaryOp(attrs, VerifySumResult); } +void VerifyCopyMemory(mkldnn::memory in_mem1, mkldnn::memory out_mem) { + float *in1 = static_cast(in_mem1.get_data_handle()); + float *out = static_cast(out_mem.get_data_handle()); + EXPECT_EQ(in_mem1.get_primitive_desc().get_size(), out_mem.get_primitive_desc().get_size()); + EXPECT_EQ(memcmp(in1, out, in_mem1.get_primitive_desc().get_size()),0); +} + + +TEST(MKLDNN_BASE, CreateMKLDNNMem) { + std::vector in_arrs = GetTestInputArrays(InitDefaultArray); + TestArrayShapes tas = GetTestArrayShapes(); + std::vector pds = tas.pds; + + MKLDNNStream *stream = MKLDNNStream::Get(); + + for (auto in_arr : in_arrs) { + + if (!SupportMKLDNN(in_arr.arr) || in_arr.arr.IsView()) + continue; + + std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds, + InitDefaultArray); + for (auto out_arr : out_arrs) { + auto in_mem = in_arr.arr.GetMKLDNNData(); + auto output_mem_t = CreateMKLDNNMem(out_arr.arr, in_mem->get_primitive_desc(), kWriteTo); + const_cast(out_arr.arr).CopyFrom(*in_mem); + CommitOutput(out_arr.arr, output_mem_t); + stream->Submit(); + VerifyCopyMemory(*in_mem, *out_arr.arr.GetMKLDNNData()); + } + + // in place + + } + + +} + #endif