Skip to content

Commit

Permalink
[dev] inference support bs > 1 (#3003)
Browse files Browse the repository at this point in the history
* bs>1 for YOLO
  • Loading branch information
cnn committed May 20, 2021
1 parent fd49465 commit 5e19955
Show file tree
Hide file tree
Showing 12 changed files with 346 additions and 190 deletions.
2 changes: 2 additions & 0 deletions deploy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ python tools/export_model.py -c configs/yolov3/yolov3_mobilenet_v1_roadsign.yml
* C++部署 支持`CPU``GPU``XPU`环境,支持,windows、linux系统,支持NV Jetson嵌入式设备上部署。参考文档[C++部署](cpp/README.md)
* PaddleDetection支持TensorRT加速,相关文档请参考[TensorRT预测部署教程](TENSOR_RT.md)

**注意:** Paddle预测库版本需要>=2.1,batch_size>1仅支持YOLOv3和PP-YOLO。

## 2.PaddleServing部署
### 2.1 导出模型

Expand Down
9 changes: 6 additions & 3 deletions deploy/cpp/include/object_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::vector<int> GenerateColorMap(int num_class);
// Visualiztion Detection Result
cv::Mat VisualizeResult(const cv::Mat& img,
const std::vector<ObjectResult>& results,
const std::vector<std::string>& lable_list,
const std::vector<std::string>& lables,
const std::vector<int>& colormap,
const bool is_rbox);

Expand Down Expand Up @@ -93,11 +93,12 @@ class ObjectDetector {
const std::string& run_mode = "fluid");

// Run predictor
void Predict(const cv::Mat& im,
void Predict(const std::vector<cv::Mat> imgs,
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
std::vector<ObjectResult>* result = nullptr,
std::vector<int>* bbox_num = nullptr,
std::vector<double>* times = nullptr);

// Get Model Label list
Expand All @@ -120,14 +121,16 @@ class ObjectDetector {
void Preprocess(const cv::Mat& image_mat);
// Postprocess result
void Postprocess(
const cv::Mat& raw_mat,
const std::vector<cv::Mat> mats,
std::vector<ObjectResult>* result,
std::vector<int> bbox_num,
bool is_rbox);

std::shared_ptr<Predictor> predictor_;
Preprocessor preprocessor_;
ImageBlob inputs_;
std::vector<float> output_data_;
std::vector<int> out_bbox_num_data_;
float threshold_;
ConfigPaser config_;
std::vector<int> image_shape_;
Expand Down
143 changes: 94 additions & 49 deletions deploy/cpp/src/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <numeric>
#include <sys/types.h>
#include <sys/stat.h>
#include <math.h>

#ifdef _WIN32
#include <direct.h>
Expand All @@ -37,6 +38,7 @@
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(image_file, "", "Path of input image");
DEFINE_string(image_dir, "", "Dir of input image, `image_file` has a higher priority.");
DEFINE_int32(batch_size, 1, "batch_size");
DEFINE_string(video_file, "", "Path of input video, `video_file` or `camera_id` has a highest priority.");
DEFINE_int32(camera_id, -1, "Device id of camera to predict");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
Expand Down Expand Up @@ -189,6 +191,7 @@ void PredictVideo(const std::string& video_path,
}

std::vector<PaddleDetection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times;
auto labels = det->GetLabelList();
auto colormap = PaddleDetection::GenerateColorMap(labels.size());
Expand All @@ -200,8 +203,9 @@ void PredictVideo(const std::string& video_path,
if (frame.empty()) {
break;
}

det->Predict(frame, 0.5, 0, 1, &result, &det_times);
std::vector<cv::Mat> imgs;
imgs.push_back(frame);
det->Predict(imgs, 0.5, 0, 1, &result, &bbox_num, &det_times);
for (const auto& item : result) {
if (item.rect.size() > 6){
is_rbox = true;
Expand Down Expand Up @@ -238,70 +242,107 @@ void PredictVideo(const std::string& video_path,
video_out.release();
}

void PredictImage(const std::vector<std::string> all_img_list,
void PredictImage(const std::vector<std::string> all_img_paths,
const int batch_size,
const double threshold,
const bool run_benchmark,
PaddleDetection::ObjectDetector* det,
const std::string& output_dir = "output") {
std::vector<double> det_t = {0, 0, 0};
for (auto image_file : all_img_list) {
// Open input image as an opencv cv::Mat object
cv::Mat im = cv::imread(image_file, 1);
int steps = ceil(float(all_img_paths.size()) / batch_size);
printf("total images = %d, batch_size = %d, total steps = %d\n",
all_img_paths.size(), batch_size, steps);
for (int idx = 0; idx < steps; idx++) {
std::vector<cv::Mat> batch_imgs;
int left_image_cnt = all_img_paths.size() - idx * batch_size;
if (left_image_cnt > batch_size) {
left_image_cnt = batch_size;
}
for (int bs = 0; bs < left_image_cnt; bs++) {
std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
cv::Mat im = cv::imread(image_file_path, 1);
batch_imgs.insert(batch_imgs.end(), im);
}

// Store all detected result
std::vector<PaddleDetection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times;
bool is_rbox = false;
if (run_benchmark) {
det->Predict(im, threshold, 10, 10, &result, &det_times);
det->Predict(batch_imgs, threshold, 10, 10, &result, &bbox_num, &det_times);
} else {
det->Predict(im, 0.5, 0, 1, &result, &det_times);
for (const auto& item : result) {
if (item.rect.size() > 6){
is_rbox = true;
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
det->Predict(batch_imgs, 0.5, 0, 1, &result, &bbox_num, &det_times);
// get labels and colormap
auto labels = det->GetLabelList();
auto colormap = PaddleDetection::GenerateColorMap(labels.size());

int item_start_idx = 0;
for (int i = 0; i < left_image_cnt; i++) {
std::cout << all_img_paths.at(idx * batch_size + i) << "result" << std::endl;
if (bbox_num[i] <= 1) {
continue;
}
else{
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
for (int j = 0; j < bbox_num[i]; j++) {
PaddleDetection::ObjectResult item = result[item_start_idx + j];
if (item.rect.size() > 6){
is_rbox = true;
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
}
else{
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
}
item_start_idx = item_start_idx + bbox_num[i];
}
// Visualization result
auto labels = det->GetLabelList();
auto colormap = PaddleDetection::GenerateColorMap(labels.size());
cv::Mat vis_img = PaddleDetection::VisualizeResult(
im, result, labels, colormap, is_rbox);
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
std::string output_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
output_path += OS_PATH_SEP;
int bbox_idx = 0;
for (int bs = 0; bs < batch_imgs.size(); bs++) {
if (bbox_num[bs] <= 1) {
continue;
}
cv::Mat im = batch_imgs[bs];
std::vector<PaddleDetection::ObjectResult> im_result;
for (int k = 0; k < bbox_num[bs]; k++) {
im_result.push_back(result[bbox_idx+k]);
}
bbox_idx += bbox_num[bs];
cv::Mat vis_img = PaddleDetection::VisualizeResult(
im, im_result, labels, colormap, is_rbox);
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
std::string output_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
output_path += OS_PATH_SEP;
}
std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
output_path += image_file_path.substr(image_file_path.find_last_of('/') + 1);
cv::imwrite(output_path, vis_img, compression_params);
printf("Visualized output saved as %s\n", output_path.c_str());
}
;
output_path += image_file.substr(image_file.find_last_of('/') + 1);
cv::imwrite(output_path, vis_img, compression_params);
printf("Visualized output saved as %s\n", output_path.c_str());
}
det_t[0] += det_times[0];
det_t[1] += det_times[1];
det_t[2] += det_times[2];
}
PrintBenchmarkLog(det_t, all_img_list.size());
PrintBenchmarkLog(det_t, all_img_paths.size());
}

int main(int argc, char** argv) {
Expand Down Expand Up @@ -329,13 +370,17 @@ int main(int argc, char** argv) {
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
}
std::vector<std::string> all_img_list;
std::vector<std::string> all_imgs;
if (!FLAGS_image_file.empty()) {
all_img_list.push_back(FLAGS_image_file);
all_imgs.push_back(FLAGS_image_file);
if (FLAGS_batch_size > 1) {
std::cout << "batch_size should be 1, when image_file is not None" << std::endl;
FLAGS_batch_size = 1;
}
} else {
GetAllFiles((char *)FLAGS_image_dir.c_str(), all_img_list);
GetAllFiles((char *)FLAGS_image_dir.c_str(), all_imgs);
}
PredictImage(all_img_list, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir);
PredictImage(all_imgs, FLAGS_batch_size, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir);
}
return 0;
}
Loading

0 comments on commit 5e19955

Please sign in to comment.