diff --git a/fastdeploy/backends/paddle/paddle_backend.cc b/fastdeploy/backends/paddle/paddle_backend.cc
index 674a37954..25951dae5 100644
--- a/fastdeploy/backends/paddle/paddle_backend.cc
+++ b/fastdeploy/backends/paddle/paddle_backend.cc
@@ -19,6 +19,7 @@ namespace fastdeploy {
 
 void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
+  option_ = option;
   if (option.use_gpu) {
     config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
     if (option.enable_trt) {
@@ -190,6 +191,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
   outputs->resize(outputs_desc_.size());
   for (size_t i = 0; i < outputs_desc_.size(); ++i) {
     auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name);
+    (*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
     CopyTensorToCpu(handle, &((*outputs)[i]));
   }
   return true;
diff --git a/fastdeploy/backends/paddle/paddle_backend.h b/fastdeploy/backends/paddle/paddle_backend.h
index 78b939fea..1d4f53db3 100755
--- a/fastdeploy/backends/paddle/paddle_backend.h
+++ b/fastdeploy/backends/paddle/paddle_backend.h
@@ -53,6 +53,7 @@ struct PaddleBackendOption {
   int gpu_mem_init_size = 100;
   // gpu device id
   int gpu_id = 0;
+  bool enable_pinned_memory = false;
 
   std::vector<std::string> delete_pass_names = {};
 };
@@ -105,6 +106,7 @@
       std::map<std::string, std::vector<int>>* opt_shape) const;
   void SetTRTDynamicShapeToConfig(const PaddleBackendOption& option);
 #endif
+  PaddleBackendOption option_;
   paddle_infer::Config config_;
   std::shared_ptr<paddle_infer::Predictor> predictor_;
   std::vector<TensorInfo> inputs_desc_;
diff --git a/fastdeploy/backends/paddle/util.cc b/fastdeploy/backends/paddle/util.cc
index 005f9966b..d8cc1dbb9 100644
--- a/fastdeploy/backends/paddle/util.cc
+++ b/fastdeploy/backends/paddle/util.cc
@@ -67,7 +67,7 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
   std::vector<int64_t> shape;
   auto tmp_shape = tensor->shape();
   shape.assign(tmp_shape.begin(), tmp_shape.end());
-  fd_tensor->Allocate(shape, fd_dtype, tensor->name());
+  fd_tensor->Resize(shape, fd_dtype, tensor->name());
   if (fd_tensor->dtype == FDDataType::FP32) {
     tensor->CopyToCpu(static_cast<float*>(fd_tensor->MutableData()));
     return;
diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc
index 395215db0..363a9d1ce 100644
--- a/fastdeploy/backends/tensorrt/trt_backend.cc
+++ b/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -306,17 +306,21 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
   SetInputs(inputs);
   AllocateOutputsBuffer(outputs);
+
   if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
     FDERROR << "Failed to Infer with TensorRT." << std::endl;
     return false;
   }
   for (size_t i = 0; i < outputs->size(); ++i) {
     FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
-                             outputs_buffer_[(*outputs)[i].name].data(),
+                             outputs_device_buffer_[(*outputs)[i].name].data(),
                              (*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
                              stream_) == 0,
              "[ERROR] Error occurs while copy memory from GPU to CPU.");
   }
+  FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
+           "[ERROR] Error occurs while sync cuda stream.");
+
   return true;
 }
@@ -332,10 +336,10 @@ void TrtBackend::GetInputOutputInfo() {
     auto dtype = engine_->getBindingDataType(i);
     if (engine_->bindingIsInput(i)) {
       inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
-      inputs_buffer_[name] = FDDeviceBuffer(dtype);
+      inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
     } else {
       outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
-      outputs_buffer_[name] = FDDeviceBuffer(dtype);
+      outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
     }
   }
   bindings_.resize(num_binds);
@@ -357,30 +361,31 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
                "please use INT32 input");
     } else {
       // no copy
-      inputs_buffer_[item.name].SetExternalData(dims, item.Data());
+      inputs_device_buffer_[item.name].SetExternalData(dims, item.Data());
     }
   } else {
     // Allocate input buffer memory
-    inputs_buffer_[item.name].resize(dims);
+    inputs_device_buffer_[item.name].resize(dims);
     // copy from cpu to gpu
     if (item.dtype == FDDataType::INT64) {
       int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
       std::vector<int32_t> casted_data(data, data + item.Numel());
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
+      FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
                                static_cast<void*>(casted_data.data()),
                                item.Nbytes() / 2, cudaMemcpyHostToDevice,
                                stream_) == 0,
               "Error occurs while copy memory from CPU to GPU.");
     } else {
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
+      FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
+                               item.Data(),
                                item.Nbytes(), cudaMemcpyHostToDevice,
                                stream_) == 0,
                "Error occurs while copy memory from CPU to GPU.");
     }
   }
   // binding input buffer
-  bindings_[idx] = inputs_buffer_[item.name].data();
+  bindings_[idx] = inputs_device_buffer_[item.name].data();
 }
@@ -399,15 +404,19 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
              "Cannot find output: %s of tensorrt network from the original model.",
              outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
+
     // set user's outputs info
     std::vector<int64_t> shape(output_dims.d,
                                output_dims.d + output_dims.nbDims);
+    (*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
     (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
                                outputs_desc_[i].name);
+
     // Allocate output buffer memory
-    outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
+    outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);
+
     // binding output buffer
-    bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
+    bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
   }
 }
diff --git a/fastdeploy/backends/tensorrt/trt_backend.h b/fastdeploy/backends/tensorrt/trt_backend.h
index ad3ace6a4..09f18b2df 100755
--- a/fastdeploy/backends/tensorrt/trt_backend.h
+++ b/fastdeploy/backends/tensorrt/trt_backend.h
@@ -70,6 +70,7 @@ struct TrtBackendOption {
   std::map<std::string, std::vector<int32_t>> min_shape;
   std::map<std::string, std::vector<int32_t>> opt_shape;
   std::string serialize_file = "";
+  bool enable_pinned_memory = false;
 
   // inside parameter, maybe remove next version
   bool remove_multiclass_nms_ = false;
@@ -118,8 +119,8 @@ class TrtBackend : public BaseBackend {
   std::vector<void*> bindings_;
   std::vector<TrtValueInfo> inputs_desc_;
   std::vector<TrtValueInfo> outputs_desc_;
-  std::map<std::string, FDDeviceBuffer> inputs_buffer_;
-  std::map<std::string, FDDeviceBuffer> outputs_buffer_;
+  std::map<std::string, FDDeviceBuffer> inputs_device_buffer_;
+  std::map<std::string, FDDeviceBuffer> outputs_device_buffer_;
 
   std::string calibration_str_;
diff --git a/fastdeploy/backends/tensorrt/utils.h b/fastdeploy/backends/tensorrt/utils.h
index f76230526..7f2e7344b 100644
--- a/fastdeploy/backends/tensorrt/utils.h
+++ b/fastdeploy/backends/tensorrt/utils.h
@@ -206,6 +206,8 @@ class FDGenericBuffer {
 };
 
 using FDDeviceBuffer = FDGenericBuffer<FDDeviceAllocator, FDDeviceFree>;
+using FDDeviceHostBuffer =
+    FDGenericBuffer<FDDeviceHostAllocator, FDDeviceHostFree>;
 
 class FDTrtLogger : public nvinfer1::ILogger {
  public:
diff --git a/fastdeploy/core/allocate.cc b/fastdeploy/core/allocate.cc
index 285642d5c..e71cd3443 100644
--- a/fastdeploy/core/allocate.cc
+++ b/fastdeploy/core/allocate.cc
@@ -34,6 +34,12 @@ bool FDDeviceAllocator::operator()(void** ptr, size_t size) const {
 
 void FDDeviceFree::operator()(void* ptr) const { cudaFree(ptr); }
 
+bool FDDeviceHostAllocator::operator()(void** ptr, size_t size) const {
+  return cudaMallocHost(ptr, size) == cudaSuccess;
+}
+
+void FDDeviceHostFree::operator()(void* ptr) const { cudaFreeHost(ptr); }
+
 #endif
 
 }  // namespace fastdeploy
diff --git a/fastdeploy/core/allocate.h b/fastdeploy/core/allocate.h
index c48bb7cee..1e88787f4 100644
--- a/fastdeploy/core/allocate.h
+++ b/fastdeploy/core/allocate.h
@@ -45,6 +45,16 @@ class FASTDEPLOY_DECL FDDeviceFree {
   void operator()(void* ptr) const;
 };
 
+class FASTDEPLOY_DECL FDDeviceHostAllocator {
+ public:
+  bool operator()(void** ptr, size_t size) const;
+};
+
+class FASTDEPLOY_DECL FDDeviceHostFree {
+ public:
+  void operator()(void* ptr) const;
+};
+
 #endif
 
 }  // namespace fastdeploy
diff --git a/fastdeploy/core/fd_tensor.cc b/fastdeploy/core/fd_tensor.cc
index 1161d2b0e..e98a81e1b 100644
--- a/fastdeploy/core/fd_tensor.cc
+++ b/fastdeploy/core/fd_tensor.cc
@@ -207,9 +207,27 @@ bool FDTensor::ReallocFn(size_t nbytes) {
              "-DWITH_GPU=ON,"
              "so this is an unexpected problem happend.");
 #endif
+  } else {
+    if (is_pinned_memory) {
+#ifdef WITH_GPU
+      size_t original_nbytes = Nbytes();
+      if (nbytes > original_nbytes) {
+        if (buffer_ != nullptr) {
+          FDDeviceHostFree()(buffer_);
+        }
+        FDDeviceHostAllocator()(&buffer_, nbytes);
+      }
+      return buffer_ != nullptr;
+#else
+      FDASSERT(false,
+               "The FastDeploy FDTensor allocator didn't compile under "
+               "-DWITH_GPU=ON,"
+               "so this is an unexpected problem happend.");
+#endif
+    }
+    buffer_ = realloc(buffer_, nbytes);
+    return buffer_ != nullptr;
   }
-  buffer_ = realloc(buffer_, nbytes);
-  return buffer_ != nullptr;
 }
@@ -220,7 +238,13 @@ void FDTensor::FreeFn() {
     FDDeviceFree()(buffer_);
 #endif
   } else {
-    FDHostFree()(buffer_);
+    if (is_pinned_memory) {
+#ifdef WITH_GPU
+      FDDeviceHostFree()(buffer_);
+#endif
+    } else {
+      FDHostFree()(buffer_);
+    }
   }
   buffer_ = nullptr;
 }
@@ -231,7 +255,6 @@ void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
 #ifdef WITH_GPU
     FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyDeviceToDevice) == 0,
              "[ERROR] Error occurs while copy memory from GPU to GPU");
-
 #else
     FDASSERT(false,
              "The FastDeploy didn't compile under -DWITH_GPU=ON, so copying "
@@ -239,7 +262,19 @@ void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
              "an unexpected problem happend.");
 #endif
   } else {
-    std::memcpy(dst, src, nbytes);
+    if (is_pinned_memory) {
+#ifdef WITH_GPU
+      FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyHostToHost) == 0,
host to host"); +#else + FDASSERT(false, + "The FastDeploy didn't compile under -DWITH_GPU=ON, so copying " + "gpu buffer is " + "an unexpected problem happend."); +#endif + } else { + std::memcpy(dst, src, nbytes); + } } } diff --git a/fastdeploy/core/fd_tensor.h b/fastdeploy/core/fd_tensor.h index 7e8bb7851..1619fe271 100644 --- a/fastdeploy/core/fd_tensor.h +++ b/fastdeploy/core/fd_tensor.h @@ -40,6 +40,10 @@ struct FASTDEPLOY_DECL FDTensor { // so we can skip data transfer, which may improve the efficience Device device = Device::CPU; + // Whether the data buffer is in pinned memory, which is allocated + // with cudaMallocHost() + bool is_pinned_memory = false; + // if the external data is not on CPU, we use this temporary buffer // to transfer data to CPU at some cases we need to visit the // other devices' data diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc index 6d8eb7804..70f9a5917 100755 --- a/fastdeploy/pybind/runtime.cc +++ b/fastdeploy/pybind/runtime.cc @@ -44,6 +44,8 @@ void BindRuntime(pybind11::module& m) { .def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16) .def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16) .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile) + .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory) + .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory) .def("enable_paddle_trt_collect_shape", &RuntimeOption::EnablePaddleTrtCollectShape) .def("disable_paddle_trt_collect_shape", &RuntimeOption::DisablePaddleTrtCollectShape) .def_readwrite("model_file", &RuntimeOption::model_file) @@ -200,6 +202,7 @@ void BindRuntime(pybind11::module& m) { .def("numel", &FDTensor::Numel) .def("nbytes", &FDTensor::Nbytes) .def_readwrite("name", &FDTensor::name) + .def_readwrite("is_pinned_memory", &FDTensor::is_pinned_memory) .def_readonly("shape", &FDTensor::shape) .def_readonly("dtype", &FDTensor::dtype) .def_readonly("device", &FDTensor::device); diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc index 0877402d7..5037dc120 100755 --- a/fastdeploy/runtime.cc +++ b/fastdeploy/runtime.cc @@ -356,6 +356,10 @@ void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; } void RuntimeOption::DisableTrtFP16() { trt_enable_fp16 = false; } +void RuntimeOption::EnablePinnedMemory() { enable_pinned_memory = true; } + +void RuntimeOption::DisablePinnedMemory() { enable_pinned_memory = false; } + void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) { trt_serialize_file = cache_file_path; } @@ -503,6 +507,7 @@ void Runtime::CreatePaddleBackend() { pd_option.gpu_id = option.device_id; pd_option.delete_pass_names = option.pd_delete_pass_names; pd_option.cpu_thread_num = option.cpu_thread_num; + pd_option.enable_pinned_memory = option.enable_pinned_memory; #ifdef ENABLE_TRT_BACKEND if (pd_option.use_gpu && option.pd_enable_trt) { pd_option.enable_trt = true; @@ -516,6 +521,7 @@ void Runtime::CreatePaddleBackend() { trt_option.min_shape = option.trt_min_shape; trt_option.opt_shape = option.trt_opt_shape; trt_option.serialize_file = option.trt_serialize_file; + trt_option.enable_pinned_memory = option.enable_pinned_memory; pd_option.trt_option = trt_option; } #endif @@ -606,6 +612,7 @@ void Runtime::CreateTrtBackend() { trt_option.min_shape = option.trt_min_shape; trt_option.opt_shape = option.trt_opt_shape; trt_option.serialize_file = option.trt_serialize_file; + trt_option.enable_pinned_memory = option.enable_pinned_memory; // TODO(jiangjiajun): inside usage, maybe remove this later 
   trt_option.remove_multiclass_nms_ = option.remove_multiclass_nms_;
diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h
index 32ad1615c..021103cb2 100755
--- a/fastdeploy/runtime.h
+++ b/fastdeploy/runtime.h
@@ -204,6 +204,15 @@ struct FASTDEPLOY_DECL RuntimeOption {
    */
  void SetTrtCacheFile(const std::string& cache_file_path);
 
+  /**
+   * @brief Enable pinned memory. Pinned memory can be utilized to speed up the data transfer between CPU and GPU. Currently it is only supported in the TRT backend and the Paddle Inference backend.
+   */
+  void EnablePinnedMemory();
+
+  /**
+   * @brief Disable pinned memory
+   */
+  void DisablePinnedMemory();
 
   /**
    * @brief Enable to collect shape in paddle trt backend
@@ -223,6 +232,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
 
   Device device = Device::CPU;
 
+  bool enable_pinned_memory = false;
+
   // ======Only for ORT Backend========
   // -1 means use default value by ort
   // 0: ORT_DISABLE_ALL 1: ORT_ENABLE_BASIC 2: ORT_ENABLE_EXTENDED 3:
diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py
index 90e64d400..61d103931 100755
--- a/python/fastdeploy/runtime.py
+++ b/python/fastdeploy/runtime.py
@@ -319,6 +319,16 @@ def disable_trt_fp16(self):
         """
         return self._option.disable_trt_fp16()
 
+    def enable_pinned_memory(self):
+        """Enable pinned memory. Pinned memory can be utilized to speed up the data transfer between CPU and GPU. Currently it is only supported in the TRT backend and the Paddle Inference backend.
+        """
+        return self._option.enable_pinned_memory()
+
+    def disable_pinned_memory(self):
+        """Disable pinned memory.
+        """
+        return self._option.disable_pinned_memory()
+
     def enable_paddle_to_trt(self):
         """While using TensorRT backend, enable_paddle_to_trt() will change to use Paddle Inference backend, and use its integrated TensorRT instead.
         """
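
Usage note (reviewer sketch, not part of the diff): a minimal example of how the new switch could be used from the Python API, assuming the existing RuntimeOption / Runtime entry points; the model file names are placeholders, and only enable_pinned_memory() comes from this patch. It only takes effect for the TensorRT and Paddle Inference backends running on GPU.

import fastdeploy as fd

option = fd.RuntimeOption()
# Placeholder model files; substitute a real Paddle model.
option.set_model_path("model.pdmodel", "model.pdiparams")
option.use_gpu(0)
option.use_trt_backend()
# Added by this patch: output FDTensors are then backed by cudaMallocHost
# (pinned) host memory, which can speed up the device-to-host copies that
# run after inference.
option.enable_pinned_memory()

runtime = fd.Runtime(option)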