feat(aten::permute): Implement permute support
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Jun 15, 2020
1 parent 461e2ca commit c7d6b49
Showing 2 changed files with 115 additions and 28 deletions.
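For context, aten::permute reorders a tensor's dimensions according to an index list: output axis i takes its size and data from input axis dims[i]. A minimal libtorch illustration of the op this commit adds conversion support for (the shape and order below simply mirror the new 4D test and are purely illustrative):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto x = at::randint(0, 5, {2, 3, 2, 3});   // 4D input, same shape as the new test below
  auto y = x.permute({3, 0, 1, 2});           // aten::permute(x, [3, 0, 1, 2])
  std::cout << y.sizes() << std::endl;        // [3, 2, 3, 2]
  return 0;
}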
21 changes: 21 additions & 0 deletions core/conversion/converters/impl/shuffle.cpp
@@ -59,6 +59,27 @@ static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
        auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
        LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

        return true;
      }
    }).pattern({
      "aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))",
      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        auto in = args[0].ITensor();
        auto in_shape = util::toVec(in->getDimensions());
        auto new_order = args[1].unwrapToIntList().vec();

        LOG_DEBUG("Shuffle to: " << util::toDims(new_order));

        auto shuffle = ctx->net->addShuffle(*in);
        TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
        nvinfer1::Permutation permute;
        std::copy(new_order.begin(), new_order.end(), permute.order);
        shuffle->setSecondTranspose(permute);
        shuffle->setName(util::node_info(n).c_str());

        auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
        LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

        return true;
      }
    });
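A note on the converter above: TensorRT's IShuffleLayer applies, in sequence, a first transpose, an optional reshape, and a second transpose. Since neither the reshape nor the first transpose is configured here, populating only the second transpose turns the layer into a pure permutation, with output dimension i equal to input dimension new_order[i]. Below is a rough standalone sketch of that same mapping outside the TRTorch converter machinery; it is illustrative only, not part of the commit, and the input name "x" and shapes are made up for the example:

#include <NvInfer.h>
#include <algorithm>
#include <iostream>
#include <vector>

// Minimal logger the TensorRT builder requires.
class Logger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) std::cout << msg << std::endl;
  }
};

int main() {
  Logger logger;
  auto builder = nvinfer1::createInferBuilder(logger);
  auto network = builder->createNetworkV2(
      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));

  // Static 4D input, same shape as the new permute test.
  auto in = network->addInput("x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{2, 3, 2, 3});

  // aten::permute(x, [3, 0, 1, 2]) expressed as a shuffle layer whose second
  // transpose carries the permutation; no reshape or first transpose is set.
  std::vector<int64_t> new_order = {3, 0, 1, 2};
  auto shuffle = network->addShuffle(*in);
  nvinfer1::Permutation permute;
  std::copy(new_order.begin(), new_order.end(), permute.order);
  shuffle->setSecondTranspose(permute);

  auto dims = shuffle->getOutput(0)->getDimensions();
  for (int i = 0; i < dims.nbDims; i++) {
    std::cout << dims.d[i] << (i + 1 < dims.nbDims ? ", " : "\n");  // expected: 3, 2, 3, 2
  }
  return 0;
}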
122 changes: 94 additions & 28 deletions tests/core/converters/test_shuffle.cpp
@@ -30,17 +30,87 @@ TEST(Converters, ATenFlattenConvertsCorrectly) {

// TODO: IR Parser doesnt work well with neg numbers
TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %1 : int = prim::Constant[value=1]()
        %2 : int = prim::Constant[value=2]()
        %3 : Tensor = aten::flatten(%0, %1, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
  auto trt = trt_results[0].reshape_as(jit_results[0]);

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenReshapeConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %1 : int = prim::Constant[value=3]()
        %2 : int = prim::Constant[value=2]()
        %3 : int[] = prim::ListConstruct(%1, %2)
        %4 : Tensor = aten::reshape(%0, %3)
        return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
  auto trt = trt_results[0].reshape_as(jit_results[0]);

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenViewConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %1 : int = prim::Constant[value=3]()
        %2 : int = prim::Constant[value=2]()
        %3 : int[] = prim::ListConstruct(%1, %2)
        %4 : Tensor = aten::view(%0, %3)
        return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
  auto trt = trt_results[0].reshape_as(jit_results[0]);

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenPermuteConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int[] = prim::Constant[value=[3, 0, 1, 2]]()
        %3 : Tensor = aten::permute(%x.1, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(0, 5, {2, 3, 2, 3}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

@@ -52,19 +122,17 @@ TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenPermute3DConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int[] = prim::Constant[value=[0, 2, 1]]()
        %3 : Tensor = aten::permute(%x.1, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(0, 5, {2, 2, 3}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

@@ -76,19 +144,17 @@ TEST(Converters, ATenReshapeConvertsCorrectly) {
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenPermute5DConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int[] = prim::Constant[value=[3, 4, 0, 2, 1]]()
        %3 : Tensor = aten::permute(%x.1, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(0, 5, {2, 2, 1, 2, 3}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

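Together, the three tests cover 4D, 3D, and 5D permutations, each comparing TorchScript execution (RunGraph) against the generated TensorRT engine (RunGraphEngine) on the same random input. As an illustrative sanity check of the output shapes the two smaller tests should produce (not part of the test suite; the 4D case is shown near the top of this page):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // 3D test: permute order [0, 2, 1] on a {2, 2, 3} input.
  std::cout << at::zeros({2, 2, 3}).permute({0, 2, 1}).sizes() << std::endl;              // [2, 3, 2]
  // 5D test: permute order [3, 4, 0, 2, 1] on a {2, 2, 1, 2, 3} input.
  std::cout << at::zeros({2, 2, 1, 2, 3}).permute({3, 4, 0, 2, 1}).sizes() << std::endl;  // [2, 3, 2, 1, 2]
  return 0;
}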
