FDTensor support GPU device (#190)
* fdtensor support GPU

* TRT backend support GPU FDTensor

* FDHostAllocator add FASTDEPLOY_DECL

* fix FDTensor Data

* fix FDTensor dtype

Co-authored-by: Jason <jiangjiajun@baidu.com>
heliqi and jiangjiajun committed Sep 8, 2022
1 parent bc8e9e4 commit 4d1f264
Showing 17 changed files with 430 additions and 151 deletions.
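
The common thread across the diffs below is that FDTensor now carries a device field, and each backend branches on it: a tensor already on the GPU can have its buffer shared directly, while a CPU tensor still goes through a host-to-device copy. The following is a minimal standalone C++ sketch of that dispatch; DeviceTag and TensorLike are illustrative stand-ins, not the real FastDeploy types.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-ins for the device-tagged tensor idea; not the FastDeploy API.
enum class DeviceTag { CPU, GPU };

struct TensorLike {
  std::string name;
  std::vector<int64_t> shape;
  DeviceTag device = DeviceTag::CPU;
  const void* data = nullptr;  // assumed to already live on `device`
};

// A backend that sees the device tag can skip the staging copy for GPU tensors.
void BindInput(const TensorLike& t) {
  if (t.device == DeviceTag::GPU) {
    std::cout << t.name << ": share the GPU buffer directly (zero copy)\n";
  } else {
    std::cout << t.name << ": copy the host buffer to the device first\n";
  }
}

int main() {
  BindInput({"x_cpu", {1, 3, 224, 224}, DeviceTag::CPU, nullptr});
  BindInput({"x_gpu", {1, 3, 224, 224}, DeviceTag::GPU, nullptr});
  return 0;
}
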
29 changes: 16 additions & 13 deletions csrc/fastdeploy/backends/ort/ort_backend.cc
@@ -13,7 +13,9 @@
// limitations under the License.

#include "fastdeploy/backends/ort/ort_backend.h"

#include <memory>

#include "fastdeploy/backends/ort/ops/multiclass_nms.h"
#include "fastdeploy/backends/ort/utils.h"
#include "fastdeploy/utils/utils.h"
@@ -164,33 +166,34 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
return true;
}

void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name) {
void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
const std::string& name) {
const auto info = value.GetTensorTypeAndShapeInfo();
const auto data_type = info.GetElementType();
size_t numel = info.GetElementCount();
auto shape = info.GetShape();
FDDataType dtype;

if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
tensor->Allocate(info.GetShape(), FDDataType::FP32, name);
memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
numel * sizeof(float));
dtype = FDDataType::FP32;
numel *= sizeof(float);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
tensor->Allocate(info.GetShape(), FDDataType::INT32, name);
memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
numel * sizeof(int32_t));
dtype = FDDataType::INT32;
numel *= sizeof(int32_t);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
tensor->Allocate(info.GetShape(), FDDataType::INT64, name);
memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
numel * sizeof(int64_t));
dtype = FDDataType::INT64;
numel *= sizeof(int64_t);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
tensor->Allocate(info.GetShape(), FDDataType::FP64, name);
memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
numel * sizeof(double));
dtype = FDDataType::FP64;
numel *= sizeof(double);
} else {
FDASSERT(
false,
"Unrecognized data type of %d while calling OrtBackend::CopyToCpu().",
data_type);
}
tensor->Resize(shape, dtype, name);
memcpy(tensor->MutableData(), value.GetTensorData<void*>(), numel);
}

bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
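
The rewritten OrtBackend::CopyToCpu above collapses the per-type branches: each branch now only records the FDDataType and the byte count, and a single Resize plus memcpy performs the actual transfer. Below is a self-contained sketch of that consolidation; ElemType, SourceValue and DestTensor are placeholders, not the ONNX Runtime or FastDeploy API.

#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Placeholders for the idea only; not the ONNX Runtime / FastDeploy types.
enum class ElemType { F32, I32, I64, F64 };

struct SourceValue {
  ElemType type;
  std::vector<int64_t> shape;
  const void* data;
  size_t count;  // number of elements
};

struct DestTensor {
  std::vector<int64_t> shape;
  ElemType dtype = ElemType::F32;
  std::vector<uint8_t> buffer;
};

void CopyToCpu(const SourceValue& v, DestTensor* t) {
  size_t elem_size = 0;
  switch (v.type) {  // one switch decides both dtype and element size
    case ElemType::F32: elem_size = sizeof(float); break;
    case ElemType::I32: elem_size = sizeof(int32_t); break;
    case ElemType::I64: elem_size = sizeof(int64_t); break;
    case ElemType::F64: elem_size = sizeof(double); break;
  }
  if (elem_size == 0) throw std::runtime_error("unsupported element type");
  t->shape = v.shape;
  t->dtype = v.type;
  t->buffer.resize(v.count * elem_size);                       // single resize ...
  std::memcpy(t->buffer.data(), v.data, v.count * elem_size);  // ... single copy
}

int main() {
  std::vector<float> src = {1.f, 2.f, 3.f};
  SourceValue v{ElemType::F32, {3}, src.data(), src.size()};
  DestTensor t;
  CopyToCpu(v, &t);
  return 0;
}
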
3 changes: 2 additions & 1 deletion csrc/fastdeploy/backends/ort/ort_backend.h
@@ -88,6 +88,7 @@ class OrtBackend : public BaseBackend {
Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
#endif
OrtBackendOption option_;
void CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name);
void CopyToCpu(const Ort::Value& value, FDTensor* tensor,
const std::string& name);
};
} // namespace fastdeploy
13 changes: 10 additions & 3 deletions csrc/fastdeploy/backends/paddle/paddle_backend.cc
@@ -79,16 +79,23 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
}

TensorInfo PaddleBackend::GetInputInfo(int index) {
FDASSERT(index < NumInputs(), "The index: %d should less than the number of inputs: %d.", index, NumInputs());
FDASSERT(index < NumInputs(),
"The index: %d should less than the number of inputs: %d.", index,
NumInputs());
return inputs_desc_[index];
}

std::vector<TensorInfo> PaddleBackend::GetInputInfo() { return inputs_desc_; }

TensorInfo PaddleBackend::GetOutputInfo(int index) {
FDASSERT(index < NumOutputs(),
"The index: %d should less than the number of outputs %d.", index, NumOutputs());
"The index: %d should less than the number of outputs %d.", index,
NumOutputs());
return outputs_desc_[index];
}

std::vector<TensorInfo> PaddleBackend::GetOutputInfo() { return outputs_desc_; }

bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs) {
if (inputs.size() != inputs_desc_.size()) {
@@ -100,7 +107,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,

for (size_t i = 0; i < inputs.size(); ++i) {
auto handle = predictor_->GetInputHandle(inputs[i].name);
ShareTensorFromCpu(handle.get(), inputs[i]);
ShareTensorFromFDTensor(handle.get(), inputs[i]);
}

predictor_->Run();
7 changes: 6 additions & 1 deletion csrc/fastdeploy/backends/paddle/paddle_backend.h
@@ -44,8 +44,11 @@ struct PaddleBackendOption {
std::vector<std::string> delete_pass_names = {};
};

// convert FD device to paddle place type
paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device);

// Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);

// Copy memory data from paddle_infer::Tensor to fastdeploy::FDTensor
void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -72,6 +75,8 @@ class PaddleBackend : public BaseBackend {

TensorInfo GetInputInfo(int index);
TensorInfo GetOutputInfo(int index);
std::vector<TensorInfo> GetInputInfo();
std::vector<TensorInfo> GetOutputInfo();

private:
paddle_infer::Config config_;
28 changes: 21 additions & 7 deletions csrc/fastdeploy/backends/paddle/util.cc
@@ -15,23 +15,33 @@
#include "fastdeploy/backends/paddle/paddle_backend.h"

namespace fastdeploy {
void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor) {
paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device) {
if (device == Device::GPU) {
return paddle_infer::PlaceType::kGPU;
}
return paddle_infer::PlaceType::kCPU;
}

void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
FDTensor& fd_tensor) {
std::vector<int> shape(fd_tensor.shape.begin(), fd_tensor.shape.end());
tensor->Reshape(shape);
auto place = ConvertFDDeviceToPlace(fd_tensor.device);
if (fd_tensor.dtype == FDDataType::FP32) {
tensor->ShareExternalData(static_cast<const float*>(fd_tensor.Data()),
shape, paddle_infer::PlaceType::kCPU);
shape, place);
return;
} else if (fd_tensor.dtype == FDDataType::INT32) {
tensor->ShareExternalData(static_cast<const int32_t*>(fd_tensor.Data()),
shape, paddle_infer::PlaceType::kCPU);
shape, place);
return;
} else if (fd_tensor.dtype == FDDataType::INT64) {
tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
shape, paddle_infer::PlaceType::kCPU);
shape, place);
return;
}
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor.dtype).c_str());
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
Str(fd_tensor.dtype).c_str());
}

void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -51,7 +61,8 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
return;
}
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor->dtype).c_str());
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
Str(fd_tensor->dtype).c_str());
}

FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
@@ -65,7 +76,10 @@ FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
} else if (dtype == paddle_infer::UINT8) {
fd_dtype = FDDataType::UINT8;
} else {
FDASSERT(false, "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.", int(dtype));
FDASSERT(
false,
"Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.",
int(dtype));
}
return fd_dtype;
}
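
util.cc now routes the FDTensor's device through ConvertFDDeviceToPlace before sharing the buffer, so a GPU-resident FDTensor is handed to the predictor in place instead of being staged through the CPU. Below is a simplified, self-contained model of that flow; PlaceKind, PredictorTensor and SimpleTensor are mock-ups, not the paddle_infer or FastDeploy API.

#include <cstdint>
#include <vector>

// Mock-ups of the flow only; not the paddle_infer / FastDeploy API.
enum class Dev { CPU, GPU };
enum class PlaceKind { kCPU, kGPU };

PlaceKind ConvertDeviceToPlace(Dev d) {
  return d == Dev::GPU ? PlaceKind::kGPU : PlaceKind::kCPU;
}

struct PredictorTensor {
  std::vector<int> shape;
  const void* external = nullptr;
  PlaceKind place = PlaceKind::kCPU;
  void Reshape(const std::vector<int>& s) { shape = s; }
  // Mirrors the intent of ShareExternalData: remember the caller's buffer and
  // its place instead of copying the data.
  void ShareExternalData(const void* data, const std::vector<int>& s, PlaceKind p) {
    external = data;
    shape = s;
    place = p;
  }
};

struct SimpleTensor {
  std::vector<int64_t> shape;
  Dev device = Dev::CPU;
  const void* data = nullptr;
};

void ShareTensor(PredictorTensor* dst, const SimpleTensor& src) {
  std::vector<int> shape(src.shape.begin(), src.shape.end());
  dst->Reshape(shape);
  // The place now follows the tensor's device, so a GPU buffer is shared as-is.
  dst->ShareExternalData(src.data, shape, ConvertDeviceToPlace(src.device));
}

int main() {
  float host_data[4] = {0.f, 1.f, 2.f, 3.f};
  SimpleTensor t{{4}, Dev::CPU, host_data};
  PredictorTensor pt;
  ShareTensor(&pt, t);
  return 0;
}
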
94 changes: 57 additions & 37 deletions csrc/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -13,9 +13,11 @@
// limitations under the License.

#include "fastdeploy/backends/tensorrt/trt_backend.h"

#include <cstring>

#include "NvInferSafeRuntime.h"
#include "fastdeploy/utils/utils.h"
#include <cstring>
#ifdef ENABLE_PADDLE_FRONTEND
#include "paddle2onnx/converter.h"
#endif
@@ -210,9 +212,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
outputs_desc_.resize(onnx_reader.num_outputs);
for (int i = 0; i < onnx_reader.num_inputs; ++i) {
std::string name(onnx_reader.inputs[i].name);
std::vector<int64_t> shape(onnx_reader.inputs[i].shape,
onnx_reader.inputs[i].shape +
onnx_reader.inputs[i].rank);
std::vector<int64_t> shape(
onnx_reader.inputs[i].shape,
onnx_reader.inputs[i].shape + onnx_reader.inputs[i].rank);
inputs_desc_[i].name = name;
inputs_desc_[i].shape.assign(shape.begin(), shape.end());
inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
@@ -231,9 +233,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,

for (int i = 0; i < onnx_reader.num_outputs; ++i) {
std::string name(onnx_reader.outputs[i].name);
std::vector<int64_t> shape(onnx_reader.outputs[i].shape,
onnx_reader.outputs[i].shape +
onnx_reader.outputs[i].rank);
std::vector<int64_t> shape(
onnx_reader.outputs[i].shape,
onnx_reader.outputs[i].shape + onnx_reader.outputs[i].rank);
outputs_desc_[i].name = name;
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
outputs_desc_[i].dtype =
@@ -286,24 +288,8 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
BuildTrtEngine();
}

AllocateBufferInDynamicShape(inputs, outputs);
std::vector<void*> input_binds(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs[i].dtype == FDDataType::INT64) {
int64_t* data = static_cast<int64_t*>(inputs[i].Data());
std::vector<int32_t> casted_data(data, data + inputs[i].Numel());
FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
static_cast<void*>(casted_data.data()),
inputs[i].Nbytes() / 2, cudaMemcpyHostToDevice,
stream_) == 0,
"[ERROR] Error occurs while copy memory from CPU to GPU.");
} else {
FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
inputs[i].Data(), inputs[i].Nbytes(),
cudaMemcpyHostToDevice, stream_) == 0,
"[ERROR] Error occurs while copy memory from CPU to GPU.");
}
}
SetInputs(inputs);
AllocateOutputsBuffer(outputs);
if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
FDERROR << "Failed to Infer with TensorRT." << std::endl;
return false;
@@ -339,18 +325,50 @@ void TrtBackend::GetInputOutputInfo() {
bindings_.resize(num_binds);
}

void TrtBackend::AllocateBufferInDynamicShape(
const std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) {
void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
for (const auto& item : inputs) {
auto idx = engine_->getBindingIndex(item.name.c_str());
std::vector<int> shape(item.shape.begin(), item.shape.end());
auto dims = ToDims(shape);
context_->setBindingDimensions(idx, dims);
if (item.Nbytes() > inputs_buffer_[item.name].nbBytes()) {

if (item.device == Device::GPU) {
if (item.dtype == FDDataType::INT64) {
// TODO(liqi): cast int64 to int32
// TRT don't support INT64
FDASSERT(false,
"TRT don't support INT64 input on GPU, "
"please use INT32 input");
} else {
// no copy
inputs_buffer_[item.name].SetExternalData(dims, item.Data());
}
} else {
// Allocate input buffer memory
inputs_buffer_[item.name].resize(dims);
bindings_[idx] = inputs_buffer_[item.name].data();

// copy from cpu to gpu
if (item.dtype == FDDataType::INT64) {
int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
std::vector<int32_t> casted_data(data, data + item.Numel());
FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
static_cast<void*>(casted_data.data()),
item.Nbytes() / 2, cudaMemcpyHostToDevice,
stream_) == 0,
"Error occurs while copy memory from CPU to GPU.");
} else {
FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
item.Nbytes(), cudaMemcpyHostToDevice,
stream_) == 0,
"Error occurs while copy memory from CPU to GPU.");
}
}
// binding input buffer
bindings_[idx] = inputs_buffer_[item.name].data();
}
}

void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
if (outputs->size() != outputs_desc_.size()) {
outputs->resize(outputs_desc_.size());
}
@@ -365,13 +383,15 @@ void TrtBackend::AllocateBufferInDynamicShape(
"Cannot find output: %s of tensorrt network from the original model.",
outputs_desc_[i].name.c_str());
auto ori_idx = iter->second;
std::vector<int64_t> shape(output_dims.d, output_dims.d + output_dims.nbDims);
(*outputs)[ori_idx].Allocate(shape, GetFDDataType(outputs_desc_[i].dtype), outputs_desc_[i].name);
if ((*outputs)[ori_idx].Nbytes() >
outputs_buffer_[outputs_desc_[i].name].nbBytes()) {
outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
}
// set user's outputs info
std::vector<int64_t> shape(output_dims.d,
output_dims.d + output_dims.nbDims);
(*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
outputs_desc_[i].name);
// Allocate output buffer memory
outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
// binding output buffer
bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
}
}

@@ -580,4 +600,4 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
info.dtype = GetFDDataType(outputs_desc_[index].dtype);
return info;
}
} // namespace fastdeploy
} // namespace fastdeploy
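
TrtBackend::SetInputs above covers three cases: a GPU FDTensor is bound with SetExternalData and never copied, a CPU INT64 tensor is narrowed to INT32 on the host before the copy (TensorRT does not accept INT64 inputs here, hence the Nbytes() / 2), and every other CPU tensor is copied as-is with cudaMemcpyAsync. The following is a self-contained sketch of that decision tree; InputLike and DeviceBufferLike are illustrative, not the FastDeploy or TensorRT classes.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-ins only; not the FastDeploy / TensorRT classes.
enum class Dev { CPU, GPU };
enum class DType { FP32, INT32, INT64 };

struct InputLike {
  Dev device;
  DType dtype;
  const void* data;
  size_t numel;
};

struct DeviceBufferLike {
  const void* external = nullptr;  // borrowed pointer (zero-copy path)
  std::vector<uint8_t> owned;      // owned staging buffer (copy path)
  void SetExternalData(const void* p) { external = p; }
  void CopyFromHost(const void* p, size_t nbytes) {  // stands in for cudaMemcpyAsync
    owned.assign(static_cast<const uint8_t*>(p),
                 static_cast<const uint8_t*>(p) + nbytes);
  }
};

void SetInput(DeviceBufferLike* buf, const InputLike& in) {
  if (in.device == Dev::GPU) {
    // No on-GPU INT64 -> INT32 cast is implemented yet, so this combination
    // is rejected, matching the FDASSERT in the commit.
    assert(in.dtype != DType::INT64 && "INT64 GPU inputs are not supported");
    buf->SetExternalData(in.data);  // reuse the existing GPU buffer, no copy
    return;
  }
  if (in.dtype == DType::INT64) {
    // CPU path: narrow to INT32 first, which is why the real code copies
    // only half of Nbytes().
    const int64_t* src = static_cast<const int64_t*>(in.data);
    std::vector<int32_t> casted(src, src + in.numel);
    buf->CopyFromHost(casted.data(), casted.size() * sizeof(int32_t));
    return;
  }
  size_t elem = (in.dtype == DType::FP32) ? sizeof(float) : sizeof(int32_t);
  buf->CopyFromHost(in.data, in.numel * elem);
}

int main() {
  int64_t ids[2] = {7, 9};
  DeviceBufferLike buf;
  SetInput(&buf, {Dev::CPU, DType::INT64, ids, 2});  // narrowed, then copied
  return 0;
}
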
9 changes: 5 additions & 4 deletions csrc/fastdeploy/backends/tensorrt/trt_backend.h
@@ -14,6 +14,8 @@

#pragma once

#include <cuda_runtime_api.h>

#include <iostream>
#include <map>
#include <string>
@@ -23,7 +25,6 @@
#include "NvOnnxParser.h"
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/backends/tensorrt/utils.h"
#include <cuda_runtime_api.h>

namespace fastdeploy {

@@ -109,12 +110,12 @@ class TrtBackend : public BaseBackend {
std::map<std::string, ShapeRangeInfo> shape_range_info_;

void GetInputOutputInfo();
void AllocateBufferInDynamicShape(const std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs);
bool CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer);
bool BuildTrtEngine();
bool LoadTrtCache(const std::string& trt_engine_file);
int ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs);
void SetInputs(const std::vector<FDTensor>& inputs);
void AllocateOutputsBuffer(std::vector<FDTensor>* outputs);
};

} // namespace fastdeploy
} // namespace fastdeploy
