Skip to content

Commit

Permalink
bugfix/trt engine op (#11487)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn committed Jun 15, 2018
1 parent 34ac0eb commit 5fd142c
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 37 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/op_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class OpConverter {
(*it)(op, scope, test_mode);
}

// convert fluid block to tensorrt network
// Convert a fluid block to tensorrt network, NOTE it just convert operators,
// the INetwork's inputs and outputs should specified in some other modules.
void ConvertBlock(const framework::proto::BlockDesc& block,
const std::unordered_set<std::string>& parameters,
const framework::Scope& scope, TensorRTEngine* engine) {
Expand Down
32 changes: 23 additions & 9 deletions paddle/fluid/inference/tensorrt/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ class TensorRTEngine : public EngineBase {
nvinfer1::Weights w_;
};

TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
TensorRTEngine(int max_batch, int max_workspace,
cudaStream_t* stream = nullptr,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
stream_(stream),
stream_(stream ? stream : &default_stream_),
logger_(logger) {}

virtual ~TensorRTEngine();
Expand Down Expand Up @@ -121,6 +122,8 @@ class TensorRTEngine : public EngineBase {
// the max memory size the engine uses
int max_workspace_;
cudaStream_t* stream_;
// If stream_ is not set from outside, hold its own stream.
cudaStream_t default_stream_;
nvinfer1::ILogger& logger_;

std::vector<Buffer> buffers_;
Expand Down Expand Up @@ -165,20 +168,31 @@ class TensorRTEngine : public EngineBase {
*/
class TRT_EngineManager {
public:
TensorRTEngine* Create(int max_batch, int max_workspace,
cudaStream_t* stream) {
engines_.emplace_back(new TensorRTEngine(max_batch, max_workspace, stream));
return engines_.back().get();
bool HasEngine(const std::string& name) const {
return engines_.count(name) != 0;
}

// Get an engine called `name`.
TensorRTEngine* Get(const std::string& name) const {
return engines_.at(name).get();
}

// Create or get an engine called `name`
TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream,
const std::string& name) {
auto* p = new TensorRTEngine(max_batch, max_workspace, stream);
engines_[name].reset(p);
return p;
}

void DeleteALl() {
for (auto& ptr : engines_) {
ptr.reset(nullptr);
for (auto& item : engines_) {
item.second.reset(nullptr);
}
}

private:
std::vector<std::unique_ptr<TensorRTEngine>> engines_;
std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
};

} // namespace tensorrt
Expand Down
28 changes: 18 additions & 10 deletions paddle/fluid/operators/tensorrt_engine_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,25 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
} // namespace

template <typename DeviceContext, typename T>
void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
void TensorRTEngineKernel<DeviceContext, T>::Prepare(
const framework::ExecutionContext &context) const {
VLOG(4) << "Prepare engine";
// Get the ProgramDesc and pass to convert.
framework::proto::BlockDesc block_desc;
block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
max_batch_ = context.Attr<int>("max_batch");
int max_batch = context.Attr<int>("max_batch");
auto max_workspace = context.Attr<int>("max_workspace");
engine_ = Singleton<TRT_EngineManager>::Global().Create(
max_batch_, max_workspace, &stream_);
engine_->InitNetwork();
auto params = context.Attr<std::vector<std::string>>("parameters");
std::unordered_set<std::string> parameters;
for (const auto &param : params) {
parameters.insert(param);
}

// TODO(Superjomn) replace this with a different stream
auto *engine = Singleton<TRT_EngineManager>::Global().Create(
max_batch, max_workspace, nullptr /*engine hold its own stream*/,
context.Attr<std::string>("engine_uniq_key"));
engine->InitNetwork();

framework::BlockDesc block(nullptr /*programdesc*/, &block_desc);
// Add inputs
Expand All @@ -87,24 +95,23 @@ void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
"TensorRT engine only takes LoDTensor as input");
auto shape = var->GetShape();
engine_->DeclareInput(
engine->DeclareInput(
input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(var->GetShape()));
}

// TODO(Superjomn) parameters should be passed after analysised from outside.
inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
block_desc, {}, context.scope(), engine_);
block_desc, parameters, context.scope(), engine);

// Add outputs
VLOG(4) << "declare outputs";
for (auto &output : context.Outputs("Ys")) {
VLOG(4) << "declare output " << output;
engine_->DeclareOutput(output);
engine->DeclareOutput(output);
}

engine_->FreezeNetwork();
engine->FreezeNetwork();
}

class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
Expand All @@ -113,6 +120,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Xs", "A list of inputs.").AsDuplicable();
AddOutput("Ys", "A list of outputs").AsDuplicable();
AddAttr<std::string>("subgraph", "the subgraph.");
AddAttr<std::string>("engine_uniq_key", "unique key for the TRT engine.");
AddAttr<int>("max_batch", "the maximum batch size.");
AddAttr<int>("max_workspace", "the maximum batch size.");
AddComment("TensorRT engine operator.");
Expand Down
33 changes: 17 additions & 16 deletions paddle/fluid/operators/tensorrt_engine_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"

namespace paddle {
namespace operators {

using inference::Singleton;
using inference::tensorrt::TRT_EngineManager;

class TensorRTEngineOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
Expand All @@ -47,37 +51,39 @@ template <typename DeviceContext, typename T>
class TensorRTEngineKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
if (!engine_) {
auto engine_name = context.Attr<std::string>("engine_uniq_key");
if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
Prepare(context);
}
auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name);
auto input_names = context.op().Inputs("Xs");
PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
// Try to determine a batch_size
auto& tensor0 = inference::analysis::GetFromScope<framework::LoDTensor>(
context.scope(), input_names.front());
int batch_size = tensor0.dims()[0];
PADDLE_ENFORCE_LE(batch_size, max_batch_);
PADDLE_ENFORCE_LE(batch_size, context.Attr<int>("max_batch"));

// Convert input tensor from fluid to engine.
for (const auto& x : context.Inputs("Xs")) {
// convert input and copy to TRT engine's buffer
auto& t = inference::analysis::GetFromScope<framework::LoDTensor>(
context.scope(), x);
if (platform::is_cpu_place(t.place())) {
engine_->SetInputFromCPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
engine->SetInputFromCPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
} else {
engine_->SetInputFromGPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
engine->SetInputFromGPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
}
}
// Execute the engine.
PADDLE_ENFORCE_GT(batch_size, 0);
engine_->Execute(batch_size);
engine->Execute(batch_size);
// Convert output tensor from engine to fluid
for (const auto& y : context.Outputs("Ys")) {
// convert output and copy to fluid.
nvinfer1::ITensor* trt_t = engine_->GetITensor(y);
nvinfer1::ITensor* trt_t = engine->GetITensor(y);
auto dims = trt_t->getDimensions();
// Use the output ITensor's dims to reshape the Fluid Tensor.
std::vector<int> ddim(dims.d, dims.d + dims.nbDims);
Expand All @@ -89,27 +95,22 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
if (platform::is_cpu_place(fluid_t->place())) {
// TODO(Superjomn) change this float to dtype size.
engine_->GetOutputInCPU(
engine->GetOutputInCPU(
y, fluid_t->mutable_data<float>(platform::CPUPlace()),
size * sizeof(float));
} else {
engine_->GetOutputInGPU(
engine->GetOutputInGPU(
y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
size * sizeof(float));
}
}

cudaStreamSynchronize(stream_);
cudaStreamSynchronize(*engine->stream());
}

protected:
// Build the engine.
void Prepare(const framework::ExecutionContext& context) const;

private:
mutable cudaStream_t stream_;
mutable inference::tensorrt::TensorRTEngine* engine_{nullptr};
mutable int max_batch_{0};
};

} // namespace operators
Expand Down
99 changes: 98 additions & 1 deletion paddle/fluid/operators/tensorrt_engine_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ void SetAttr<int64_t>(framework::proto::OpDesc* op, const std::string& name,
attr->set_type(paddle::framework::proto::AttrType::LONG);
attr->set_l(data);
}
template <>
void SetAttr<std::vector<std::string>>(framework::proto::OpDesc* op,
const std::string& name,
const std::vector<std::string>& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRINGS);
for (const auto& s : data) {
attr->add_strings(s.c_str());
}
}

} // namespace

Expand Down Expand Up @@ -123,11 +134,15 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch", 30);
SetAttr<int>(engine_op_desc.Proto(), "max_batch", 100);
SetAttr<int>(engine_op_desc.Proto(), "max_workspace", 1 << 10);
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
std::vector<std::string>({}));

LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
LOG(INFO) << "engine_op " << engine_op.get();

framework::Scope scope;
platform::CPUPlace place;
Expand All @@ -145,6 +160,88 @@ TEST(TensorRTEngineOp, manual) {
engine_op->Run(scope, place);
}

void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);

auto* block_ = program.Proto()->add_blocks();
block_->set_idx(0);
block_->set_parent_idx(-1);

using shape_t = std::vector<int64_t>;

LOG(INFO) << "create block desc";
framework::BlockDesc block_desc(&program, block_);

auto AddFCLayer = [&](const std::string& x_name, const std::string& y_name,
const std::string& z_name, bool x_created,
const shape_t& x_shape, const shape_t& y_shape,
const shape_t& z_shape) {

LOG(INFO) << "create fc op";
auto* fc = block_desc.AppendOp();
fc->SetType("mul");
fc->SetInput("X", std::vector<std::string>({x_name}));
fc->SetInput("Y", std::vector<std::string>({y_name}));
fc->SetOutput("Out", std::vector<std::string>({z_name}));

// Set inputs' variable shape in BlockDesc
if (!x_created) {
AddTensorToBlockDesc(block_, x_name,
std::vector<int64_t>({batch_size, input_dim, 1, 1}));
}
AddTensorToBlockDesc(block_, y_name,
std::vector<int64_t>({input_dim, output_dim}));
AddTensorToBlockDesc(block_, z_name,
std::vector<int64_t>({batch_size, output_dim}));

// Prepare variables.
if (!x_created) {
CreateCPUTensor(&scope, x_name, std::vector<int64_t>(x_shape));
}
CreateCPUTensor(&scope, y_name, std::vector<int64_t>(y_shape));
CreateCPUTensor(&scope, z_name, std::vector<int64_t>(z_shape));

// It is wired, need to copy manually.
*block_->add_ops() = *fc->Proto();
};

// Test with 4 layer FC
AddFCLayer("x0", "y0", "z0", false, {batch_size, input_dim},
{input_dim, output_dim}, {batch_size, output_dim});
AddFCLayer("z0", "y1", "z1", true, {}, {output_dim, output_dim},
{batch_size, output_dim});
AddFCLayer("z1", "y2", "z2", true, {}, {output_dim, output_dim},
{batch_size, output_dim});
AddFCLayer("z2", "y3", "z3", true, {}, {output_dim, output_dim},
{batch_size, output_dim});

LOG(INFO) << "create tensorrt desc";
framework::OpDesc engine_op_desc(nullptr);
engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z3"}));

SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch", batch_size);
SetAttr<int>(engine_op_desc.Proto(), "max_workspace", 2 << 10);
SetAttr<std::vector<std::string>>(
engine_op_desc.Proto(), "parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "b_engine");

auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());

// Execute them.
engine_op->Run(scope, place);
}

// Test with a larger FC layer.
TEST(TensorRTEngineOp, fc) { Execute(40, 256, 256); }

} // namespace operators
} // namespace paddle

Expand Down

0 comments on commit 5fd142c

Please sign in to comment.