Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
copyfrom test back
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Jun 15, 2018
1 parent 8f8990e commit 7f9dac4
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions tests/cpp/operator/mkldnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,27 @@ void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
printf(")\n");
}

TEST(MKLDNN_NDArray, CopyFrom) {
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(InitDefaultArray);
for (auto in_arr : in_arrs) {
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds,
InitDefaultArray);
for (auto out_arr : out_arrs) {
if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView())
in_arr.arr = in_arr.arr.Reorder2Default();
const mkldnn::memory *mem = in_arr.arr.GetMKLDNNData();
out_arr.arr.CopyFrom(*mem);
MKLDNNStream::Get()->Submit();
std::vector<NDArray *> inputs(1);
inputs[0] = &in_arr.arr;
VerifyCopyResult(inputs, out_arr.arr);
}
}
}

void TestUnaryOp(const OpAttrs &attrs, InitFunc init_fn, VerifyFunc verify_fn) {
std::vector<NDArray*> inputs(1);
std::vector<NDArray*> outputs(1);
Expand Down Expand Up @@ -731,27 +752,6 @@ void TestBinaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
}
}

TEST(MKLDNN_NDArray, CopyFrom) {
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(InitDefaultArray);
for (auto in_arr : in_arrs) {
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds,
InitDefaultArray);
for (auto out_arr : out_arrs) {
if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView())
in_arr.arr = in_arr.arr.Reorder2Default();
const mkldnn::memory *mem = in_arr.arr.GetMKLDNNData();
out_arr.arr.CopyFrom(*mem);
MKLDNNStream::Get()->Submit();
std::vector<NDArray *> inputs(1);
inputs[0] = &in_arr.arr;
VerifyCopyResult(inputs, out_arr.arr);
}
}
}

TEST(IMPERATIVE, UnaryOp) {
OpAttrs attrs = GetCopyOp();
TestUnaryOp(attrs, InitDefaultArray, VerifyCopyResult);
Expand Down

0 comments on commit 7f9dac4

Please sign in to comment.