
FDTensor support GPU device #190

Merged
merged 13 commits on Sep 8, 2022
29 changes: 16 additions & 13 deletions csrc/fastdeploy/backends/ort/ort_backend.cc
@@ -13,7 +13,9 @@
// limitations under the License.

#include "fastdeploy/backends/ort/ort_backend.h"

#include <memory>

#include "fastdeploy/backends/ort/ops/multiclass_nms.h"
#include "fastdeploy/backends/ort/utils.h"
#include "fastdeploy/utils/utils.h"
@@ -164,33 +166,34 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
return true;
}

- void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name) {
+ void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
+ const std::string& name) {
const auto info = value.GetTensorTypeAndShapeInfo();
const auto data_type = info.GetElementType();
size_t numel = info.GetElementCount();
+ auto shape = info.GetShape();
+ FDDataType dtype;

if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
- tensor->Allocate(info.GetShape(), FDDataType::FP32, name);
- memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
- numel * sizeof(float));
+ dtype = FDDataType::FP32;
+ numel *= sizeof(float);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
- tensor->Allocate(info.GetShape(), FDDataType::INT32, name);
- memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
- numel * sizeof(int32_t));
+ dtype = FDDataType::INT32;
+ numel *= sizeof(int32_t);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
- tensor->Allocate(info.GetShape(), FDDataType::INT64, name);
- memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
- numel * sizeof(int64_t));
+ dtype = FDDataType::INT64;
+ numel *= sizeof(int64_t);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
- tensor->Allocate(info.GetShape(), FDDataType::FP64, name);
- memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
- numel * sizeof(double));
+ dtype = FDDataType::FP64;
+ numel *= sizeof(double);
} else {
FDASSERT(
false,
"Unrecognized data type of %d while calling OrtBackend::CopyToCpu().",
data_type);
}
+ tensor->Resize(shape, dtype, name);
+ memcpy(tensor->MutableData(), value.GetTensorData<void*>(), numel);
}

bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
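
Reviewer note: the refactor above replaces the per-type Allocate plus memcpy pairs with a single Resize plus memcpy once the branch has resolved the element type and byte count. Below is a minimal standalone sketch of that pattern; the enums and FakeTensor are stand-ins, not the FastDeploy or ONNX Runtime API.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Stand-ins for the ONNX Runtime element types and the FDTensor dtype tag.
enum class ElemType { FP32, INT32, INT64, FP64 };
enum class DType { FP32, INT32, INT64, FP64 };

struct FakeTensor {            // stand-in for FDTensor
  std::vector<uint8_t> bytes;  // raw storage
  DType dtype = DType::FP32;
  void Resize(size_t nbytes, DType dt) { bytes.resize(nbytes); dtype = dt; }
  void* MutableData() { return bytes.data(); }
};

// Mirrors the refactored CopyToCpu: the branch only picks (dtype, byte count),
// then one resize and one memcpy replace the per-type Allocate/memcpy pairs.
void CopyToCpuSketch(const void* src, size_t numel, ElemType et, FakeTensor* dst) {
  DType dtype = DType::FP32;
  size_t nbytes = numel;
  if (et == ElemType::FP32) {
    dtype = DType::FP32;
    nbytes *= sizeof(float);
  } else if (et == ElemType::INT32) {
    dtype = DType::INT32;
    nbytes *= sizeof(int32_t);
  } else if (et == ElemType::INT64) {
    dtype = DType::INT64;
    nbytes *= sizeof(int64_t);
  } else {  // ElemType::FP64
    dtype = DType::FP64;
    nbytes *= sizeof(double);
  }
  dst->Resize(nbytes, dtype);
  std::memcpy(dst->MutableData(), src, nbytes);
}

int main() {
  std::vector<float> src = {1.f, 2.f, 3.f};
  FakeTensor dst;
  CopyToCpuSketch(src.data(), src.size(), ElemType::FP32, &dst);
  std::cout << dst.bytes.size() << " bytes copied\n";  // prints "12 bytes copied"
  return 0;
}
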
3 changes: 2 additions & 1 deletion csrc/fastdeploy/backends/ort/ort_backend.h
@@ -88,6 +88,7 @@ class OrtBackend : public BaseBackend {
Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
#endif
OrtBackendOption option_;
- void CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name);
+ void CopyToCpu(const Ort::Value& value, FDTensor* tensor,
+ const std::string& name);
};
} // namespace fastdeploy
13 changes: 10 additions & 3 deletions csrc/fastdeploy/backends/paddle/paddle_backend.cc
@@ -79,16 +79,23 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
}

TensorInfo PaddleBackend::GetInputInfo(int index) {
- FDASSERT(index < NumInputs(), "The index: %d should less than the number of inputs: %d.", index, NumInputs());
+ FDASSERT(index < NumInputs(),
+ "The index: %d should less than the number of inputs: %d.", index,
+ NumInputs());
return inputs_desc_[index];
}

+ std::vector<TensorInfo> PaddleBackend::GetInputInfo() { return inputs_desc_; }

TensorInfo PaddleBackend::GetOutputInfo(int index) {
FDASSERT(index < NumOutputs(),
"The index: %d should less than the number of outputs %d.", index, NumOutputs());
"The index: %d should less than the number of outputs %d.", index,
NumOutputs());
return outputs_desc_[index];
}

+ std::vector<TensorInfo> PaddleBackend::GetOutputInfo() { return outputs_desc_; }

bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs) {
if (inputs.size() != inputs_desc_.size()) {
@@ -100,7 +107,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,

for (size_t i = 0; i < inputs.size(); ++i) {
auto handle = predictor_->GetInputHandle(inputs[i].name);
- ShareTensorFromCpu(handle.get(), inputs[i]);
+ ShareTensorFromFDTensor(handle.get(), inputs[i]);
}

predictor_->Run();
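
Reviewer note: the new vector overloads let callers fetch all input or output descriptions in one call instead of looping over indices. A sketch of the expected call pattern, using a stand-in TensorInfo and backend rather than the real PaddleBackend:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Stand-in mirroring fastdeploy::TensorInfo (name/shape fields assumed here).
struct TensorInfo {
  std::string name;
  std::vector<int64_t> shape;
};

// Stand-in backend exposing the same overload pair as PaddleBackend.
struct FakeBackend {
  std::vector<TensorInfo> inputs{{"x", {1, 3, 224, 224}}, {"scale_factor", {1, 2}}};
  TensorInfo GetInputInfo(int index) { return inputs[index]; }
  std::vector<TensorInfo> GetInputInfo() { return inputs; }
};

int main() {
  FakeBackend backend;
  // One call returns every input description; no index loop needed.
  for (const auto& info : backend.GetInputInfo()) {
    std::cout << info.name << ": rank " << info.shape.size() << "\n";
  }
  return 0;
}
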
7 changes: 6 additions & 1 deletion csrc/fastdeploy/backends/paddle/paddle_backend.h
@@ -44,8 +44,11 @@ struct PaddleBackendOption {
std::vector<std::string> delete_pass_names = {};
};

+ // convert FD device to paddle place type
+ paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device);

// Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
- void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
+ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);

// Copy memory data from paddle_infer::Tensor to fastdeploy::FDTensor
void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -72,6 +75,8 @@ class PaddleBackend : public BaseBackend {

TensorInfo GetInputInfo(int index);
TensorInfo GetOutputInfo(int index);
+ std::vector<TensorInfo> GetInputInfo();
+ std::vector<TensorInfo> GetOutputInfo();

private:
paddle_infer::Config config_;
28 changes: 21 additions & 7 deletions csrc/fastdeploy/backends/paddle/util.cc
@@ -15,23 +15,33 @@
#include "fastdeploy/backends/paddle/paddle_backend.h"

namespace fastdeploy {
- void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor) {
+ paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device) {
+ if (device == Device::GPU) {
+ return paddle_infer::PlaceType::kGPU;
+ }
+ return paddle_infer::PlaceType::kCPU;
+ }

+ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
+ FDTensor& fd_tensor) {
std::vector<int> shape(fd_tensor.shape.begin(), fd_tensor.shape.end());
tensor->Reshape(shape);
+ auto place = ConvertFDDeviceToPlace(fd_tensor.device);
if (fd_tensor.dtype == FDDataType::FP32) {
tensor->ShareExternalData(static_cast<const float*>(fd_tensor.Data()),
- shape, paddle_infer::PlaceType::kCPU);
+ shape, place);
return;
} else if (fd_tensor.dtype == FDDataType::INT32) {
tensor->ShareExternalData(static_cast<const int32_t*>(fd_tensor.Data()),
- shape, paddle_infer::PlaceType::kCPU);
+ shape, place);
return;
} else if (fd_tensor.dtype == FDDataType::INT64) {
tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
- shape, paddle_infer::PlaceType::kCPU);
+ shape, place);
return;
}
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor.dtype).c_str());
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
Str(fd_tensor.dtype).c_str());
}

void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -51,7 +61,8 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
return;
}
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor->dtype).c_str());
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
Str(fd_tensor->dtype).c_str());
}

FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
@@ -65,7 +76,10 @@ FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
} else if (dtype == paddle_infer::UINT8) {
fd_dtype = FDDataType::UINT8;
} else {
FDASSERT(false, "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.", int(dtype));
FDASSERT(
false,
"Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.",
int(dtype));
}
return fd_dtype;
}
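
Reviewer note: ConvertFDDeviceToPlace keeps the device decision in one place, so ShareTensorFromFDTensor can hand GPU-resident data to Paddle Inference instead of hard-coding kCPU. A minimal sketch of the same mapping with stand-in enums; the real types live in fastdeploy and paddle_infer.

#include <cassert>

// Stand-ins for fastdeploy::Device and paddle_infer::PlaceType.
enum class Device { CPU, GPU };
enum class PlaceType { kCPU, kGPU };

// Same shape as ConvertFDDeviceToPlace in util.cc: GPU maps to kGPU,
// everything else falls back to kCPU.
PlaceType ConvertDeviceToPlace(Device device) {
  if (device == Device::GPU) {
    return PlaceType::kGPU;
  }
  return PlaceType::kCPU;
}

int main() {
  assert(ConvertDeviceToPlace(Device::GPU) == PlaceType::kGPU);
  assert(ConvertDeviceToPlace(Device::CPU) == PlaceType::kCPU);
  return 0;
}
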
94 changes: 57 additions & 37 deletions csrc/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -13,9 +13,11 @@
// limitations under the License.

#include "fastdeploy/backends/tensorrt/trt_backend.h"

+ #include <cstring>

#include "NvInferSafeRuntime.h"
#include "fastdeploy/utils/utils.h"
- #include <cstring>
#ifdef ENABLE_PADDLE_FRONTEND
#include "paddle2onnx/converter.h"
#endif
@@ -210,9 +212,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
outputs_desc_.resize(onnx_reader.num_outputs);
for (int i = 0; i < onnx_reader.num_inputs; ++i) {
std::string name(onnx_reader.inputs[i].name);
- std::vector<int64_t> shape(onnx_reader.inputs[i].shape,
- onnx_reader.inputs[i].shape +
- onnx_reader.inputs[i].rank);
+ std::vector<int64_t> shape(
+ onnx_reader.inputs[i].shape,
+ onnx_reader.inputs[i].shape + onnx_reader.inputs[i].rank);
inputs_desc_[i].name = name;
inputs_desc_[i].shape.assign(shape.begin(), shape.end());
inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
@@ -231,9 +233,9 @@

for (int i = 0; i < onnx_reader.num_outputs; ++i) {
std::string name(onnx_reader.outputs[i].name);
- std::vector<int64_t> shape(onnx_reader.outputs[i].shape,
- onnx_reader.outputs[i].shape +
- onnx_reader.outputs[i].rank);
+ std::vector<int64_t> shape(
+ onnx_reader.outputs[i].shape,
+ onnx_reader.outputs[i].shape + onnx_reader.outputs[i].rank);
outputs_desc_[i].name = name;
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
outputs_desc_[i].dtype =
@@ -286,24 +288,8 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
BuildTrtEngine();
}

- AllocateBufferInDynamicShape(inputs, outputs);
- std::vector<void*> input_binds(inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i) {
- if (inputs[i].dtype == FDDataType::INT64) {
- int64_t* data = static_cast<int64_t*>(inputs[i].Data());
- std::vector<int32_t> casted_data(data, data + inputs[i].Numel());
- FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
- static_cast<void*>(casted_data.data()),
- inputs[i].Nbytes() / 2, cudaMemcpyHostToDevice,
- stream_) == 0,
- "[ERROR] Error occurs while copy memory from CPU to GPU.");
- } else {
- FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
- inputs[i].Data(), inputs[i].Nbytes(),
- cudaMemcpyHostToDevice, stream_) == 0,
- "[ERROR] Error occurs while copy memory from CPU to GPU.");
- }
- }
+ SetInputs(inputs);
+ AllocateOutputsBuffer(outputs);
if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
FDERROR << "Failed to Infer with TensorRT." << std::endl;
return false;
@@ -339,18 +325,50 @@ void TrtBackend::GetInputOutputInfo() {
bindings_.resize(num_binds);
}

- void TrtBackend::AllocateBufferInDynamicShape(
- const std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) {
+ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
for (const auto& item : inputs) {
auto idx = engine_->getBindingIndex(item.name.c_str());
std::vector<int> shape(item.shape.begin(), item.shape.end());
auto dims = ToDims(shape);
context_->setBindingDimensions(idx, dims);
- if (item.Nbytes() > inputs_buffer_[item.name].nbBytes()) {

+ if (item.device == Device::GPU) {
+ if (item.dtype == FDDataType::INT64) {
+ // TODO(liqi): cast int64 to int32
+ // TRT don't support INT64
+ FDASSERT(false,
+ "TRT don't support INT64 input on GPU, "
+ "please use INT32 input");
+ } else {
+ // no copy
+ inputs_buffer_[item.name].SetExternalData(dims, item.Data());
+ }
+ } else {
+ // Allocate input buffer memory
inputs_buffer_[item.name].resize(dims);
- bindings_[idx] = inputs_buffer_[item.name].data();

+ // copy from cpu to gpu
+ if (item.dtype == FDDataType::INT64) {
+ int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
+ std::vector<int32_t> casted_data(data, data + item.Numel());
+ FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
+ static_cast<void*>(casted_data.data()),
+ item.Nbytes() / 2, cudaMemcpyHostToDevice,
+ stream_) == 0,
+ "Error occurs while copy memory from CPU to GPU.");
+ } else {
+ FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
+ item.Nbytes(), cudaMemcpyHostToDevice,
+ stream_) == 0,
+ "Error occurs while copy memory from CPU to GPU.");
+ }
}
+ // binding input buffer
+ bindings_[idx] = inputs_buffer_[item.name].data();
}
+ }

+ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
if (outputs->size() != outputs_desc_.size()) {
outputs->resize(outputs_desc_.size());
}
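
Reviewer note: on the CPU path above, INT64 inputs are narrowed to INT32 before the host-to-device copy because TensorRT has no INT64 support, which is why the copy size is Nbytes() / 2. A standalone sketch of that narrowing, with plain memcpy standing in for cudaMemcpyAsync and a host vector standing in for the device buffer:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Narrow an int64 host buffer to int32 and "upload" it. In the real SetInputs
// the destination is a TensorRT device buffer and the copy is cudaMemcpyAsync;
// here a host vector and memcpy keep the sketch self-contained.
void StageInt64AsInt32(const int64_t* data, size_t numel,
                       std::vector<int32_t>* device_buffer) {
  // Element-wise narrowing cast (values are assumed to fit in int32).
  std::vector<int32_t> casted(data, data + numel);
  device_buffer->resize(numel);
  // Half the original byte count, matching the Nbytes() / 2 in the diff.
  std::memcpy(device_buffer->data(), casted.data(), numel * sizeof(int32_t));
}

int main() {
  std::vector<int64_t> host = {7, 8, 9};
  std::vector<int32_t> device_stand_in;
  StageInt64AsInt32(host.data(), host.size(), &device_stand_in);
  std::cout << device_stand_in[2] << "\n";  // prints 9
  return 0;
}
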
@@ -365,13 +383,15 @@ void TrtBackend::AllocateBufferInDynamicShape(
"Cannot find output: %s of tensorrt network from the original model.",
outputs_desc_[i].name.c_str());
auto ori_idx = iter->second;
- std::vector<int64_t> shape(output_dims.d, output_dims.d + output_dims.nbDims);
- (*outputs)[ori_idx].Allocate(shape, GetFDDataType(outputs_desc_[i].dtype), outputs_desc_[i].name);
- if ((*outputs)[ori_idx].Nbytes() >
- outputs_buffer_[outputs_desc_[i].name].nbBytes()) {
- outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
- bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
- }
+ // set user's outputs info
+ std::vector<int64_t> shape(output_dims.d,
+ output_dims.d + output_dims.nbDims);
+ (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
+ outputs_desc_[i].name);
+ // Allocate output buffer memory
+ outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
+ // binding output buffer
+ bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
}
}

@@ -580,4 +600,4 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
info.dtype = GetFDDataType(outputs_desc_[index].dtype);
return info;
}
- } // namespace fastdeploy
+ } // namespace fastdeploy
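
Reviewer note: for GPU-resident inputs SetInputs does not copy at all; SetExternalData records the caller's device pointer and dims, while CPU inputs still go through resize plus upload. A sketch of that owned-versus-external buffer idea with a stand-in class; the real device buffer's interface may differ.

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for the device buffer used by TrtBackend: it either owns storage
// (resize) or merely points at memory owned by the caller (SetExternalData).
class DeviceBufferSketch {
 public:
  // Zero-copy path: remember the external pointer, drop any owned storage.
  void SetExternalData(size_t nbytes, const void* data) {
    owned_.clear();
    external_ = data;
    nbytes_ = nbytes;
  }
  // Owned path: allocate (host memory standing in for a CUDA allocation).
  void resize(size_t nbytes) {
    owned_.resize(nbytes);
    external_ = nullptr;
    nbytes_ = nbytes;
  }
  const void* data() const { return external_ ? external_ : owned_.data(); }
  size_t nbytes() const { return nbytes_; }

 private:
  std::vector<unsigned char> owned_;
  const void* external_ = nullptr;
  size_t nbytes_ = 0;
};

int main() {
  std::vector<float> gpu_like_input(16, 1.0f);  // pretend this lives on the GPU
  DeviceBufferSketch buffer;
  buffer.SetExternalData(gpu_like_input.size() * sizeof(float), gpu_like_input.data());
  std::cout << (buffer.data() == gpu_like_input.data()) << "\n";  // 1: no copy made
  return 0;
}
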
9 changes: 5 additions & 4 deletions csrc/fastdeploy/backends/tensorrt/trt_backend.h
@@ -14,6 +14,8 @@

#pragma once

+ #include <cuda_runtime_api.h>

#include <iostream>
#include <map>
#include <string>
@@ -23,7 +25,6 @@
#include "NvOnnxParser.h"
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/backends/tensorrt/utils.h"
- #include <cuda_runtime_api.h>

namespace fastdeploy {

@@ -109,12 +110,12 @@ class TrtBackend : public BaseBackend {
std::map<std::string, ShapeRangeInfo> shape_range_info_;

void GetInputOutputInfo();
- void AllocateBufferInDynamicShape(const std::vector<FDTensor>& inputs,
- std::vector<FDTensor>* outputs);
bool CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer);
bool BuildTrtEngine();
bool LoadTrtCache(const std::string& trt_engine_file);
int ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs);
+ void SetInputs(const std::vector<FDTensor>& inputs);
+ void AllocateOutputsBuffer(std::vector<FDTensor>* outputs);
};

- } // namespace fastdeploy
+ } // namespace fastdeploy