[Model] Yolov5/v5lite/v6/v7/v7end2end: CUDA preprocessing #370

Merged: 24 commits from wang-xinyu:cuda_preproc into PaddlePaddle:develop, Oct 19, 2022

Commits (24)
4ea3407
add yolo cuda preprocessing
wang-xinyu Oct 14, 2022
757c151
cmake build cuda src
wang-xinyu Oct 14, 2022
4365073
yolov5 support cuda preprocessing
wang-xinyu Oct 14, 2022
ef1475e
Merge branch 'develop' into cuda_preproc
jiangjiajun Oct 14, 2022
07eb593
Merge branch 'PaddlePaddle:develop' into cuda_preproc
wang-xinyu Oct 17, 2022
3f7e6ba
yolov5 cuda preprocessing configurable
wang-xinyu Oct 18, 2022
73f5d89
Merge branch 'PaddlePaddle:develop' into cuda_preproc
wang-xinyu Oct 18, 2022
010e66d
Merge branch 'cuda_preproc' of https://github.com/wang-xinyu/FastDepl…
wang-xinyu Oct 18, 2022
e061b35
yolov5 update get mat data api
wang-xinyu Oct 18, 2022
389da0a
yolov5 check cuda preprocess args
wang-xinyu Oct 18, 2022
d580801
refactor cuda function name
wang-xinyu Oct 18, 2022
e3d7e1e
yolo cuda preprocess padding value configurable
wang-xinyu Oct 18, 2022
fdd6621
yolov5 release cuda memory
wang-xinyu Oct 18, 2022
fd3725a
cuda preprocess pybind api update
wang-xinyu Oct 18, 2022
1e1d2f3
Merge branch 'develop' into cuda_preproc
wang-xinyu Oct 18, 2022
d43e604
move use_cuda_preprocessing option to yolov5 model
wang-xinyu Oct 18, 2022
4f4c456
yolov5lite cuda preprocessing
wang-xinyu Oct 18, 2022
4c8f8d2
yolov6 cuda preprocessing
wang-xinyu Oct 18, 2022
5e10efa
yolov7 cuda preprocessing
wang-xinyu Oct 18, 2022
1905986
yolov7_e2e cuda preprocessing
wang-xinyu Oct 18, 2022
8ce4ea6
remove cuda preprocessing in runtime option
wang-xinyu Oct 18, 2022
6a7a81e
Merge branch 'develop' into cuda_preproc
wang-xinyu Oct 19, 2022
05c289c
refine log and cmake variable name
wang-xinyu Oct 19, 2022
2ba76dd
fix model runtime ptr type
wang-xinyu Oct 19, 2022
24 changes: 23 additions & 1 deletion CMakeLists.txt
@@ -92,6 +92,18 @@ if(BUILD_ON_JETSON)
set(ENABLE_ORT_BACKEND ON)
endif()

# Whether to build CUDA source files in fastdeploy
# Only CPU inference & GPU (TensorRT) inference are supported for now
option(BUILD_CUDA_SRC "Whether to build CUDA source files in fastdeploy" OFF)
if(WITH_GPU AND UNIX)
set(BUILD_CUDA_SRC ON)
add_definitions(-DENABLE_CUDA_SRC)
enable_language(CUDA)
set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
else()
set(BUILD_CUDA_SRC OFF)
endif()

# config GIT_URL with github mirrors to speed up dependent repos clone
option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL})
if(NOT GIT_URL)
@@ -174,6 +186,7 @@ file(GLOB_RECURSE DEPLOY_TRT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastde
file(GLOB_RECURSE DEPLOY_OPENVINO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/openvino/*.cc)
file(GLOB_RECURSE DEPLOY_LITE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/lite/*.cc)
file(GLOB_RECURSE DEPLOY_VISION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cc)
file(GLOB_RECURSE DEPLOY_VISION_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cu)
file(GLOB_RECURSE DEPLOY_TEXT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cc)
file(GLOB_RECURSE DEPLOY_PYBIND_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*_pybind.cc)
list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_POROS_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_OPENVINO_SRCS} ${DEPLOY_LITE_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS})
@@ -373,6 +386,9 @@ if(ENABLE_VISION)
endif()
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp)
list(APPEND DEPEND_LIBS yaml-cpp)
if(BUILD_CUDA_SRC)
list(APPEND DEPLOY_VISION_SRCS ${DEPLOY_VISION_CUDA_SRCS})
endif()
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_VISION_SRCS})
include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
include(${PROJECT_SOURCE_DIR}/cmake/opencv.cmake)
@@ -428,7 +444,13 @@ elseif(ANDROID)
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_MINSIZEREL ${COMMON_LINK_FLAGS_REL})
elseif(MSVC)
else()
set_target_properties(${LIBRARY_NAME} PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
if(BUILD_CUDA_SRC)
set_target_properties(${LIBRARY_NAME} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
set_target_properties(${LIBRARY_NAME} PROPERTIES INTERFACE_COMPILE_OPTIONS
"$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CXX>>:-fvisibility=hidden>$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=-fvisibility=hidden>")
else()
set_target_properties(${LIBRARY_NAME} PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
endif()
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-Wl,--exclude-libs,ALL")
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_RELEASE -s)
endif()
1 change: 1 addition & 0 deletions cmake/summary.cmake
@@ -51,6 +51,7 @@ function(fastdeploy_summary)
message(STATUS " WITH_GPU : ${WITH_GPU}")
message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")
message(STATUS " TRT_DRECTORY : ${TRT_DIRECTORY}")
message(STATUS " BUILD_CUDA_SRC : ${BUILD_CUDA_SRC}")
endif()
message(STATUS " ENABLE_VISION : ${ENABLE_VISION}")
message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}")
2 changes: 1 addition & 1 deletion fastdeploy/fastdeploy_model.h
@@ -99,7 +99,7 @@ class FASTDEPLOY_DECL FastDeployModel {
std::vector<Backend> valid_external_backends;

private:
std::unique_ptr<Runtime> runtime_;
std::shared_ptr<Runtime> runtime_;
bool runtime_initialized_ = false;
// whether to record inference time
bool enable_record_time_of_runtime_ = false;
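Note for the diffs below: the yolov5.cc / yolov5lite.cc changes call a
CUDA_CHECK macro pulled in from fastdeploy/vision/utils/cuda_utils.h, which
is not part of this page. A minimal sketch of the usual idiom, assuming the
header follows the common CUDA error-checking pattern:

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Hypothetical sketch of CUDA_CHECK; the real definition lives in
// fastdeploy/vision/utils/cuda_utils.h (not shown in this PR page).
#define CUDA_CHECK(call)                                          \
  do {                                                            \
    cudaError_t err = (call);                                     \
    if (err != cudaSuccess) {                                     \
      fprintf(stderr, "CUDA error %s at %s:%d\n",                 \
              cudaGetErrorString(err), __FILE__, __LINE__);       \
      exit(EXIT_FAILURE);                                         \
    }                                                             \
  } while (0)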
96 changes: 91 additions & 5 deletions fastdeploy/vision/detection/contrib/yolov5.cc
@@ -16,6 +16,9 @@

#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
#ifdef ENABLE_CUDA_SRC
#include "fastdeploy/vision/utils/cuda_utils.h"
#endif // ENABLE_CUDA_SRC

namespace fastdeploy {
namespace vision {
@@ -104,9 +107,20 @@ bool YOLOv5::Initialize() {
// if (!is_dynamic_input_) {
// is_mini_pad_ = false;
// }

return true;
}

YOLOv5::~YOLOv5() {
#ifdef ENABLE_CUDA_SRC
if (use_cuda_preprocessing_) {
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
}
#endif // ENABLE_CUDA_SRC
}

bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size,
@@ -156,6 +170,69 @@ bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
return true;
}

void YOLOv5::UseCudaPreprocessing(int max_image_size) {
#ifdef ENABLE_CUDA_SRC
use_cuda_preprocessing_ = true;
is_scale_up_ = true;
if (input_img_cuda_buffer_host_ == nullptr) {
// prepare input data cache in GPU pinned memory
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
// prepare input data cache in GPU device memory
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size_[0] * size_[1] * sizeof(float)));
}
#else
FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
<< std::endl;
use_cuda_preprocessing_ = false;
#endif
}

bool YOLOv5::CudaPreprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size,
const std::vector<float> padding_value,
bool is_mini_pad, bool is_no_pad, bool is_scale_up,
int stride, float max_wh, bool multi_label) {
#ifdef ENABLE_CUDA_SRC
if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
FDERROR << "Upsupported arguments for CUDA preprocess." << std::endl;
return false;
}

// Record the shape of image and the shape of preprocessed image
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};

cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
input_img_cuda_buffer_host_,
src_img_buf_size, cudaMemcpyHostToDevice, stream));
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
mat->Height(), input_tensor_cuda_buffer_device_,
size[0], size[1], padding_value, stream);
cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);

// Record output shape of preprocessed image
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};

output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
input_tensor_cuda_buffer_device_);
output->device = Device::GPU;
output->shape.insert(output->shape.begin(), 1); // reshape to n, c, h, w
return true;
#else
FDERROR << "CUDA src code was not enabled." << std::endl;
return false;
#endif // ENABLE_CUDA_SRC
}

bool YOLOv5::Postprocess(
std::vector<FDTensor>& infer_results, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
@@ -262,11 +339,20 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,

std::map<std::string, std::array<float, 2>> im_info;

if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
multi_label_)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
if (use_cuda_preprocessing_) {
if (!CudaPreprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
multi_label_)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
} else {
if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
multi_label_)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
}

input_tensors[0].name = InputInfoOfRuntime(0).name;
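The device kernel behind utils::CudaYoloPreprocess lives in a .cu file that
is not shown on this page. As a rough sketch of what such a letterbox kernel
does (scale-preserving resize, constant padding, BGR to RGB, HWC uint8 to
CHW float), assuming nearest-neighbor sampling, top-left anchoring, and a
scalar padding value for brevity; the real kernel may center the image,
interpolate bilinearly, and take the per-channel padding_value seen above:

#include <cstdint>
#include <cmath>
#include <cuda_runtime.h>

// Simplified sketch, not the PR's actual kernel: one thread per
// destination pixel.
__global__ void YoloLetterboxSketch(const uint8_t* __restrict__ src,
                                    int src_w, int src_h,
                                    float* __restrict__ dst,
                                    int dst_w, int dst_h,
                                    float scale, float pad_val) {
  int x = blockIdx.x * blockDim.x + threadIdx.x;
  int y = blockIdx.y * blockDim.y + threadIdx.y;
  if (x >= dst_w || y >= dst_h) return;

  // Map the destination pixel back into the source image.
  int sx = static_cast<int>(x / scale);
  int sy = static_cast<int>(y / scale);

  float b, g, r;
  if (sx < src_w && sy < src_h) {
    const uint8_t* p = src + (sy * src_w + sx) * 3;  // HWC, BGR order
    b = p[0]; g = p[1]; r = p[2];
  } else {
    b = g = r = pad_val;  // letterbox padding region
  }

  // Write CHW, RGB, normalized to [0, 1], matching the FP32 tensor that
  // CudaPreprocess exposes via SetExternalData.
  int area = dst_w * dst_h;
  dst[0 * area + y * dst_w + x] = r / 255.0f;
  dst[1 * area + y * dst_w + x] = g / 255.0f;
  dst[2 * area + y * dst_w + x] = b / 255.0f;
}

// Host-side launch shaped like the call site above.
inline void LaunchYoloLetterboxSketch(const uint8_t* src, int src_w,
                                      int src_h, float* dst, int dst_w,
                                      int dst_h, float pad_val,
                                      cudaStream_t stream) {
  float scale = fminf(dst_w / static_cast<float>(src_w),
                      dst_h / static_cast<float>(src_h));
  dim3 block(16, 16);
  dim3 grid((dst_w + block.x - 1) / block.x,
            (dst_h + block.y - 1) / block.y);
  YoloLetterboxSketch<<<grid, block, 0, stream>>>(
      src, src_w, src_h, dst, dst_w, dst_h, scale, pad_val);
}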
22 changes: 22 additions & 0 deletions fastdeploy/vision/detection/contrib/yolov5.h
@@ -13,6 +13,7 @@
// limitations under the License.

#pragma once

#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
@@ -27,6 +28,8 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::ONNX);

~YOLOv5();

std::string ModelName() const { return "yolov5"; }

virtual bool Predict(cv::Mat* im, DetectionResult* result,
@@ -42,6 +45,17 @@ bool is_scale_up = false, int stride = 32,
bool is_scale_up = false, int stride = 32,
float max_wh = 7680.0, bool multi_label = true);

void UseCudaPreprocessing(int max_img_size = 3840 * 2160);

bool CudaPreprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size = {640, 640},
const std::vector<float> padding_value = {114.0, 114.0,
114.0},
bool is_mini_pad = false, bool is_no_pad = false,
bool is_scale_up = false, int stride = 32,
float max_wh = 7680.0, bool multi_label = true);

static bool Postprocess(
std::vector<FDTensor>& infer_results, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
@@ -85,6 +99,14 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
// value will
// auto check by fastdeploy after the internal Runtime already initialized.
bool is_dynamic_input_;
// CUDA host buffer for input image
uint8_t* input_img_cuda_buffer_host_ = nullptr;
// CUDA device buffer for input image
uint8_t* input_img_cuda_buffer_device_ = nullptr;
// CUDA device buffer for TRT input tensor
float* input_tensor_cuda_buffer_device_ = nullptr;
// Whether to use CUDA preprocessing
bool use_cuda_preprocessing_ = false;
};

} // namespace detection
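For orientation, a minimal end-to-end C++ usage sketch of the API added in
this header; the model path, image path, and RuntimeOption setup are
illustrative, not taken from this PR:

#include <iostream>
#include <opencv2/opencv.hpp>
#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UseGpu();  // CUDA preprocessing presumes GPU inference

  fastdeploy::vision::detection::YOLOv5 model("yolov5s.onnx", "", option);
  if (!model.Initialized()) return -1;

  // Opt in to the CUDA path added by this PR; the argument bounds the
  // image size the pinned/device staging buffers are sized for
  // (default 3840 * 2160 per the declaration above).
  model.UseCudaPreprocessing(3840 * 2160);

  cv::Mat im = cv::imread("test.jpg");
  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(&im, &res)) return -1;

  std::cout << res.Str() << std::endl;
  return 0;
}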
4 changes: 4 additions & 0 deletions fastdeploy/vision/detection/contrib/yolov5_pybind.cc
@@ -27,6 +27,10 @@ void BindYOLOv5(pybind11::module& m) {
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def("use_cuda_preprocessing",
[](vision::detection::YOLOv5& self, int max_image_size) {
self.UseCudaPreprocessing(max_image_size);
})
.def_static("preprocess",
[](pybind11::array& data, const std::vector<int>& size,
const std::vector<float> padding_value, bool is_mini_pad,
85 changes: 82 additions & 3 deletions fastdeploy/vision/detection/contrib/yolov5lite.cc
@@ -15,6 +15,9 @@
#include "fastdeploy/vision/detection/contrib/yolov5lite.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
#ifdef ENABLE_CUDA_SRC
#include "fastdeploy/vision/utils/cuda_utils.h"
#endif // ENABLE_CUDA_SRC

namespace fastdeploy {
namespace vision {
@@ -136,6 +139,16 @@ bool YOLOv5Lite::Initialize() {
return true;
}

YOLOv5Lite::~YOLOv5Lite() {
#ifdef ENABLE_CUDA_SRC
if (use_cuda_preprocessing_) {
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
}
#endif // ENABLE_CUDA_SRC
}

bool YOLOv5Lite::Preprocess(
Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
@@ -176,6 +189,65 @@ bool YOLOv5Lite::Preprocess(
return true;
}

void YOLOv5Lite::UseCudaPreprocessing(int max_image_size) {
#ifdef ENABLE_CUDA_SRC
use_cuda_preprocessing_ = true;
is_scale_up = true;
if (input_img_cuda_buffer_host_ == nullptr) {
// prepare input data cache in GPU pinned memory
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
// prepare input data cache in GPU device memory
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
}
#else
FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
<< std::endl;
use_cuda_preprocessing_ = false;
#endif
}

bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
#ifdef ENABLE_CUDA_SRC
if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
FDERROR << "Upsupported arguments for CUDA preprocess." << std::endl;
return false;
}

// Record the shape of image and the shape of preprocessed image
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};

cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
input_img_cuda_buffer_host_,
src_img_buf_size, cudaMemcpyHostToDevice, stream));
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
mat->Height(), input_tensor_cuda_buffer_device_,
size[0], size[1], padding_value, stream);
cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);

// Record output shape of preprocessed image
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};

output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
input_tensor_cuda_buffer_device_);
output->device = Device::GPU;
output->shape.insert(output->shape.begin(), 1); // reshape to n, c, h, w
return true;
#else
FDERROR << "CUDA src code was not enabled." << std::endl;
return false;
#endif // ENABLE_CUDA_SRC
}

bool YOLOv5Lite::PostprocessWithDecode(
FDTensor& infer_result, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
@@ -348,9 +420,16 @@ bool YOLOv5Lite::Predict(cv::Mat* im, DetectionResult* result,
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};

if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
if (use_cuda_preprocessing_) {
if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
} else {
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
}

input_tensors[0].name = InputInfoOfRuntime(0).name;
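One pattern worth noting in both yolov5.cc and yolov5lite.cc above: the host
staging buffer is allocated with cudaMallocHost (page-locked memory), which
is what lets the cudaMemcpyAsync host-to-device copy run asynchronously on
the stream; with a pageable source it would silently degrade to a
synchronous copy. The pattern in isolation (names illustrative):

#include <cstring>
#include <cuda_runtime.h>

// Pinned-staging upload as used by CudaPreprocess: pageable -> pinned
// memcpy, then an async H2D copy on the caller's stream.
void StagedUploadSketch(const void* src, std::size_t bytes,
                        void* pinned_host, void* device_dst,
                        cudaStream_t stream) {
  std::memcpy(pinned_host, src, bytes);
  cudaMemcpyAsync(device_dst, pinned_host, bytes,
                  cudaMemcpyHostToDevice, stream);
}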