Skip to content

Commit

Permalink
feat(aten::Int): Lowers out aten::Int
Browse files Browse the repository at this point in the history
This commit adds a pass to lower out aten::[Int/Float/Bool],
aten::NumToTensor pairs w.o. exception.
We are assumming this is safe as there are similar
passes in PyTorch for ONNX lowering however the scope
of this rule is intentionally limited to avoid possible cases
where it is not safe. Therefore it should not be expected that
all aten::Int issues will be solved with this change and
the operator itself remains a limitation of TorchTRT

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Apr 4, 2022
1 parent ba9f730 commit 908340f
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 0 deletions.
2 changes: 2 additions & 0 deletions core/lowering/lowering.cpp
@@ -1,6 +1,7 @@
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/create_functional_graphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/erase_number_types.h"
#include "torch/csrc/jit/passes/freeze_module.h"
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/guard_elimination.h"
Expand Down Expand Up @@ -63,6 +64,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
passes::RemoveNOPs(g);
passes::AliasOperators(g);
passes::SiluToSigmoidMultipication(g);
passes::RemoveUnnecessaryCasts(g);
LOG_GRAPH(*g);
}

Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
Expand Up @@ -23,6 +23,7 @@ cc_library(
"view_to_reshape.cpp",
"remove_dropout.cpp",
"remove_nops.cpp",
"remove_unnecessary_casts.cpp",
"silu_to_sigmoid_multiplication.cpp",
"unpack_addmm.cpp",
"unpack_batch_norm.cpp",
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
Expand Up @@ -27,6 +27,7 @@ void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
Expand Down
61 changes: 61 additions & 0 deletions core/lowering/passes/remove_unnecessary_casts.cpp
@@ -0,0 +1,61 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

#include <vector>

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {


// Presumably this is safe since torch::jit::EraseNumberTypesOnBlock exists which just
// removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
std::string int_cast_pattern = R"IR(
graph(%1: int):
%2: Tensor = aten::NumToTensor(%1)
%3: int = aten::Int(%2)
return (%3))IR";
std::string int_clean_pattern = R"IR(
graph(%1: int):
return (%1))IR";

std::string float_cast_pattern = R"IR(
graph(%1: float):
%2: Tensor = aten::NumToTensor(%1)
%3: float = aten::Float(%2)
return (%3))IR";
std::string float_clean_pattern = R"IR(
graph(%1: float):
return (%1))IR";

std::string bool_cast_pattern = R"IR(
graph(%1: bool):
%2: Tensor = aten::NumToTensor(%1)
%3: bool = aten::Bool(%2)
return (%3))IR";
std::string bool_clean_pattern = R"IR(
graph(%1: bool):
return (%1))IR";

torch::jit::SubgraphRewriter int_cast_rewriter;
int_cast_rewriter.RegisterRewritePattern(int_cast_pattern, int_clean_pattern);
int_cast_rewriter.runOnGraph(graph);

torch::jit::SubgraphRewriter float_cast_rewriter;
float_cast_rewriter.RegisterRewritePattern(float_cast_pattern, float_clean_pattern);
float_cast_rewriter.runOnGraph(graph);

torch::jit::SubgraphRewriter bool_cast_rewriter;
bool_cast_rewriter.RegisterRewritePattern(bool_cast_pattern, bool_clean_pattern);
bool_cast_rewriter.runOnGraph(graph);

LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
5 changes: 5 additions & 0 deletions tests/core/lowering/BUILD
Expand Up @@ -50,6 +50,10 @@ lowering_test(
name = "test_remove_detach_pass",
)

lowering_test(
name = "test_remove_unnecessary_casts",
)

lowering_test(
name = "test_view_to_reshape_pass",
)
Expand Down Expand Up @@ -81,6 +85,7 @@ test_suite(
":test_remove_detach_pass",
":test_view_to_reshape_pass",
":test_remove_dropout_pass",
":test_remove_unnecessary_casts",
":test_reduce_to_pass",
":test_reduce_gelu",
":test_unpack_hardswish",
Expand Down
79 changes: 79 additions & 0 deletions tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -0,0 +1,79 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"

TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
std::string source_graph = R"IR(
graph(%1: int):
%2: Tensor = aten::NumToTensor(%1)
%3: int = aten::Int(%2)
%4: int = aten::add(%3, %3, %3)
return (%4))IR";
std::string target_graph = R"IR(
graph(%1: int):
%4: int = aten::add(%1, %1, %1)
return (%4))IR";

torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());
torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveUnnecessaryCastFloatCorrectly) {
std::string source_graph = R"IR(
graph(%1: float):
%2: Tensor = aten::NumToTensor(%1)
%3: float = aten::Float(%2)
%4: float = aten::add(%3, %3, %3)
return (%3))IR";
std::string target_graph = R"IR(
graph(%1: float):
%4: float = aten::add(%1, %1, %1)
return (%4))IR";

torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());
torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveUnnecessaryCastBoolCorrectly) {
std::string source_graph = R"IR(
graph(%1: bool):
%2: Tensor = aten::NumToTensor(%1)
%3: bool = aten::Bool(%2)
%4: bool = aten::__and__(%3, %3)
return (%3))IR";
std::string target_graph = R"IR(
graph(%1: bool):
%4: bool = aten::__and__(%1, %1)
return (%4))IR";

torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());
torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

0 comments on commit 908340f

Please sign in to comment.