Skip to content

Commit

Permalink
feat(aten::std|aten::masked_fill): Implement masked_fill, aten::std
Browse files Browse the repository at this point in the history
works for non bias corrected cases

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Jul 28, 2021
1 parent fa7d6d9 commit a086a5b
Show file tree
Hide file tree
Showing 16 changed files with 565 additions and 35 deletions.
6 changes: 6 additions & 0 deletions core/conversion/conversion.cpp
Expand Up @@ -87,6 +87,9 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
if (eval) {
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
if (eval.value().isTuple() && eval.value().toTuple()->elements().size() == 1) {
eval.value() = {eval.value().toTuple()->elements()[0]};
}
} else {
LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
}
Expand Down Expand Up @@ -283,6 +286,9 @@ void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n, boo
auto eval = EvaluateNode(ctx, bn);
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value());
if (eval.value().isTuple() && eval.value().toTuple()->elements().size() == 1) {
eval.value() = {eval.value().toTuple()->elements()[0]};
}
} else {
LOG_DEBUG(
ctx->logger,
Expand Down
24 changes: 24 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
Expand Up @@ -185,6 +185,30 @@ auto element_wise_registrations TRTORCH_UNUSED =
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Should implement self - alpha * other
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].unwrapToScalar().to<float>();
auto alpha = args[2].unwrapToScalar().to<float>();

auto rhs = other * alpha;
if (1 != rhs) {
auto rhs_tensor = tensor_to_const(ctx, torch::tensor({rhs}));
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, rhs_tensor, util::node_info(n));
TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);
sub->setName(util::node_info(n).c_str());
LOG_DEBUG("Output tensor shape: " << sub->getOutput(0)->getDimensions());
ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0));
return true;
} else {
LOG_DEBUG("Nothing to be done this layer, passing through input");
LOG_DEBUG("Output tensor shape: " << self->getDimensions());

ctx->AssociateValueAndTensor(n->outputs()[0], self);
return true;
}
}})
.pattern({"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar "
"alpha=1) -> (Tensor(a!))",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
Expand Down
18 changes: 9 additions & 9 deletions core/conversion/converters/impl/reduce.cpp
Expand Up @@ -40,23 +40,23 @@ auto reduce_registrations TRTORCH_UNUSED =
auto in_dims = util::toVec(in_tensor->getDimensions());
LOG_DEBUG("InDims " << in_dims); // Some abuse of toDim but just for debug info
LOG_DEBUG(
"Dim to reduce(original):" << util::toDims(dims)); // Some abuse of toDim but just for debug info
"Dim to reduce (original): " << util::toDims(dims)); // Some abuse of toDim but just for debug info
for (size_t i = 0; i < dims.size(); i++) {
auto dim_val = dims[i] < 0 ? (in_dims.size() + dims[i]) : dims[i];
calculated_dims.push_back(dim_val);
}
LOG_DEBUG(
"Dim to reduce(converted):"
"Dim to reduce (converted): "
<< util::toDims(calculated_dims)); // Some abuse of toDim but just for debug info

uint32_t axis_mask = 0;
for (size_t d = 0; d < calculated_dims.size(); d++) {
axis_mask |= 1 << calculated_dims[d];
}
LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));

auto keepdim = args[2].unwrapToBool();
LOG_DEBUG("Keep dims :" << keepdim);
LOG_DEBUG("Keep dims: " << keepdim);
LOG_WARNING("Mean converter disregards dtype");
auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, keepdim);
TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);
Expand Down Expand Up @@ -106,10 +106,10 @@ auto reduce_registrations TRTORCH_UNUSED =
for (size_t d = 0; d < calculated_dims.size(); d++) {
axis_mask |= 1 << calculated_dims[d];
}
LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));

auto keepdim = args[2].unwrapToBool();
LOG_DEBUG("Keep dims :" << keepdim);
LOG_DEBUG("Keep dims: " << keepdim);

LOG_WARNING("Sum converter disregards dtype");
auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
Expand Down Expand Up @@ -145,13 +145,13 @@ auto reduce_registrations TRTORCH_UNUSED =
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in_tensor = args[0].ITensorOrFreeze(ctx);
auto dim = args[1].unwrapToInt();
LOG_DEBUG("Dim to reduce:" << dim); // Some abuse of toDim but just for debug info
LOG_DEBUG("Dim to reduce: " << dim); // Some abuse of toDim but just for debug info

uint32_t axis_mask = 1 << dim;
LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));

auto keepdim = args[2].unwrapToBool();
LOG_DEBUG("Keep dims :" << keepdim);
LOG_DEBUG("Keep dims: " << keepdim);

LOG_WARNING("Prod converter disregards dtype");
auto prod_layer =
Expand Down
34 changes: 16 additions & 18 deletions core/conversion/converters/impl/select.cpp
Expand Up @@ -71,35 +71,34 @@ auto select_registrations TRTORCH_UNUSED =
RegisterNodeConversionPatterns()
.pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto in = args[0].ITensorOrFreeze(ctx);
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
auto axis = args[1].unwrapToInt();
axis = axis < 0 ? axis + maxDim : axis;
auto ind = (int32_t)args[2].unwrapToInt();

// index to access needs to be an at::Tensor
at::Tensor indices = torch::tensor({ind}).to(torch::kI32);
auto weights = Weights(ctx, indices);

// IConstantLayer to convert indices from Weights to ITensor
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
auto const_out = const_layer->getOutput(0);
auto const_out = tensor_to_const(ctx, indices);

// IGatherLayer takes in input tensor, the indices, and the axis
// of input tensor to take indices from
auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
auto gather_out = gather_layer->getOutput(0);
auto out = gather_layer->getOutput(0);

// IShuffleLayer removes redundant dimensions
auto shuffle_layer = ctx->net->addShuffle(*gather_out);
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
shuffle_layer->setReshapeDimensions(util::squeezeDims(gather_out->getDimensions(), axis));
shuffle_layer->setName(util::node_info(n).c_str());
auto shuffle_out = shuffle_layer->getOutput(0);
LOG_DEBUG("Gather tensor shape: " << out->getDimensions());

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out);
if (out->getDimensions().nbDims != 1) {
// IShuffleLayer removes redundant dimensions
auto shuffle_layer = ctx->net->addShuffle(*out);
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
shuffle_layer->setReshapeDimensions(util::squeezeDims(out->getDimensions(), axis));
shuffle_layer->setName(util::node_info(n).c_str());
out = shuffle_layer->getOutput(0);
}

out = ctx->AssociateValueAndTensor(n->outputs()[0], out);

LOG_DEBUG("Output tensor shape: " << out->getDimensions());

Expand Down Expand Up @@ -253,15 +252,14 @@ auto select_registrations TRTORCH_UNUSED =
"aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
LOG_DEBUG(args[1].unwrapToTensor());
auto mask = castITensor(ctx, args[1].ITensorOrFreeze(ctx), nvinfer1::DataType::kBOOL);
mask = addPadding(ctx, n, mask, self->getDimensions().nbDims, false, true);
auto val = args[2].unwrapToScalar().to<float>();
LOG_DEBUG(torch::full(util::toVec(self->getDimensions()), val));
auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));

TRTORCH_CHECK(util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false), "Self and mask tensors are not broadcastable");

auto new_layer = ctx->net->addSelect(*mask, *self, *val_t);
auto new_layer = ctx->net->addSelect(*mask, *val_t, *self);
TRTORCH_CHECK(new_layer, "Unable to create layer for aten::masked_fill");

new_layer->setName(util::node_info(n).c_str());
Expand Down
4 changes: 2 additions & 2 deletions core/conversion/evaluators/aten.cpp
Expand Up @@ -573,10 +573,10 @@ auto aten_registrations TRTORCH_UNUSED =
auto dtype = args.at(n->input(1)).IValue();
auto device = args.at(n->input(2)).IValue();
auto tensor = createTensorFromList(*data, *dtype, *device);
LOG_DEBUG(tensor);
if (tensor.dtype() == at::kByte) {
return tensor.to(at::kInt);
return tensor.to(at::kFloat);
}
std::cout << tensor << std::endl;
return tensor;
},
EvalOptions().validSchemas({
Expand Down
2 changes: 2 additions & 0 deletions core/lowering/lowering.cpp
Expand Up @@ -48,6 +48,8 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
passes::UnpackAddMM(g);
// passes::UnpackBatchNorm(g);
passes::UnpackLogSoftmax(g);
passes::UnpackStd(g);
passes::UnpackVar(g);
passes::RemoveNOPs(g);
passes::AliasOperators(g);
passes::SiluToSigmoidMultipication(g);
Expand Down
4 changes: 3 additions & 1 deletion core/lowering/passes/BUILD
Expand Up @@ -24,7 +24,9 @@ cc_library(
"unpack_addmm.cpp",
"unpack_batch_norm.cpp",
"unpack_log_softmax.cpp",
"unpack_hardswish.cpp"
"unpack_hardswish.cpp",
"unpack_std.cpp",
"unpack_var.cpp",
],
hdrs = [
"passes.h",
Expand Down
2 changes: 2 additions & 0 deletions core/lowering/passes/passes.h
Expand Up @@ -19,6 +19,8 @@ void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackStd(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
Expand Down
30 changes: 30 additions & 0 deletions core/lowering/passes/unpack_std.cpp
@@ -0,0 +1,30 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

void UnpackStd(std::shared_ptr<torch::jit::Graph>& graph) {
std::string std_pattern = R"IR(
graph(%1, %dim, %unbiased, %keepdim):
%out: Tensor = aten::std(%1, %dim, %unbiased, %keepdim)
return (%out))IR";
std::string unpacked_pattern = R"IR(
graph(%1, %dim, %unbiased, %keepdim):
%z: Tensor = aten::var(%1, %dim, %unbiased, %keepdim)
%out: Tensor = aten::sqrt(%z)
return (%out))IR";

torch::jit::SubgraphRewriter std_rewriter;
std_rewriter.RegisterRewritePattern(std_pattern, unpacked_pattern);
std_rewriter.runOnGraph(graph);
LOG_GRAPH("Post unpack std: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
51 changes: 51 additions & 0 deletions core/lowering/passes/unpack_var.cpp
@@ -0,0 +1,51 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
std::string var_pattern = R"IR(
graph(%input, %dim, %unbiased, %keepdim):
%out: Tensor = aten::var(%input, %dim, %unbiased, %keepdim)
return (%out))IR";
std::string unpacked_pattern = R"IR(
graph(%input, %dims, %unbiased, %keepdim):
%none: None = prim::Constant()
%false: bool = prim::Constant[value=0]()
%0: int = prim::Constant[value=0]()
%1: int = prim::Constant[value=1]()
%sqrd: Tensor = aten::mul(%input, %input)
%sqrdmean: Tensor = aten::mean(%sqrd, %dims, %keepdim, %none)
%mean: Tensor = aten::mean(%input, %dims, %keepdim, %none)
%meansqrd: Tensor = aten::mul(%mean, %mean)
%var: Tensor = aten::sub(%sqrdmean, %meansqrd, %1)
%varout : Tensor = prim::If(%unbiased)
block0():
%shape: int[] = aten::size(%input)
%shapet: Tensor = aten::tensor(%shape, %0, %none, %false)
%dim: int = prim::ListUnpack(%dims)
%reduceddims: Tensor = aten::select(%shapet, %0, %dim)
%numel: Tensor = aten::prod(%reduceddims, %dim, %keepdim, %none)
%mul: Tensor = aten::mul(%var, %numel)
%sub: Tensor = aten::sub(%numel, %1, %1)
%v: Tensor = aten::div(%mul, %sub)
-> (%v)
block1():
-> (%var)
return(%varout))IR";

torch::jit::SubgraphRewriter var_rewriter;
var_rewriter.RegisterRewritePattern(var_pattern, unpacked_pattern);
var_rewriter.runOnGraph(graph);
LOG_DEBUG("Post unpack var: " << *graph);

}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
2 changes: 0 additions & 2 deletions core/util/trt_util.h
Expand Up @@ -21,8 +21,6 @@ inline std::ostream& operator<<(std::ostream& os, const nvinfer1::TensorFormat&

inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType& dtype) {
switch (dtype) {
case nvinfer1::DataType::kBOOL:
return stream << "Bool";
case nvinfer1::DataType::kFLOAT:
return stream << "Float32";
case nvinfer1::DataType::kHALF:
Expand Down

0 comments on commit a086a5b

Please sign in to comment.