Skip to content

Commit

Permalink
fix(demo): fix num_class value in TRT inference (#251)
Browse files Browse the repository at this point in the history
fix(demo): fix num_class value in TRT inference
  • Loading branch information
F0xZz committed Aug 6, 2021
1 parent 5bc3f23 commit 0f8513d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions demo/TensorRT/cpp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ Check the 'model_trt.engine' file generated from Step 1, which will be automatic

Please follow the [TensorRT Installation Guide](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) to install TensorRT.

And you should set the TensorRT path and CUDA path in CMakeLists.txt.

If you train your custom dataset, you may need to modify the value of `num_class`.

```c++
const int num_class = 80;
```

Install opencv with ```sudo apt-get install libopencv-dev``` (we don't need a higher version of opencv like v3.3+).

build the demo:
Expand All @@ -37,3 +45,4 @@ or
```shell
./yolox <path/to/your/engine_file> -i <path/to/image>
```

2 changes: 1 addition & 1 deletion demo/TensorRT/cpp/yolox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ static void generate_yolox_proposals(std::vector<GridAndStride> grid_strides, fl
const int grid1 = grid_strides[anchor_idx].grid1;
const int stride = grid_strides[anchor_idx].stride;

const int basic_pos = anchor_idx * 85;
const int basic_pos = anchor_idx * (num_class + 5);

// yolox/models/yolo_head.py decode logic
float x_center = (feat_blob[basic_pos+0] + grid0) * stride;
Expand Down

0 comments on commit 0f8513d

Please sign in to comment.