Skip to content

Commit

Permalink
feat(aten::matmul|aten::addmm): Adds support for aten::matmul and
Browse files Browse the repository at this point in the history
aten::admm

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed May 2, 2020
1 parent d945eb9 commit c5b6202
Show file tree
Hide file tree
Showing 17 changed files with 197 additions and 95 deletions.
6 changes: 5 additions & 1 deletion core/conversion/conversion.cpp
Expand Up @@ -73,7 +73,11 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
LOG_DEBUG(ctx->logger, "Node input is a value that needs to be evaluated");
auto eval = EvaluateNode(ctx, input_node);
if (eval) {
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
} else {
LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
}
ctx->evaluated_value_map[input] = std::move(eval.value());
node_args.push_back(&(ctx->evaluated_value_map[input]));
} else {
Expand Down
2 changes: 2 additions & 0 deletions core/conversion/converters/Arg.cpp
Expand Up @@ -89,6 +89,7 @@ std::string Arg::type_name() const {
}

const torch::jit::IValue* Arg::IValue() const {
TRTORCH_CHECK(isIValue(), "Requested IValue from Arg, however arg type is " << type_name());
if (type_ == Type::kIValue) {
return ptr_.ivalue;
} else {
Expand All @@ -97,6 +98,7 @@ const torch::jit::IValue* Arg::IValue() const {
}

nvinfer1::ITensor* Arg::ITensor() const {
TRTORCH_CHECK(isITensor(), "Requested ITensor from Arg, however arg type is " << type_name());
if (type_ == Type::kITensor) {
return ptr_.tensor;
} else {
Expand Down
1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
Expand Up @@ -15,6 +15,7 @@ cc_library(
"impl/conv_deconv.cpp",
"impl/element_wise.cpp",
"impl/linear.cpp",
"impl/matrix_multiply.cpp",
"impl/pooling.cpp",
"impl/reduce.cpp",
"impl/shuffle.cpp",
Expand Down
9 changes: 9 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
Expand Up @@ -14,6 +14,15 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera

TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims), "Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims);

if (self_dims != other_dims) {
LOG_DEBUG("Input shape dont match inserting shuffle layers to reshape to " << self_dims);
auto other_shuffle = ctx->net->addShuffle(*other);
other_shuffle->setReshapeDimensions(self_dims);
other_shuffle->setName(std::string("[Reshape other to " + util::toStr(self_dims) + ']').c_str());
other = other_shuffle->getOutput(0);
}


nvinfer1::ILayer* ele;
if (scalar != 1) {
LOG_WARNING("Please verify scalar handling in add converter, channel axis set to 3 but scaling is uniform");
Expand Down
9 changes: 1 addition & 8 deletions core/conversion/converters/impl/linear.cpp
Expand Up @@ -9,13 +9,6 @@ namespace impl {
namespace {

auto linear_registrations = RegisterNodeConversionPatterns()
// .pattern({
// "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
// [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> {
// auto in = args[0].ITensor();

// }
// })
.pattern({
"aten::linear(Tensor input, Tensor weight, Tensor? bias = None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
Expand Down Expand Up @@ -71,4 +64,4 @@ auto linear_registrations = RegisterNodeConversionPatterns()
} // namespace converters
} // namespace conversion
} // namespace core
} // trtorch
} // namespace trtorch
55 changes: 55 additions & 0 deletions core/conversion/converters/impl/matrix_multiply.cpp
@@ -0,0 +1,55 @@
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

auto mm_registrations = RegisterNodeConversionPatterns()
.pattern({
"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
nvinfer1::ITensor* self;
if (args[0].isIValue()) {
auto t = args[0].unwrapToTensor();
auto t_weights = Weights(ctx, t);
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor self for node: " << *n);
const_layer->setName((util::node_info(n) + " [Freeze Tensor(self)]").c_str());
self = const_layer->getOutput(0);
} else {
self = args[0].ITensor();
}
LOG_DEBUG("self tensor shape: " << self->getDimensions());

nvinfer1::ITensor* other;
if (args[1].isIValue()) {
auto t = args[1].unwrapToTensor();
auto t_weights = Weights(ctx, t);
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor other for node: " << *n);
const_layer->setName((util::node_info(n) + " [Freeze Tensor(other)]").c_str());
other = const_layer->getOutput(0);
} else {
other = args[1].ITensor();
}
LOG_DEBUG("other tensor shape: " << other->getDimensions());

auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
mm_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}
});
} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
4 changes: 3 additions & 1 deletion core/lowering/BUILD
Expand Up @@ -8,12 +8,14 @@ cc_library(
srcs = [
"lowering.cpp",
"drop_unused_nodes.cpp",
"register_const_op.cpp"
],
deps = [
"@libtorch//:libtorch",
"//core/lowering/passes",
"//core/util:prelude"
]
],
alwayslink = True
)

load("@rules_pkg//:pkg.bzl", "pkg_tar")
Expand Down
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
Expand Up @@ -25,6 +25,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
torch::jit::FuseLinear(g);
passes::RemoveDropout(g);
passes::FuseFlattenLinear(g);
passes::UnpackAddMM(g);
passes::ExpandLogSoftmax(g);
//passes::RemoveDimExeception(g);
//irfusers::UnpackBatchNorm(g);
Expand Down
3 changes: 2 additions & 1 deletion core/lowering/passes/BUILD
Expand Up @@ -10,7 +10,8 @@ cc_library(
"expand_log_softmax.cpp",
"remove_dropout.cpp",
"unpack_batch_norm.cpp",
"exception_elimination.cpp"
"exception_elimination.cpp",
"unpack_addmm.cpp"
],
deps = [
"//core/util:prelude",
Expand Down
33 changes: 0 additions & 33 deletions core/lowering/passes/fuse_flatten_linear.cpp
Expand Up @@ -40,39 +40,6 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
flatten_linear_bias_none_to_linear.runOnGraph(graph);
}

void FuseFlattenAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
//TensorRT implicitly adds a flatten layer infront of FC layers if necessary
std::string flatten_linear_pattern = R"IR(
graph(%input, %6, %7, %weight, %bias):
%flat = aten::flatten(%input, %6, %7)
%res = aten::linear(%flat, %weight, %bias)
return (%res))IR";
std::string flatten_linear_bias_none_pattern = R"IR(
graph(%input, %6, %7, %weight):
%flat = aten::flatten(%input, %6, %7)
%bias: Tensor? = prim::Constant()
%res = aten::linear(%flat, %weight, %bias)
return (%res))IR";
std::string fused_linear = R"IR(
graph(%input, %6, %7, %weight, %bias):
%res = aten::linear(%input, %weight, %bias)
return (%res))IR";

std::string fused_linear_bias_none = R"IR(
graph(%input, %6, %7, %weight):
%bias: Tensor? = prim::Constant()
%res = aten::linear(%input, %weight, %bias)
return (%res))IR";

torch::jit::SubgraphRewriter flatten_linear_to_linear;
flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
flatten_linear_to_linear.runOnGraph(graph);

torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
flatten_linear_bias_none_to_linear.RegisterRewritePattern(
flatten_linear_bias_none_pattern, fused_linear_bias_none);
flatten_linear_bias_none_to_linear.runOnGraph(graph);
}
} // namespace passes
} // namespace lowering
} // namespace core
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
Expand Up @@ -11,6 +11,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);

} // namespace irfusers
Expand Down
32 changes: 32 additions & 0 deletions core/lowering/passes/unpack_addmm.cpp
@@ -0,0 +1,32 @@
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
//TensorRT implicitly adds a flatten layer infront of FC layers if necessary
std::string addmm_pattern = R"IR(
graph(%b, %x, %w, %1):
%out: Tensor = aten::addmm(%b, %x, %w, %1, %1)
return (%out))IR";
std::string mm_add_pattern = R"IR(
graph(%b, %x, %w, %1):
%mm: Tensor = aten::matmul(%x, %w)
%bias: Tensor = trt::const(%b)
%out: Tensor = aten::add_(%bias, %mm, %1)
return (%out))IR";


torch::jit::SubgraphRewriter unpack_addmm;
unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
unpack_addmm.runOnGraph(graph);
}


} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
20 changes: 0 additions & 20 deletions core/lowering/passes/unpack_batch_norm.cpp
@@ -1,25 +1,5 @@
#include "torch/csrc/jit/runtime/custom_operator.h"
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

namespace torch {
namespace jit {

c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}

RegisterOperators trt_const_op_reg({
Operator(
"trt::const(Tensor val) -> Tensor",
[](Stack& stack) {
return 0; //nop
},
aliasAnalysisFromSchema())});

} // namespace jit
} // namespace torch

namespace trtorch {
namespace core {
namespace lowering {
Expand Down
21 changes: 21 additions & 0 deletions core/lowering/register_const_op.cpp
@@ -0,0 +1,21 @@
#include "torch/csrc/jit/runtime/custom_operator.h"

namespace torch {
namespace jit {

c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}

/// Op marks a Tensor to be conveted from an Torch Tensor
/// to a TRT constant Tensor
RegisterOperators trt_const_op_reg({
Operator(
"trt::const(Tensor val) -> Tensor",
[](Stack& stack) {
return 0; //noop
},
aliasAnalysisFromSchema())});

} // namespace jit
} // namespace torch
47 changes: 26 additions & 21 deletions tests/core/converters/BUILD
@@ -1,54 +1,59 @@
load("//tests/core/converters:converter_test.bzl", "converter_test")

converter_test(
name = "test_softmax"
name = "test_activation"
)

converter_test(
name = "test_shuffle"
name = "test_conv"
)

converter_test(
name = "test_activation"
name = "test_element_wise"
)

converter_test(
name = "test_pooling"
name = "test_linear"
)

converter_test(
name = "test_unary"
name = "test_matrix_multiply"
)

converter_test(
name = "test_linear"
name = "test_pooling"
)

converter_test(
name = "test_element_wise"
name = "test_reduce"
)

converter_test(
name = "test_conv"
name = "test_shuffle"
)

converter_test(
name = "test_reduce"
name = "test_softmax"
)

converter_test(
name = "test_unary"
)

test_suite(
name = "test_converters",
tests = [
":test_softmax",
":test_shuffle",
":test_activation",
":test_pooling",
":test_unary",
":test_linear",
":test_element_wise",
":test_conv",
":test_reduce"
]
name = "test_converters",
tests = [
":test_activation",
":test_conv",
":test_element_wise",
":test_linear",
":test_matrix_multiply",
":test_pooling",
":test_reduce",
":test_shuffle",
":test_softmax",
":test_unary",
]
)


0 comments on commit c5b6202

Please sign in to comment.