forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_static_runtime.cc
46 lines (39 loc) · 1.57 KB
/
test_static_runtime.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <gtest/gtest.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include "deep_wide_pt.h"
// Checks that StaticRuntime produces exactly the same tensor as the module's
// regular forward() call for the trivial three-input script model.
TEST(StaticRuntime, TrivialModel) {
  torch::jit::Module module = getTrivialScriptModel();

  // Three random 2x2 operands, shared by both execution paths.
  auto lhs = torch::randn({2, 2});
  auto mid = torch::randn({2, 2});
  auto rhs = torch::randn({2, 2});

  // Reference path: invoke the module's forward() with IValue arguments.
  std::vector<at::IValue> jit_args{lhs, mid, rhs};
  at::Tensor expected = module.forward(jit_args).toTensor();

  // Path under test: prepare the module, build a StaticRuntime from the
  // result, and take the first tensor it returns.
  std::vector<at::Tensor> static_args{lhs, mid, rhs};
  auto prepared = torch::jit::PrepareForStaticRuntime(module);
  torch::jit::StaticRuntime static_runtime(prepared);
  auto static_outputs = static_runtime.run(static_args);
  at::Tensor actual = static_outputs[0];

  EXPECT_TRUE(expected.equal(actual));
}
// Compares StaticRuntime against the module's regular forward() call on the
// deep-and-wide script model, for several batch sizes and repeated random
// inputs. The same StaticRuntime instance is reused across all iterations.
TEST(StaticRuntime, DeepWide) {
  constexpr int kEmbeddingSize = 32;
  constexpr int kNumFeatures = 50;
  constexpr int kRepeatsPerBatchSize = 5;

  torch::jit::Module module = getDeepAndWideSciptModel();
  auto prepared = torch::jit::PrepareForStaticRuntime(module);
  torch::jit::StaticRuntime static_runtime(prepared);

  for (int batch : {1, 8, 32}) {
    for (int repeat = 0; repeat != kRepeatsPerBatchSize; ++repeat) {
      // Fresh random inputs for every iteration.
      auto ad_embedding = torch::randn({batch, 1, kEmbeddingSize});
      auto user_embedding = torch::randn({batch, 1, kEmbeddingSize});
      auto wide_features = torch::randn({batch, kNumFeatures});

      // Reference path: invoke the module's forward() with IValue arguments.
      std::vector<at::IValue> jit_args{
          ad_embedding, user_embedding, wide_features};
      at::Tensor expected = module.forward(jit_args).toTensor();

      // Path under test: feed the same tensors to the static runtime and
      // take the first tensor it returns.
      std::vector<at::Tensor> static_args{
          ad_embedding, user_embedding, wide_features};
      auto static_outputs = static_runtime.run(static_args);
      at::Tensor actual = static_outputs[0];

      EXPECT_TRUE(expected.equal(actual));
    }
  }
}