Skip to content

Commit

Permalink
feat(conv2d_to_convolution): A pass to map aten::conv2d to _convolution
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed May 7, 2020
1 parent 272ef40 commit 2c5c0d5
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
Expand Up @@ -25,6 +25,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
torch::jit::FuseLinear(g);
passes::RemoveDropout(g);
passes::FuseFlattenLinear(g);
passes::Conv2DToConvolution(g);
passes::UnpackAddMM(g);
passes::UnpackLogSoftmax(g);
//passes::RemoveDimExeception(g);
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
Expand Up @@ -6,6 +6,7 @@ cc_library(
"passes.h",
],
srcs = [
"conv2d_to_convolution.cpp",
"exception_elimination.cpp",
"fuse_flatten_linear.cpp",
"remove_dropout.cpp",
Expand Down
34 changes: 34 additions & 0 deletions core/lowering/passes/conv2d_to_convolution.cpp
@@ -0,0 +1,34 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
std::string conv2d_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
return (%4))IR";
std::string convolution_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%1 : bool = prim::Constant[value=1]()
%2 : int[] = prim::Constant[value=[0, 0]]()
%3 : bool = prim::Constant[value=0]()
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %3)
return (%4))IR";;

// replace matmul + add pattern to linear
torch::jit::SubgraphRewriter map_conv2d_to_convolution;
map_conv2d_to_convolution.RegisterRewritePattern(
conv2d_pattern, convolution_pattern);
map_conv2d_to_convolution.runOnGraph(graph);
LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
Expand Up @@ -7,6 +7,7 @@ namespace core {
namespace lowering {
namespace passes {

void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
Expand Down

0 comments on commit 2c5c0d5

Please sign in to comment.