Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(aten::matmul|aten::addmm): Adds support for aten::matmul and
aten::addmm Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
- Loading branch information
1 parent
d945eb9
commit c5b6202
Showing
17 changed files
with
197 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

// Converter for aten::matmul(Tensor self, Tensor other) -> (Tensor).
// Maps the op onto a TensorRT IMatrixMultiplyLayer with no transpose on
// either operand.
auto mm_registrations = RegisterNodeConversionPatterns()
    .pattern({
        "aten::matmul(Tensor self, Tensor other) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Resolve a converter argument to an ITensor in the TRT network.
            // Arguments already evaluated as ITensors are used directly;
            // Torch tensors carried as IValues are frozen into TensorRT
            // constant layers first. `label` names the operand ("self" /
            // "other") in error messages and layer names.
            auto to_itensor = [&](auto& arg, const char* label) -> nvinfer1::ITensor* {
                if (arg.isIValue()) {
                    auto t = arg.unwrapToTensor();
                    auto t_weights = Weights(ctx, t);
                    auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
                    TRTORCH_CHECK(const_layer, "Unable to freeze tensor " << label << " for node: " << *n);
                    const_layer->setName((util::node_info(n) + " [Freeze Tensor(" + label + ")]").c_str());
                    return const_layer->getOutput(0);
                }
                return arg.ITensor();
            };

            auto self = to_itensor(args[0], "self");
            LOG_DEBUG("self tensor shape: " << self->getDimensions());

            auto other = to_itensor(args[1], "other");
            LOG_DEBUG("other tensor shape: " << other->getDimensions());

            // Plain matrix product: neither operand is transposed.
            // NOTE(review): 1-D operands would need MatrixOperation::kVECTOR
            // to match aten::matmul's vector semantics — confirm expected ranks.
            auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
            TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
            mm_layer->setName(util::node_info(n).c_str());
            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));

            LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
            return true;
        }
    });
} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

// Lowering pass: rewrites aten::addmm(bias, mat1, mat2, beta, alpha) into an
// explicit aten::matmul followed by an in-place add, so the converter library
// only needs a matmul converter plus element-wise add.
//
// NOTE(review): both scalar slots of addmm are bound to the single pattern
// value %1, so the rewrite only fires when beta and alpha are the same value
// node in the graph — addmm calls with distinct beta/alpha are left
// untouched. Confirm this restriction is intended.
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
  //TensorRT implicitly adds a flatten layer in front of FC layers if necessary
  std::string addmm_pattern = R"IR(
    graph(%b, %x, %w, %1):
      %out: Tensor = aten::addmm(%b, %x, %w, %1, %1)
      return (%out))IR";
  // Replacement: matmul(x, w), then add the bias. trt::const is a no-op
  // marker op (registered elsewhere in this commit) that tags the bias
  // tensor so conversion can freeze it as a TRT constant.
  std::string mm_add_pattern = R"IR(
    graph(%b, %x, %w, %1):
      %mm: Tensor = aten::matmul(%x, %w)
      %bias: Tensor = trt::const(%b)
      %out: Tensor = aten::add_(%bias, %mm, %1)
      return (%out))IR";

  // Apply the textual IR rewrite over the whole graph.
  torch::jit::SubgraphRewriter unpack_addmm;
  unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
  unpack_addmm.runOnGraph(graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#include "torch/csrc/jit/runtime/custom_operator.h"

namespace torch {
namespace jit {

// Alias analysis mode for the op below: derive aliasing behavior from the
// declared schema rather than treating the op conservatively.
c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

/// Op marks a Tensor to be converted from a Torch Tensor
/// to a TRT constant Tensor
// The implementation is a runtime no-op (the input stays on the stack as the
// output); the op exists purely as a marker consumed by the lowering /
// conversion passes.
RegisterOperators trt_const_op_reg({
  Operator(
    "trt::const(Tensor val) -> Tensor",
    [](Stack& stack) {
      return 0; //noop
    },
    aliasAnalysisFromSchema())});

} // namespace jit
} // namespace torch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,59 @@ | ||
# Converter unit tests, one target per converter family, kept in sorted order.
load("//tests/core/converters:converter_test.bzl", "converter_test")

converter_test(
    name = "test_activation"
)

converter_test(
    name = "test_conv"
)

converter_test(
    name = "test_element_wise"
)

converter_test(
    name = "test_linear"
)

converter_test(
    name = "test_matrix_multiply"
)

converter_test(
    name = "test_pooling"
)

converter_test(
    name = "test_reduce"
)

converter_test(
    name = "test_shuffle"
)

converter_test(
    name = "test_softmax"
)

converter_test(
    name = "test_unary"
)

# Aggregate suite running every converter test above (sorted).
test_suite(
    name = "test_converters",
    tests = [
        ":test_activation",
        ":test_conv",
        ":test_element_wise",
        ":test_linear",
        ":test_matrix_multiply",
        ":test_pooling",
        ":test_reduce",
        ":test_shuffle",
        ":test_softmax",
        ":test_unary",
    ]
)
Oops, something went wrong.