diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index c9719cc9ea..82f06e0036 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -24,7 +24,7 @@
 #include "fastdeploy/vision/detection/contrib/yolov5/yolov5.h"
 #include "fastdeploy/vision/detection/contrib/yolov5lite.h"
 #include "fastdeploy/vision/detection/contrib/yolov6.h"
-#include "fastdeploy/vision/detection/contrib/yolov7.h"
+#include "fastdeploy/vision/detection/contrib/yolov7/yolov7.h"
 #include "fastdeploy/vision/detection/contrib/yolov7end2end_ort.h"
 #include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
 #include "fastdeploy/vision/detection/contrib/yolox.h"
diff --git a/fastdeploy/vision/detection/contrib/yolov5/postprocessor.cc b/fastdeploy/vision/detection/contrib/yolov5/postprocessor.cc
index dd61efb002..4fe01dfeb8 100755
--- a/fastdeploy/vision/detection/contrib/yolov5/postprocessor.cc
+++ b/fastdeploy/vision/detection/contrib/yolov5/postprocessor.cc
@@ -103,9 +103,9 @@ bool YOLOv5Postprocessor::Run(const std::vector<FDTensor>& tensors,
   float ipt_h = iter_ipt->second[0];
   float ipt_w = iter_ipt->second[1];
   float scale = std::min(out_h / ipt_h, out_w / ipt_w);
+  float pad_h = (out_h - ipt_h * scale) / 2;
+  float pad_w = (out_w - ipt_w * scale) / 2;
   for (size_t i = 0; i < (*results)[bs].boxes.size(); ++i) {
-    float pad_h = (out_h - ipt_h * scale) / 2;
-    float pad_w = (out_w - ipt_w * scale) / 2;
     int32_t label_id = ((*results)[bs].label_ids)[i];
     // clip box
     (*results)[bs].boxes[i][0] = (*results)[bs].boxes[i][0] - max_wh_ * label_id;
diff --git a/fastdeploy/vision/detection/contrib/yolov5/postprocessor.h b/fastdeploy/vision/detection/contrib/yolov5/postprocessor.h
index a1479dd940..88f9400fa2 100755
--- a/fastdeploy/vision/detection/contrib/yolov5/postprocessor.h
+++ b/fastdeploy/vision/detection/contrib/yolov5/postprocessor.h
@@ -55,7 +55,7 @@ class FASTDEPLOY_DECL YOLOv5Postprocessor {
   /// Get nms_threshold, default 0.5
   float GetNMSThreshold() const { return nms_threshold_; }
 
-  /// Set multi_label, default true
+  /// Set multi_label, set true for evaluation, default true
   void SetMultiLabel(bool multi_label) {
     multi_label_ = multi_label;
   }
diff --git a/fastdeploy/vision/detection/contrib/yolov5/preprocessor.cc b/fastdeploy/vision/detection/contrib/yolov5/preprocessor.cc
index 112a4d4d5d..846e251316 100755
--- a/fastdeploy/vision/detection/contrib/yolov5/preprocessor.cc
+++ b/fastdeploy/vision/detection/contrib/yolov5/preprocessor.cc
@@ -24,7 +24,7 @@ YOLOv5Preprocessor::YOLOv5Preprocessor() {
   padding_value_ = {114.0, 114.0, 114.0};
   is_mini_pad_ = false;
   is_no_pad_ = false;
-  is_scale_up_ = false;
+  is_scale_up_ = true;
   stride_ = 32;
   max_wh_ = 7680.0;
 }
@@ -50,7 +50,9 @@ void YOLOv5Preprocessor::LetterBox(FDMat* mat) {
     resize_h = size_[1];
     resize_w = size_[0];
   }
-  Resize::Run(mat, resize_w, resize_h);
+  if (std::fabs(scale - 1.0f) > 1e-06) {
+    Resize::Run(mat, resize_w, resize_h);
+  }
   if (pad_h > 0 || pad_w > 0) {
     float half_h = pad_h * 1.0 / 2;
     int top = int(round(half_h - 0.1));
@@ -67,19 +69,6 @@ bool YOLOv5Preprocessor::Preprocess(FDMat* mat, FDTensor* output,
   // Record the shape of image and the shape of preprocessed image
   (*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
                                static_cast<float>(mat->Width())};
-
-  // process after image load
-  double ratio = (size_[0] * 1.0) / std::max(static_cast<float>(mat->Height()),
-                                             static_cast<float>(mat->Width()));
-  if (std::fabs(ratio - 1.0f) > 1e-06) {
-    int interp = cv::INTER_AREA;
-    if (ratio > 1.0) {
-      interp = cv::INTER_LINEAR;
-    }
-    int resize_h = int(mat->Height() * ratio);
-    int resize_w = int(mat->Width() * ratio);
-    Resize::Run(mat, resize_w, resize_h, -1, -1, interp);
-  }
   // yolov5's preprocess steps
   // 1. letterbox
   // 2. convert_and_permute(swap_rb=true)
diff --git a/fastdeploy/vision/detection/contrib/yolov5/preprocessor.h b/fastdeploy/vision/detection/contrib/yolov5/preprocessor.h
index b3559685db..f0cf438df0 100755
--- a/fastdeploy/vision/detection/contrib/yolov5/preprocessor.h
+++ b/fastdeploy/vision/detection/contrib/yolov5/preprocessor.h
@@ -52,6 +52,15 @@ class FASTDEPLOY_DECL YOLOv5Preprocessor {
   /// Get padding value, size should be the same as channels
   std::vector<float> GetPaddingValue() const { return padding_value_; }
 
+  /// Set is_scale_up; if is_scale_up is false, the input image can only
+  /// be zoomed out and the maximum resize scale cannot exceed 1.0, default true
+  void SetScaleUp(bool is_scale_up) {
+    is_scale_up_ = is_scale_up;
+  }
+
+  /// Get is_scale_up, default true
+  bool GetScaleUp() const { return is_scale_up_; }
+
  protected:
   bool Preprocess(FDMat* mat, FDTensor* output,
                   std::map<std::string, std::array<float, 2>>* im_info);
diff --git a/fastdeploy/vision/detection/contrib/yolov5/yolov5_pybind.cc b/fastdeploy/vision/detection/contrib/yolov5/yolov5_pybind.cc
index f44891d984..7b1574401f 100755
--- a/fastdeploy/vision/detection/contrib/yolov5/yolov5_pybind.cc
+++ b/fastdeploy/vision/detection/contrib/yolov5/yolov5_pybind.cc
@@ -35,7 +35,8 @@ void BindYOLOv5(pybind11::module& m) {
         return make_pair(outputs, ims_info);
       })
       .def_property("size", &vision::detection::YOLOv5Preprocessor::GetSize, &vision::detection::YOLOv5Preprocessor::SetSize)
-      .def_property("padding_value", &vision::detection::YOLOv5Preprocessor::GetPaddingValue, &vision::detection::YOLOv5Preprocessor::SetPaddingValue);
+      .def_property("padding_value", &vision::detection::YOLOv5Preprocessor::GetPaddingValue, &vision::detection::YOLOv5Preprocessor::SetPaddingValue)
+      .def_property("is_scale_up", &vision::detection::YOLOv5Preprocessor::GetScaleUp, &vision::detection::YOLOv5Preprocessor::SetScaleUp);
 
   pybind11::class_<vision::detection::YOLOv5Postprocessor>(
       m, "YOLOv5Postprocessor")
diff --git a/fastdeploy/vision/detection/contrib/yolov7.cc b/fastdeploy/vision/detection/contrib/yolov7.cc
deleted file mode 100755
index 9185e16ed0..0000000000
--- a/fastdeploy/vision/detection/contrib/yolov7.cc
+++ /dev/null
@@ -1,344 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "fastdeploy/vision/detection/contrib/yolov7.h"
-
-#include "fastdeploy/utils/perf.h"
-#include "fastdeploy/vision/utils/utils.h"
-#ifdef ENABLE_CUDA_PREPROCESS
-#include "fastdeploy/vision/utils/cuda_utils.h"
-#endif  // ENABLE_CUDA_PREPROCESS
-
-namespace fastdeploy {
-namespace vision {
-namespace detection {
-
-void YOLOv7::LetterBox(Mat* mat, const std::vector<int>& size,
-                       const std::vector<float>& color, bool _auto,
-                       bool scale_fill, bool scale_up, int stride) {
-  float scale =
-      std::min(size[1] * 1.0 / mat->Height(), size[0] * 1.0 / mat->Width());
-  if (!scale_up) {
-    scale = std::min(scale, 1.0f);
-  }
-
-  int resize_h = int(round(mat->Height() * scale));
-  int resize_w = int(round(mat->Width() * scale));
-
-  int pad_w = size[0] - resize_w;
-  int pad_h = size[1] - resize_h;
-  if (_auto) {
-    pad_h = pad_h % stride;
-    pad_w = pad_w % stride;
-  } else if (scale_fill) {
-    pad_h = 0;
-    pad_w = 0;
-    resize_h = size[1];
-    resize_w = size[0];
-  }
-  if (resize_h != mat->Height() || resize_w != mat->Width()) {
-    Resize::Run(mat, resize_w, resize_h);
-  }
-  if (pad_h > 0 || pad_w > 0) {
-    float half_h = pad_h * 1.0 / 2;
-    int top = int(round(half_h - 0.1));
-    int bottom = int(round(half_h + 0.1));
-    float half_w = pad_w * 1.0 / 2;
-    int left = int(round(half_w - 0.1));
-    int right = int(round(half_w + 0.1));
-    Pad::Run(mat, top, bottom, left, right, color);
-  }
-}
-
-YOLOv7::YOLOv7(const std::string& model_file, const std::string& params_file,
-               const RuntimeOption& custom_option,
-               const ModelFormat& model_format) {
-  if (model_format == ModelFormat::ONNX) {
-    valid_cpu_backends = {Backend::OPENVINO, Backend::ORT};
-    valid_gpu_backends = {Backend::ORT, Backend::TRT};
-  } else {
-    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::LITE};
-    valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
-  }
-  runtime_option = custom_option;
-  runtime_option.model_format = model_format;
-  runtime_option.model_file = model_file;
-  runtime_option.params_file = params_file;
-#ifdef ENABLE_CUDA_PREPROCESS
-  cudaSetDevice(runtime_option.device_id);
-  cudaStream_t stream;
-  CUDA_CHECK(cudaStreamCreate(&stream));
-  cuda_stream_ = reinterpret_cast<void*>(stream);
-  runtime_option.SetExternalStream(cuda_stream_);
-#endif  // ENABLE_CUDA_PREPROCESS
-  initialized = Initialize();
-}
-
-bool YOLOv7::Initialize() {
-  // parameters for preprocess
-  size = {640, 640};
-  padding_value = {114.0, 114.0, 114.0};
-  is_mini_pad = false;
-  is_no_pad = false;
-  is_scale_up = false;
-  stride = 32;
-  max_wh = 7680.0;
-  reused_input_tensors_.resize(1);
-
-  if (!InitRuntime()) {
-    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
-    return false;
-  }
-  // Check if the input shape is dynamic after Runtime already initialized.
-  // Note that we need to force is_mini_pad to 'false' to keep a static
-  // shape after padding (LetterBox) when is_dynamic_shape is 'false'.
-  is_dynamic_input_ = false;
-  auto shape = InputInfoOfRuntime(0).shape;
-  for (int i = 0; i < shape.size(); ++i) {
-    // if height or width is dynamic
-    if (i >= 2 && shape[i] <= 0) {
-      is_dynamic_input_ = true;
-      break;
-    }
-  }
-  if (!is_dynamic_input_) {
-    is_mini_pad = false;
-  }
-  return true;
-}
-
-YOLOv7::~YOLOv7() {
-#ifdef ENABLE_CUDA_PREPROCESS
-  if (use_cuda_preprocessing_) {
-    CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
-    CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
-    CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
-    CUDA_CHECK(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(cuda_stream_)));
-  }
-#endif  // ENABLE_CUDA_PREPROCESS
-}
-
-bool YOLOv7::Preprocess(Mat* mat, FDTensor* output,
-                        std::map<std::string, std::array<float, 2>>* im_info) {
-  // process after image load
-  float ratio = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
-                         size[0] * 1.0f / static_cast<float>(mat->Width()));
-  if (std::fabs(ratio - 1.0f) > 1e-06) {
-    int interp = cv::INTER_AREA;
-    if (ratio > 1.0) {
-      interp = cv::INTER_LINEAR;
-    }
-    int resize_h = int(mat->Height() * ratio);
-    int resize_w = int(mat->Width() * ratio);
-    Resize::Run(mat, resize_w, resize_h, -1, -1, interp);
-  }
-  // yolov7's preprocess steps
-  // 1. letterbox
-  // 2. BGR->RGB
-  // 3. HWC->CHW
-  YOLOv7::LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad,
-                    is_scale_up, stride);
-  BGR2RGB::Run(mat);
-  // Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0),
-  //                std::vector<float>(mat->Channels(), 1.0));
-  // Compute `result = mat * alpha + beta` directly by channel
-  std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
-  std::vector<float> beta = {0.0f, 0.0f, 0.0f};
-  Convert::Run(mat, alpha, beta);
-
-  // Record output shape of preprocessed image
-  (*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
-                                static_cast<float>(mat->Width())};
-
-  HWC2CHW::Run(mat);
-  Cast::Run(mat, "float");
-  mat->ShareWithTensor(output);
-  output->shape.insert(output->shape.begin(), 1);  // reshape to n, h, w, c
-  return true;
-}
-
-void YOLOv7::UseCudaPreprocessing(int max_image_size) {
-#ifdef ENABLE_CUDA_PREPROCESS
-  use_cuda_preprocessing_ = true;
-  is_scale_up = true;
-  if (input_img_cuda_buffer_host_ == nullptr) {
-    // prepare input data cache in GPU pinned memory
-    CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_,
-                              max_image_size * 3));
-    // prepare input data cache in GPU device memory
-    CUDA_CHECK(
-        cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
-    CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_,
-                          3 * size[0] * size[1] * sizeof(float)));
-  }
-#else
-  FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
-            << std::endl;
-  use_cuda_preprocessing_ = false;
-#endif
-}
-
-bool YOLOv7::CudaPreprocess(
-    Mat* mat, FDTensor* output,
-    std::map<std::string, std::array<float, 2>>* im_info) {
-#ifdef ENABLE_CUDA_PREPROCESS
-  if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
-    FDERROR << "Preprocessing with CUDA is only available when the arguments "
-               "satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)."
-            << std::endl;
-    return false;
-  }
-
-  // Record the shape of image and the shape of preprocessed image
-  (*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
-                               static_cast<float>(mat->Width())};
-  (*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
-                                static_cast<float>(mat->Width())};
-
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream_);
-  int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
-  memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
-  CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
-                             input_img_cuda_buffer_host_, src_img_buf_size,
-                             cudaMemcpyHostToDevice, stream));
-  utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
-                            mat->Height(), input_tensor_cuda_buffer_device_,
-                            size[0], size[1], padding_value, stream);
-
-  // Record output shape of preprocessed image
-  (*im_info)["output_shape"] = {static_cast<float>(size[0]),
-                                static_cast<float>(size[1])};
-
-  output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
-                          input_tensor_cuda_buffer_device_);
-  output->device = Device::GPU;
-  output->shape.insert(output->shape.begin(), 1);  // reshape to n, h, w, c
-  return true;
-#else
-  FDERROR << "CUDA src code was not enabled." << std::endl;
-  return false;
-#endif  // ENABLE_CUDA_PREPROCESS
-}
-
-bool YOLOv7::Postprocess(
-    FDTensor& infer_result, DetectionResult* result,
-    const std::map<std::string, std::array<float, 2>>& im_info,
-    float conf_threshold, float nms_iou_threshold) {
-  FDASSERT(infer_result.shape[0] == 1, "Only support batch = 1 now.");
-  result->Clear();
-  result->Reserve(infer_result.shape[1]);
-  if (infer_result.dtype != FDDataType::FP32) {
-    FDERROR << "Only support post process with float32 data." << std::endl;
-    return false;
-  }
-  float* data = static_cast<float*>(infer_result.Data());
-  for (size_t i = 0; i < infer_result.shape[1]; ++i) {
-    int s = i * infer_result.shape[2];
-    float confidence = data[s + 4];
-    float* max_class_score =
-        std::max_element(data + s + 5, data + s + infer_result.shape[2]);
-    confidence *= (*max_class_score);
-    // filter boxes by conf_threshold
-    if (confidence <= conf_threshold) {
-      continue;
-    }
-    int32_t label_id = std::distance(data + s + 5, max_class_score);
-    // convert from [x, y, w, h] to [x1, y1, x2, y2]
-    result->boxes.emplace_back(std::array<float, 4>{
-        data[s] - data[s + 2] / 2.0f + label_id * max_wh,
-        data[s + 1] - data[s + 3] / 2.0f + label_id * max_wh,
-        data[s + 0] + data[s + 2] / 2.0f + label_id * max_wh,
-        data[s + 1] + data[s + 3] / 2.0f + label_id * max_wh});
-    result->label_ids.push_back(label_id);
-    result->scores.push_back(confidence);
-  }
-  utils::NMS(result, nms_iou_threshold);
-
-  // scale the boxes to the origin image shape
-  auto iter_out = im_info.find("output_shape");
-  auto iter_ipt = im_info.find("input_shape");
-  FDASSERT(iter_out != im_info.end() && iter_ipt != im_info.end(),
-           "Cannot find input_shape or output_shape from im_info.");
-  float out_h = iter_out->second[0];
-  float out_w = iter_out->second[1];
-  float ipt_h = iter_ipt->second[0];
-  float ipt_w = iter_ipt->second[1];
-  float scale = std::min(out_h / ipt_h, out_w / ipt_w);
-  float pad_h = (out_h - ipt_h * scale) / 2.0f;
-  float pad_w = (out_w - ipt_w * scale) / 2.0f;
-  if (is_mini_pad) {
-    pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
-    pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
-  }
-  for (size_t i = 0; i < result->boxes.size(); ++i) {
-    int32_t label_id = (result->label_ids)[i];
-    // clip box
-    result->boxes[i][0] = result->boxes[i][0] - max_wh * label_id;
-    result->boxes[i][1] = result->boxes[i][1] - max_wh * label_id;
-    result->boxes[i][2] = result->boxes[i][2] - max_wh * label_id;
-    result->boxes[i][3] = result->boxes[i][3] - max_wh * label_id;
-    result->boxes[i][0] = std::max((result->boxes[i][0] - pad_w) / scale, 0.0f);
-    result->boxes[i][1] = std::max((result->boxes[i][1] - pad_h) / scale, 0.0f);
-    result->boxes[i][2] = std::max((result->boxes[i][2] - pad_w) / scale, 0.0f);
-    result->boxes[i][3] = std::max((result->boxes[i][3] - pad_h) / scale, 0.0f);
-    result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
-    result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
-    result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
-    result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
-  }
-  return true;
-}
-
-bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
-                     float nms_iou_threshold) {
-  Mat mat(*im);
-
-  std::map<std::string, std::array<float, 2>> im_info;
-
-  // Record the shape of image and the shape of preprocessed image
-  im_info["input_shape"] = {static_cast<float>(mat.Height()),
-                            static_cast<float>(mat.Width())};
-  im_info["output_shape"] = {static_cast<float>(mat.Height()),
-                             static_cast<float>(mat.Width())};
-
-  if (use_cuda_preprocessing_) {
-    if (!CudaPreprocess(&mat, &reused_input_tensors_[0], &im_info)) {
-      FDERROR << "Failed to preprocess input image." << std::endl;
-      return false;
-    }
-  } else {
-    if (!Preprocess(&mat, &reused_input_tensors_[0], &im_info)) {
-      FDERROR << "Failed to preprocess input image." << std::endl;
-      return false;
-    }
-  }
-
-  reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
-  if (!Infer()) {
-    FDERROR << "Failed to inference." << std::endl;
-    return false;
-  }
-
-  if (!Postprocess(reused_output_tensors_[0], result, im_info, conf_threshold,
-                   nms_iou_threshold)) {
-    FDERROR << "Failed to post process." << std::endl;
-    return false;
-  }
-
-  return true;
-}
-
-}  // namespace detection
-}  // namespace vision
-}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov7.h b/fastdeploy/vision/detection/contrib/yolov7.h
deleted file mode 100644
index b9d637ed9a..0000000000
--- a/fastdeploy/vision/detection/contrib/yolov7.h
+++ /dev/null
@@ -1,113 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. //NOLINT
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-#include "fastdeploy/fastdeploy_model.h"
-#include "fastdeploy/vision/common/processors/transform.h"
-#include "fastdeploy/vision/common/result.h"
-
-namespace fastdeploy {
-namespace vision {
-namespace detection {
-/*! @brief YOLOv7 model object, used to load a YOLOv7 model exported by the YOLOv7 repository.
- */
-class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
- public:
-  /** \brief Set path of model file and the configuration of runtime.
-   *
-   * \param[in] model_file Path of model file, e.g ./yolov7.onnx
-   * \param[in] params_file Path of parameter file, e.g ppyoloe/model.pdiparams, if the model format is ONNX, this parameter will be ignored
-   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
-   * \param[in] model_format Model format of the loaded model, default is ONNX format
-   */
-  YOLOv7(const std::string& model_file, const std::string& params_file = "",
-         const RuntimeOption& custom_option = RuntimeOption(),
-         const ModelFormat& model_format = ModelFormat::ONNX);
-
-  ~YOLOv7();
-
-  virtual std::string ModelName() const { return "yolov7"; }
-  /** \brief Predict the detection result for an input image
-   *
-   * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
-   * \param[in] result The output detection result will be written to this structure
-   * \param[in] conf_threshold confidence threshold for postprocessing, default is 0.25
-   * \param[in] nms_iou_threshold iou threshold for NMS, default is 0.5
-   * \return true if the prediction succeeded, otherwise false
-   */
-  virtual bool Predict(cv::Mat* im, DetectionResult* result,
-                       float conf_threshold = 0.25,
-                       float nms_iou_threshold = 0.5);
-
-
-  void UseCudaPreprocessing(int max_img_size = 3840 * 2160);
-
-  /*! @brief
-  Argument for image preprocessing step, tuple of (width, height), decides the target size after resize, default size = {640, 640}
-  */
-  std::vector<int> size;
-  // padding value, size should be the same as channels
-
-  std::vector<float> padding_value;
-  // only pad to the minimum rectangle whose height and width are multiples of stride
-  bool is_mini_pad;
-  // when is_mini_pad = false and is_no_pad = true,
-  // resize the image directly to the target size
-  bool is_no_pad;
-  // if is_scale_up is false, the input image can only be zoomed out,
-  // the maximum resize scale cannot exceed 1.0
-  bool is_scale_up;
-  // padding stride, for is_mini_pad
-  int stride;
-  // for offsetting the boxes by classes when using NMS
-  float max_wh;
-
- private:
-  bool Initialize();
-
-  bool Preprocess(Mat* mat, FDTensor* output,
-                  std::map<std::string, std::array<float, 2>>* im_info);
-
-  bool CudaPreprocess(Mat* mat, FDTensor* output,
-                      std::map<std::string, std::array<float, 2>>* im_info);
-
-  bool Postprocess(FDTensor& infer_result, DetectionResult* result,
-                   const std::map<std::string, std::array<float, 2>>& im_info,
-                   float conf_threshold, float nms_iou_threshold);
-
-  void LetterBox(Mat* mat, const std::vector<int>& size,
-                 const std::vector<float>& color, bool _auto,
-                 bool scale_fill = false, bool scale_up = true,
-                 int stride = 32);
-
-  // whether to run inference with dynamic shape (e.g. ONNX exported with
-  // dynamic shape or not); when is_dynamic_input_ is 'false', is_mini_pad
-  // will be forced to 'false'. This value is auto-checked by fastdeploy
-  // after the internal Runtime is initialized.
-  bool is_dynamic_input_;
-  // CUDA host buffer for input image
-  uint8_t* input_img_cuda_buffer_host_ = nullptr;
-  // CUDA device buffer for input image
-  uint8_t* input_img_cuda_buffer_device_ = nullptr;
-  // CUDA device buffer for TRT input tensor
-  float* input_tensor_cuda_buffer_device_ = nullptr;
-  // Whether to use CUDA preprocessing
-  bool use_cuda_preprocessing_ = false;
-  // CUDA stream
-  void* cuda_stream_ = nullptr;
-};
-}  // namespace detection
-}  // namespace vision
-}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov7/postprocessor.cc b/fastdeploy/vision/detection/contrib/yolov7/postprocessor.cc
new file mode 100755
index 0000000000..01d657adb3
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov7/postprocessor.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/detection/contrib/yolov7/postprocessor.h"
+#include "fastdeploy/vision/utils/utils.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace detection {
+
+YOLOv7Postprocessor::YOLOv7Postprocessor() {
+  conf_threshold_ = 0.25;
+  nms_threshold_ = 0.5;
+  max_wh_ = 7680.0;
+}
+
+bool YOLOv7Postprocessor::Run(
+    const std::vector<FDTensor>& tensors, std::vector<DetectionResult>* results,
+    const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) {
+  int batch = tensors[0].shape[0];
+
+  results->resize(batch);
+
+  for (size_t bs = 0; bs < batch; ++bs) {
+    (*results)[bs].Clear();
+    (*results)[bs].Reserve(tensors[0].shape[1]);
+    if (tensors[0].dtype != FDDataType::FP32) {
+      FDERROR << "Only support post process with float32 data." << std::endl;
+      return false;
+    }
+    const float* data = reinterpret_cast<const float*>(tensors[0].Data()) +
+                        bs * tensors[0].shape[1] * tensors[0].shape[2];
+    for (size_t i = 0; i < tensors[0].shape[1]; ++i) {
+      int s = i * tensors[0].shape[2];
+      float confidence = data[s + 4];
+      const float* max_class_score =
+          std::max_element(data + s + 5, data + s + tensors[0].shape[2]);
+      confidence *= (*max_class_score);
+      // filter boxes by conf_threshold
+      if (confidence <= conf_threshold_) {
+        continue;
+      }
+      int32_t label_id = std::distance(data + s + 5, max_class_score);
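+      // Boxes of different classes are offset by label_id * max_wh_ so they
+      // can never overlap across classes; the single NMS pass below is thus
+      // effectively class-aware. The offset is subtracted again afterwards.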
+      // convert from [x, y, w, h] to [x1, y1, x2, y2]
+      (*results)[bs].boxes.emplace_back(std::array<float, 4>{
+          data[s] - data[s + 2] / 2.0f + label_id * max_wh_,
+          data[s + 1] - data[s + 3] / 2.0f + label_id * max_wh_,
+          data[s + 0] + data[s + 2] / 2.0f + label_id * max_wh_,
+          data[s + 1] + data[s + 3] / 2.0f + label_id * max_wh_});
+      (*results)[bs].label_ids.push_back(label_id);
+      (*results)[bs].scores.push_back(confidence);
+    }
+
+    if ((*results)[bs].boxes.size() == 0) {
+      // no detections for this image; keep processing the rest of the batch
+      continue;
+    }
+
+    utils::NMS(&((*results)[bs]), nms_threshold_);
+
+    // scale the boxes to the origin image shape
+    auto iter_out = ims_info[bs].find("output_shape");
+    auto iter_ipt = ims_info[bs].find("input_shape");
+    FDASSERT(iter_out != ims_info[bs].end() && iter_ipt != ims_info[bs].end(),
+             "Cannot find input_shape or output_shape from im_info.");
+    float out_h = iter_out->second[0];
+    float out_w = iter_out->second[1];
+    float ipt_h = iter_ipt->second[0];
+    float ipt_w = iter_ipt->second[1];
+    float scale = std::min(out_h / ipt_h, out_w / ipt_w);
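+    // (out - ipt * scale) / 2 is the padding LetterBox added on each side
+    // of the resized image; remove it before mapping boxes back.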
+    float pad_h = (out_h - ipt_h * scale) / 2;
+    float pad_w = (out_w - ipt_w * scale) / 2;
+    for (size_t i = 0; i < (*results)[bs].boxes.size(); ++i) {
+      int32_t label_id = ((*results)[bs].label_ids)[i];
+      // clip box
+      (*results)[bs].boxes[i][0] = (*results)[bs].boxes[i][0] - max_wh_ * label_id;
+      (*results)[bs].boxes[i][1] = (*results)[bs].boxes[i][1] - max_wh_ * label_id;
+      (*results)[bs].boxes[i][2] = (*results)[bs].boxes[i][2] - max_wh_ * label_id;
+      (*results)[bs].boxes[i][3] = (*results)[bs].boxes[i][3] - max_wh_ * label_id;
+      (*results)[bs].boxes[i][0] = std::max(((*results)[bs].boxes[i][0] - pad_w) / scale, 0.0f);
+      (*results)[bs].boxes[i][1] = std::max(((*results)[bs].boxes[i][1] - pad_h) / scale, 0.0f);
+      (*results)[bs].boxes[i][2] = std::max(((*results)[bs].boxes[i][2] - pad_w) / scale, 0.0f);
+      (*results)[bs].boxes[i][3] = std::max(((*results)[bs].boxes[i][3] - pad_h) / scale, 0.0f);
+      (*results)[bs].boxes[i][0] = std::min((*results)[bs].boxes[i][0], ipt_w - 1.0f);
+      (*results)[bs].boxes[i][1] = std::min((*results)[bs].boxes[i][1], ipt_h - 1.0f);
+      (*results)[bs].boxes[i][2] = std::min((*results)[bs].boxes[i][2], ipt_w - 1.0f);
+      (*results)[bs].boxes[i][3] = std::min((*results)[bs].boxes[i][3], ipt_h - 1.0f);
+    }
+  }
+  return true;
+}
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov7/postprocessor.h b/fastdeploy/vision/detection/contrib/yolov7/postprocessor.h
new file mode 100755
index 0000000000..5ece87eb8b
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov7/postprocessor.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+namespace vision {
+
+namespace detection {
+/*! @brief Postprocessor object for the YOLOv7 series of models.
+ */
+class FASTDEPLOY_DECL YOLOv7Postprocessor {
+ public:
+  /** \brief Create a postprocessor instance for the YOLOv7 series of models
+   */
+  YOLOv7Postprocessor();
+
+  /** \brief Process the result of runtime and fill to DetectionResult structure
+   *
+   * \param[in] tensors The inference result from runtime
+   * \param[in] results The output result of detection
+   * \param[in] ims_info The shape info list, record input_shape and output_shape
+   * \return true if the postprocess succeeded, otherwise false
+   */
+  bool Run(const std::vector<FDTensor>& tensors,
+           std::vector<DetectionResult>* results,
+           const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info);
+
+  /// Set conf_threshold, default 0.25
+  void SetConfThreshold(const float& conf_threshold) {
+    conf_threshold_ = conf_threshold;
+  }
+
+  /// Get conf_threshold, default 0.25
+  float GetConfThreshold() const { return conf_threshold_; }
+
+  /// Set nms_threshold, default 0.5
+  void SetNMSThreshold(const float& nms_threshold) {
+    nms_threshold_ = nms_threshold;
+  }
+
+  /// Get nms_threshold, default 0.5
+  float GetNMSThreshold() const { return nms_threshold_; }
+
+ protected:
+  float conf_threshold_;
+  float nms_threshold_;
+  float max_wh_;
+};
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov7/preprocessor.cc b/fastdeploy/vision/detection/contrib/yolov7/preprocessor.cc
new file mode 100755
index 0000000000..91e22f32b4
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov7/preprocessor.cc
@@ -0,0 +1,116 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/detection/contrib/yolov7/preprocessor.h"
+#include "fastdeploy/function/concat.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace detection {
+
+YOLOv7Preprocessor::YOLOv7Preprocessor() {
+  size_ = {640, 640};
+  padding_value_ = {114.0, 114.0, 114.0};
+  is_mini_pad_ = false;
+  is_no_pad_ = false;
+  is_scale_up_ = true;
+  stride_ = 32;
+  max_wh_ = 7680.0;
+}
+
+void YOLOv7Preprocessor::LetterBox(FDMat* mat) {
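+  // Compute the uniform scale that fits the image inside size_ while
+  // keeping its aspect ratio; the remainder is filled with padding_value_.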
+ +#include "fastdeploy/vision/detection/contrib/yolov7/preprocessor.h" +#include "fastdeploy/function/concat.h" + +namespace fastdeploy { +namespace vision { +namespace detection { + +YOLOv7Preprocessor::YOLOv7Preprocessor() { + size_ = {640, 640}; + padding_value_ = {114.0, 114.0, 114.0}; + is_mini_pad_ = false; + is_no_pad_ = false; + is_scale_up_ = true; + stride_ = 32; + max_wh_ = 7680.0; +} + +void YOLOv7Preprocessor::LetterBox(FDMat* mat) { + float scale = + std::min(size_[1] * 1.0 / mat->Height(), size_[0] * 1.0 / mat->Width()); + if (!is_scale_up_) { + scale = std::min(scale, 1.0f); + } + + int resize_h = int(round(mat->Height() * scale)); + int resize_w = int(round(mat->Width() * scale)); + + int pad_w = size_[0] - resize_w; + int pad_h = size_[1] - resize_h; + if (is_mini_pad_) { + pad_h = pad_h % stride_; + pad_w = pad_w % stride_; + } else if (is_no_pad_) { + pad_h = 0; + pad_w = 0; + resize_h = size_[1]; + resize_w = size_[0]; + } + if (std::fabs(scale - 1.0f) > 1e-06) { + Resize::Run(mat, resize_w, resize_h); + } + if (pad_h > 0 || pad_w > 0) { + float half_h = pad_h * 1.0 / 2; + int top = int(round(half_h - 0.1)); + int bottom = int(round(half_h + 0.1)); + float half_w = pad_w * 1.0 / 2; + int left = int(round(half_w - 0.1)); + int right = int(round(half_w + 0.1)); + Pad::Run(mat, top, bottom, left, right, padding_value_); + } +} + +bool YOLOv7Preprocessor::Preprocess(FDMat* mat, FDTensor* output, + std::map>* im_info) { + // Record the shape of image and the shape of preprocessed image + (*im_info)["input_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + // yolov7's preprocess steps + // 1. letterbox + // 2. convert_and_permute(swap_rb=true) + LetterBox(mat); + std::vector alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f}; + std::vector beta = {0.0f, 0.0f, 0.0f}; + ConvertAndPermute::Run(mat, alpha, beta, true); + + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + mat->ShareWithTensor(output); + output->ExpandDim(0); // reshape to n, h, w, c + return true; +} + +bool YOLOv7Preprocessor::Run(std::vector* images, std::vector* outputs, + std::vector>>* ims_info) { + if (images->size() == 0) { + FDERROR << "The size of input images should be greater than 0." << std::endl; + return false; + } + ims_info->resize(images->size()); + outputs->resize(1); + // Concat all the preprocessed data to a batch tensor + std::vector tensors(images->size()); + for (size_t i = 0; i < images->size(); ++i) { + if (!Preprocess(&(*images)[i], &tensors[i], &(*ims_info)[i])) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + } + + if (tensors.size() == 1) { + (*outputs)[0] = std::move(tensors[0]); + } else { + function::Concat(tensors, &((*outputs)[0]), 0); + } + return true; +} + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/yolov7/preprocessor.h b/fastdeploy/vision/detection/contrib/yolov7/preprocessor.h new file mode 100755 index 0000000000..ff6c6cad55 --- /dev/null +++ b/fastdeploy/vision/detection/contrib/yolov7/preprocessor.h @@ -0,0 +1,96 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+  ConvertAndPermute::Run(mat, alpha, beta, true);
+
+  // Record output shape of preprocessed image
+  (*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
+                                static_cast<float>(mat->Width())};
+
+  mat->ShareWithTensor(output);
+  output->ExpandDim(0);  // reshape to n, c, h, w
+  return true;
+}
+
+bool YOLOv7Preprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
+                             std::vector<std::map<std::string, std::array<float, 2>>>* ims_info) {
+  if (images->size() == 0) {
+    FDERROR << "The size of input images should be greater than 0." << std::endl;
+    return false;
+  }
+  ims_info->resize(images->size());
+  outputs->resize(1);
+  // Concat all the preprocessed data to a batch tensor
+  std::vector<FDTensor> tensors(images->size());
+  for (size_t i = 0; i < images->size(); ++i) {
+    if (!Preprocess(&(*images)[i], &tensors[i], &(*ims_info)[i])) {
+      FDERROR << "Failed to preprocess input image." << std::endl;
+      return false;
+    }
+  }
+
+  if (tensors.size() == 1) {
+    (*outputs)[0] = std::move(tensors[0]);
+  } else {
+    function::Concat(tensors, &((*outputs)[0]), 0);
+  }
+  return true;
+}
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov7/preprocessor.h b/fastdeploy/vision/detection/contrib/yolov7/preprocessor.h
new file mode 100755
index 0000000000..ff6c6cad55
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov7/preprocessor.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+namespace vision {
+
+namespace detection {
+/*! @brief Preprocessor object for the YOLOv7 series of models.
+ */
+class FASTDEPLOY_DECL YOLOv7Preprocessor {
+ public:
+  /** \brief Create a preprocessor instance for the YOLOv7 series of models
+   */
+  YOLOv7Preprocessor();
+
+  /** \brief Process the input image and prepare input tensors for runtime
+   *
+   * \param[in] images The input image data list, all the elements are returned by cv::imread()
+   * \param[in] outputs The output tensors which will feed in runtime
+   * \param[in] ims_info The shape info list, record input_shape and output_shape
+   * \return true if the preprocess succeeded, otherwise false
+   */
+  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
+           std::vector<std::map<std::string, std::array<float, 2>>>* ims_info);
+
+  /// Set target size, tuple of (width, height), default size = {640, 640}
+  void SetSize(const std::vector<int>& size) { size_ = size; }
+
+  /// Get target size, tuple of (width, height), default size = {640, 640}
+  std::vector<int> GetSize() const { return size_; }
+
+  /// Set padding value, size should be the same as channels
+  void SetPaddingValue(const std::vector<float>& padding_value) {
+    padding_value_ = padding_value;
+  }
+
+  /// Get padding value, size should be the same as channels
+  std::vector<float> GetPaddingValue() const { return padding_value_; }
+
+  /// Set is_scale_up; if is_scale_up is false, the input image can only
+  /// be zoomed out and the maximum resize scale cannot exceed 1.0, default true
+  void SetScaleUp(bool is_scale_up) {
+    is_scale_up_ = is_scale_up;
+  }
+
+  /// Get is_scale_up, default true
+  bool GetScaleUp() const { return is_scale_up_; }
+
+ protected:
+  bool Preprocess(FDMat* mat, FDTensor* output,
+                  std::map<std::string, std::array<float, 2>>* im_info);
+
+  void LetterBox(FDMat* mat);
+
+  // target size, tuple of (width, height), default size = {640, 640}
+  std::vector<int> size_;
+
+  // padding value, size should be the same as channels
+  std::vector<float> padding_value_;
+
+  // only pad to the minimum rectangle whose height and width are multiples of stride
+  bool is_mini_pad_;
+
+  // when is_mini_pad = false and is_no_pad = true,
+  // resize the image directly to the target size
+  bool is_no_pad_;
+
+  // if is_scale_up is false, the input image can only be zoomed out,
+  // the maximum resize scale cannot exceed 1.0
+  bool is_scale_up_;
+
+  // padding stride, for is_mini_pad
+  int stride_;
+
+  // for offsetting the boxes by classes when using NMS
+  float max_wh_;
+};
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov7/yolov7.cc b/fastdeploy/vision/detection/contrib/yolov7/yolov7.cc
new file mode 100755
index 0000000000..513351a095
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov7/yolov7.cc
@@ -0,0 +1,89 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/detection/contrib/yolov7/yolov7.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace detection {
+
+YOLOv7::YOLOv7(const std::string& model_file, const std::string& params_file,
+               const RuntimeOption& custom_option,
+               const ModelFormat& model_format) {
+  if (model_format == ModelFormat::ONNX) {
+    valid_cpu_backends = {Backend::OPENVINO, Backend::ORT};
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};
+  } else {
+    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::LITE};
+    valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
+  }
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  runtime_option.params_file = params_file;
+  initialized = Initialize();
+}
+
+bool YOLOv7::Initialize() {
+  if (!InitRuntime()) {
+    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold, float nms_threshold) {
+  postprocessor_.SetConfThreshold(conf_threshold);
+  postprocessor_.SetNMSThreshold(nms_threshold);
+  if (!Predict(*im, result)) {
+    return false;
+  }
+  return true;
+}
+
+bool YOLOv7::Predict(const cv::Mat& im, DetectionResult* result) {
+  std::vector<DetectionResult> results;
+  if (!BatchPredict({im}, &results)) {
+    return false;
+  }
+  *result = std::move(results[0]);
+  return true;
+}
+
+bool YOLOv7::BatchPredict(const std::vector<cv::Mat>& images, std::vector<DetectionResult>* results) {
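+  // Pipeline: preprocess -> runtime inference -> postprocess. The
+  // reused_input_tensors_ / reused_output_tensors_ members are reused
+  // across calls to avoid reallocating tensors for every prediction.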
+  std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
+  std::vector<FDMat> fd_images = WrapMat(images);
+
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &ims_info)) {
+    FDERROR << "Failed to preprocess the input image." << std::endl;
+    return false;
+  }
+
+  reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
+  if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
+    FDERROR << "Failed to inference by runtime." << std::endl;
+    return false;
+  }
+
+  if (!postprocessor_.Run(reused_output_tensors_, results, ims_info)) {
+    FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
+    return false;
+  }
+
+  return true;
+}
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov7/yolov7.h b/fastdeploy/vision/detection/contrib/yolov7/yolov7.h
new file mode 100755
index 0000000000..2c36fd0c80
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov7/yolov7.h
@@ -0,0 +1,88 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. //NOLINT
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/detection/contrib/yolov7/preprocessor.h"
+#include "fastdeploy/vision/detection/contrib/yolov7/postprocessor.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace detection {
+/*! @brief YOLOv7 model object, used to load a YOLOv7 model exported by the YOLOv7 repository.
+ */
+class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
+ public:
+  /** \brief Set path of model file and the configuration of runtime.
+   *
+   * \param[in] model_file Path of model file, e.g ./yolov7.onnx
+   * \param[in] params_file Path of parameter file, e.g ppyoloe/model.pdiparams, if the model format is ONNX, this parameter will be ignored
+   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
+   * \param[in] model_format Model format of the loaded model, default is ONNX format
+   */
+  YOLOv7(const std::string& model_file, const std::string& params_file = "",
+         const RuntimeOption& custom_option = RuntimeOption(),
+         const ModelFormat& model_format = ModelFormat::ONNX);
+
+  std::string ModelName() const { return "yolov7"; }
+
+  /** \brief DEPRECATED Predict the detection result for an input image, will be removed in version 1.0
+   *
+   * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
+   * \param[in] result The output detection result will be written to this structure
+   * \param[in] conf_threshold confidence threshold for postprocessing, default is 0.25
+   * \param[in] nms_threshold iou threshold for NMS, default is 0.5
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool Predict(cv::Mat* im, DetectionResult* result,
+                       float conf_threshold = 0.25,
+                       float nms_threshold = 0.5);
+
+  /** \brief Predict the detection result for an input image
+   *
+   * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
+   * \param[in] result The output detection result will be written to this structure
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool Predict(const cv::Mat& img, DetectionResult* result);
+
+  /** \brief Predict the detection results for a batch of input images
+   *
+   * \param[in] imgs The input image list, each element comes from cv::imread()
+   * \param[in] results The output detection result list
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
+                            std::vector<DetectionResult>* results);
+
+  /// Get preprocessor reference of YOLOv7
+  virtual YOLOv7Preprocessor& GetPreprocessor() {
+    return preprocessor_;
+  }
+
+  /// Get postprocessor reference of YOLOv7
+  virtual YOLOv7Postprocessor& GetPostprocessor() {
+    return postprocessor_;
+  }
+
+ protected:
+  bool Initialize();
+  YOLOv7Preprocessor preprocessor_;
+  YOLOv7Postprocessor postprocessor_;
+};
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov7/yolov7_pybind.cc b/fastdeploy/vision/detection/contrib/yolov7/yolov7_pybind.cc
new file mode 100755
index 0000000000..6899faa916
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov7/yolov7_pybind.cc
@@ -0,0 +1,87 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/pybind/main.h"
+
+namespace fastdeploy {
+void BindYOLOv7(pybind11::module& m) {
+  pybind11::class_<vision::detection::YOLOv7Preprocessor>(
+      m, "YOLOv7Preprocessor")
+      .def(pybind11::init<>())
+      .def("run", [](vision::detection::YOLOv7Preprocessor& self, std::vector<pybind11::array>& im_list) {
+        std::vector<vision::FDMat> images;
+        for (size_t i = 0; i < im_list.size(); ++i) {
+          images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
+        }
+        std::vector<FDTensor> outputs;
+        std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
+        if (!self.Run(&images, &outputs, &ims_info)) {
+          pybind11::eval("raise Exception('Failed to preprocess the input data in YOLOv7Preprocessor.')");
+        }
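+        // StopSharing copies the tensor data out of the shared FDMat buffers
+        // so the arrays returned to Python stay valid after `images` is freed.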
+        for (size_t i = 0; i < outputs.size(); ++i) {
+          outputs[i].StopSharing();
+        }
+        return make_pair(outputs, ims_info);
+      })
+      .def_property("size", &vision::detection::YOLOv7Preprocessor::GetSize, &vision::detection::YOLOv7Preprocessor::SetSize)
+      .def_property("padding_value", &vision::detection::YOLOv7Preprocessor::GetPaddingValue, &vision::detection::YOLOv7Preprocessor::SetPaddingValue)
+      .def_property("is_scale_up", &vision::detection::YOLOv7Preprocessor::GetScaleUp, &vision::detection::YOLOv7Preprocessor::SetScaleUp);
+
+  pybind11::class_<vision::detection::YOLOv7Postprocessor>(
+      m, "YOLOv7Postprocessor")
+      .def(pybind11::init<>())
+      .def("run", [](vision::detection::YOLOv7Postprocessor& self, std::vector<FDTensor>& inputs,
+                     const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) {
+        std::vector<vision::DetectionResult> results;
+        if (!self.Run(inputs, &results, ims_info)) {
+          pybind11::eval("raise Exception('Failed to postprocess the runtime result in YOLOv7Postprocessor.')");
+        }
+        return results;
+      })
+      .def("run", [](vision::detection::YOLOv7Postprocessor& self, std::vector<pybind11::array>& input_array,
+                     const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) {
+        std::vector<vision::DetectionResult> results;
+        std::vector<FDTensor> inputs;
+        PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
+        if (!self.Run(inputs, &results, ims_info)) {
+          pybind11::eval("raise Exception('Failed to postprocess the runtime result in YOLOv7Postprocessor.')");
+        }
+        return results;
+      })
+      .def_property("conf_threshold", &vision::detection::YOLOv7Postprocessor::GetConfThreshold, &vision::detection::YOLOv7Postprocessor::SetConfThreshold)
+      .def_property("nms_threshold", &vision::detection::YOLOv7Postprocessor::GetNMSThreshold, &vision::detection::YOLOv7Postprocessor::SetNMSThreshold);
+
+  pybind11::class_<vision::detection::YOLOv7, FastDeployModel>(m, "YOLOv7")
+      .def(pybind11::init<std::string, std::string, RuntimeOption, ModelFormat>())
+      .def("predict",
+           [](vision::detection::YOLOv7& self, pybind11::array& data) {
+             auto mat = PyArrayToCvMat(data);
+             vision::DetectionResult res;
+             self.Predict(mat, &res);
+             return res;
+           })
+      .def("batch_predict", [](vision::detection::YOLOv7& self, std::vector<pybind11::array>& data) {
+        std::vector<cv::Mat> images;
+        for (size_t i = 0; i < data.size(); ++i) {
+          images.push_back(PyArrayToCvMat(data[i]));
+        }
+        std::vector<vision::DetectionResult> results;
+        self.BatchPredict(images, &results);
+        return results;
+      })
+      .def_property_readonly("preprocessor", &vision::detection::YOLOv7::GetPreprocessor)
+      .def_property_readonly("postprocessor", &vision::detection::YOLOv7::GetPostprocessor);
+}
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov7_pybind.cc b/fastdeploy/vision/detection/contrib/yolov7_pybind.cc
deleted file mode 100644
index d7ab993401..0000000000
--- a/fastdeploy/vision/detection/contrib/yolov7_pybind.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "fastdeploy/pybind/main.h"
-
-namespace fastdeploy {
-void BindYOLOv7(pybind11::module& m) {
-  pybind11::class_<vision::detection::YOLOv7, FastDeployModel>(m, "YOLOv7")
-      .def(pybind11::init<std::string, std::string, RuntimeOption, ModelFormat>())
-      .def("predict",
-           [](vision::detection::YOLOv7& self, pybind11::array& data,
-              float conf_threshold, float nms_iou_threshold) {
-             auto mat = PyArrayToCvMat(data);
-             vision::DetectionResult res;
-             self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
-             return res;
-           })
-      .def("use_cuda_preprocessing",
-           [](vision::detection::YOLOv7& self, int max_image_size) {
-             self.UseCudaPreprocessing(max_image_size);
-           })
-      .def_readwrite("size", &vision::detection::YOLOv7::size)
-      .def_readwrite("padding_value", &vision::detection::YOLOv7::padding_value)
-      .def_readwrite("is_mini_pad", &vision::detection::YOLOv7::is_mini_pad)
-      .def_readwrite("is_no_pad", &vision::detection::YOLOv7::is_no_pad)
-      .def_readwrite("is_scale_up", &vision::detection::YOLOv7::is_scale_up)
-      .def_readwrite("stride", &vision::detection::YOLOv7::stride)
-      .def_readwrite("max_wh", &vision::detection::YOLOv7::max_wh);
-}
-}  // namespace fastdeploy
diff --git a/python/fastdeploy/vision/detection/__init__.py b/python/fastdeploy/vision/detection/__init__.py
index 6de4a3fa63..b5f01f3a77 100755
--- a/python/fastdeploy/vision/detection/__init__.py
+++ b/python/fastdeploy/vision/detection/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from __future__ import absolute_import
-from .contrib.yolov7 import YOLOv7
+from .contrib.yolov7 import *
 from .contrib.yolor import YOLOR
 from .contrib.scaled_yolov4 import ScaledYOLOv4
 from .contrib.nanodet_plus import NanoDetPlus
diff --git a/python/fastdeploy/vision/detection/contrib/yolov5.py b/python/fastdeploy/vision/detection/contrib/yolov5.py
index 42eccb88d4..b8113f3b83 100644
--- a/python/fastdeploy/vision/detection/contrib/yolov5.py
+++ b/python/fastdeploy/vision/detection/contrib/yolov5.py
@@ -41,9 +41,19 @@ def size(self):
 
     @property
     def padding_value(self):
+        """
+        padding value for preprocessing, default [114.0, 114.0, 114.0]
+        """
         # padding value, size should be the same as channels
         return self._preprocessor.padding_value
 
+    @property
+    def is_scale_up(self):
+        """
+        is_scale_up for preprocessing; if False, the input image can only be zoomed out and the maximum resize scale cannot exceed 1.0, default True
+        """
+        return self._preprocessor.is_scale_up
+
     @size.setter
     def size(self, wh):
         assert isinstance(wh, (list, tuple)),\
@@ -60,6 +70,13 @@ def padding_value(self, value):
             list), "The value to set `padding_value` must be type of list."
         self._preprocessor.padding_value = value
 
+    @is_scale_up.setter
+    def is_scale_up(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `is_scale_up` must be type of bool."
+        self._preprocessor.is_scale_up = value
+
 
 class YOLOv5Postprocessor:
     def __init__(self):
@@ -93,7 +110,7 @@ def nms_threshold(self):
     @property
     def multi_label(self):
         """
-        multi_label for postprocessing, default is true
+        multi_label for postprocessing; set to True for evaluation, default is True
         """
         return self._postprocessor.multi_label
 
diff --git a/python/fastdeploy/vision/detection/contrib/yolov7.py b/python/fastdeploy/vision/detection/contrib/yolov7.py
index 0334504851..510b72ed65 100644
--- a/python/fastdeploy/vision/detection/contrib/yolov7.py
+++ b/python/fastdeploy/vision/detection/contrib/yolov7.py
@@ -18,77 +18,41 @@
 from .... import c_lib_wrap as C
 
 
-class YOLOv7(FastDeployModel):
-    def __init__(self,
-                 model_file,
-                 params_file="",
-                 runtime_option=None,
-                 model_format=ModelFormat.ONNX):
-        """Load a YOLOv7 model exported by YOLOv7.
-
-        :param model_file: (str)Path of model file, e.g ./yolov7.onnx
-        :param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams, if the model_format is ModelFormat.ONNX, this param will be ignored, can be set as empty string
-        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
-        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+class YOLOv7Preprocessor:
+    def __init__(self):
+        """Create a preprocessor for YOLOv7
         """
-        # Initialize backend_option via the base class constructor;
-        # the initialized option is stored in self._runtime_option
-        super(YOLOv7, self).__init__(runtime_option)
-
-        self._model = C.vision.detection.YOLOv7(
-            model_file, params_file, self._runtime_option, model_format)
-        # self.initialized indicates whether the whole model initialized successfully
-        assert self.initialized, "YOLOv7 initialize failed."
+        self._preprocessor = C.vision.detection.YOLOv7Preprocessor()
 
-    def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
-        """Detect an input image
+    def run(self, input_ims):
+        """Preprocess input images for YOLOv7
 
-        :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
-        :param conf_threshold: confidence threshold for postprocessing, default is 0.25
-        :param nms_iou_threshold: iou threshold for NMS, default is 0.5
-        :return: DetectionResult
+        :param: input_ims: (list of numpy.ndarray)The input images
+        :return: list of FDTensor
         """
-        return self._model.predict(input_image, conf_threshold,
-                                   nms_iou_threshold)
+        return self._preprocessor.run(input_ims)
 
-    # Property wrappers for YOLOv7 options; most relate to preprocessing,
-    # e.g. model.size = [1280, 1280] changes the resize target
-    # (provided the model supports it)
     @property
     def size(self):
         """
         Argument for image preprocessing step, the preprocess image size, tuple of (width, height), default size = [640, 640]
         """
-        return self._model.size
+        return self._preprocessor.size
 
     @property
     def padding_value(self):
+        """
+        padding value for preprocessing, default [114.0, 114.0, 114.0]
+        """
         # padding value, size should be the same as channels
-        return self._model.padding_value
-
-    @property
-    def is_no_pad(self):
-        # when is_mini_pad = false and is_no_pad = true, resize the image directly to the target size
-        return self._model.is_no_pad
-
-    @property
-    def is_mini_pad(self):
-        # only pad to the minimum rectangle whose height and width are multiples of stride
-        return self._model.is_mini_pad
+        return self._preprocessor.padding_value
 
     @property
     def is_scale_up(self):
-        # if is_scale_up is false, the input image can only be zoomed out, the maximum resize scale cannot exceed 1.0
-        return self._model.is_scale_up
-
-    @property
-    def stride(self):
-        # padding stride, for is_mini_pad
-        return self._model.stride
-
-    @property
-    def max_wh(self):
-        # for offsetting the boxes by classes when using NMS
-        return self._model.max_wh
+        """
+        is_scale_up for preprocessing; if False, the input image can only be zoomed out and the maximum resize scale cannot exceed 1.0, default True
+        """
+        return self._preprocessor.is_scale_up
 
     @size.setter
     def size(self, wh):
@@ -97,43 +61,122 @@ def size(self, wh):
         assert len(wh) == 2,\
             "The value to set `size` must contain 2 elements meaning [width, height], but now it contains {} elements.".format(
                 len(wh))
-        self._model.size = wh
+        self._preprocessor.size = wh
 
     @padding_value.setter
     def padding_value(self, value):
         assert isinstance(
             value,
             list), "The value to set `padding_value` must be type of list."
-        self._model.padding_value = value
-
-    @is_no_pad.setter
-    def is_no_pad(self, value):
-        assert isinstance(
-            value, bool), "The value to set `is_no_pad` must be type of bool."
-        self._model.is_no_pad = value
-
-    @is_mini_pad.setter
-    def is_mini_pad(self, value):
-        assert isinstance(
-            value,
-            bool), "The value to set `is_mini_pad` must be type of bool."
-        self._model.is_mini_pad = value
+        self._preprocessor.padding_value = value
 
     @is_scale_up.setter
     def is_scale_up(self, value):
         assert isinstance(
             value,
             bool), "The value to set `is_scale_up` must be type of bool."
-        self._model.is_scale_up = value
+        self._preprocessor.is_scale_up = value
 
-    @stride.setter
-    def stride(self, value):
-        assert isinstance(
-            value, int), "The value to set `stride` must be type of int."
-        self._model.stride = value
-
-    @max_wh.setter
-    def max_wh(self, value):
-        assert isinstance(
-            value, float), "The value to set `max_wh` must be type of float."
-        self._model.max_wh = value
+
+class YOLOv7Postprocessor:
+    def __init__(self):
+        """Create a postprocessor for YOLOv7
+        """
+        self._postprocessor = C.vision.detection.YOLOv7Postprocessor()
+
+    def run(self, runtime_results, ims_info):
+        """Postprocess the runtime results for YOLOv7
+
+        :param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
+        :param: ims_info: (list of dict)Record input_shape and output_shape
+        :return: list of DetectionResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
+        """
+        return self._postprocessor.run(runtime_results, ims_info)
+
+    @property
+    def conf_threshold(self):
+        """
+        confidence threshold for postprocessing, default is 0.25
+        """
+        return self._postprocessor.conf_threshold
+
+    @property
+    def nms_threshold(self):
+        """
+        nms threshold for postprocessing, default is 0.5
+        """
+        return self._postprocessor.nms_threshold
+
+    @conf_threshold.setter
+    def conf_threshold(self, conf_threshold):
+        assert isinstance(conf_threshold, float),\
+            "The value to set `conf_threshold` must be type of float."
+        self._postprocessor.conf_threshold = conf_threshold
+
+    @nms_threshold.setter
+    def nms_threshold(self, nms_threshold):
+        assert isinstance(nms_threshold, float),\
+            "The value to set `nms_threshold` must be type of float."
+        self._postprocessor.nms_threshold = nms_threshold
+
+
+class YOLOv7(FastDeployModel):
+    def __init__(self,
+                 model_file,
+                 params_file="",
+                 runtime_option=None,
+                 model_format=ModelFormat.ONNX):
+        """Load a YOLOv7 model exported by YOLOv7.
+
+        :param model_file: (str)Path of model file, e.g ./yolov7.onnx
+        :param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams, if the model_format is ModelFormat.ONNX, this param will be ignored, can be set as empty string
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+        # Initialize backend_option via the base class constructor;
+        # the initialized option is stored in self._runtime_option
+        super(YOLOv7, self).__init__(runtime_option)
+
+        assert model_format == ModelFormat.ONNX, "YOLOv7 only supports model format of ModelFormat.ONNX now."
+        self._model = C.vision.detection.YOLOv7(
+            model_file, params_file, self._runtime_option, model_format)
+        # self.initialized indicates whether the whole model initialized successfully
+        assert self.initialized, "YOLOv7 initialize failed."
+
+    def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
+        """Detect an input image
+
+        :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+        :param conf_threshold: confidence threshold for postprocessing, default is 0.25
+        :param nms_iou_threshold: iou threshold for NMS, default is 0.5
+        :return: DetectionResult
+        """
+
+        self.postprocessor.conf_threshold = conf_threshold
+        self.postprocessor.nms_threshold = nms_iou_threshold
+        return self._model.predict(input_image)
+
+    def batch_predict(self, images):
+        """Detect a batch of input images
+
+        :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
+        :return list of DetectionResult
+        """
+
+        return self._model.batch_predict(images)
+
+    @property
+    def preprocessor(self):
+        """Get YOLOv7Preprocessor object of the loaded model
+
+        :return YOLOv7Preprocessor
+        """
+        return self._model.preprocessor
+
+    @property
+    def postprocessor(self):
+        """Get YOLOv7Postprocessor object of the loaded model
+
+        :return YOLOv7Postprocessor
+        """
+        return self._model.postprocessor
diff --git a/tests/models/test_yolov7.py b/tests/models/test_yolov7.py
new file mode 100755
index 0000000000..ba08fbaf5b
--- /dev/null
+++ b/tests/models/test_yolov7.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from fastdeploy import ModelFormat
+import fastdeploy as fd
+import cv2
+import os
+import pickle
+import numpy as np
+import runtime_config as rc
+
+
+def test_detection_yolov7():
+    model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov7.onnx"
+    input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
+    input_url2 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000570688.jpg"
+    result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov7_result1.pkl"
+    result_url2 = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov7_result2.pkl"
+    fd.download(model_url, "resources")
+    fd.download(input_url1, "resources")
+    fd.download(input_url2, "resources")
+    fd.download(result_url1, "resources")
+    fd.download(result_url2, "resources")
+
+    model_file = "resources/yolov7.onnx"
+    model = fd.vision.detection.YOLOv7(
+        model_file, runtime_option=rc.test_option)
+
+    with open("resources/yolov7_result1.pkl", "rb") as f:
+        expect1 = pickle.load(f)
+
+    with open("resources/yolov7_result2.pkl", "rb") as f:
+        expect2 = pickle.load(f)
+
+    # compare diff
+    im1 = cv2.imread("./resources/000000014439.jpg")
+    im2 = cv2.imread("./resources/000000570688.jpg")
+
+    for i in range(3):
+        # test single predict
+        result1 = model.predict(im1)
+        result2 = model.predict(im2)
+
+        diff_boxes_1 = np.fabs(
+            np.array(result1.boxes) - np.array(expect1["boxes"]))
+        diff_boxes_2 = np.fabs(
+            np.array(result2.boxes) - np.array(expect2["boxes"]))
+
+        diff_label_1 = np.fabs(
+            np.array(result1.label_ids) - np.array(expect1["label_ids"]))
+        diff_label_2 = np.fabs(
+            np.array(result2.label_ids) - np.array(expect2["label_ids"]))
+
+        diff_scores_1 = np.fabs(
+            np.array(result1.scores) - np.array(expect1["scores"]))
+        diff_scores_2 = np.fabs(
+            np.array(result2.scores) - np.array(expect2["scores"]))
+
+        assert diff_boxes_1.max(
+        ) < 1e-06, "There's difference in detection boxes 1."
+        assert diff_label_1.max(
+        ) < 1e-06, "There's difference in detection label 1."
+        assert diff_scores_1.max(
+        ) < 1e-05, "There's difference in detection score 1."
+
+        assert diff_boxes_2.max(
+        ) < 1e-06, "There's difference in detection boxes 2."
+        assert diff_label_2.max(
+        ) < 1e-06, "There's difference in detection label 2."
+        assert diff_scores_2.max(
+        ) < 1e-05, "There's difference in detection score 2."
+
+        # test batch predict
+        results = model.batch_predict([im1, im2])
+        result1 = results[0]
+        result2 = results[1]
+
+        diff_boxes_1 = np.fabs(
+            np.array(result1.boxes) - np.array(expect1["boxes"]))
+        diff_boxes_2 = np.fabs(
+            np.array(result2.boxes) - np.array(expect2["boxes"]))
+
+        diff_label_1 = np.fabs(
+            np.array(result1.label_ids) - np.array(expect1["label_ids"]))
+        diff_label_2 = np.fabs(
+            np.array(result2.label_ids) - np.array(expect2["label_ids"]))
+
+        diff_scores_1 = np.fabs(
+            np.array(result1.scores) - np.array(expect1["scores"]))
+        diff_scores_2 = np.fabs(
+            np.array(result2.scores) - np.array(expect2["scores"]))
+
+        assert diff_boxes_1.max(
+        ) < 1e-06, "There's difference in detection boxes 1."
+        assert diff_label_1.max(
+        ) < 1e-06, "There's difference in detection label 1."
+        assert diff_scores_1.max(
+        ) < 1e-05, "There's difference in detection score 1."
+
+        assert diff_boxes_2.max(
+        ) < 1e-06, "There's difference in detection boxes 2."
+        assert diff_label_2.max(
+        ) < 1e-06, "There's difference in detection label 2."
+        assert diff_scores_2.max(
+        ) < 1e-05, "There's difference in detection score 2."
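
The pickled baselines are only read by the test above, but their layout is implied by the lookups (expect1["boxes"], expect1["label_ids"], expect1["scores"]). A baseline of that shape could be regenerated along the following lines, treating the current model output as the reference; this is a hypothetical sketch, not the script that produced the hosted .pkl files:

import pickle

import cv2
import numpy as np
import fastdeploy as fd

# Assumes the resources downloaded by the test above are already on disk.
model = fd.vision.detection.YOLOv7("resources/yolov7.onnx")
result = model.predict(cv2.imread("resources/000000014439.jpg"))
expect = {
    "boxes": np.array(result.boxes),          # per-box [xmin, ymin, xmax, ymax]
    "label_ids": np.array(result.label_ids),
    "scores": np.array(result.scores),
}
with open("resources/yolov7_result1.pkl", "wb") as f:
    pickle.dump(expect, f)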
+
+
+def test_detection_yolov7_runtime():
+    model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov7.onnx"
+    input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
+    result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov7_result1.pkl"
+    fd.download(model_url, "resources")
+    fd.download(input_url1, "resources")
+    fd.download(result_url1, "resources")
+
+    model_file = "resources/yolov7.onnx"
+
+    preprocessor = fd.vision.detection.YOLOv7Preprocessor()
+    postprocessor = fd.vision.detection.YOLOv7Postprocessor()
+
+    rc.test_option.set_model_path(model_file, model_format=ModelFormat.ONNX)
+    rc.test_option.use_openvino_backend()
+    runtime = fd.Runtime(rc.test_option)
+
+    with open("resources/yolov7_result1.pkl", "rb") as f:
+        expect1 = pickle.load(f)
+
+    # compare diff
+    im1 = cv2.imread("./resources/000000014439.jpg")
+
+    for i in range(3):
+        # test runtime
+        input_tensors, ims_info = preprocessor.run([im1.copy()])
+        output_tensors = runtime.infer({"images": input_tensors[0]})
+        results = postprocessor.run(output_tensors, ims_info)
+        result1 = results[0]
+
+        diff_boxes_1 = np.fabs(
+            np.array(result1.boxes) - np.array(expect1["boxes"]))
+        diff_label_1 = np.fabs(
+            np.array(result1.label_ids) - np.array(expect1["label_ids"]))
+        diff_scores_1 = np.fabs(
+            np.array(result1.scores) - np.array(expect1["scores"]))
+
+        assert diff_boxes_1.max(
+        ) < 1e-04, "There's difference in detection boxes 1."
+        assert diff_label_1.max(
+        ) < 1e-06, "There's difference in detection label 1."
+        assert diff_scores_1.max(
+        ) < 1e-05, "There's difference in detection score 1."
+
+
+if __name__ == "__main__":
+    test_detection_yolov7()
+    test_detection_yolov7_runtime()
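
Beyond the numeric diffs checked above, detections can also be inspected visually. A sketch under the assumption that fd.vision.vis_detection is available (as in FastDeploy's other detection examples) and that the test resources have already been downloaded:

import cv2
import fastdeploy as fd

model = fd.vision.detection.YOLOv7("resources/yolov7.onnx")
im = cv2.imread("resources/000000014439.jpg")
result = model.predict(im, conf_threshold=0.25, nms_iou_threshold=0.5)
# vis_detection draws boxes/labels above the given score threshold; assumed helper
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.25)
cv2.imwrite("visualized.jpg", vis_im)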