From 0c0c5df4cbed8a9c947fd2819640e9d402555ed1 Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Fri, 1 Jun 2018 15:39:30 +0800
Subject: [PATCH] feature/add TRT fc converter (#11043)

---
 .../inference/tensorrt/convert/CMakeLists.txt |   2 +
 .../inference/tensorrt/convert/conv2d_op.cc   |   3 +-
 .../fluid/inference/tensorrt/convert/fc_op.cc | 119 ++++++++++++++++++
 .../inference/tensorrt/convert/mul_op.cc      |   5 +-
 .../inference/tensorrt/convert/op_converter.h |  41 ++++--
 .../inference/tensorrt/convert/test_fc_op.cc  |  46 +++++++
 .../inference/tensorrt/convert/test_mul_op.cc |   4 +-
 .../tensorrt/convert/test_op_converter.cc     |   7 +-
 .../inference/tensorrt/convert/ut_helper.h    |  40 +++---
 paddle/fluid/inference/tensorrt/engine.cc     |   1 +
 paddle/fluid/inference/tensorrt/engine.h      |   4 +-
 paddle/fluid/operators/tensorrt_engine_op.cc  |   3 +-
 12 files changed, 240 insertions(+), 35 deletions(-)
 create mode 100644 paddle/fluid/inference/tensorrt/convert/fc_op.cc
 create mode 100644 paddle/fluid/inference/tensorrt/convert/test_fc_op.cc

diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 5ada1d6312692..23ca8bfac84f3 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -8,3 +8,5 @@ nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS
 nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
 nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
+nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
index 209936c3bafb0..668d344f1bba1 100644
--- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
@@ -21,7 +21,8 @@ namespace tensorrt {
 class Conv2dOpConverter : public OpConverter {
  public:
   Conv2dOpConverter() {}
-  void operator()(const framework::proto::OpDesc& op) override {
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope) override {
     LOG(INFO)
         << "convert a fluid conv2d op to tensorrt conv layer without bias";
   }
diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc
new file mode 100644
index 0000000000000..bd05608d7620e
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc
@@ -0,0 +1,119 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+// Reorder the elements from istrides to ostrides, borrowed from the TRT
+// converter in TensorFlow.
+// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorrt/convert/convert_nodes.cc#L318
+template <typename T>
+void Reorder2(nvinfer1::DimsHW shape, const T* idata,
+              nvinfer1::DimsHW istrides, T* odata,
+              nvinfer1::DimsHW ostrides) {
+  for (int h = 0; h < shape.h(); ++h) {
+    for (int w = 0; w < shape.w(); ++w) {
+      odata[h * ostrides.h() + w * ostrides.w()] =
+          idata[h * istrides.h() + w * istrides.w()];
+    }
+  }
+}
+
+// Reorder the data layout from CK to KC.
+void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
+                   TensorRTEngine::Weight* oweights) {
+  int c = iweights.dims[0];
+  int k = iweights.dims[1];
+  oweights->dims.assign({k, c});
+  nvinfer1::DimsHW istrides = {1, k};
+  nvinfer1::DimsHW ostrides = {c, 1};
+  Reorder2({k, c}, static_cast<float const*>(iweights.get().values), istrides,
+           static_cast<float*>(const_cast<void*>(oweights->get().values)),
+           ostrides);
+}
+
+/*
+ * The FC converter converts a MUL op in Fluid to an FC layer in TRT.
+ */
+class FcOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope) override {
+    VLOG(4) << "convert a fluid fc op to tensorrt fc layer without bias";
+
+    framework::OpDesc op_desc(op, nullptr, nullptr);
+    PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
+    PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);  // Y is a weight
+    PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
+
+    // Declare inputs
+    auto* X = engine_->GetITensor(op_desc.Input("X").front());
+
+    // Declare weights
+    auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
+    PADDLE_ENFORCE_NOT_NULL(Y_v);
+    auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
+    // This may trigger a GPU->CPU copy, because TRT's weight can only be
+    // assigned from CPU memory, which cannot be avoided.
+    auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace());
+    PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL);  // a matrix
+    size_t n_output = Y_t->dims()[1];
+
+    framework::LoDTensor tmp;
+    tmp.Resize(Y_t->dims());
+    memcpy(tmp.mutable_data<float>(platform::CPUPlace()), Y_t->data<float>(),
+           Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));
+
+    TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
+                                  static_cast<void*>(weight_data),
+                                  Y_t->memory_size() / sizeof(float)};
+    TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT,
+                                      static_cast<void*>(tmp.data<float>()),
+                                      Y_t->memory_size() / sizeof(float));
+    weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]});
+    tmp_weight.dims = weight.dims;
+
+    // The data layout of TRT FC layer's weight is different from fluid's FC,
+    // need to reorder the elements.
+    ReorderCKtoKC(tmp_weight, &weight);
+
+    // Currently, the framework can only handle one fluid op -> one TRT layer,
+    // but fc fuses `mul` and `bias` (2 fluid ops), so here is a trick, just
+    // handle `mul`, leave `add` as another layer.
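+    // For example, a fluid `mul` with X: [M, C] and weight Y: [C, K] becomes
+    // a FullyConnected layer with K outputs and a null bias; a following
+    // elementwise add, if any, is converted as its own layer.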
+    TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
+
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected,
+                                       *const_cast<nvinfer1::ITensor*>(X),
+                                       n_output, weight.get(), bias.get());
+
+    auto output_name = op_desc.Output("Out").front();
+    engine_->DeclareOutput(layer, 0, output_name);
+  }
+};
+
+REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(mul);
diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
index aa8e66490f7e4..6bb07709c7ee1 100644
--- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
@@ -24,8 +24,9 @@ namespace tensorrt {
 class MulOpConverter : public OpConverter {
  public:
   MulOpConverter() {}
-  void operator()(const framework::proto::OpDesc& op) override {
-    VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias";
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope) override {
+    VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";
 
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index 1cd3ed9a00ace..4d21e241c0fe0 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -31,27 +31,42 @@ namespace tensorrt {
 class OpConverter {
  public:
   OpConverter() {}
-  virtual void operator()(const framework::proto::OpDesc& op) {}
-  void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
-    std::string type = op.type();
-    auto* it = Registry<OpConverter>::Lookup(type);
-    PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
-    it->SetEngine(engine);
-    (*it)(op);
-  }
+  // Converter logic for an op.
+  virtual void operator()(const framework::proto::OpDesc& op,
+                          const framework::Scope& scope) {}
+
+  // Convert a single fluid operator and add the corresponding layer to TRT.
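+  // Note that a `mul` whose weight input `Y` is a parameter is dispatched to
+  // the `fc` converter instead, since TRT's FullyConnected layer takes its
+  // weight as a constant rather than as an input tensor.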
+  void ConvertOp(const framework::proto::OpDesc& op,
+                 const std::unordered_set<std::string>& parameters,
+                 const framework::Scope& scope, TensorRTEngine* engine) {
+    framework::OpDesc op_desc(op, nullptr, nullptr);
+
+    OpConverter* it{nullptr};
 
-  // convert fluid op to tensorrt layer
-  void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
-    OpConverter::Run(op, engine);
+    if (op_desc.Type() == "mul") {
+      PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
+      std::string Y = op_desc.Input("Y")[0];
+      if (parameters.count(Y)) {
+        it = Registry<OpConverter>::Lookup("fc");
+      }
+    }
+    if (!it) {
+      it = Registry<OpConverter>::Lookup(op_desc.Type());
+    }
+    PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
+                            op_desc.Type());
+    it->SetEngine(engine);
+    (*it)(op, scope);
   }
 
   // convert fluid block to tensorrt network
   void ConvertBlock(const framework::proto::BlockDesc& block,
-                    TensorRTEngine* engine) {
+                    const std::unordered_set<std::string>& parameters,
+                    const framework::Scope& scope, TensorRTEngine* engine) {
     for (int i = 0; i < block.ops_size(); i++) {
       const auto& op = block.ops(i);
-      OpConverter::Run(op, engine);
+      ConvertOp(op, parameters, scope, engine);
     }
   }
diff --git a/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc b/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
new file mode 100644
index 0000000000000..a30253072ac58
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
@@ -0,0 +1,46 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(fc_op, test) {
+  std::unordered_set<std::string> parameters({"mul-Y"});
+  framework::Scope scope;
+  TRTConvertValidation validator(20, parameters, scope, 1000);
+
+  validator.DeclInputVar("mul-X", nvinfer1::Dims4(8, 3, 1, 1));
+  validator.DeclParamVar("mul-Y", nvinfer1::Dims2(3, 2));
+  validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(8, 2));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  desc.SetType("mul");
+  desc.SetInput("X", {"mul-X"});
+  desc.SetInput("Y", {"mul-Y"});
+  desc.SetOutput("Out", {"mul-Out"});
+
+  validator.SetOp(*desc.Proto());
+
+  validator.Execute(10);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc b/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
index d8b61d5f08ffd..1ce1130e5d660 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
@@ -21,7 +21,9 @@ namespace inference {
 namespace tensorrt {
 
 TEST(MulOpConverter, main) {
-  TRTConvertValidation validator(10, 1000);
+  framework::Scope scope;
+  std::unordered_set<std::string> parameters;
+  TRTConvertValidation validator(10, parameters, scope, 1000);
   validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
   validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
   validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));
diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
index 9ae7de9cbfa65..1d3f5eabb2f83 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
 #include <gtest/gtest.h>
 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 
 namespace paddle {
 namespace inference {
@@ -27,7 +28,9 @@ TEST(OpConverter, ConvertBlock) {
   conv2d_op->SetType("conv2d");
 
   OpConverter converter;
-  converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/);
+  framework::Scope scope;
+  converter.ConvertBlock(*block->Proto(), {}, scope,
+                         nullptr /*TensorRTEngine*/);
 }
 
 }  // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
index 684bbc208fc1c..d7e05dd5b5b23 100644
--- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h
+++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
@@ -61,7 +61,10 @@ class TRTConvertValidation {
  public:
   TRTConvertValidation() = delete;
 
-  explicit TRTConvertValidation(int batch_size, int workspace_size = 1024) {
+  TRTConvertValidation(int batch_size,
+                       const std::unordered_set<std::string>& parameters,
+                       framework::Scope& scope, int workspace_size = 1 << 10)
+      : parameters_(parameters), scope_(scope) {
     // create engine.
     engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
     engine_->InitNetwork();
@@ -76,19 +79,22 @@
     engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
   }
 
+  // Declare a parameter variable in the scope.
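+  // Parameters (e.g. the weight of a mul/fc op) are declared like other
+  // variables, but SetOp() skips them when binding engine inputs, since the
+  // converter bakes them into the TRT engine as constant weights.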
+  void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) {
+    DeclVar(name, dims);
+  }
+
   void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
     DeclVar(name, dims);
   }
 
+  // Declare a variable in a fluid Scope.
   void DeclVar(const std::string& name, const nvinfer1::Dims& dims) {
     platform::CPUPlace place;
     platform::CPUDeviceContext ctx(place);
 
     // Init Fluid tensor.
-    std::vector<int64_t> dim_vec(dims.nbDims);
-    for (int i = 0; i < dims.nbDims; i++) {
-      dim_vec[i] = dims.d[i];
-    }
+    std::vector<int64_t> dim_vec(dims.d, dims.d + dims.nbDims);
     auto* x = scope_.Var(name);
     auto* x_tensor = x->GetMutable<framework::LoDTensor>();
     x_tensor->Resize(framework::make_ddim(dim_vec));
@@ -99,7 +105,7 @@
     op_ = framework::OpRegistry::CreateOp(desc);
 
     OpConverter op_converter;
-    op_converter.ConvertOp(desc, engine_.get());
+    op_converter.ConvertOp(desc, parameters_, scope_, engine_.get());
 
     engine_->FreezeNetwork();
 
@@ -108,11 +114,13 @@
     // Set Inputs.
     for (const auto& input : op_desc_->InputArgumentNames()) {
+      if (parameters_.count(input)) continue;
       auto* var = scope_.FindVar(input);
       PADDLE_ENFORCE(var);
       auto tensor = var->GetMutable<framework::LoDTensor>();
+
       engine_->SetInputFromCPU(
-          input, static_cast<void*>(tensor->data<float>()),
+          input, static_cast<void*>(tensor->data<void>()),
           sizeof(float) *
               analysis::AccuDims(tensor->dims(), tensor->dims().size()));
     }
   }
 
   void Execute(int batch_size) {
     // Execute Fluid Op
-    // Execute TRT
     platform::CPUPlace place;
     platform::CPUDeviceContext ctx(place);
-    engine_->Execute(batch_size);
-
     op_->Run(scope_, place);
+    // Execute TRT.
+    engine_->Execute(batch_size);
+    cudaStreamSynchronize(*engine_->stream());
 
     ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
+    const size_t output_space_size = 200;
     for (const auto& output : op_desc_->OutputArgumentNames()) {
       std::vector<float> fluid_out;
-      std::vector<float> trt_out(200);
-      engine_->GetOutputInCPU(output, &trt_out[0], 200 * sizeof(float));
+      std::vector<float> trt_out(output_space_size);
+      engine_->GetOutputInCPU(output, &trt_out[0],
+                              output_space_size * sizeof(float));
+      cudaStreamSynchronize(*engine_->stream());
 
       auto* var = scope_.FindVar(output);
       auto tensor = var->GetMutable<framework::LoDTensor>();
       framework::TensorToVector(*tensor, ctx, &fluid_out);
 
       // Compare two output
       ASSERT_FALSE(fluid_out.empty());
       for (size_t i = 0; i < fluid_out.size(); i++) {
-        EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 0.001);
+        EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 1e-6);
       }
     }
   }
 
  private:
   std::unique_ptr<TensorRTEngine> engine_;
   cudaStream_t stream_;
-  framework::Scope scope_;
   std::unique_ptr<framework::OperatorBase> op_;
   std::unique_ptr<framework::OpDesc> op_desc_;
+  const std::unordered_set<std::string>& parameters_;
+  framework::Scope& scope_;
 };
 
 }  // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index a88236ae98e18..3d75fefc1a735 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -106,6 +106,7 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
                          name);
   auto* output = layer->getOutput(offset);
 
+  SetITensor(name, output);
   PADDLE_ENFORCE(output != nullptr);
   output->setName(name.c_str());
   infer_network_->markOutput(*output);
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index d9d3163b66d4c..fabcfd9e80cc0 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -37,13 +37,15 @@ class TensorRTEngine : public EngineBase {
   // Weight is model parameter.
   class Weight {
    public:
-    Weight(nvinfer1::DataType dtype, void* value, int num_elem) {
+    Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) {
       w_.type = dtype;
       w_.values = value;
       w_.count = num_elem;
     }
     const nvinfer1::Weights& get() { return w_; }
 
+    std::vector<int64_t> dims;
+
    private:
     nvinfer1::Weights w_;
   };
diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc
index 83e768b4dc9c6..855157e7c4c5c 100644
--- a/paddle/fluid/operators/tensorrt_engine_op.cc
+++ b/paddle/fluid/operators/tensorrt_engine_op.cc
@@ -31,8 +31,9 @@ void paddle::operators::TensorRTEngineKernel::Prepare(
   auto max_workspace = context.Attr<int>("max_workspace");
   engine_.reset(new inference::tensorrt::TensorRTEngine(
       max_batch_, max_workspace, nullptr));
+  // TODO(Superjomn) parameters should be passed in after being analyzed from outside.
   inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
-      block, engine_.get());
+      block, {}, context.scope(), engine_.get());
   engine_->FreezeNetwork();
 }