feature/add TRT fc converter (#11043)
Superjomn committed Jun 1, 2018
1 parent 18d6402 commit 0c0c5df
Showing 12 changed files with 240 additions and 35 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -8,3 +8,5 @@ nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
3 changes: 2 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
@@ -21,7 +21,8 @@ namespace tensorrt {
class Conv2dOpConverter : public OpConverter {
public:
Conv2dOpConverter() {}
void operator()(const framework::proto::OpDesc& op) override {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override {
LOG(INFO)
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
}
119 changes: 119 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/fc_op.cc
@@ -0,0 +1,119 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Reorder the elements from istrides to ostrides; borrowed from the TRT
// converter in TensorFlow:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorrt/convert/convert_nodes.cc#L318
template <typename T>
void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
T* odata, nvinfer1::DimsHW ostrides) {
for (int h = 0; h < shape.h(); ++h) {
for (int w = 0; w < shape.w(); ++w) {
odata[h * ostrides.h() + w * ostrides.w()] =
idata[h * istrides.h() + w * istrides.w()];
}
}
}

// Reorder the data layout from CK to KC.
void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
TensorRTEngine::Weight* oweights) {
int c = iweights.dims[0];
int k = iweights.dims[1];
oweights->dims.assign({k, c});
nvinfer1::DimsHW istrides = {1, k};
nvinfer1::DimsHW ostrides = {c, 1};
Reorder2({k, c}, static_cast<float const*>(iweights.get().values), istrides,
static_cast<float*>(const_cast<void*>(oweights->get().values)),
ostrides);
}

/*
* The FC converter converts a MUL op in Fluid to an FC layer in TRT.
*/
class FcOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override {
VLOG(4) << "convert a fluid fc op to tensorrt fc layer without bias";

framework::OpDesc op_desc(op, nullptr, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); // Y is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);

// Declare inputs
auto* X = engine_->GetITensor(op_desc.Input("X").front());

// Declare weights
auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
PADDLE_ENFORCE_NOT_NULL(Y_v);
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
// TRT weights can only be assigned from CPU memory, so this may trigger a
// GPU->CPU copy; that copy cannot be avoided.
auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace());
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL); // a matrix
size_t n_output = Y_t->dims()[1];

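// Copy the weight into a temporary buffer first: ReorderCKtoKC below reads
// from this copy (tmp_weight) while writing the transposed KC layout back
// into the tensor's own buffer (weight_data).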
framework::LoDTensor tmp;
tmp.Resize(Y_t->dims());
memcpy(tmp.mutable_data<float>(platform::CPUPlace()), Y_t->data<float>(),
Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));

TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
Y_t->memory_size() / sizeof(float)};
TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT,
static_cast<void*>(tmp.data<float>()),
Y_t->memory_size() / sizeof(float));
weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]});
tmp_weight.dims = weight.dims;

// The data layout of the TRT FC layer's weight differs from fluid's FC,
// so the elements need to be reordered.
ReorderCKtoKC(tmp_weight, &weight);

// Currently, the framework can only map one fluid op to one TRT layer,
// but fc fuses `mul` and `bias` (two fluid ops). As a workaround, only
// `mul` is handled here; the bias `add` is left to a separate layer.
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};

auto* layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected,
*const_cast<nvinfer1::ITensor*>(X),
n_output, weight.get(), bias.get());

auto output_name = op_desc.Output("Out").front();
engine_->DeclareOutput(layer, 0, output_name);
}
};

REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);

} // namespace tensorrt
} // namespace inference
} // namespace paddle

USE_OP(mul);
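
Note: to make the stride arithmetic in Reorder2/ReorderCKtoKC concrete, here is a minimal standalone sketch (plain C++ without the TRT types; the helper name and sample values are illustrative only). Reading a row-major C x K matrix with strides {1, k} and writing with strides {c, 1} yields its K x C transpose:

#include <cstdio>

// Same loop as Reorder2 above, with the nvinfer1::DimsHW pairs replaced by
// plain ints: copy an (hdim x wdim) grid from input strides to output strides.
template <typename T>
void Reorder2Demo(int hdim, int wdim, const T* idata, int ih, int iw,
                  T* odata, int oh, int ow) {
  for (int h = 0; h < hdim; ++h) {
    for (int w = 0; w < wdim; ++w) {
      odata[h * oh + w * ow] = idata[h * ih + w * iw];
    }
  }
}

int main() {
  const int c = 2, k = 3;
  float ck[c * k] = {1, 2, 3, 4, 5, 6};  // CK layout: {{1,2,3},{4,5,6}}
  float kc[k * c] = {};
  // shape {k, c}, istrides {1, k}, ostrides {c, 1}, as in ReorderCKtoKC.
  Reorder2Demo(k, c, ck, /*ih=*/1, /*iw=*/k, kc, /*oh=*/c, /*ow=*/1);
  for (int i = 0; i < k * c; ++i) printf("%g ", kc[i]);  // 1 4 2 5 3 6
  return 0;
}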
5 changes: 3 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/mul_op.cc
@@ -24,8 +24,9 @@ namespace tensorrt {
class MulOpConverter : public OpConverter {
public:
MulOpConverter() {}
void operator()(const framework::proto::OpDesc& op) override {
VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias";
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override {
VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";

framework::OpDesc op_desc(op, nullptr);
// Declare inputs
41 changes: 28 additions & 13 deletions paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -31,27 +31,42 @@ namespace tensorrt {
class OpConverter {
public:
OpConverter() {}
virtual void operator()(const framework::proto::OpDesc& op) {}

void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
std::string type = op.type();
auto* it = Registry<OpConverter>::Lookup(type);
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
it->SetEngine(engine);
(*it)(op);
}
// Converter logic for an op.
virtual void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) {}

// Convert a single fluid operator and add the corresponding layer to TRT.
void ConvertOp(const framework::proto::OpDesc& op,
const std::unordered_set<std::string>& parameters,
const framework::Scope& scope, TensorRTEngine* engine) {
framework::OpDesc op_desc(op, nullptr, nullptr);

OpConverter* it{nullptr};

// convert fluid op to tensorrt layer
void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
OpConverter::Run(op, engine);
if (op_desc.Type() == "mul") {
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
std::string Y = op_desc.Input("Y")[0];
if (parameters.count(Y)) {
it = Registry<OpConverter>::Lookup("fc");
}
}
if (!it) {
it = Registry<OpConverter>::Lookup(op_desc.Type());
}
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type());
it->SetEngine(engine);
(*it)(op, scope);
}

// convert fluid block to tensorrt network
void ConvertBlock(const framework::proto::BlockDesc& block,
TensorRTEngine* engine) {
const std::unordered_set<std::string>& parameters,
const framework::Scope& scope, TensorRTEngine* engine) {
for (int i = 0; i < block.ops_size(); i++) {
const auto& op = block.ops(i);
OpConverter::Run(op, engine);
ConvertOp(op, parameters, scope, engine);
}
}

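Note: the routing rule in ConvertOp is the subtle part of this change: a `mul` whose Y input is a persistent parameter is treated as the matmul half of an FC and dispatched to the `fc` converter. A minimal standalone sketch of just that decision (the function name and Registry-free shape are illustrative, not the real API):

#include <iostream>
#include <string>
#include <unordered_set>

// Mirrors the lookup-key selection in ConvertOp: `mul` ops whose weight
// operand is a parameter go to the "fc" converter; everything else is
// looked up under its own op type.
std::string ConverterKey(const std::string& op_type, const std::string& y_name,
                         const std::unordered_set<std::string>& parameters) {
  if (op_type == "mul" && parameters.count(y_name)) return "fc";
  return op_type;
}

int main() {
  std::unordered_set<std::string> parameters{"fc1.w"};
  std::cout << ConverterKey("mul", "fc1.w", parameters) << "\n";   // fc
  std::cout << ConverterKey("mul", "hidden0", parameters) << "\n"; // mul
  return 0;
}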
46 changes: 46 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
@@ -0,0 +1,46 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

TEST(fc_op, test) {
std::unordered_set<std::string> parameters({"mul-Y"});
framework::Scope scope;
TRTConvertValidation validator(20, parameters, scope, 1000);

validator.DeclInputVar("mul-X", nvinfer1::Dims4(8, 3, 1, 1));
validator.DeclParamVar("mul-Y", nvinfer1::Dims2(3, 2));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(8, 2));

// Prepare Op description
framework::OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {"mul-X"});
desc.SetInput("Y", {"mul-Y"});
desc.SetOutput("Out", {"mul-Out"});

validator.SetOp(*desc.Proto());

validator.Execute(10);
}

} // namespace tensorrt
} // namespace inference
} // namespace paddle
4 changes: 3 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
@@ -21,7 +21,9 @@ namespace inference {
namespace tensorrt {

TEST(MulOpConverter, main) {
TRTConvertValidation validator(10, 1000);
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));
7 changes: 5 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
@@ -27,7 +28,9 @@ TEST(OpConverter, ConvertBlock) {
conv2d_op->SetType("conv2d");

OpConverter converter;
converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/);
framework::Scope scope;
converter.ConvertBlock(*block->Proto(), {}, scope,
nullptr /*TensorRTEngine*/);
}

} // namespace tensorrt
40 changes: 26 additions & 14 deletions paddle/fluid/inference/tensorrt/convert/ut_helper.h
@@ -61,7 +61,10 @@ class TRTConvertValidation {
public:
TRTConvertValidation() = delete;

explicit TRTConvertValidation(int batch_size, int workspace_size = 1024) {
TRTConvertValidation(int batch_size,
const std::unordered_set<std::string>& parameters,
framework::Scope& scope, int workspace_size = 1 << 10)
: parameters_(parameters), scope_(scope) {
// create engine.
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
engine_->InitNetwork();
@@ -76,19 +79,22 @@
engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
}

// Declare a parameter variable in the scope.
void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims);
}

void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims);
}

// Declare a variable in a fluid Scope.
void DeclVar(const std::string& name, const nvinfer1::Dims& dims) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);

// Init Fluid tensor.
std::vector<int> dim_vec(dims.nbDims);
for (int i = 0; i < dims.nbDims; i++) {
dim_vec[i] = dims.d[i];
}
std::vector<int> dim_vec(dims.d, dims.d + dims.nbDims);
auto* x = scope_.Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec));
@@ -99,7 +105,7 @@
op_ = framework::OpRegistry::CreateOp(desc);

OpConverter op_converter;
op_converter.ConvertOp(desc, engine_.get());
op_converter.ConvertOp(desc, parameters_, scope_, engine_.get());

engine_->FreezeNetwork();

@@ -108,38 +114,43 @@

// Set Inputs.
for (const auto& input : op_desc_->InputArgumentNames()) {
if (parameters_.count(input)) continue;
auto* var = scope_.FindVar(input);
PADDLE_ENFORCE(var);
auto tensor = var->GetMutable<framework::LoDTensor>();

engine_->SetInputFromCPU(
input, static_cast<void*>(tensor->data<float>()),
input, static_cast<void*>(tensor->data<void>()),
sizeof(float) *
analysis::AccuDims(tensor->dims(), tensor->dims().size()));
}
}

void Execute(int batch_size) {
// Execute Fluid Op
// Execute TRT
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
engine_->Execute(batch_size);

op_->Run(scope_, place);
// Execute TRT.
engine_->Execute(batch_size);
cudaStreamSynchronize(*engine_->stream());

ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
const size_t output_space_size = 200;
for (const auto& output : op_desc_->OutputArgumentNames()) {
std::vector<float> fluid_out;
std::vector<float> trt_out(200);
engine_->GetOutputInCPU(output, &trt_out[0], 200 * sizeof(float));
std::vector<float> trt_out(output_space_size);
engine_->GetOutputInCPU(output, &trt_out[0],
output_space_size * sizeof(float));
cudaStreamSynchronize(*engine_->stream());

auto* var = scope_.FindVar(output);
auto tensor = var->GetMutable<framework::LoDTensor>();
framework::TensorToVector(*tensor, ctx, &fluid_out);
// Compare two output
ASSERT_FALSE(fluid_out.empty());
for (size_t i = 0; i < fluid_out.size(); i++) {
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 0.001);
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 1e-6);
}
}
}
@@ -149,9 +160,10 @@
private:
std::unique_ptr<TensorRTEngine> engine_;
cudaStream_t stream_;
framework::Scope scope_;
std::unique_ptr<framework::OperatorBase> op_;
std::unique_ptr<framework::OpDesc> op_desc_;
const std::unordered_set<std::string>& parameters_;
framework::Scope& scope_;
};

} // namespace tensorrt
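Note: the byte count passed to SetInputFromCPU above is computed as sizeof(float) * analysis::AccuDims(dims, size). A sketch under the assumption (not verified against the real header) that AccuDims is simply the product of the leading `size` dimension extents:

#include <cassert>
#include <cstddef>
#include <vector>

// Assumed semantics of analysis::AccuDims: product of the first `size`
// dimension extents, i.e. the element count of the buffer.
template <typename Dims>
size_t AccuDimsDemo(const Dims& dims, size_t size) {
  size_t n = 1;
  for (size_t i = 0; i < size; ++i) n *= static_cast<size_t>(dims[i]);
  return n;
}

int main() {
  std::vector<int> dims{10, 6};  // e.g. the mul-X input in test_mul_op.cc
  assert(AccuDimsDemo(dims, dims.size()) == 60);
  // SetInputFromCPU would then receive sizeof(float) * 60 bytes.
  return 0;
}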