Skip to content

Commit

Permalink
feat(converters): add adaptive_avg_pool2d plugin and a test for it
Browse files Browse the repository at this point in the history
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>

Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>
  • Loading branch information
abhi-iyer committed Jun 18, 2020
1 parent 9458f21 commit fa227b0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 17 deletions.
59 changes: 42 additions & 17 deletions core/conversion/converters/impl/pooling.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"
#include "plugins/interpolate_plugin.h"


namespace trtorch {
namespace core {
Expand Down Expand Up @@ -273,30 +275,51 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
in_shape = util::toVec(in->getDimensions());
}

auto out_shape = args[1].IValue()->toIntList();
//auto out_size = args[1].IValue()->toIntList();
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));

if (ctx->input_is_dynamic) {
LOG_WARNING("Pooling layer will be run through ATen, not TensorRT. Performance may differ.");

std::vector<int64_t> stride(out_shape.size());
for (size_t i = 0; i < out_shape.size(); i++) {
stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_shape[(out_shape.size() - 1) - i];
}
LOG_DEBUG("Stride: " << util::toDims(stride));
auto out_shape = in_shape;
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

std::vector<int64_t> window(out_shape.size());
for (size_t i = 0; i < out_shape.size(); i++) {
window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_shape[out_shape.size() - 1 - i] - 1) * stride[stride.size() - 1 - i];
}
auto creator = new plugins::InterpolatePluginCreator();
auto plugin = creator->createPlugin("adaptive_pool2d", in_shape, out_shape, out_size, std::string("adaptive_pool2d"), false);

LOG_DEBUG("Window: " << util::toDims(window));
auto pooling_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
TRTORCH_CHECK(pooling_layer, "Unable to create pooling (interpolation) plugin from node" << *n);

auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window));
TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n);
pooling_layer->setName(util::node_info(n).c_str());

new_layer->setStrideNd(util::toDims(stride));
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], pooling_layer->getOutput(0));

new_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
} else {
std::vector<int64_t> stride(out_size.size());
for (size_t i = 0; i < out_size.size(); i++) {
stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_size[(out_size.size() - 1) - i];
}
LOG_DEBUG("Stride: " << util::toDims(stride));

LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
std::vector<int64_t> window(out_size.size());
for (size_t i = 0; i < out_size.size(); i++) {
window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_size[out_size.size() - 1 - i] - 1) * stride[stride.size() - 1 - i];
}

LOG_DEBUG("Window: " << util::toDims(window));

auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window));
TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n);

new_layer->setStrideNd(util::toDims(stride));

new_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
}

return true;
}
});
Expand All @@ -306,3 +329,5 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
} // namespace conversion
} // namespace core
} // trtorch


26 changes: 26 additions & 0 deletions tests/core/converters/test_pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,29 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) {

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

// Verifies that aten::adaptive_avg_pool2d converts correctly when the engine is
// built with dynamic input shapes (RunGraphEngineDynamic), i.e. the converter's
// ctx->input_is_dynamic branch, which routes through the interpolate plugin
// rather than a static TensorRT pooling layer.
TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) {
// IR: adaptive_avg_pool2d(%0, [3, 4]) — fixed 3x4 output size.
const auto graph = R"IR(
graph(%0 : Tensor):
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=4]()
%6 : int[] = prim::ListConstruct(%2, %3)
%10 : Tensor = aten::adaptive_avg_pool2d(%0, %6)
return (%10))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

// adaptive_avg_pool2d accepts a 3D (C, H, W) input, used here as {10, 18, 36}.
// (Previous comment referenced MaxPool by mistake — copy-paste from another test.)
auto in = at::randint(-5, 5, {10, 18, 36}, at::kCUDA);

// Run the graph through the TorchScript interpreter for reference output.
auto jit_in = at::clone(in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

// Run the same graph through a TensorRT engine built with dynamic shapes.
auto trt_in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

// Outputs are averages of small integers; 2e-6 absolute tolerance suffices.
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

0 comments on commit fa227b0

Please sign in to comment.