feat(aten::Int): Adding a new pass to remove single use 0D Tensors

Now we remove select, more complex aten::Int cases found in
models such as BERT, rewriting, for example, the first graph below into the second:

```
graph(%0: int):
      %1: Tensor = prim::Constant[value={8}]()
      %2: int = prim::Constant[value=1]()
      %3: Tensor = prim::NumToTensor(%0)
      %4: Tensor = aten::add(%1, %3, %2)
      %5: int = aten::Int(%4)
      %6: int = aten::add(%5, %5)
      return (%6)";

graph(%0: int):
      %1: int = prim::Constant[value=8]()
      %4: int = aten::add(%1, %0)
      %6: int = aten::add(%4, %4)
      return (%6)";
```
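
For reference, a minimal sketch (not part of this change) of how the new pass can be exercised on its own, mirroring the flow the added tests use; `run_pass_on_ir` is an illustrative helper, not an existing API:

```
// Minimal sketch, assuming a Torch-TensorRT build: parse TorchScript IR and
// run the new lowering pass directly. Helper name is illustrative only.
#include <memory>
#include <string>

#include "torch/csrc/jit/ir/irparser.h"

#include "core/lowering/passes/passes.h"

std::shared_ptr<torch::jit::Graph> run_pass_on_ir(const std::string& source_graph) {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, g.get());

  // Fold single-use 0D Tensor constants and their NumToTensor / aten::Int
  // (or aten::Float) wrappers into plain scalar ops.
  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(g);
  return g;
}
```

The tests below follow the same sequence, first swapping the parsed constant for a true 0-dim tensor via `c10::scalar_to_tensor`, presumably because the IR parser's `[8]` literal yields a 1-D tensor.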

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Apr 4, 2022
1 parent 908340f commit 46ac757
Showing 5 changed files with 232 additions and 3 deletions.
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -64,6 +64,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
passes::RemoveNOPs(g);
passes::AliasOperators(g);
passes::SiluToSigmoidMultipication(g);
passes::RemoveSingleUse0DTensors(g);
passes::RemoveUnnecessaryCasts(g);
LOG_GRAPH(*g);
}
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -27,6 +27,7 @@ void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
35 changes: 35 additions & 0 deletions core/lowering/passes/remove_set_attrs.cpp
@@ -0,0 +1,35 @@
#include <stack>
#include <unordered_set>

#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/lowering/passes/passes.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void RemoveSetAttrs(const torch::jit::Module& mod, std::string method_name) {
auto g = mod.get_method(method_name).graph();

std::string set_attr_pattern = R"IR(
graph(%self, %0):
None = prim::SetAttr[name="_has_warned"](%self, %0)
return ())IR";
std::string no_set_attr_pattern = R"IR(
graph(%self, %0):
return ())IR";

// remove prim::SetAttr nodes matching the pattern above
torch::jit::SubgraphRewriter remove_set_attr;
remove_set_attr.RegisterRewritePattern(set_attr_pattern, no_set_attr_pattern);
remove_set_attr.runOnGraph(g);
LOG_GRAPH("Post remove contiguous: " << *g);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
114 changes: 114 additions & 0 deletions core/lowering/passes/remove_unnecessary_casts.cpp
@@ -1,4 +1,5 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
#include "torch/csrc/jit/ir/constants.h"

#include "core/util/prelude.h"

@@ -55,6 +56,119 @@ void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
}

void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) {
if (it->kind() == torch::jit::prim::Constant) {
// A single-use 0D Tensor constant can be fused into its user
if (it->output()->type()->isSubtypeOf(c10::TensorType::get())) {
// Get the tensor stored in constant
at::Tensor t = *torch::jit::constant_as<at::Tensor>(it->output());
// If shape is 0D
if (t.sizes() == std::vector<int64_t>({})) {
LOG_GRAPH("Found a 0D Tensor: " << it->output()->debugName());
LOG_GRAPH("Number of uses: " << it->output()->uses().size());
// If the tensor is only used once
if (it->output()->uses().size() == 1) {
auto use = it->output()->uses()[0];
auto user = use.user;

// Is a NumToTensor / aten::[Int/Float] case
if (user->outputs().size() == 1 && user->outputs()[0]->type()->isSubtypeOf(c10::TensorType::get())) {
if (user->output()->uses().size() == 1) {
auto potential_cast = user->output()->uses()[0].user;
// The downstream user is aten::Int / aten::Float
if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int")
|| potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
LOG_GRAPH("Downstream user is aten::Int/aten::Float");
auto arg = use.offset;

for (size_t k = 0; k < user->inputs().size(); ++k) {
if (k != arg) {
if (user->inputs()[k]->type()->isSubtypeOf(c10::TensorType::get())) {
LOG_GRAPH("Input " << k << " is a Tensor");
if (user->inputs()[k]->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
auto num_to_tensor = user->inputs()[k]->node();

LOG_GRAPH("Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n "
<< *(*it)
<< *num_to_tensor
<< *user
<< *potential_cast);

// Replace the Tensor Constant with a scalar constant
LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
torch::jit::WithInsertPoint guard(*it);

auto new_const_val = g->insertConstant(t.item(), c10::nullopt, it->scope());
new_const_val->copyMetadata(it->output());
// TODO: determine the internal scalar type instead of assuming it from the downstream cast
if (potential_cast->kind() == c10::aten::Int) {
new_const_val->setType(c10::IntType::get());
} else if (potential_cast->kind() == c10::aten::Float) {
new_const_val->setType(c10::FloatType::get());
}
it->output()->replaceAllUsesWith(new_const_val);
it.destroyCurrent();

LOG_GRAPH("New constant: " << *new_const_val->node());

// Delete NumToTensor
LOG_GRAPH("Deleting NumToTensor: " << *num_to_tensor);
num_to_tensor->output()->replaceAllUsesWith(num_to_tensor->inputs()[0]);
num_to_tensor->destroy();

// Change intermediate op output type
LOG_GRAPH(user->schema());

torch::jit::Node* new_node;
switch (user->kind()) {
// Use this to handle special cases where the scalar version of the intermediate operator
// has a different schema than the original
case c10::aten::add:
new_node = g->create(
user->kind(),
torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
1);
new_node->insertAfter(user);
// Set the scalar type implied by the downstream cast instead of assuming int
if (potential_cast->kind() == c10::aten::Int) {
  new_node->outputs()[0]->setType(c10::IntType::get());
} else {
  new_node->outputs()[0]->setType(c10::FloatType::get());
}
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;
default:
new_node = g->create(
user->kind(),
user->inputs(),
1);
new_node->insertAfter(user);
// Same scalar-type selection as above
if (potential_cast->kind() == c10::aten::Int) {
  new_node->outputs()[0]->setType(c10::IntType::get());
} else {
  new_node->outputs()[0]->setType(c10::FloatType::get());
}
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;
}

LOG_GRAPH("New intermediate operation: " << *new_node);
LOG_GRAPH(new_node->schema());

// Delete aten::Int
LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
potential_cast->destroy();
}
}
}
}
}
}
}
}
}
}
}
}
LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
}


} // namespace passes
} // namespace lowering
} // namespace core
84 changes: 81 additions & 3 deletions tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -22,7 +22,7 @@ TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());
torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());
@@ -46,7 +46,7 @@ TEST(LoweringPasses, RemoveUnnecessaryCastFloatCorrectly) {
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());
torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());
@@ -70,7 +70,85 @@ TEST(LoweringPasses, RemoveUnnecessaryCastBoolCorrectly) {
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());
torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsIntCorrectly) {
std::string source_graph = R"IR(
graph(%0: int):
%1: Tensor = prim::Constant[value=[8]]()
%2: int = prim::Constant[value=1]()
%3: Tensor = prim::NumToTensor(%0)
%4: Tensor = aten::add(%1, %3, %2)
%5: int = aten::Int(%4)
%6: int = aten::add(%5, %5)
return (%6))IR";
std::string target_graph = R"IR(
graph(%0: int):
%1: int = prim::Constant[value=8]()
%4: int = aten::add(%1, %0)
%6: int = aten::add(%4, %4)
return (%6))IR";

torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());

auto first_op = *(sg->block()->nodes().begin());
torch::jit::WithInsertPoint guard(first_op);
torch::jit::Value* r = sg->insertConstant(
c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
r->copyMetadata(first_op->output());
r->setType(c10::TensorType::get());
first_op->output()->replaceAllUsesWith(r);
first_op->destroy();

torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
std::string source_graph = R"IR(
graph(%0: float):
%1: Tensor = prim::Constant[value=[8.]]()
%2: float = prim::Constant[value=1.]()
%3: Tensor = prim::NumToTensor(%0)
%4: Tensor = aten::add(%1, %3, %2)
%5: float = aten::Float(%4)
%6: float = aten::add(%5, %5)
return (%6))IR";
std::string target_graph = R"IR(
graph(%0: float):
%1: float = prim::Constant[value=8.]()
%4: float = aten::add(%1, %0)
%6: float = aten::add(%4, %4)
return (%6))IR";

torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());

auto first_op = *(sg->block()->nodes().begin());
torch::jit::WithInsertPoint guard(first_op);
torch::jit::Value* r = sg->insertConstant(
c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
r->copyMetadata(first_op->output());
r->setType(c10::TensorType::get());
first_op->output()->replaceAllUsesWith(r);
first_op->destroy();

torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());
