From 63dd0c8f978902320a5b8bd5c9c086b46ec609e4 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Mon, 1 Aug 2022 09:34:41 +0000 Subject: [PATCH 1/5] Fix runtime with python --- csrcs/fastdeploy/backends/ort/ort_backend.cc | 30 +---- csrcs/fastdeploy/backends/ort/utils.cc | 6 +- csrcs/fastdeploy/backends/ort/utils.h | 6 +- csrcs/fastdeploy/core/fd_tensor.cc | 6 +- csrcs/fastdeploy/core/fd_type.cc | 116 +++++++---------- csrcs/fastdeploy/core/fd_type.h | 8 +- csrcs/fastdeploy/pybind/fastdeploy_runtime.cc | 1 + csrcs/fastdeploy/pybind/main.cc | 9 +- csrcs/fastdeploy/vision/ppdet/ppyoloe.cc | 29 ++++- csrcs/fastdeploy/vision/ppdet/ppyoloe.h | 4 + external/paddle2onnx.cmake | 2 +- fastdeploy/__init__.py | 74 +---------- .../{fastdeploy_runtime.py => model.py} | 32 ----- fastdeploy/runtime.py | 121 ++++++++++++++++++ 14 files changed, 219 insertions(+), 225 deletions(-) rename fastdeploy/{fastdeploy_runtime.py => model.py} (61%) create mode 100644 fastdeploy/runtime.py diff --git a/csrcs/fastdeploy/backends/ort/ort_backend.cc b/csrcs/fastdeploy/backends/ort/ort_backend.cc index 9fdb3c66b7f..c17890109b7 100644 --- a/csrcs/fastdeploy/backends/ort/ort_backend.cc +++ b/csrcs/fastdeploy/backends/ort/ort_backend.cc @@ -26,35 +26,6 @@ namespace fastdeploy { std::vector OrtBackend::custom_operators_ = std::vector(); -ONNXTensorElementDataType GetOrtDtype(FDDataType fd_dtype) { - if (fd_dtype == FDDataType::FP32) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (fd_dtype == FDDataType::FP64) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; - } else if (fd_dtype == FDDataType::INT32) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; - } else if (fd_dtype == FDDataType::INT64) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } - FDERROR << "Unrecognized fastdeply data type:" << FDDataTypeStr(fd_dtype) - << "." << std::endl; - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; -} - -FDDataType GetFdDtype(ONNXTensorElementDataType ort_dtype) { - if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - return FDDataType::FP32; - } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - return FDDataType::FP64; - } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { - return FDDataType::INT32; - } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - return FDDataType::INT64; - } - FDERROR << "Unrecognized ort data type:" << ort_dtype << "." << std::endl; - return FDDataType::FP32; -} - void OrtBackend::BuildOption(const OrtBackendOption& option) { option_ = option; if (option.graph_optimization_level >= 0) { @@ -263,6 +234,7 @@ bool OrtBackend::Infer(std::vector& inputs, (*outputs)[i].name = outputs_desc_[i].name; CopyToCpu(ort_outputs[i], &((*outputs)[i])); } + return true; } diff --git a/csrcs/fastdeploy/backends/ort/utils.cc b/csrcs/fastdeploy/backends/ort/utils.cc index bbef1f3786e..ae3e45b8664 100644 --- a/csrcs/fastdeploy/backends/ort/utils.cc +++ b/csrcs/fastdeploy/backends/ort/utils.cc @@ -27,8 +27,8 @@ ONNXTensorElementDataType GetOrtDtype(const FDDataType& fd_dtype) { } else if (fd_dtype == FDDataType::INT64) { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; } - FDERROR << "Unrecognized fastdeply data type:" << FDDataTypeStr(fd_dtype) - << "." << std::endl; + FDERROR << "Unrecognized fastdeply data type:" << Str(fd_dtype) << "." + << std::endl; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -64,4 +64,4 @@ Ort::Value CreateOrtValue(FDTensor& tensor, bool is_backend_cuda) { return ort_value; } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/backends/ort/utils.h b/csrcs/fastdeploy/backends/ort/utils.h index b1b29e5ab11..e2912ad38f5 100644 --- a/csrcs/fastdeploy/backends/ort/utils.h +++ b/csrcs/fastdeploy/backends/ort/utils.h @@ -20,7 +20,7 @@ #include #include "fastdeploy/backends/backend.h" -#include "onnxruntime_cxx_api.h" // NOLINT +#include "onnxruntime_cxx_api.h" // NOLINT namespace fastdeploy { @@ -28,7 +28,7 @@ namespace fastdeploy { ONNXTensorElementDataType GetOrtDtype(const FDDataType& fd_dtype); // Convert OrtDataType to FDDataType -FDDataType GetFdDtype(const ONNXTensorElementDataType* ort_dtype); +FDDataType GetFdDtype(const ONNXTensorElementDataType& ort_dtype); // Create Ort::Value // is_backend_cuda specify if the onnxruntime use CUDAExectionProvider @@ -36,4 +36,4 @@ FDDataType GetFdDtype(const ONNXTensorElementDataType* ort_dtype); // Will directly share the cuda data in tensor to OrtValue Ort::Value CreateOrtValue(FDTensor& tensor, bool is_backend_cuda = false); -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/core/fd_tensor.cc b/csrcs/fastdeploy/core/fd_tensor.cc index 97b33dad589..dbefbd9ecca 100644 --- a/csrcs/fastdeploy/core/fd_tensor.cc +++ b/csrcs/fastdeploy/core/fd_tensor.cc @@ -119,9 +119,9 @@ void FDTensor::PrintInfo(const std::string& prefix) { for (int i = 0; i < shape.size(); ++i) { std::cout << shape[i] << " "; } - std::cout << ", dtype=" << FDDataTypeStr(dtype) << ", mean=" << mean - << ", max=" << max << ", min=" << min << std::endl; + std::cout << ", dtype=" << Str(dtype) << ", mean=" << mean << ", max=" << max + << ", min=" << min << std::endl; } FDTensor::FDTensor(const std::string& tensor_name) { name = tensor_name; } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/core/fd_type.cc b/csrcs/fastdeploy/core/fd_type.cc index b66cabeb8b0..8d624cdf270 100644 --- a/csrcs/fastdeploy/core/fd_type.cc +++ b/csrcs/fastdeploy/core/fd_type.cc @@ -17,7 +17,7 @@ namespace fastdeploy { -int FDDataTypeSize(FDDataType data_type) { +int FDDataTypeSize(const FDDataType& data_type) { FDASSERT(data_type != FDDataType::FP16, "Float16 is not supported."); if (data_type == FDDataType::BOOL) { return sizeof(bool); @@ -34,89 +34,63 @@ int FDDataTypeSize(FDDataType data_type) { } else if (data_type == FDDataType::UINT8) { return sizeof(uint8_t); } else { - FDASSERT(false, "Unexpected data type: " + FDDataTypeStr(data_type)); + FDASSERT(false, "Unexpected data type: " + Str(data_type)); } return -1; } -std::string FDDataTypeStr(FDDataType data_type) { - FDASSERT(data_type != FDDataType::FP16, "Float16 is not supported."); - if (data_type == FDDataType::BOOL) { - return "bool"; - } else if (data_type == FDDataType::INT16) { - return "int16"; - } else if (data_type == FDDataType::INT32) { - return "int32"; - } else if (data_type == FDDataType::INT64) { - return "int64"; - } else if (data_type == FDDataType::FP16) { - return "float16"; - } else if (data_type == FDDataType::FP32) { - return "float32"; - } else if (data_type == FDDataType::FP64) { - return "float64"; - } else if (data_type == FDDataType::UINT8) { - return "uint8"; - } else if (data_type == FDDataType::INT8) { - return "int8"; - } else { - FDASSERT(false, "Unexpected data type: " + FDDataTypeStr(data_type)); - } - return "UNKNOWN!"; -} - -std::string Str(Device& d) { +std::string Str(const Device& d) { std::string out; switch (d) { - case Device::DEFAULT: - out = "Device::DEFAULT"; - break; - case Device::CPU: - out = "Device::CPU"; - break; - case Device::GPU: - out = "Device::GPU"; - break; - default: - out = "Device::UNKOWN"; + case Device::DEFAULT: + out = "Device::DEFAULT"; + break; + case Device::CPU: + out = "Device::CPU"; + break; + case Device::GPU: + out = "Device::GPU"; + break; + default: + out = "Device::UNKOWN"; } return out; } -std::string Str(FDDataType& fdt) { +std::string Str(const FDDataType& fdt) { std::string out; switch (fdt) { - case FDDataType::BOOL: - out = "FDDataType::BOOL"; - break; - case FDDataType::INT16: - out = "FDDataType::INT16"; - break; - case FDDataType::INT32: - out = "FDDataType::INT32"; - break; - case FDDataType::INT64: - out = "FDDataType::INT64"; - break; - case FDDataType::FP32: - out = "FDDataType::FP32"; - break; - case FDDataType::FP64: - out = "FDDataType::FP64"; - break; - case FDDataType::FP16: - out = "FDDataType::FP16"; - break; - case FDDataType::UINT8: - out = "FDDataType::UINT8"; - break; - case FDDataType::INT8: - out = "FDDataType::INT8"; - break; - default: - out = "FDDataType::UNKNOWN"; + case FDDataType::BOOL: + out = "FDDataType::BOOL"; + break; + case FDDataType::INT16: + out = "FDDataType::INT16"; + break; + case FDDataType::INT32: + out = "FDDataType::INT32"; + break; + case FDDataType::INT64: + out = "FDDataType::INT64"; + break; + case FDDataType::FP32: + out = "FDDataType::FP32"; + break; + case FDDataType::FP64: + out = "FDDataType::FP64"; + break; + case FDDataType::FP16: + out = "FDDataType::FP16"; + break; + case FDDataType::UINT8: + out = "FDDataType::UINT8"; + break; + case FDDataType::INT8: + out = "FDDataType::INT8"; + break; + default: + out = "FDDataType::UNKNOWN"; } return out; } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/core/fd_type.h b/csrcs/fastdeploy/core/fd_type.h index 768ac1e369f..325551dfb3a 100644 --- a/csrcs/fastdeploy/core/fd_type.h +++ b/csrcs/fastdeploy/core/fd_type.h @@ -24,7 +24,7 @@ namespace fastdeploy { enum FASTDEPLOY_DECL Device { DEFAULT, CPU, GPU }; -FASTDEPLOY_DECL std::string Str(Device& d); +FASTDEPLOY_DECL std::string Str(const Device& d); enum FASTDEPLOY_DECL FDDataType { BOOL, @@ -51,9 +51,7 @@ enum FASTDEPLOY_DECL FDDataType { INT8 }; -FASTDEPLOY_DECL std::string Str(FDDataType& fdt); +FASTDEPLOY_DECL std::string Str(const FDDataType& fdt); -FASTDEPLOY_DECL int32_t FDDataTypeSize(FDDataType data_dtype); - -FASTDEPLOY_DECL std::string FDDataTypeStr(FDDataType data_dtype); +FASTDEPLOY_DECL int32_t FDDataTypeSize(const FDDataType& data_dtype); } // namespace fastdeploy diff --git a/csrcs/fastdeploy/pybind/fastdeploy_runtime.cc b/csrcs/fastdeploy/pybind/fastdeploy_runtime.cc index 3ede38040bf..412b1ccefd3 100644 --- a/csrcs/fastdeploy/pybind/fastdeploy_runtime.cc +++ b/csrcs/fastdeploy/pybind/fastdeploy_runtime.cc @@ -79,6 +79,7 @@ void BindRuntime(pybind11::module& m) { memcpy(inputs[index].data.data(), iter->second.mutable_data(), iter->second.nbytes()); inputs[index].name = iter->first; + index += 1; } std::vector outputs(self.NumOutputs()); diff --git a/csrcs/fastdeploy/pybind/main.cc b/csrcs/fastdeploy/pybind/main.cc index 86467215e22..e0c00c8a044 100644 --- a/csrcs/fastdeploy/pybind/main.cc +++ b/csrcs/fastdeploy/pybind/main.cc @@ -32,7 +32,7 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) { dt = pybind11::dtype::of(); } else { FDASSERT(false, "The function doesn't support data type of " + - FDDataTypeStr(fd_dtype) + "."); + Str(fd_dtype) + "."); } return dt; } @@ -47,8 +47,9 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) { } else if (np_dtype.is(pybind11::dtype::of())) { return FDDataType::FP64; } - FDASSERT(false, "NumpyDataTypeToFDDataType() only support " - "int32/int64/float32/float64 now."); + FDASSERT(false, + "NumpyDataTypeToFDDataType() only support " + "int32/int64/float32/float64 now."); return FDDataType::FP32; } @@ -112,4 +113,4 @@ PYBIND11_MODULE(fastdeploy_main, m) { #endif } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc index 9c698976e0a..5152db3fa26 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc @@ -23,20 +23,31 @@ PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file, initialized = Initialize(); } -bool PPYOLOE::Initialize() { -#ifdef ENABLE_PADDLE_FRONTEND - // remove multiclass_nms3 now - // this is a trick operation for ppyoloe while inference on trt +void PPYOLOE::GetNmsInfo() { if (runtime_option.model_format == Frontend::PADDLE) { std::string contents; if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) { - return false; + return; } auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size()); if (reader.has_nms) { has_nms_ = true; + background_label = reader.nms_params.background_label; + keep_top_k = reader.nms_params.keep_top_k; + nms_eta = reader.nms_params.nms_eta; + nms_threshold = reader.nms_params.nms_threshold; + score_threshold = reader.nms_params.score_threshold; + nms_top_k = reader.nms_params.nms_top_k; + normalized = reader.nms_params.normalized; } } +} + +bool PPYOLOE::Initialize() { +#ifdef ENABLE_PADDLE_FRONTEND + // remove multiclass_nms3 now + // this is a trick operation for ppyoloe while inference on trt + GetNmsInfo(); runtime_option.remove_multiclass_nms_ = true; runtime_option.custom_op_info_["multiclass_nms3"] = "MultiClassNMS"; #endif @@ -52,8 +63,12 @@ bool PPYOLOE::Initialize() { if (has_nms_ && runtime_option.backend == Backend::TRT) { FDINFO << "Detected operator multiclass_nms3 in your model, will replace " - "it with fastdeploy::backend::MultiClassNMS replace it." - << std::endl; + "it with fastdeploy::backend::MultiClassNMS(background_label=" + << background_label << ", keep_top_k=" << keep_top_k + << ", nms_eta=" << nms_eta << ", nms_threshold=" << nms_threshold + << ", score_threshold=" << score_threshold + << ", nms_top_k=" << nms_top_k << ", normalized=" << normalized + << ")." << std::endl; has_nms_ = false; } return true; diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h index ec22aa2cedb..d86508fa184 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h @@ -42,6 +42,10 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { int64_t nms_top_k = 10000; bool normalized = true; bool has_nms_ = false; + + // This function will used to check if this model contains multiclass_nms + // and get parameters from the operator + void GetNmsInfo(); }; } // namespace ppdet } // namespace vision diff --git a/external/paddle2onnx.cmake b/external/paddle2onnx.cmake index e226bc6c954..ae6f4acdab5 100644 --- a/external/paddle2onnx.cmake +++ b/external/paddle2onnx.cmake @@ -43,7 +43,7 @@ else() endif(WIN32) set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/paddle2onnx/libs/") -set(PADDLE2ONNX_VERSION "1.0.0rc2") +set(PADDLE2ONNX_VERSION "1.0.0rc3") if(WIN32) set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip") if(NOT CMAKE_CL_64) diff --git a/fastdeploy/__init__.py b/fastdeploy/__init__.py index f9b9f686e0c..6a23cd3d2c4 100644 --- a/fastdeploy/__init__.py +++ b/fastdeploy/__init__.py @@ -16,12 +16,14 @@ import os import sys + def add_dll_search_dir(dir_path): os.environ["path"] = dir_path + ";" + os.environ["path"] sys.path.insert(0, dir_path) if sys.version_info[:2] >= (3, 8): os.add_dll_directory(dir_path) + if os.name == "nt": current_path = os.path.abspath(__file__) dirname = os.path.dirname(current_path) @@ -33,82 +35,19 @@ def add_dll_search_dir(dir_path): add_dll_search_dir(os.path.join(dirname, root, d)) from .fastdeploy_main import Frontend, Backend, FDDataType, TensorInfo, Device -from .fastdeploy_runtime import * +from .runtime import Runtime, RuntimeOption +from .model import FastDeployModel from . import fastdeploy_main as C from . import vision from .download import download, download_and_decompress + def TensorInfoStr(tensor_info): message = "TensorInfo(name : '{}', dtype : '{}', shape : '{}')".format( tensor_info.name, tensor_info.dtype, tensor_info.shape) return message -class RuntimeOption: - def __init__(self): - self._option = C.RuntimeOption() - - def set_model_path(self, model_path, params_path="", model_format="paddle"): - return self._option.set_model_path(model_path, params_path, model_format) - - def use_gpu(self, device_id=0): - return self._option.use_gpu(device_id) - - def use_cpu(self): - return self._option.use_cpu() - - def set_cpu_thread_num(self, thread_num=8): - return self._option.set_cpu_thread_num(thread_num) - - def use_paddle_backend(self): - return self._option.use_paddle_backend() - - def use_ort_backend(self): - return self._option.use_ort_backend() - - def use_trt_backend(self): - return self._option.use_trt_backend() - - def enable_paddle_mkldnn(self): - return self._option.enable_paddle_mkldnn() - - def disable_paddle_mkldnn(self): - return self._option.disable_paddle_mkldnn() - - def set_paddle_mkldnn_cache_size(self, cache_size): - return self._option.set_paddle_mkldnn_cache_size(cache_size) - - def set_trt_input_shape(self, tensor_name, min_shape, opt_shape=None, max_shape=None): - if opt_shape is None and max_shape is None: - opt_shape = min_shape - max_shape = min_shape - else: - assert opt_shape is not None and max_shape is not None, "Set min_shape only, or set min_shape, opt_shape, max_shape both." - return self._option.set_trt_input_shape(tensor_name, min_shape, opt_shape, max_shape) - - def set_trt_cache_file(self, cache_file_path): - return self._option.set_trt_cache_file(cache_file_path) - - def enable_trt_fp16(self): - return self._option.enable_trt_fp16() - - def dissable_trt_fp16(self): - return self._option.disable_trt_fp16() - - def __repr__(self): - attrs = dir(self._option) - message = "RuntimeOption(\n" - for attr in attrs: - if attr.startswith("__"): - continue - if hasattr(getattr(self._option, attr), "__call__"): - continue - message += " {} : {}\t\n".format(attr, getattr(self._option, attr)) - message.strip("\n") - message += ")" - return message - - def RuntimeOptionStr(runtime_option): attrs = dir(runtime_option) message = "RuntimeOption(\n" @@ -122,5 +61,6 @@ def RuntimeOptionStr(runtime_option): message += ")" return message + C.TensorInfo.__repr__ = TensorInfoStr -C.RuntimeOption.__repr__ = RuntimeOptionStr \ No newline at end of file +C.RuntimeOption.__repr__ = RuntimeOptionStr diff --git a/fastdeploy/fastdeploy_runtime.py b/fastdeploy/model.py similarity index 61% rename from fastdeploy/fastdeploy_runtime.py rename to fastdeploy/model.py index e07e28993bc..f0faa161032 100644 --- a/fastdeploy/fastdeploy_runtime.py +++ b/fastdeploy/model.py @@ -54,35 +54,3 @@ def initialized(self): if self._model is None: return False return self._model.initialized() - - -class Runtime: - def __init__(self, runtime_option): - self._runtime = C.Runtime() - assert self._runtime.init(runtime_option), "Initialize Runtime Failed!" - - def infer(self, data): - assert isinstance(data, dict), "The input data should be type of dict." - return self._runtime.infer(data) - - def num_inputs(self): - return self._runtime.num_inputs() - - def num_outputs(self): - return self._runtime.num_outputs() - - def get_input_info(self, index): - assert isinstance( - index, int), "The input parameter index should be type of int." - assert index < self.num_inputs( - ), "The input parameter index:{} should less than number of inputs:{}.".format( - index, self.num_inputs) - return self._runtime.get_input_info(index) - - def get_output_info(self, index): - assert isinstance( - index, int), "The input parameter index should be type of int." - assert index < self.num_outputs( - ), "The input parameter index:{} should less than number of outputs:{}.".format( - index, self.num_outputs) - return self._runtime.get_output_info(index) diff --git a/fastdeploy/runtime.py b/fastdeploy/runtime.py new file mode 100644 index 00000000000..a560f63a22b --- /dev/null +++ b/fastdeploy/runtime.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +import logging +from . import fastdeploy_main as C + + +class Runtime: + def __init__(self, runtime_option): + self._runtime = C.Runtime() + assert self._runtime.init(runtime_option), "Initialize Runtime Failed!" + + def infer(self, data): + assert isinstance(data, dict), "The input data should be type of dict." + return self._runtime.infer(data) + + def num_inputs(self): + return self._runtime.num_inputs() + + def num_outputs(self): + return self._runtime.num_outputs() + + def get_input_info(self, index): + assert isinstance( + index, int), "The input parameter index should be type of int." + assert index < self.num_inputs( + ), "The input parameter index:{} should less than number of inputs:{}.".format( + index, self.num_inputs) + return self._runtime.get_input_info(index) + + def get_output_info(self, index): + assert isinstance( + index, int), "The input parameter index should be type of int." + assert index < self.num_outputs( + ), "The input parameter index:{} should less than number of outputs:{}.".format( + index, self.num_outputs) + return self._runtime.get_output_info(index) + + +class RuntimeOption: + def __init__(self): + self._option = C.RuntimeOption() + + def set_model_path(self, model_path, params_path="", + model_format="paddle"): + return self._option.set_model_path(model_path, params_path, + model_format) + + def use_gpu(self, device_id=0): + return self._option.use_gpu(device_id) + + def use_cpu(self): + return self._option.use_cpu() + + def set_cpu_thread_num(self, thread_num=8): + return self._option.set_cpu_thread_num(thread_num) + + def use_paddle_backend(self): + return self._option.use_paddle_backend() + + def use_ort_backend(self): + return self._option.use_ort_backend() + + def use_trt_backend(self): + return self._option.use_trt_backend() + + def enable_paddle_mkldnn(self): + return self._option.enable_paddle_mkldnn() + + def disable_paddle_mkldnn(self): + return self._option.disable_paddle_mkldnn() + + def set_paddle_mkldnn_cache_size(self, cache_size): + return self._option.set_paddle_mkldnn_cache_size(cache_size) + + def set_trt_input_shape(self, + tensor_name, + min_shape, + opt_shape=None, + max_shape=None): + if opt_shape is None and max_shape is None: + opt_shape = min_shape + max_shape = min_shape + else: + assert opt_shape is not None and max_shape is not None, "Set min_shape only, or set min_shape, opt_shape, max_shape both." + return self._option.set_trt_input_shape(tensor_name, min_shape, + opt_shape, max_shape) + + def set_trt_cache_file(self, cache_file_path): + return self._option.set_trt_cache_file(cache_file_path) + + def enable_trt_fp16(self): + return self._option.enable_trt_fp16() + + def disable_trt_fp16(self): + return self._option.disable_trt_fp16() + + def __repr__(self): + attrs = dir(self._option) + message = "RuntimeOption(\n" + for attr in attrs: + if attr.startswith("__"): + continue + if hasattr(getattr(self._option, attr), "__call__"): + continue + message += " {} : {}\t\n".format(attr, + getattr(self._option, attr)) + message.strip("\n") + message += ")" + return message From 379c58cae34db5f9c4d40a6a115361c3d379e531 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Tue, 2 Aug 2022 13:06:03 +0000 Subject: [PATCH 2/5] Add CenterNet/PicoDet/PPYOLO/PPYOLOv2/YOLOv3 --- CMakeLists.txt | 2 +- FastDeploy.cmake.in | 2 +- csrcs/fastdeploy/vision.h | 4 +- csrcs/fastdeploy/vision/ppdet/centernet.cc | 25 ++++++++ csrcs/fastdeploy/vision/ppdet/centernet.h | 19 ++++++ csrcs/fastdeploy/vision/ppdet/model.h | 6 ++ csrcs/fastdeploy/vision/ppdet/picodet.cc | 52 ++++++++++++++++ csrcs/fastdeploy/vision/ppdet/picodet.h | 22 +++++++ csrcs/fastdeploy/vision/ppdet/ppyolo.cc | 70 ++++++++++++++++++++++ csrcs/fastdeploy/vision/ppdet/ppyolo.h | 25 ++++++++ csrcs/fastdeploy/vision/ppdet/ppyoloe.cc | 16 ++--- csrcs/fastdeploy/vision/ppdet/ppyoloe.h | 16 ++--- csrcs/fastdeploy/vision/ppdet/yolov3.cc | 56 +++++++++++++++++ csrcs/fastdeploy/vision/ppdet/yolov3.h | 21 +++++++ model_zoo/vision/ppyoloe/api.md | 18 ++---- 15 files changed, 322 insertions(+), 32 deletions(-) create mode 100644 csrcs/fastdeploy/vision/ppdet/centernet.cc create mode 100644 csrcs/fastdeploy/vision/ppdet/centernet.h create mode 100644 csrcs/fastdeploy/vision/ppdet/model.h create mode 100644 csrcs/fastdeploy/vision/ppdet/picodet.cc create mode 100644 csrcs/fastdeploy/vision/ppdet/picodet.h create mode 100644 csrcs/fastdeploy/vision/ppdet/ppyolo.cc create mode 100644 csrcs/fastdeploy/vision/ppdet/ppyolo.h create mode 100644 csrcs/fastdeploy/vision/ppdet/yolov3.cc create mode 100644 csrcs/fastdeploy/vision/ppdet/yolov3.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c15fac1b12..7c1bfdde744 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -348,6 +348,6 @@ endif(BUILD_FASTDEPLOY_PYTHON) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS "5.4.0") string(STRIP "${CMAKE_CXX_COMPILER_VERSION}" CMAKE_CXX_COMPILER_VERSION) - message(WARNING "[WARNING] FastDeploy require g++ version >= 5.4.0, but now your g++ version is ${CMAKE_CXX_COMPILER_VERSION}, this may cause failure! Use -DCMAKE_CXX_COMPILER to define path of your compiler.") + message(FATAL_ERROR "[ERROR] FastDeploy require g++ version >= 5.4.0, but now your g++ version is ${CMAKE_CXX_COMPILER_VERSION}, this may cause failure! Use -DCMAKE_CXX_COMPILER to define path of your compiler.") endif() endif() diff --git a/FastDeploy.cmake.in b/FastDeploy.cmake.in index 818533c8bd4..4f4643fdfba 100644 --- a/FastDeploy.cmake.in +++ b/FastDeploy.cmake.in @@ -113,6 +113,6 @@ message(STATUS " ENABLE_VISION : ${ENABLE_VISION}") if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS "5.4.0") string(STRIP "${CMAKE_CXX_COMPILER_VERSION}" CMAKE_CXX_COMPILER_VERSION) - message(WARNING "[WARNING] FastDeploy require g++ version >= 5.4.0, but now your g++ version is ${CMAKE_CXX_COMPILER_VERSION}, this may cause failure! Use -DCMAKE_CXX_COMPILER to define path of your compiler.") + message(FATAL_ERROR "[ERROR] FastDeploy require g++ version >= 5.4.0, but now your g++ version is ${CMAKE_CXX_COMPILER_VERSION}, this may cause failure! Use -DCMAKE_CXX_COMPILER to define path of your compiler.") endif() endif() diff --git a/csrcs/fastdeploy/vision.h b/csrcs/fastdeploy/vision.h index 2c0bdd1fa85..205183bcdc7 100644 --- a/csrcs/fastdeploy/vision.h +++ b/csrcs/fastdeploy/vision.h @@ -21,14 +21,14 @@ #include "fastdeploy/vision/megvii/yolox.h" #include "fastdeploy/vision/meituan/yolov6.h" #include "fastdeploy/vision/ppcls/model.h" -#include "fastdeploy/vision/ppdet/ppyoloe.h" +#include "fastdeploy/vision/ppdet/model.h" +#include "fastdeploy/vision/ppogg/yolov5lite.h" #include "fastdeploy/vision/ppseg/model.h" #include "fastdeploy/vision/rangilyu/nanodet_plus.h" #include "fastdeploy/vision/ultralytics/yolov5.h" #include "fastdeploy/vision/wongkinyiu/scaledyolov4.h" #include "fastdeploy/vision/wongkinyiu/yolor.h" #include "fastdeploy/vision/wongkinyiu/yolov7.h" -#include "fastdeploy/vision/ppogg/yolov5lite.h" #endif #include "fastdeploy/vision/visualize/visualize.h" diff --git a/csrcs/fastdeploy/vision/ppdet/centernet.cc b/csrcs/fastdeploy/vision/ppdet/centernet.cc new file mode 100644 index 00000000000..259ecf620d2 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/centernet.cc @@ -0,0 +1,25 @@ +#include "fastdeploy/vision/ppdet/centernet.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +CenterNet::CenterNet(const std::string& model_file, + const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const Frontend& model_format) { + config_file_ = config_file; + valid_cpu_backends = {Backend::PDINFER}; + valid_gpu_backends = {Backend::PDINFER}; + has_nms_ = true; + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/centernet.h b/csrcs/fastdeploy/vision/ppdet/centernet.h new file mode 100644 index 00000000000..a6ae756cd10 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/centernet.h @@ -0,0 +1,19 @@ +#pragma once +#include "fastdeploy/vision/ppdet/ppyolo.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +class FASTDEPLOY_DECL CenterNet : public PPYOLO { + public: + CenterNet(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const Frontend& model_format = Frontend::PADDLE); + + virtual std::string ModelName() const { return "PaddleDetection/CenterNet"; } +}; +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/model.h b/csrcs/fastdeploy/vision/ppdet/model.h new file mode 100644 index 00000000000..89c59fd1a9f --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/model.h @@ -0,0 +1,6 @@ +#pragma once +#include "fastdeploy/vision/ppdet/centernet.h" +#include "fastdeploy/vision/ppdet/picodet.h" +#include "fastdeploy/vision/ppdet/ppyolo.h" +#include "fastdeploy/vision/ppdet/ppyoloe.h" +#include "fastdeploy/vision/ppdet/yolov3.h" diff --git a/csrcs/fastdeploy/vision/ppdet/picodet.cc b/csrcs/fastdeploy/vision/ppdet/picodet.cc new file mode 100644 index 00000000000..f8070a3d240 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/picodet.cc @@ -0,0 +1,52 @@ +#include "fastdeploy/vision/ppdet/picodet.h" +#include "yaml-cpp/yaml.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +PicoDet::PicoDet(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const Frontend& model_format) { + config_file_ = config_file; + valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_gpu_backends = {Backend::PDINFER, Backend::ORT}; + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + background_label = -1; + keep_top_k = 100; + nms_eta = 1; + nms_threshold = 0.6; + nms_top_k = 1000; + normalized = true; + score_threshold = 0.025; + CheckIfContainDecodeAndNMS(); + initialized = Initialize(); +} + +bool PicoDet::CheckIfContainDecodeAndNMS() { + YAML::Node cfg; + try { + cfg = YAML::LoadFile(config_file_); + } catch (YAML::BadFile& e) { + FDERROR << "Failed to load yaml file " << config_file_ + << ", maybe you should check this file." << std::endl; + return false; + } + + if (cfg["arch"].as() == "PicoDet") { + FDERROR << "The arch in config file is PicoDet, which means this model " + "doesn contain box decode and nms, please export model with " + "decode and nms." + << std::endl; + return false; + } + return true; +} + +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/picodet.h b/csrcs/fastdeploy/vision/ppdet/picodet.h new file mode 100644 index 00000000000..90242692610 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/picodet.h @@ -0,0 +1,22 @@ +#pragma once +#include "fastdeploy/vision/ppdet/ppyoloe.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +class FASTDEPLOY_DECL PicoDet : public PPYOLOE { + public: + PicoDet(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const Frontend& model_format = Frontend::PADDLE); + + // Only support picodet contains decode and nms + bool CheckIfContainDecodeAndNMS(); + + virtual std::string ModelName() const { return "PaddleDetection/PicoDet"; } +}; +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/ppyolo.cc b/csrcs/fastdeploy/vision/ppdet/ppyolo.cc new file mode 100644 index 00000000000..2926b54a092 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/ppyolo.cc @@ -0,0 +1,70 @@ +#include "fastdeploy/vision/ppdet/ppyolo.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +PPYOLO::PPYOLO(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const Frontend& model_format) { + config_file_ = config_file; + valid_cpu_backends = {Backend::PDINFER}; + valid_gpu_backends = {Backend::PDINFER}; + has_nms_ = true; + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +bool PPYOLO::Initialize() { + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." + << std::endl; + return false; + } + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + return true; +} + +bool PPYOLO::Preprocess(Mat* mat, std::vector* outputs) { + int origin_w = mat->Width(); + int origin_h = mat->Height(); + mat->PrintInfo("Origin"); + for (size_t i = 0; i < processors_.size(); ++i) { + if (!(*(processors_[i].get()))(mat)) { + FDERROR << "Failed to process image data in " << processors_[i]->Name() + << "." << std::endl; + return false; + } + mat->PrintInfo(processors_[i]->Name()); + } + + outputs->resize(3); + (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape"); + (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor"); + std::cout << "111111111" << std::endl; + float* ptr0 = static_cast((*outputs)[0].MutableData()); + ptr0[0] = mat->Height(); + ptr0[1] = mat->Width(); + std::cout << "090909" << std::endl; + float* ptr2 = static_cast((*outputs)[2].MutableData()); + ptr2[0] = mat->Height() * 1.0 / origin_h; + ptr2[1] = mat->Width() * 1.0 / origin_w; + std::cout << "88888" << std::endl; + (*outputs)[1].name = "image"; + mat->ShareWithTensor(&((*outputs)[1])); + // reshape to [1, c, h, w] + (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1); + std::cout << "??????" << std::endl; + return true; +} + +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/ppyolo.h b/csrcs/fastdeploy/vision/ppdet/ppyolo.h new file mode 100644 index 00000000000..b17f54b3e6a --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/ppyolo.h @@ -0,0 +1,25 @@ +#pragma once +#include "fastdeploy/vision/ppdet/ppyoloe.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +class FASTDEPLOY_DECL PPYOLO : public PPYOLOE { + public: + PPYOLO(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const Frontend& model_format = Frontend::PADDLE); + + virtual std::string ModelName() const { return "PaddleDetection/PPYOLO"; } + + virtual bool Preprocess(Mat* mat, std::vector* outputs); + virtual bool Initialize(); + + protected: + PPYOLO() {} +}; +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc index 5152db3fa26..99dae0fc6b0 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc @@ -85,12 +85,6 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() { return false; } - if (cfg["arch"].as() != "YOLO") { - FDERROR << "Require the arch of model is YOLO, but arch defined in " - "config file is " - << cfg["arch"].as() << "." << std::endl; - return false; - } processors_.push_back(std::make_shared()); for (const auto& op : cfg["Preprocess"]) { @@ -128,12 +122,14 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() { bool PPYOLOE::Preprocess(Mat* mat, std::vector* outputs) { int origin_w = mat->Width(); int origin_h = mat->Height(); + mat->PrintInfo("Origin"); for (size_t i = 0; i < processors_.size(); ++i) { if (!(*(processors_[i].get()))(mat)) { FDERROR << "Failed to process image data in " << processors_[i]->Name() << "." << std::endl; return false; } + mat->PrintInfo(processors_[i]->Name()); } outputs->resize(2); @@ -217,8 +213,7 @@ bool PPYOLOE::Postprocess(std::vector& infer_result, return true; } -bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result, - float conf_threshold, float iou_threshold) { +bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) { Mat mat(*im); std::vector processed_data; if (!Preprocess(&mat, &processed_data)) { @@ -227,6 +222,9 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result, return false; } + processed_data[0].PrintInfo("Before infer"); + float* tmp = static_cast(processed_data[1].Data()); + std::cout << "==== " << tmp[0] << " " << tmp[1] << std::endl; std::vector infer_result; if (!Infer(processed_data, &infer_result)) { FDERROR << "Failed to inference while using model:" << ModelName() << "." @@ -234,6 +232,8 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result, return false; } + infer_result[0].PrintInfo("Boxes"); + infer_result[1].PrintInfo("Num"); if (!Postprocess(infer_result, result)) { FDERROR << "Failed to postprocess while using model:" << ModelName() << "." << std::endl; diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h index d86508fa184..84fe0781ba1 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h @@ -16,7 +16,7 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { const RuntimeOption& custom_option = RuntimeOption(), const Frontend& model_format = Frontend::PADDLE); - std::string ModelName() const { return "PaddleDetection/PPYOLOE"; } + virtual std::string ModelName() const { return "PaddleDetection/PPYOLOE"; } virtual bool Initialize(); @@ -27,10 +27,14 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { virtual bool Postprocess(std::vector& infer_result, DetectionResult* result); - virtual bool Predict(cv::Mat* im, DetectionResult* result, - float conf_threshold = 0.5, float nms_threshold = 0.7); + virtual bool Predict(cv::Mat* im, DetectionResult* result); + + protected: + PPYOLOE() {} + // This function will used to check if this model contains multiclass_nms + // and get parameters from the operator + void GetNmsInfo(); - private: std::vector> processors_; std::string config_file_; // configuration for nms @@ -42,10 +46,6 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { int64_t nms_top_k = 10000; bool normalized = true; bool has_nms_ = false; - - // This function will used to check if this model contains multiclass_nms - // and get parameters from the operator - void GetNmsInfo(); }; } // namespace ppdet } // namespace vision diff --git a/csrcs/fastdeploy/vision/ppdet/yolov3.cc b/csrcs/fastdeploy/vision/ppdet/yolov3.cc new file mode 100644 index 00000000000..9d683819473 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/yolov3.cc @@ -0,0 +1,56 @@ +#include "fastdeploy/vision/ppdet/yolov3.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +YOLOv3::YOLOv3(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const Frontend& model_format) { + config_file_ = config_file; + valid_cpu_backends = {Backend::PDINFER}; + valid_gpu_backends = {Backend::PDINFER}; + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +bool YOLOv3::Preprocess(Mat* mat, std::vector* outputs) { + int origin_w = mat->Width(); + int origin_h = mat->Height(); + mat->PrintInfo("Origin"); + for (size_t i = 0; i < processors_.size(); ++i) { + if (!(*(processors_[i].get()))(mat)) { + FDERROR << "Failed to process image data in " << processors_[i]->Name() + << "." << std::endl; + return false; + } + mat->PrintInfo(processors_[i]->Name()); + } + + outputs->resize(3); + (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape"); + (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor"); + std::cout << "111111111" << std::endl; + float* ptr0 = static_cast((*outputs)[0].MutableData()); + ptr0[0] = mat->Height(); + ptr0[1] = mat->Width(); + std::cout << "090909" << std::endl; + float* ptr2 = static_cast((*outputs)[2].MutableData()); + ptr2[0] = mat->Height() * 1.0 / origin_h; + ptr2[1] = mat->Width() * 1.0 / origin_w; + std::cout << "88888" << std::endl; + (*outputs)[1].name = "image"; + mat->ShareWithTensor(&((*outputs)[1])); + // reshape to [1, c, h, w] + (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1); + std::cout << "??????" << std::endl; + return true; +} + +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/yolov3.h b/csrcs/fastdeploy/vision/ppdet/yolov3.h new file mode 100644 index 00000000000..9919c023358 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/yolov3.h @@ -0,0 +1,21 @@ +#pragma once +#include "fastdeploy/vision/ppdet/ppyoloe.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +class FASTDEPLOY_DECL YOLOv3 : public PPYOLOE { + public: + YOLOv3(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const Frontend& model_format = Frontend::PADDLE); + + virtual std::string ModelName() const { return "PaddleDetection/YOLOv3"; } + + virtual bool Preprocess(Mat* mat, std::vector* outputs); +}; +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/model_zoo/vision/ppyoloe/api.md b/model_zoo/vision/ppyoloe/api.md index 1c5cbcaadbd..99565e56365 100644 --- a/model_zoo/vision/ppyoloe/api.md +++ b/model_zoo/vision/ppyoloe/api.md @@ -4,7 +4,7 @@ ### PPYOLOE类 ``` -fastdeploy.vision.ultralytics.PPYOLOE(model_file, params_file, config_file, runtime_option=None, model_format=fd.Frontend.PADDLE) +fastdeploy.vision.ppdet.PPYOLOE(model_file, params_file, config_file, runtime_option=None, model_format=fd.Frontend.PADDLE) ``` PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当前仅支持model_format为Paddle格式 @@ -18,15 +18,13 @@ PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当 #### predict函数 > ``` -> PPYOLOE.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5) +> PPYOLOE.predict(image_data) > ``` > 模型预测结口,输入图像直接输出检测结果。 > > **参数** > > > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式 -> > * **conf_threshold**(float): 检测框置信度过滤阈值 -> > * **nms_iou_threshold**(float): NMS处理过程中iou阈值(当模型中包含nms处理时,此参数自动无效) 示例代码参考[ppyoloe.py](./ppyoloe.py) @@ -35,12 +33,12 @@ PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当 ### PPYOLOE类 ``` -fastdeploy::vision::ultralytics::PPYOLOE( +fastdeploy::vision::ppdet::PPYOLOE( const string& model_file, const string& params_file, const string& config_file, const RuntimeOption& runtime_option = RuntimeOption(), - const Frontend& model_format = Frontend::ONNX) + const Frontend& model_format = Frontend::PADDLE) ``` PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当前仅支持model_format为Paddle格式 @@ -54,9 +52,7 @@ PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当 #### Predict函数 > ``` -> YOLOv5::Predict(cv::Mat* im, DetectionResult* result, -> float conf_threshold = 0.25, -> float nms_iou_threshold = 0.5) +> PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) > ``` > 模型预测接口,输入图像直接输出检测结果。 > @@ -64,10 +60,8 @@ PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当 > > > * **im**: 输入图像,注意需为HWC,BGR格式 > > * **result**: 检测结果,包括检测框,各个框的置信度 -> > * **conf_threshold**: 检测框置信度过滤阈值 -> > * **nms_iou_threshold**: NMS处理过程中iou阈值(当模型中包含nms处理时,此参数自动无效) -示例代码参考[cpp/yolov5.cc](cpp/yolov5.cc) +示例代码参考[cpp/ppyoloe.cc](cpp/ppyoloe.cc) ## 其它API使用 From 56bdbe5b0df177df64adbe9e77bf72d4d63024aa Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Wed, 3 Aug 2022 08:15:23 +0000 Subject: [PATCH 3/5] add more ppdet models --- .../vision/common/processors/pad_to_size.cc | 141 ++++++++++++++++++ .../vision/common/processors/pad_to_size.h | 46 ++++++ .../vision/common/processors/stride_pad.cc | 124 +++++++++++++++ .../vision/common/processors/stride_pad.h | 44 ++++++ .../vision/common/processors/transform.h | 2 + .../vision/ppdet/build_preprocess.cc | 78 ++++++++++ csrcs/fastdeploy/vision/ppdet/centernet.cc | 25 ---- csrcs/fastdeploy/vision/ppdet/centernet.h | 19 --- csrcs/fastdeploy/vision/ppdet/model.h | 16 ++ csrcs/fastdeploy/vision/ppdet/picodet.cc | 14 ++ csrcs/fastdeploy/vision/ppdet/picodet.h | 14 ++ csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc | 16 ++ csrcs/fastdeploy/vision/ppdet/ppyolo.cc | 18 ++- csrcs/fastdeploy/vision/ppdet/ppyoloe.cc | 34 +++-- csrcs/fastdeploy/vision/ppdet/ppyoloe.h | 19 +++ csrcs/fastdeploy/vision/ppdet/rcnn.cc | 89 +++++++++++ csrcs/fastdeploy/vision/ppdet/rcnn.h | 39 +++++ csrcs/fastdeploy/vision/ppdet/yolov3.cc | 18 ++- csrcs/fastdeploy/vision/ppdet/yolov3.h | 14 ++ csrcs/fastdeploy/vision/ppdet/yolox.cc | 74 +++++++++ csrcs/fastdeploy/vision/ppdet/yolox.h | 35 +++++ csrcs/fastdeploy/vision/ppogg/yolov5lite.cc | 1 - fastdeploy/vision/ppdet/__init__.py | 66 +++++++- 23 files changed, 883 insertions(+), 63 deletions(-) create mode 100644 csrcs/fastdeploy/vision/common/processors/pad_to_size.cc create mode 100644 csrcs/fastdeploy/vision/common/processors/pad_to_size.h create mode 100644 csrcs/fastdeploy/vision/common/processors/stride_pad.cc create mode 100644 csrcs/fastdeploy/vision/common/processors/stride_pad.h create mode 100644 csrcs/fastdeploy/vision/ppdet/build_preprocess.cc delete mode 100644 csrcs/fastdeploy/vision/ppdet/centernet.cc delete mode 100644 csrcs/fastdeploy/vision/ppdet/centernet.h create mode 100644 csrcs/fastdeploy/vision/ppdet/rcnn.cc create mode 100644 csrcs/fastdeploy/vision/ppdet/rcnn.h create mode 100644 csrcs/fastdeploy/vision/ppdet/yolox.cc create mode 100644 csrcs/fastdeploy/vision/ppdet/yolox.h diff --git a/csrcs/fastdeploy/vision/common/processors/pad_to_size.cc b/csrcs/fastdeploy/vision/common/processors/pad_to_size.cc new file mode 100644 index 00000000000..d4cbacd879c --- /dev/null +++ b/csrcs/fastdeploy/vision/common/processors/pad_to_size.cc @@ -0,0 +1,141 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/common/processors/pad_to_size.h" + +namespace fastdeploy { +namespace vision { + +bool PadToSize::CpuRun(Mat* mat) { + if (mat->layout != Layout::HWC) { + FDERROR << "PadToSize: The input data must be Layout::HWC format!" + << std::endl; + return false; + } + if (mat->Channels() > 4) { + FDERROR << "PadToSize: Only support channels <= 4." << std::endl; + return false; + } + if (mat->Channels() != value_.size()) { + FDERROR + << "PadToSize: Require input channels equals to size of padding value, " + "but now channels = " + << mat->Channels() << ", the size of padding values = " << value_.size() + << "." << std::endl; + return false; + } + int origin_w = mat->Width(); + int origin_h = mat->Height(); + if (origin_w > width_) { + FDERROR << "PadToSize: the input width:" << origin_w + << " is greater than the target width: " << width_ << "." + << std::endl; + return false; + } + if (origin_h > height_) { + FDERROR << "PadToSize: the input height:" << origin_h + << " is greater than the target height: " << height_ << "." + << std::endl; + return false; + } + if (origin_w == width_ && origin_h == height_) { + return true; + } + + cv::Mat* im = mat->GetCpuMat(); + cv::Scalar value; + if (value_.size() == 1) { + value = cv::Scalar(value_[0]); + } else if (value_.size() == 2) { + value = cv::Scalar(value_[0], value_[1]); + } else if (value_.size() == 3) { + value = cv::Scalar(value_[0], value_[1], value_[2]); + } else { + value = cv::Scalar(value_[0], value_[1], value_[2], value_[3]); + } + // top, bottom, left, right + cv::copyMakeBorder(*im, *im, 0, height_ - origin_h, 0, width_ - origin_w, + cv::BORDER_CONSTANT, value); + mat->SetHeight(height_); + mat->SetWidth(width_); + return true; +} + +#ifdef ENABLE_OPENCV_CUDA +bool PadToSize::GpuRun(Mat* mat) { + if (mat->layout != Layout::HWC) { + FDERROR << "PadToSize: The input data must be Layout::HWC format!" + << std::endl; + return false; + } + if (mat->Channels() > 4) { + FDERROR << "PadToSize: Only support channels <= 4." << std::endl; + return false; + } + if (mat->Channels() != value_.size()) { + FDERROR + << "PadToSize: Require input channels equals to size of padding value, " + "but now channels = " + << mat->Channels() << ", the size of padding values = " << value_.size() + << "." << std::endl; + return false; + } + + int origin_w = mat->Width(); + int origin_h = mat->Height(); + if (origin_w > width_) { + FDERROR << "PadToSize: the input width:" << origin_w + << " is greater than the target width: " << width_ << "." + << std::endl; + return false; + } + if (origin_h > height_) { + FDERROR << "PadToSize: the input height:" << origin_h + << " is greater than the target height: " << height_ << "." + << std::endl; + return false; + } + if (origin_w == width_ && origin_h == height_) { + return true; + } + + cv::cuda::GpuMat* im = mat->GetGpuMat(); + cv::Scalar value; + if (value_.size() == 1) { + value = cv::Scalar(value_[0]); + } else if (value_.size() == 2) { + value = cv::Scalar(value_[0], value_[1]); + } else if (value_.size() == 3) { + value = cv::Scalar(value_[0], value_[1], value_[2]); + } else { + value = cv::Scalar(value_[0], value_[1], value_[2], value_[3]); + } + + // top, bottom, left, right + cv::cuda::copyMakeBorder(*im, *im, 0, height_ - origin_h, 0, + width_ - origin_w, cv::BORDER_CONSTANT, value); + mat->SetHeight(height_); + mat->SetWidth(width_); + return true; +} +#endif + +bool PadToSize::Run(Mat* mat, int width, int height, + const std::vector& value, ProcLib lib) { + auto p = PadToSize(width, height, value); + return p(mat, lib); +} + +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/common/processors/pad_to_size.h b/csrcs/fastdeploy/vision/common/processors/pad_to_size.h new file mode 100644 index 00000000000..ece0158f7be --- /dev/null +++ b/csrcs/fastdeploy/vision/common/processors/pad_to_size.h @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/vision/common/processors/base.h" + +namespace fastdeploy { +namespace vision { + +class PadToSize : public Processor { + public: + // only support pad with left-top padding mode + PadToSize(int width, int height, const std::vector& value) { + width_ = width; + height_ = height; + value_ = value; + } + bool CpuRun(Mat* mat); +#ifdef ENABLE_OPENCV_CUDA + bool GpuRun(Mat* mat); +#endif + std::string Name() { return "PadToSize"; } + + static bool Run(Mat* mat, int width, int height, + const std::vector& value, + ProcLib lib = ProcLib::OPENCV_CPU); + + private: + int width_; + int height_; + std::vector value_; +}; +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/common/processors/stride_pad.cc b/csrcs/fastdeploy/vision/common/processors/stride_pad.cc new file mode 100644 index 00000000000..8597c83758e --- /dev/null +++ b/csrcs/fastdeploy/vision/common/processors/stride_pad.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/common/processors/stride_pad.h" + +namespace fastdeploy { +namespace vision { + +bool StridePad::CpuRun(Mat* mat) { + if (mat->layout != Layout::HWC) { + FDERROR << "StridePad: The input data must be Layout::HWC format!" + << std::endl; + return false; + } + if (mat->Channels() > 4) { + FDERROR << "StridePad: Only support channels <= 4." << std::endl; + return false; + } + if (mat->Channels() != value_.size()) { + FDERROR + << "StridePad: Require input channels equals to size of padding value, " + "but now channels = " + << mat->Channels() << ", the size of padding values = " << value_.size() + << "." << std::endl; + return false; + } + int origin_w = mat->Width(); + int origin_h = mat->Height(); + + int pad_h = (mat->Height() / stride_) * stride_ + + (mat->Height() % stride_ != 0) * stride_ - mat->Height(); + int pad_w = (mat->Width() / stride_) * stride_ + + (mat->Width() % stride_ != 0) * stride_ - mat->Width(); + if (pad_h == 0 && pad_w == 0) { + return true; + } + cv::Mat* im = mat->GetCpuMat(); + cv::Scalar value; + if (value_.size() == 1) { + value = cv::Scalar(value_[0]); + } else if (value_.size() == 2) { + value = cv::Scalar(value_[0], value_[1]); + } else if (value_.size() == 3) { + value = cv::Scalar(value_[0], value_[1], value_[2]); + } else { + value = cv::Scalar(value_[0], value_[1], value_[2], value_[3]); + } + // top, bottom, left, right + cv::copyMakeBorder(*im, *im, 0, pad_h, 0, pad_w, cv::BORDER_CONSTANT, value); + mat->SetHeight(origin_h + pad_h); + mat->SetWidth(origin_w + pad_w); + return true; +} + +#ifdef ENABLE_OPENCV_CUDA +bool StridePad::GpuRun(Mat* mat) { + if (mat->layout != Layout::HWC) { + FDERROR << "StridePad: The input data must be Layout::HWC format!" + << std::endl; + return false; + } + if (mat->Channels() > 4) { + FDERROR << "StridePad: Only support channels <= 4." << std::endl; + return false; + } + if (mat->Channels() != value_.size()) { + FDERROR + << "StridePad: Require input channels equals to size of padding value, " + "but now channels = " + << mat->Channels() << ", the size of padding values = " << value_.size() + << "." << std::endl; + return false; + } + + int origin_w = mat->Width(); + int origin_h = mat->Height(); + int pad_h = (mat->Height() / stride_) * stride_ + + (mat->Height() % stride_ != 0) * stride_; + int pad_w = (mat->Width() / stride_) * stride_ + + (mat->Width() % stride_ != 0) * stride_; + if (pad_h == 0 && pad_w == 0) { + return true; + } + + cv::cuda::GpuMat* im = mat->GetGpuMat(); + cv::Scalar value; + if (value_.size() == 1) { + value = cv::Scalar(value_[0]); + } else if (value_.size() == 2) { + value = cv::Scalar(value_[0], value_[1]); + } else if (value_.size() == 3) { + value = cv::Scalar(value_[0], value_[1], value_[2]); + } else { + value = cv::Scalar(value_[0], value_[1], value_[2], value_[3]); + } + + // top, bottom, left, right + cv::cuda::copyMakeBorder(*im, *im, 0, pad_h, 0, pad_w, cv::BORDER_CONSTANT, + value); + mat->SetHeight(origin_h + pad_h); + mat->SetWidth(origin_w + pad_w); + return true; +} +#endif + +bool StridePad::Run(Mat* mat, int stride, const std::vector& value, + ProcLib lib) { + auto p = StridePad(stride, value); + return p(mat, lib); +} + +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/common/processors/stride_pad.h b/csrcs/fastdeploy/vision/common/processors/stride_pad.h new file mode 100644 index 00000000000..c002ca697bb --- /dev/null +++ b/csrcs/fastdeploy/vision/common/processors/stride_pad.h @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/vision/common/processors/base.h" + +namespace fastdeploy { +namespace vision { + +class StridePad : public Processor { + public: + // only support pad with left-top padding mode + StridePad(int stride, const std::vector& value) { + stride_ = stride; + value_ = value; + } + bool CpuRun(Mat* mat); +#ifdef ENABLE_OPENCV_CUDA + bool GpuRun(Mat* mat); +#endif + std::string Name() { return "StridePad"; } + + static bool Run(Mat* mat, int stride, + const std::vector& value = std::vector(), + ProcLib lib = ProcLib::OPENCV_CPU); + + private: + int stride_ = 32; + std::vector value_; +}; +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/common/processors/transform.h b/csrcs/fastdeploy/vision/common/processors/transform.h index 08073b4e423..fed3d0c9a25 100644 --- a/csrcs/fastdeploy/vision/common/processors/transform.h +++ b/csrcs/fastdeploy/vision/common/processors/transform.h @@ -21,5 +21,7 @@ #include "fastdeploy/vision/common/processors/hwc2chw.h" #include "fastdeploy/vision/common/processors/normalize.h" #include "fastdeploy/vision/common/processors/pad.h" +#include "fastdeploy/vision/common/processors/pad_to_size.h" #include "fastdeploy/vision/common/processors/resize.h" #include "fastdeploy/vision/common/processors/resize_by_short.h" +#include "fastdeploy/vision/common/processors/stride_pad.h" diff --git a/csrcs/fastdeploy/vision/ppdet/build_preprocess.cc b/csrcs/fastdeploy/vision/ppdet/build_preprocess.cc new file mode 100644 index 00000000000..ee3b6a16a69 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/build_preprocess.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/ppdet/ppyoloe.h" + +bool BuildPreprocessPipelineFromConfig( + std::vector>* processors, + const std::string& config_file) { + processors->clear(); + YAML::Node cfg; + try { + cfg = YAML::LoadFile(config_file); + } catch (YAML::BadFile& e) { + FDERROR << "Failed to load yaml file " << config_file_ + << ", maybe you should check this file." << std::endl; + return false; + } + + processors->push_back(std::make_shared()); + + for (const auto& op : cfg["Preprocess"]) { + std::string op_name = op["type"].as(); + if (op_name == "NormalizeImage") { + auto mean = op["mean"].as>(); + auto std = op["std"].as>(); + bool is_scale = op["is_scale"].as(); + processors->push_back(std::make_shared(mean, std, is_scale)); + } else if (op_name == "Resize") { + bool keep_ratio = op["keep_ratio"].as(); + auto target_size = op["target_size"].as>(); + int interp = op["interp"].as(); + FDASSERT(target_size.size(), + "Require size of target_size be 2, but now it's " + + std::to_string(target_size.size()) + "."); + if (!keep_ratio) { + int width = target_size[1]; + int height = target_size[0]; + processors->push_back( + std::make_shared(width, height, -1.0, -1.0, interp, false)); + } else { + int min_target_size = std::min(target_size[0], target_size[1]); + int max_target_size = std::max(target_size[0], target_size[1]); + processors->push_back(std::make_shared( + min_target_size, interp, true, max_target_size)); + } + } else if (op_name == "Permute") { + // Do nothing, do permute as the last operation + continue; + } else if (op_name == "Pad") { + auto size = op["size"].as>(); + auto value = op["fill_value"].as>(); + processors->push_back(std::make_shared("float")); + processors->push_back( + std::make_shared(size[1], size[0], value)); + } else if (op_name == "PadStride") { + auto stride = op["stride"].as(); + processors->push_back( + std::make_shared(stride, std::vector(3, 0))); + } else { + FDERROR << "Unexcepted preprocess operator: " << op_name << "." + << std::endl; + return false; + } + } + processors->push_back(std::make_shared()); + return true; +} diff --git a/csrcs/fastdeploy/vision/ppdet/centernet.cc b/csrcs/fastdeploy/vision/ppdet/centernet.cc deleted file mode 100644 index 259ecf620d2..00000000000 --- a/csrcs/fastdeploy/vision/ppdet/centernet.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "fastdeploy/vision/ppdet/centernet.h" - -namespace fastdeploy { -namespace vision { -namespace ppdet { - -CenterNet::CenterNet(const std::string& model_file, - const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option, - const Frontend& model_format) { - config_file_ = config_file; - valid_cpu_backends = {Backend::PDINFER}; - valid_gpu_backends = {Backend::PDINFER}; - has_nms_ = true; - runtime_option = custom_option; - runtime_option.model_format = model_format; - runtime_option.model_file = model_file; - runtime_option.params_file = params_file; - initialized = Initialize(); -} - -} // namespace ppdet -} // namespace vision -} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/centernet.h b/csrcs/fastdeploy/vision/ppdet/centernet.h deleted file mode 100644 index a6ae756cd10..00000000000 --- a/csrcs/fastdeploy/vision/ppdet/centernet.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once -#include "fastdeploy/vision/ppdet/ppyolo.h" - -namespace fastdeploy { -namespace vision { -namespace ppdet { - -class FASTDEPLOY_DECL CenterNet : public PPYOLO { - public: - CenterNet(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option = RuntimeOption(), - const Frontend& model_format = Frontend::PADDLE); - - virtual std::string ModelName() const { return "PaddleDetection/CenterNet"; } -}; -} // namespace ppdet -} // namespace vision -} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/model.h b/csrcs/fastdeploy/vision/ppdet/model.h index 89c59fd1a9f..81bad81c4c6 100644 --- a/csrcs/fastdeploy/vision/ppdet/model.h +++ b/csrcs/fastdeploy/vision/ppdet/model.h @@ -1,6 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include "fastdeploy/vision/ppdet/centernet.h" #include "fastdeploy/vision/ppdet/picodet.h" #include "fastdeploy/vision/ppdet/ppyolo.h" #include "fastdeploy/vision/ppdet/ppyoloe.h" +#include "fastdeploy/vision/ppdet/rcnn.h" #include "fastdeploy/vision/ppdet/yolov3.h" +#include "fastdeploy/vision/ppdet/yolox.h" diff --git a/csrcs/fastdeploy/vision/ppdet/picodet.cc b/csrcs/fastdeploy/vision/ppdet/picodet.cc index f8070a3d240..5f912b8cf45 100644 --- a/csrcs/fastdeploy/vision/ppdet/picodet.cc +++ b/csrcs/fastdeploy/vision/ppdet/picodet.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "fastdeploy/vision/ppdet/picodet.h" #include "yaml-cpp/yaml.h" diff --git a/csrcs/fastdeploy/vision/ppdet/picodet.h b/csrcs/fastdeploy/vision/ppdet/picodet.h index 90242692610..7b45b9baf17 100644 --- a/csrcs/fastdeploy/vision/ppdet/picodet.h +++ b/csrcs/fastdeploy/vision/ppdet/picodet.h @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include "fastdeploy/vision/ppdet/ppyoloe.h" diff --git a/csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc b/csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc index bd1fc4621fc..9134b59cd11 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc +++ b/csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc @@ -27,5 +27,21 @@ void BindPPDet(pybind11::module& m) { self.Predict(&mat, &res); return res; }); + pybind11::class_(ppdet_module, + "PPYOLO") + .def(pybind11::init()); + pybind11::class_(ppdet_module, + "PicoDet") + .def(pybind11::init()); + pybind11::class_(ppdet_module, + "YOLOX") + .def(pybind11::init()); + pybind11::class_( + ppdet_module, "FasterRCNN") + .def(pybind11::init()); } } // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/ppyolo.cc b/csrcs/fastdeploy/vision/ppdet/ppyolo.cc index 2926b54a092..307f8b4f07a 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyolo.cc +++ b/csrcs/fastdeploy/vision/ppdet/ppyolo.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "fastdeploy/vision/ppdet/ppyolo.h" namespace fastdeploy { @@ -48,20 +62,16 @@ bool PPYOLO::Preprocess(Mat* mat, std::vector* outputs) { outputs->resize(3); (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape"); (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor"); - std::cout << "111111111" << std::endl; float* ptr0 = static_cast((*outputs)[0].MutableData()); ptr0[0] = mat->Height(); ptr0[1] = mat->Width(); - std::cout << "090909" << std::endl; float* ptr2 = static_cast((*outputs)[2].MutableData()); ptr2[0] = mat->Height() * 1.0 / origin_h; ptr2[1] = mat->Width() * 1.0 / origin_w; - std::cout << "88888" << std::endl; (*outputs)[1].name = "image"; mat->ShareWithTensor(&((*outputs)[1])); // reshape to [1, c, h, w] (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1); - std::cout << "??????" << std::endl; return true; } diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc index 99dae0fc6b0..3aea781b4d8 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc @@ -101,21 +101,38 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() { FDASSERT(target_size.size(), "Require size of target_size be 2, but now it's " + std::to_string(target_size.size()) + "."); - FDASSERT(!keep_ratio, - "Only support keep_ratio is false while deploy " - "PaddleDetection model."); - int width = target_size[1]; - int height = target_size[0]; - processors_.push_back( - std::make_shared(width, height, -1.0, -1.0, interp, false)); + if (!keep_ratio) { + int width = target_size[1]; + int height = target_size[0]; + processors_.push_back( + std::make_shared(width, height, -1.0, -1.0, interp, false)); + } else { + int min_target_size = std::min(target_size[0], target_size[1]); + int max_target_size = std::max(target_size[0], target_size[1]); + processors_.push_back(std::make_shared( + min_target_size, interp, true, max_target_size)); + } } else if (op_name == "Permute") { - processors_.push_back(std::make_shared()); + // Do nothing, do permute as the last operation + continue; + // processors_.push_back(std::make_shared()); + } else if (op_name == "Pad") { + auto size = op["size"].as>(); + auto value = op["fill_value"].as>(); + processors_.push_back(std::make_shared("float")); + processors_.push_back( + std::make_shared(size[1], size[0], value)); + } else if (op_name == "PadStride") { + auto stride = op["stride"].as(); + processors_.push_back( + std::make_shared(stride, std::vector(3, 0))); } else { FDERROR << "Unexcepted preprocess operator: " << op_name << "." << std::endl; return false; } } + processors_.push_back(std::make_shared()); return true; } @@ -224,7 +241,6 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) { processed_data[0].PrintInfo("Before infer"); float* tmp = static_cast(processed_data[1].Data()); - std::cout << "==== " << tmp[0] << " " << tmp[1] << std::endl; std::vector infer_result; if (!Infer(processed_data, &infer_result)) { FDERROR << "Failed to inference while using model:" << ModelName() << "." diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h index 84fe0781ba1..6c79755d140 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include "fastdeploy/fastdeploy_model.h" #include "fastdeploy/vision/common/processors/transform.h" @@ -47,6 +61,11 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { bool normalized = true; bool has_nms_ = false; }; + +// Read configuration and build pipeline to process input image +bool BuildPreprocessPipelineFromConfig( + std::vector>* processors, + const std::string& config_file); } // namespace ppdet } // namespace vision } // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/rcnn.cc b/csrcs/fastdeploy/vision/ppdet/rcnn.cc new file mode 100644 index 00000000000..25b822e2468 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/rcnn.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/ppdet/rcnn.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +FasterRCNN::FasterRCNN(const std::string& model_file, + const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const Frontend& model_format) { + config_file_ = config_file; + valid_cpu_backends = {Backend::PDINFER}; + valid_gpu_backends = {Backend::PDINFER}; + has_nms_ = true; + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +bool FasterRCNN::Initialize() { + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." + << std::endl; + return false; + } + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + return true; +} + +bool FasterRCNN::Preprocess(Mat* mat, std::vector* outputs) { + int origin_w = mat->Width(); + int origin_h = mat->Height(); + mat->PrintInfo("Origin"); + float scale[2] = {1.0, 1.0}; + for (size_t i = 0; i < processors_.size(); ++i) { + if (!(*(processors_[i].get()))(mat)) { + FDERROR << "Failed to process image data in " << processors_[i]->Name() + << "." << std::endl; + return false; + } + if (processors_[i]->Name().find("Resize") != std::string::npos) { + scale[0] = mat->Height() * 1.0 / origin_h; + scale[1] = mat->Width() * 1.0 / origin_w; + } + mat->PrintInfo(processors_[i]->Name()); + } + + outputs->resize(3); + (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape"); + (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor"); + float* ptr0 = static_cast((*outputs)[0].MutableData()); + ptr0[0] = mat->Height(); + ptr0[1] = mat->Width(); + float* ptr2 = static_cast((*outputs)[2].MutableData()); + ptr2[0] = scale[0]; + ptr2[1] = scale[1]; + (*outputs)[1].name = "image"; + mat->ShareWithTensor(&((*outputs)[1])); + // reshape to [1, c, h, w] + (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1); + (*outputs)[0].PrintInfo("im_shape"); + (*outputs)[1].PrintInfo("image"); + (*outputs)[2].PrintInfo("scale_factor"); + return true; +} + +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/rcnn.h b/csrcs/fastdeploy/vision/ppdet/rcnn.h new file mode 100644 index 00000000000..2a9255a5492 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/rcnn.h @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/ppdet/ppyoloe.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +class FASTDEPLOY_DECL FasterRCNN : public PPYOLOE { + public: + FasterRCNN(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const Frontend& model_format = Frontend::PADDLE); + + virtual std::string ModelName() const { return "PaddleDetection/FasterRCNN"; } + + virtual bool Preprocess(Mat* mat, std::vector* outputs); + virtual bool Initialize(); + + protected: + FasterRCNN() {} +}; +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/yolov3.cc b/csrcs/fastdeploy/vision/ppdet/yolov3.cc index 9d683819473..608105b040e 100644 --- a/csrcs/fastdeploy/vision/ppdet/yolov3.cc +++ b/csrcs/fastdeploy/vision/ppdet/yolov3.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "fastdeploy/vision/ppdet/yolov3.h" namespace fastdeploy { @@ -34,20 +48,16 @@ bool YOLOv3::Preprocess(Mat* mat, std::vector* outputs) { outputs->resize(3); (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape"); (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor"); - std::cout << "111111111" << std::endl; float* ptr0 = static_cast((*outputs)[0].MutableData()); ptr0[0] = mat->Height(); ptr0[1] = mat->Width(); - std::cout << "090909" << std::endl; float* ptr2 = static_cast((*outputs)[2].MutableData()); ptr2[0] = mat->Height() * 1.0 / origin_h; ptr2[1] = mat->Width() * 1.0 / origin_w; - std::cout << "88888" << std::endl; (*outputs)[1].name = "image"; mat->ShareWithTensor(&((*outputs)[1])); // reshape to [1, c, h, w] (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1); - std::cout << "??????" << std::endl; return true; } diff --git a/csrcs/fastdeploy/vision/ppdet/yolov3.h b/csrcs/fastdeploy/vision/ppdet/yolov3.h index 9919c023358..27b1352c9c2 100644 --- a/csrcs/fastdeploy/vision/ppdet/yolov3.h +++ b/csrcs/fastdeploy/vision/ppdet/yolov3.h @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include "fastdeploy/vision/ppdet/ppyoloe.h" diff --git a/csrcs/fastdeploy/vision/ppdet/yolox.cc b/csrcs/fastdeploy/vision/ppdet/yolox.cc new file mode 100644 index 00000000000..90168d29891 --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/yolox.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/ppdet/yolox.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +YOLOX::YOLOX(const std::string& model_file, const std::string& params_file, + const std::string& config_file, const RuntimeOption& custom_option, + const Frontend& model_format) { + config_file_ = config_file; + valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_gpu_backends = {Backend::PDINFER, Backend::ORT}; + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + background_label = -1; + keep_top_k = 1000; + nms_eta = 1; + nms_threshold = 0.65; + nms_top_k = 10000; + normalized = true; + score_threshold = 0.001; + initialized = Initialize(); +} + +bool YOLOX::Preprocess(Mat* mat, std::vector* outputs) { + int origin_w = mat->Width(); + int origin_h = mat->Height(); + float scale[2] = {1.0, 1.0}; + mat->PrintInfo("Origin"); + for (size_t i = 0; i < processors_.size(); ++i) { + if (!(*(processors_[i].get()))(mat)) { + FDERROR << "Failed to process image data in " << processors_[i]->Name() + << "." << std::endl; + return false; + } + mat->PrintInfo(processors_[i]->Name()); + if (processors_[i]->Name().find("Resize") != std::string::npos) { + scale[0] = mat->Height() * 1.0 / origin_h; + scale[1] = mat->Width() * 1.0 / origin_w; + } + } + + outputs->resize(2); + (*outputs)[0].name = InputInfoOfRuntime(0).name; + mat->ShareWithTensor(&((*outputs)[0])); + + // reshape to [1, c, h, w] + (*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1); + + (*outputs)[1].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name); + float* ptr = static_cast((*outputs)[1].MutableData()); + ptr[0] = scale[0]; + ptr[1] = scale[1]; + return true; +} +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/yolox.h b/csrcs/fastdeploy/vision/ppdet/yolox.h new file mode 100644 index 00000000000..e689674a4ec --- /dev/null +++ b/csrcs/fastdeploy/vision/ppdet/yolox.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/ppdet/ppyoloe.h" + +namespace fastdeploy { +namespace vision { +namespace ppdet { + +class FASTDEPLOY_DECL YOLOX : public PPYOLOE { + public: + YOLOX(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const Frontend& model_format = Frontend::PADDLE); + + virtual bool Preprocess(Mat* mat, std::vector* outputs); + + virtual std::string ModelName() const { return "PaddleDetection/YOLOX"; } +}; +} // namespace ppdet +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppogg/yolov5lite.cc b/csrcs/fastdeploy/vision/ppogg/yolov5lite.cc index 320867f581a..c58f5f2cc97 100644 --- a/csrcs/fastdeploy/vision/ppogg/yolov5lite.cc +++ b/csrcs/fastdeploy/vision/ppogg/yolov5lite.cc @@ -328,7 +328,6 @@ bool YOLOv5Lite::Predict(cv::Mat* im, DetectionResult* result, #ifdef FASTDEPLOY_DEBUG TIMERECORD_START(0) #endif - std::cout << nms_iou_threshold << nms_iou_threshold << std::endl; Mat mat(*im); std::vector input_tensors(1); diff --git a/fastdeploy/vision/ppdet/__init__.py b/fastdeploy/vision/ppdet/__init__.py index 661ef0e1fcd..bc175785301 100644 --- a/fastdeploy/vision/ppdet/__init__.py +++ b/fastdeploy/vision/ppdet/__init__.py @@ -27,7 +27,7 @@ def __init__(self, model_format=Frontend.PADDLE): super(PPYOLOE, self).__init__(runtime_option) - assert model_format == Frontend.PADDLE, "PPYOLOE only support model format of Frontend.Paddle now." + assert model_format == Frontend.PADDLE, "PPYOLOE model only support model format of Frontend.Paddle now." self._model = C.vision.ppdet.PPYOLOE(model_file, params_file, config_file, self._runtime_option, model_format) @@ -36,3 +36,67 @@ def __init__(self, def predict(self, input_image): assert input_image is not None, "The input image data is None." return self._model.predict(input_image) + + +class PPYOLO(PPYOLOE): + def __init__(self, + model_file, + params_file, + config_file, + runtime_option=None, + model_format=Frontend.PADDLE): + super(PPYOLO, self).__init__(runtime_option) + + assert model_format == Frontend.PADDLE, "PPYOLO model only support model format of Frontend.Paddle now." + self._model = C.vision.ppdet.PPYOLO(model_file, params_file, + config_file, self._runtime_option, + model_format) + assert self.initialized, "PPYOLO model initialize failed." + + +class YOLOX(PPYOLOE): + def __init__(self, + model_file, + params_file, + config_file, + runtime_option=None, + model_format=Frontend.PADDLE): + super(YOLOX, self).__init__(runtime_option) + + assert model_format == Frontend.PADDLE, "YOLOX model only support model format of Frontend.Paddle now." + self._model = C.vision.ppdet.YOLOX(model_file, params_file, + config_file, self._runtime_option, + model_format) + assert self.initialized, "YOLOX model initialize failed." + + +class PicoDet(PPYOLOE): + def __init__(self, + model_file, + params_file, + config_file, + runtime_option=None, + model_format=Frontend.PADDLE): + super(PicoDet, self).__init__(runtime_option) + + assert model_format == Frontend.PADDLE, "PicoDet model only support model format of Frontend.Paddle now." + self._model = C.vision.ppdet.PicoDet(model_file, params_file, + config_file, self._runtime_option, + model_format) + assert self.initialized, "PicoDet model initialize failed." + + +class FasterRCNN(PPYOLOE): + def __init__(self, + model_file, + params_file, + config_file, + runtime_option=None, + model_format=Frontend.PADDLE): + super(FasterRCNN, self).__init__(runtime_option) + + assert model_format == Frontend.PADDLE, "FasterRCNN model only support model format of Frontend.Paddle now." + self._model = C.vision.ppdet.FasterRCNN( + model_file, params_file, config_file, self._runtime_option, + model_format) + assert self.initialized, "FasterRCNN model initialize failed." From aeaad8ef837798f7df8749ce6701aef186c239a4 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Wed, 3 Aug 2022 08:25:24 +0000 Subject: [PATCH 4/5] add model --- csrcs/fastdeploy/fastdeploy_runtime.cc | 2 +- csrcs/fastdeploy/vision/ppdet/ppyoloe.h | 4 ++++ examples/resources/.gitignore | 6 +++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/csrcs/fastdeploy/fastdeploy_runtime.cc b/csrcs/fastdeploy/fastdeploy_runtime.cc index 1f782a54ae9..9207d5cae63 100644 --- a/csrcs/fastdeploy/fastdeploy_runtime.cc +++ b/csrcs/fastdeploy/fastdeploy_runtime.cc @@ -212,7 +212,6 @@ void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) { trt_serialize_file = cache_file_path; } - bool Runtime::Init(const RuntimeOption& _option) { option = _option; if (option.model_format == Frontend::AUTOREC) { @@ -274,6 +273,7 @@ void Runtime::CreatePaddleBackend() { pd_option.mkldnn_cache_size = option.pd_mkldnn_cache_size; pd_option.use_gpu = (option.device == Device::GPU) ? true : false; pd_option.gpu_id = option.device_id; + pd_option.cpu_thread_num = option.cpu_thread_num; FDASSERT(option.model_format == Frontend::PADDLE, "PaddleBackend only support model format of Frontend::PADDLE."); backend_ = new PaddleBackend(); diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h index 6c79755d140..6a452af5639 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h @@ -60,6 +60,10 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { int64_t nms_top_k = 10000; bool normalized = true; bool has_nms_ = false; + + // This function will used to check if this model contains multiclass_nms + // and get parameters from the operator + void GetNmsInfo(); }; // Read configuration and build pipeline to process input image diff --git a/examples/resources/.gitignore b/examples/resources/.gitignore index f8c24f7a602..aadf7025293 100644 --- a/examples/resources/.gitignore +++ b/examples/resources/.gitignore @@ -6,6 +6,10 @@ models/*.pd* models/*.engine models/*.trt models/*.nb +models/*param* +models/*model* outputs/*.jpg outputs/*.jpeg -outputs/*.png \ No newline at end of file +outputs/*.png +outputs/*.txt +outputs/*.json \ No newline at end of file From 658253f67016564c4d466d76dfeb795bac2c371f Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Thu, 4 Aug 2022 09:29:14 +0000 Subject: [PATCH 5/5] fix some usage bugs for detection models --- .../backends/tensorrt/trt_backend.cc | 36 +++++----- .../backends/tensorrt/trt_backend.h | 3 + csrcs/fastdeploy/fastdeploy_model.cc | 6 +- csrcs/fastdeploy/fastdeploy_model.h | 22 +++--- .../vision/ppdet/build_preprocess.cc | 10 ++- csrcs/fastdeploy/vision/ppdet/model.h | 1 - csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc | 63 +++++++++++++---- csrcs/fastdeploy/vision/ppdet/ppyolo.cc | 2 - csrcs/fastdeploy/vision/ppdet/ppyoloe.cc | 5 -- csrcs/fastdeploy/vision/ppdet/ppyoloe.h | 3 - csrcs/fastdeploy/vision/ppdet/rcnn.cc | 5 -- csrcs/fastdeploy/vision/ppdet/yolov3.cc | 2 - csrcs/fastdeploy/vision/ppdet/yolox.cc | 2 - fastdeploy/__init__.py | 5 ++ fastdeploy/vision/ppdet/__init__.py | 24 +++++-- model_zoo/vision/ppyoloe/README.md | 52 -------------- model_zoo/vision/ppyoloe/api.md | 68 ------------------- model_zoo/vision/ppyoloe/cpp/CMakeLists.txt | 17 ----- model_zoo/vision/ppyoloe/cpp/README.md | 39 ----------- model_zoo/vision/ppyoloe/cpp/ppyoloe.cc | 51 -------------- model_zoo/vision/ppyoloe/ppyoloe.py | 24 ------- setup.py | 13 +++- 22 files changed, 130 insertions(+), 323 deletions(-) delete mode 100644 model_zoo/vision/ppyoloe/README.md delete mode 100644 model_zoo/vision/ppyoloe/api.md delete mode 100644 model_zoo/vision/ppyoloe/cpp/CMakeLists.txt delete mode 100644 model_zoo/vision/ppyoloe/cpp/README.md delete mode 100644 model_zoo/vision/ppyoloe/cpp/ppyoloe.cc delete mode 100644 model_zoo/vision/ppyoloe/ppyoloe.py diff --git a/csrcs/fastdeploy/backends/tensorrt/trt_backend.cc b/csrcs/fastdeploy/backends/tensorrt/trt_backend.cc index 6a9d21d370d..dd3f837d972 100644 --- a/csrcs/fastdeploy/backends/tensorrt/trt_backend.cc +++ b/csrcs/fastdeploy/backends/tensorrt/trt_backend.cc @@ -54,8 +54,8 @@ std::vector toVec(const nvinfer1::Dims& dim) { bool CheckDynamicShapeConfig(const paddle2onnx::OnnxReader& reader, const TrtBackendOption& option) { - //paddle2onnx::ModelTensorInfo inputs[reader.NumInputs()]; - //std::string input_shapes[reader.NumInputs()]; + // paddle2onnx::ModelTensorInfo inputs[reader.NumInputs()]; + // std::string input_shapes[reader.NumInputs()]; std::vector inputs(reader.NumInputs()); std::vector input_shapes(reader.NumInputs()); for (int i = 0; i < reader.NumInputs(); ++i) { @@ -374,27 +374,27 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model, 1U << static_cast( nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - auto builder = SampleUniquePtr( + builder_ = SampleUniquePtr( nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger())); - if (!builder) { + if (!builder_) { FDERROR << "Failed to call createInferBuilder()." << std::endl; return false; } - auto network = SampleUniquePtr( - builder->createNetworkV2(explicitBatch)); - if (!network) { + network_ = SampleUniquePtr( + builder_->createNetworkV2(explicitBatch)); + if (!network_) { FDERROR << "Failed to call createNetworkV2()." << std::endl; return false; } - auto config = - SampleUniquePtr(builder->createBuilderConfig()); + auto config = SampleUniquePtr( + builder_->createBuilderConfig()); if (!config) { FDERROR << "Failed to call createBuilderConfig()." << std::endl; return false; } if (option.enable_fp16) { - if (!builder->platformHasFastFp16()) { + if (!builder_->platformHasFastFp16()) { FDWARNING << "Detected FP16 is not supported in the current GPU, " "will use FP32 instead." << std::endl; @@ -403,25 +403,25 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model, } } - auto parser = SampleUniquePtr( - nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger())); - if (!parser) { + parser_ = SampleUniquePtr( + nvonnxparser::createParser(*network_, sample::gLogger.getTRTLogger())); + if (!parser_) { FDERROR << "Failed to call createParser()." << std::endl; return false; } - if (!parser->parse(onnx_model.data(), onnx_model.size())) { + if (!parser_->parse(onnx_model.data(), onnx_model.size())) { FDERROR << "Failed to parse ONNX model by TensorRT." << std::endl; return false; } FDINFO << "Start to building TensorRT Engine..." << std::endl; - bool fp16 = builder->platformHasFastFp16(); - builder->setMaxBatchSize(option.max_batch_size); + bool fp16 = builder_->platformHasFastFp16(); + builder_->setMaxBatchSize(option.max_batch_size); config->setMaxWorkspaceSize(option.max_workspace_size); if (option.max_shape.size() > 0) { - auto profile = builder->createOptimizationProfile(); + auto profile = builder_->createOptimizationProfile(); FDASSERT(option.max_shape.size() == option.min_shape.size() && option.min_shape.size() == option.opt_shape.size(), "[TrtBackend] Size of max_shape/opt_shape/min_shape in " @@ -459,7 +459,7 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model, } SampleUniquePtr plan{ - builder->buildSerializedNetwork(*network, *config)}; + builder_->buildSerializedNetwork(*network_, *config)}; if (!plan) { FDERROR << "Failed to call buildSerializedNetwork()." << std::endl; return false; diff --git a/csrcs/fastdeploy/backends/tensorrt/trt_backend.h b/csrcs/fastdeploy/backends/tensorrt/trt_backend.h index b2555c57668..a6bc3b05309 100644 --- a/csrcs/fastdeploy/backends/tensorrt/trt_backend.h +++ b/csrcs/fastdeploy/backends/tensorrt/trt_backend.h @@ -85,6 +85,9 @@ class TrtBackend : public BaseBackend { private: std::shared_ptr engine_; std::shared_ptr context_; + SampleUniquePtr parser_; + SampleUniquePtr builder_; + SampleUniquePtr network_; cudaStream_t stream_{}; std::vector bindings_; std::vector inputs_desc_; diff --git a/csrcs/fastdeploy/fastdeploy_model.cc b/csrcs/fastdeploy/fastdeploy_model.cc index e434e19fa5b..0558dd625ae 100644 --- a/csrcs/fastdeploy/fastdeploy_model.cc +++ b/csrcs/fastdeploy/fastdeploy_model.cc @@ -53,7 +53,7 @@ bool FastDeployModel::InitRuntime() { << std::endl; return false; } - runtime_ = new Runtime(); + runtime_ = std::unique_ptr(new Runtime()); if (!runtime_->Init(runtime_option)) { return false; } @@ -88,7 +88,7 @@ bool FastDeployModel::CreateCpuBackend() { continue; } runtime_option.backend = valid_cpu_backends[i]; - runtime_ = new Runtime(); + runtime_ = std::unique_ptr(new Runtime()); if (!runtime_->Init(runtime_option)) { return false; } @@ -111,7 +111,7 @@ bool FastDeployModel::CreateGpuBackend() { continue; } runtime_option.backend = valid_gpu_backends[i]; - runtime_ = new Runtime(); + runtime_ = std::unique_ptr(new Runtime()); if (!runtime_->Init(runtime_option)) { return false; } diff --git a/csrcs/fastdeploy/fastdeploy_model.h b/csrcs/fastdeploy/fastdeploy_model.h index 070a905f411..df83ac52588 100644 --- a/csrcs/fastdeploy/fastdeploy_model.h +++ b/csrcs/fastdeploy/fastdeploy_model.h @@ -18,7 +18,7 @@ namespace fastdeploy { class FASTDEPLOY_DECL FastDeployModel { public: - virtual std::string ModelName() const { return "NameUndefined"; }; + virtual std::string ModelName() const { return "NameUndefined"; } virtual bool InitRuntime(); virtual bool CreateCpuBackend(); @@ -47,21 +47,21 @@ class FASTDEPLOY_DECL FastDeployModel { virtual bool DebugEnabled(); private: - Runtime* runtime_ = nullptr; + std::unique_ptr runtime_; bool runtime_initialized_ = false; bool debug_ = false; }; -#define TIMERECORD_START(id) \ - TimeCounter tc_##id; \ +#define TIMERECORD_START(id) \ + TimeCounter tc_##id; \ tc_##id.Start(); -#define TIMERECORD_END(id, prefix) \ - if (DebugEnabled()) { \ - tc_##id.End(); \ - FDLogger() << __FILE__ << "(" << __LINE__ << "):" << __FUNCTION__ << " " \ - << prefix << " duration = " << tc_##id.Duration() << "s." \ - << std::endl; \ +#define TIMERECORD_END(id, prefix) \ + if (DebugEnabled()) { \ + tc_##id.End(); \ + FDLogger() << __FILE__ << "(" << __LINE__ << "):" << __FUNCTION__ << " " \ + << prefix << " duration = " << tc_##id.Duration() << "s." \ + << std::endl; \ } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/build_preprocess.cc b/csrcs/fastdeploy/vision/ppdet/build_preprocess.cc index ee3b6a16a69..20348214e91 100644 --- a/csrcs/fastdeploy/vision/ppdet/build_preprocess.cc +++ b/csrcs/fastdeploy/vision/ppdet/build_preprocess.cc @@ -12,7 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/ppdet/ppyoloe.h" +#include "yaml-cpp/yaml.h" + +namespace fastdeploy { +namespace vision { bool BuildPreprocessPipelineFromConfig( std::vector>* processors, @@ -22,7 +27,7 @@ bool BuildPreprocessPipelineFromConfig( try { cfg = YAML::LoadFile(config_file); } catch (YAML::BadFile& e) { - FDERROR << "Failed to load yaml file " << config_file_ + FDERROR << "Failed to load yaml file " << config_file << ", maybe you should check this file." << std::endl; return false; } @@ -76,3 +81,6 @@ bool BuildPreprocessPipelineFromConfig( processors->push_back(std::make_shared()); return true; } + +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/model.h b/csrcs/fastdeploy/vision/ppdet/model.h index 81bad81c4c6..17541d7fef3 100644 --- a/csrcs/fastdeploy/vision/ppdet/model.h +++ b/csrcs/fastdeploy/vision/ppdet/model.h @@ -13,7 +13,6 @@ // limitations under the License. #pragma once -#include "fastdeploy/vision/ppdet/centernet.h" #include "fastdeploy/vision/ppdet/picodet.h" #include "fastdeploy/vision/ppdet/ppyolo.h" #include "fastdeploy/vision/ppdet/ppyoloe.h" diff --git a/csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc b/csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc index 9134b59cd11..bcc1a047815 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc +++ b/csrcs/fastdeploy/vision/ppdet/ppdet_pybind.cc @@ -27,21 +27,60 @@ void BindPPDet(pybind11::module& m) { self.Predict(&mat, &res); return res; }); - pybind11::class_(ppdet_module, - "PPYOLO") + + pybind11::class_(ppdet_module, + "PPYOLO") .def(pybind11::init()); - pybind11::class_(ppdet_module, - "PicoDet") + Frontend>()) + .def("predict", [](vision::ppdet::PPYOLO& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::DetectionResult res; + self.Predict(&mat, &res); + return res; + }); + + pybind11::class_(ppdet_module, + "PicoDet") .def(pybind11::init()); - pybind11::class_(ppdet_module, - "YOLOX") + Frontend>()) + .def("predict", [](vision::ppdet::PicoDet& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::DetectionResult res; + self.Predict(&mat, &res); + return res; + }); + + pybind11::class_(ppdet_module, "YOLOX") .def(pybind11::init()); - pybind11::class_( - ppdet_module, "FasterRCNN") + Frontend>()) + .def("predict", [](vision::ppdet::YOLOX& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::DetectionResult res; + self.Predict(&mat, &res); + return res; + }); + + pybind11::class_(ppdet_module, + "FasterRCNN") .def(pybind11::init()); + Frontend>()) + .def("predict", + [](vision::ppdet::FasterRCNN& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::DetectionResult res; + self.Predict(&mat, &res); + return res; + }); + + pybind11::class_(ppdet_module, + "YOLOv3") + .def(pybind11::init()) + .def("predict", [](vision::ppdet::YOLOv3& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::DetectionResult res; + self.Predict(&mat, &res); + return res; + }); } } // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/ppyolo.cc b/csrcs/fastdeploy/vision/ppdet/ppyolo.cc index 307f8b4f07a..194ad4f69ef 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyolo.cc +++ b/csrcs/fastdeploy/vision/ppdet/ppyolo.cc @@ -49,14 +49,12 @@ bool PPYOLO::Initialize() { bool PPYOLO::Preprocess(Mat* mat, std::vector* outputs) { int origin_w = mat->Width(); int origin_h = mat->Height(); - mat->PrintInfo("Origin"); for (size_t i = 0; i < processors_.size(); ++i) { if (!(*(processors_[i].get()))(mat)) { FDERROR << "Failed to process image data in " << processors_[i]->Name() << "." << std::endl; return false; } - mat->PrintInfo(processors_[i]->Name()); } outputs->resize(3); diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc index 3aea781b4d8..0e7d00c64b7 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc @@ -139,14 +139,12 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() { bool PPYOLOE::Preprocess(Mat* mat, std::vector* outputs) { int origin_w = mat->Width(); int origin_h = mat->Height(); - mat->PrintInfo("Origin"); for (size_t i = 0; i < processors_.size(); ++i) { if (!(*(processors_[i].get()))(mat)) { FDERROR << "Failed to process image data in " << processors_[i]->Name() << "." << std::endl; return false; } - mat->PrintInfo(processors_[i]->Name()); } outputs->resize(2); @@ -239,7 +237,6 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) { return false; } - processed_data[0].PrintInfo("Before infer"); float* tmp = static_cast(processed_data[1].Data()); std::vector infer_result; if (!Infer(processed_data, &infer_result)) { @@ -248,8 +245,6 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) { return false; } - infer_result[0].PrintInfo("Boxes"); - infer_result[1].PrintInfo("Num"); if (!Postprocess(infer_result, result)) { FDERROR << "Failed to postprocess while using model:" << ModelName() << "." << std::endl; diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h index 6a452af5639..3b7e24479c9 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h @@ -45,9 +45,6 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { protected: PPYOLOE() {} - // This function will used to check if this model contains multiclass_nms - // and get parameters from the operator - void GetNmsInfo(); std::vector> processors_; std::string config_file_; diff --git a/csrcs/fastdeploy/vision/ppdet/rcnn.cc b/csrcs/fastdeploy/vision/ppdet/rcnn.cc index 25b822e2468..c976293a80e 100644 --- a/csrcs/fastdeploy/vision/ppdet/rcnn.cc +++ b/csrcs/fastdeploy/vision/ppdet/rcnn.cc @@ -50,7 +50,6 @@ bool FasterRCNN::Initialize() { bool FasterRCNN::Preprocess(Mat* mat, std::vector* outputs) { int origin_w = mat->Width(); int origin_h = mat->Height(); - mat->PrintInfo("Origin"); float scale[2] = {1.0, 1.0}; for (size_t i = 0; i < processors_.size(); ++i) { if (!(*(processors_[i].get()))(mat)) { @@ -62,7 +61,6 @@ bool FasterRCNN::Preprocess(Mat* mat, std::vector* outputs) { scale[0] = mat->Height() * 1.0 / origin_h; scale[1] = mat->Width() * 1.0 / origin_w; } - mat->PrintInfo(processors_[i]->Name()); } outputs->resize(3); @@ -78,9 +76,6 @@ bool FasterRCNN::Preprocess(Mat* mat, std::vector* outputs) { mat->ShareWithTensor(&((*outputs)[1])); // reshape to [1, c, h, w] (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1); - (*outputs)[0].PrintInfo("im_shape"); - (*outputs)[1].PrintInfo("image"); - (*outputs)[2].PrintInfo("scale_factor"); return true; } diff --git a/csrcs/fastdeploy/vision/ppdet/yolov3.cc b/csrcs/fastdeploy/vision/ppdet/yolov3.cc index 608105b040e..a02853dbbb8 100644 --- a/csrcs/fastdeploy/vision/ppdet/yolov3.cc +++ b/csrcs/fastdeploy/vision/ppdet/yolov3.cc @@ -35,14 +35,12 @@ YOLOv3::YOLOv3(const std::string& model_file, const std::string& params_file, bool YOLOv3::Preprocess(Mat* mat, std::vector* outputs) { int origin_w = mat->Width(); int origin_h = mat->Height(); - mat->PrintInfo("Origin"); for (size_t i = 0; i < processors_.size(); ++i) { if (!(*(processors_[i].get()))(mat)) { FDERROR << "Failed to process image data in " << processors_[i]->Name() << "." << std::endl; return false; } - mat->PrintInfo(processors_[i]->Name()); } outputs->resize(3); diff --git a/csrcs/fastdeploy/vision/ppdet/yolox.cc b/csrcs/fastdeploy/vision/ppdet/yolox.cc index 90168d29891..44f4ec0552f 100644 --- a/csrcs/fastdeploy/vision/ppdet/yolox.cc +++ b/csrcs/fastdeploy/vision/ppdet/yolox.cc @@ -42,14 +42,12 @@ bool YOLOX::Preprocess(Mat* mat, std::vector* outputs) { int origin_w = mat->Width(); int origin_h = mat->Height(); float scale[2] = {1.0, 1.0}; - mat->PrintInfo("Origin"); for (size_t i = 0; i < processors_.size(); ++i) { if (!(*(processors_[i].get()))(mat)) { FDERROR << "Failed to process image data in " << processors_[i]->Name() << "." << std::endl; return false; } - mat->PrintInfo(processors_[i]->Name()); if (processors_[i]->Name().find("Resize") != std::string::npos) { scale[0] = mat->Height() * 1.0 / origin_h; scale[1] = mat->Width() * 1.0 / origin_w; diff --git a/fastdeploy/__init__.py b/fastdeploy/__init__.py index 6a23cd3d2c4..b389669a371 100644 --- a/fastdeploy/__init__.py +++ b/fastdeploy/__init__.py @@ -16,6 +16,11 @@ import os import sys +try: + import paddle +except: + pass + def add_dll_search_dir(dir_path): os.environ["path"] = dir_path + ";" + os.environ["path"] diff --git a/fastdeploy/vision/ppdet/__init__.py b/fastdeploy/vision/ppdet/__init__.py index bc175785301..08d39a36b88 100644 --- a/fastdeploy/vision/ppdet/__init__.py +++ b/fastdeploy/vision/ppdet/__init__.py @@ -45,7 +45,7 @@ def __init__(self, config_file, runtime_option=None, model_format=Frontend.PADDLE): - super(PPYOLO, self).__init__(runtime_option) + super(PPYOLOE, self).__init__(runtime_option) assert model_format == Frontend.PADDLE, "PPYOLO model only support model format of Frontend.Paddle now." self._model = C.vision.ppdet.PPYOLO(model_file, params_file, @@ -61,7 +61,7 @@ def __init__(self, config_file, runtime_option=None, model_format=Frontend.PADDLE): - super(YOLOX, self).__init__(runtime_option) + super(PPYOLOE, self).__init__(runtime_option) assert model_format == Frontend.PADDLE, "YOLOX model only support model format of Frontend.Paddle now." self._model = C.vision.ppdet.YOLOX(model_file, params_file, @@ -77,7 +77,7 @@ def __init__(self, config_file, runtime_option=None, model_format=Frontend.PADDLE): - super(PicoDet, self).__init__(runtime_option) + super(PPYOLOE, self).__init__(runtime_option) assert model_format == Frontend.PADDLE, "PicoDet model only support model format of Frontend.Paddle now." self._model = C.vision.ppdet.PicoDet(model_file, params_file, @@ -93,10 +93,26 @@ def __init__(self, config_file, runtime_option=None, model_format=Frontend.PADDLE): - super(FasterRCNN, self).__init__(runtime_option) + super(PPYOLOE, self).__init__(runtime_option) assert model_format == Frontend.PADDLE, "FasterRCNN model only support model format of Frontend.Paddle now." self._model = C.vision.ppdet.FasterRCNN( model_file, params_file, config_file, self._runtime_option, model_format) assert self.initialized, "FasterRCNN model initialize failed." + + +class YOLOv3(PPYOLOE): + def __init__(self, + model_file, + params_file, + config_file, + runtime_option=None, + model_format=Frontend.PADDLE): + super(PPYOLOE, self).__init__(runtime_option) + + assert model_format == Frontend.PADDLE, "YOLOv3 model only support model format of Frontend.Paddle now." + self._model = C.vision.ppdet.YOLOv3(model_file, params_file, + config_file, self._runtime_option, + model_format) + assert self.initialized, "YOLOv3 model initialize failed." diff --git a/model_zoo/vision/ppyoloe/README.md b/model_zoo/vision/ppyoloe/README.md deleted file mode 100644 index 42d18104ad8..00000000000 --- a/model_zoo/vision/ppyoloe/README.md +++ /dev/null @@ -1,52 +0,0 @@ -# PaddleDetection/PPYOLOE部署示例 - -- 当前支持PaddleDetection版本为[release/2.4](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4) - -本文档说明如何进行[PPYOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)的快速部署推理。本目录结构如下 -``` -. -├── cpp # C++ 代码目录 -│   ├── CMakeLists.txt # C++ 代码编译CMakeLists文件 -│   ├── README.md # C++ 代码编译部署文档 -│   └── ppyoloe.cc # C++ 示例代码 -├── README.md # PPYOLOE 部署文档 -└── ppyoloe.py # Python示例代码 -``` - -## 安装FastDeploy - -使用如下命令安装FastDeploy,注意到此处安装的是`vision-cpu`,也可根据需求安装`vision-gpu` -``` -# 安装fastdeploy-python工具 -pip install fastdeploy-python -``` - -## Python部署 - -执行如下代码即会自动下载PPYOLOE模型和测试图片 -``` -python ppyoloe.py -``` - -执行完成后会将可视化结果保存在本地`vis_result.jpg`,同时输出检测结果如下 -``` -DetectionResult: [xmin, ymin, xmax, ymax, score, label_id] -162.380249,132.057449, 463.178345, 413.167114, 0.962918, 33 -414.914642,141.148666, 91.275269, 308.688293, 0.951003, 0 -163.449234,129.669067, 35.253891, 135.111786, 0.900734, 0 -267.232239,142.290436, 31.578918, 126.329773, 0.848709, 0 -581.790833,179.027115, 30.893127, 135.484940, 0.837986, 0 -104.407021,72.602615, 22.900627, 75.469055, 0.796468, 0 -348.795380,70.122147, 18.806061, 85.829330, 0.785557, 0 -364.118683,92.457428, 17.437622, 89.212891, 0.774282, 0 -75.180283,192.470490, 41.898407, 55.552414, 0.712569, 56 -328.133759,61.894299, 19.100616, 65.633575, 0.710519, 0 -504.797760,181.732574, 107.740814, 248.115082, 0.708902, 0 -379.063080,64.762360, 15.956146, 68.312546, 0.680725, 0 -25.858747,186.564178, 34.958130, 56.007080, 0.580415, 0 -``` - -## 其它文档 - -- [C++部署](./cpp/README.md) -- [PPYOLOE API文档](./api.md) diff --git a/model_zoo/vision/ppyoloe/api.md b/model_zoo/vision/ppyoloe/api.md deleted file mode 100644 index 99565e56365..00000000000 --- a/model_zoo/vision/ppyoloe/api.md +++ /dev/null @@ -1,68 +0,0 @@ -# PPYOLOE API说明 - -## Python API - -### PPYOLOE类 -``` -fastdeploy.vision.ppdet.PPYOLOE(model_file, params_file, config_file, runtime_option=None, model_format=fd.Frontend.PADDLE) -``` -PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当前仅支持model_format为Paddle格式 - -**参数** - -> * **model_file**(str): 模型文件路径 -> * **params_file**(str): 参数文件路径 -> * **config_file**(str): 模型推理配置文件 -> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 -> * **model_format**(Frontend): 模型格式 - -#### predict函数 -> ``` -> PPYOLOE.predict(image_data) -> ``` -> 模型预测结口,输入图像直接输出检测结果。 -> -> **参数** -> -> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式 - -示例代码参考[ppyoloe.py](./ppyoloe.py) - - -## C++ API - -### PPYOLOE类 -``` -fastdeploy::vision::ppdet::PPYOLOE( - const string& model_file, - const string& params_file, - const string& config_file, - const RuntimeOption& runtime_option = RuntimeOption(), - const Frontend& model_format = Frontend::PADDLE) -``` -PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当前仅支持model_format为Paddle格式 - -**参数** - -> * **model_file**(str): 模型文件路径 -> * **params_file**(str): 参数文件路径 -> * **config_file**(str): 模型推理配置文件 -> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 -> * **model_format**(Frontend): 模型格式 - -#### Predict函数 -> ``` -> PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) -> ``` -> 模型预测接口,输入图像直接输出检测结果。 -> -> **参数** -> -> > * **im**: 输入图像,注意需为HWC,BGR格式 -> > * **result**: 检测结果,包括检测框,各个框的置信度 - -示例代码参考[cpp/ppyoloe.cc](cpp/ppyoloe.cc) - -## 其它API使用 - -- [模型部署RuntimeOption配置](../../../docs/api/runtime_option.md) diff --git a/model_zoo/vision/ppyoloe/cpp/CMakeLists.txt b/model_zoo/vision/ppyoloe/cpp/CMakeLists.txt deleted file mode 100644 index 6222a00da39..00000000000 --- a/model_zoo/vision/ppyoloe/cpp/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -PROJECT(ppyoloe_demo C CXX) -CMAKE_MINIMUM_REQUIRED (VERSION 3.16) - -# 在低版本ABI环境中,通过如下代码进行兼容性编译 -# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) - -# 指定下载解压后的fastdeploy库路径 -set(FASTDEPLOY_INSTALL_DIR /fastdeploy/CustomOp/FastDeploy/build1/fastdeploy-linux-x64-gpu-0.3.0) - -include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) - -# 添加FastDeploy依赖头文件 -include_directories(${FASTDEPLOY_INCS}) - -add_executable(ppyoloe_demo ${PROJECT_SOURCE_DIR}/ppyoloe.cc) -# 添加FastDeploy库依赖 -target_link_libraries(ppyoloe_demo ${FASTDEPLOY_LIBS}) diff --git a/model_zoo/vision/ppyoloe/cpp/README.md b/model_zoo/vision/ppyoloe/cpp/README.md deleted file mode 100644 index 1027c2eeb24..00000000000 --- a/model_zoo/vision/ppyoloe/cpp/README.md +++ /dev/null @@ -1,39 +0,0 @@ -# 编译PPYOLOE示例 - - -``` -# 下载和解压预测库 -wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz -tar xvf fastdeploy-linux-x64-0.0.3.tgz - -# 编译示例代码 -mkdir build & cd build -cmake .. -make -j - -# 下载模型和图片 -wget https://bj.bcebos.com/paddle2onnx/fastdeploy/models/ppdet/ppyoloe_crn_l_300e_coco.tgz -tar xvf ppyoloe_crn_l_300e_coco.tgz -wget https://raw.githubusercontent.com/PaddlePaddle/PaddleDetection/release/2.4/demo/000000014439_640x640.jpg - -# 执行 -./ppyoloe_demo -``` - -执行完后可视化的结果保存在本地`vis_result.jpg`,同时会将检测框输出在终端,如下所示 -``` -DetectionResult: [xmin, ymin, xmax, ymax, score, label_id] -162.380249,132.057449, 463.178345, 413.167114, 0.962918, 33 -414.914642,141.148666, 91.275269, 308.688293, 0.951003, 0 -163.449234,129.669067, 35.253891, 135.111786, 0.900734, 0 -267.232239,142.290436, 31.578918, 126.329773, 0.848709, 0 -581.790833,179.027115, 30.893127, 135.484940, 0.837986, 0 -104.407021,72.602615, 22.900627, 75.469055, 0.796468, 0 -348.795380,70.122147, 18.806061, 85.829330, 0.785557, 0 -364.118683,92.457428, 17.437622, 89.212891, 0.774282, 0 -75.180283,192.470490, 41.898407, 55.552414, 0.712569, 56 -328.133759,61.894299, 19.100616, 65.633575, 0.710519, 0 -504.797760,181.732574, 107.740814, 248.115082, 0.708902, 0 -379.063080,64.762360, 15.956146, 68.312546, 0.680725, 0 -25.858747,186.564178, 34.958130, 56.007080, 0.580415, 0 -``` diff --git a/model_zoo/vision/ppyoloe/cpp/ppyoloe.cc b/model_zoo/vision/ppyoloe/cpp/ppyoloe.cc deleted file mode 100644 index e63f29e62a5..00000000000 --- a/model_zoo/vision/ppyoloe/cpp/ppyoloe.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fastdeploy/vision.h" - -int main() { - namespace vis = fastdeploy::vision; - - std::string model_file = "ppyoloe_crn_l_300e_coco/model.pdmodel"; - std::string params_file = "ppyoloe_crn_l_300e_coco/model.pdiparams"; - std::string config_file = "ppyoloe_crn_l_300e_coco/infer_cfg.yml"; - std::string img_path = "000000014439_640x640.jpg"; - std::string vis_path = "vis.jpeg"; - - auto model = vis::ppdet::PPYOLOE(model_file, params_file, config_file); - if (!model.Initialized()) { - std::cerr << "Init Failed." << std::endl; - return -1; - } - - cv::Mat im = cv::imread(img_path); - cv::Mat vis_im = im.clone(); - - vis::DetectionResult res; - if (!model.Predict(&im, &res)) { - std::cerr << "Prediction Failed." << std::endl; - return -1; - } else { - std::cout << "Prediction Done!" << std::endl; - } - - // 输出预测框结果 - std::cout << res.Str() << std::endl; - - // 可视化预测结果 - vis::Visualize::VisDetection(&vis_im, res); - cv::imwrite(vis_path, vis_im); - std::cout << "Detect Done! Saved: " << vis_path << std::endl; - return 0; -} diff --git a/model_zoo/vision/ppyoloe/ppyoloe.py b/model_zoo/vision/ppyoloe/ppyoloe.py deleted file mode 100644 index a3b12c1dc6b..00000000000 --- a/model_zoo/vision/ppyoloe/ppyoloe.py +++ /dev/null @@ -1,24 +0,0 @@ -import fastdeploy as fd -import cv2 - -# 下载模型和测试图片 -model_url = "https://bj.bcebos.com/paddle2onnx/fastdeploy/models/ppdet/ppyoloe_crn_l_300e_coco.tgz" -test_jpg_url = "https://raw.githubusercontent.com/PaddlePaddle/PaddleDetection/release/2.4/demo/000000014439_640x640.jpg" -fd.download_and_decompress(model_url, ".") -fd.download(test_jpg_url, ".", show_progress=True) - -# 加载模型 -model = fd.vision.ppdet.PPYOLOE("ppyoloe_crn_l_300e_coco/model.pdmodel", - "ppyoloe_crn_l_300e_coco/model.pdiparams", - "ppyoloe_crn_l_300e_coco/infer_cfg.yml") - -# 预测图片 -im = cv2.imread("000000014439_640x640.jpg") -result = model.predict(im) - -# 可视化结果 -fd.vision.visualize.vis_detection(im, result) -cv2.imwrite("vis_result.jpg", im) - -# 输出预测结果 -print(result) diff --git a/setup.py b/setup.py index e57dcd49338..15f5fc29d09 100644 --- a/setup.py +++ b/setup.py @@ -371,9 +371,13 @@ def run(self): for f1 in os.listdir(lib_dir_name): release_dir = os.path.join(lib_dir_name, f1) if f1 == "Release" and not os.path.isfile(release_dir): - if os.path.exists(os.path.join("fastdeploy/libs/third_libs", f)): - shutil.rmtree(os.path.join("fastdeploy/libs/third_libs", f)) - shutil.copytree(release_dir, os.path.join("fastdeploy/libs/third_libs", f, "lib")) + if os.path.exists( + os.path.join("fastdeploy/libs/third_libs", f)): + shutil.rmtree( + os.path.join("fastdeploy/libs/third_libs", f)) + shutil.copytree(release_dir, + os.path.join("fastdeploy/libs/third_libs", + f, "lib")) if platform.system().lower() == "windows": release_dir = os.path.join(".setuptools-cmake-build", "Release") @@ -398,6 +402,9 @@ def run(self): path)) rpaths = ":".join(rpaths) command = "patchelf --set-rpath '{}' ".format(rpaths) + pybind_so_file + print( + "=========================Set rpath for library===================") + print(command) # The sw_64 not suppot patchelf, so we just disable that. if platform.machine() != 'sw_64' and platform.machine() != 'mips64': assert os.system(