From 3e4252a1e1574610124cc5613c40e9dbbe3cd997 Mon Sep 17 00:00:00 2001 From: Levi Pereira Date: Thu, 9 May 2024 20:12:07 -0300 Subject: [PATCH 1/6] Clone EfficientNMSPlugin Signed-off-by: Levi Pereira --- plugin/yoloNMSPlugin/CMakeLists.txt | 22 + plugin/yoloNMSPlugin/README.md | 170 ++++ .../YoloNMSPlugin_PluginConfig.yaml | 111 +++ .../YoloNMSPlugin_PluginGoldenIO.json | 80 ++ plugin/yoloNMSPlugin/yoloNMSInference.cu | 725 ++++++++++++++++++ plugin/yoloNMSPlugin/yoloNMSInference.cuh | 261 +++++++ plugin/yoloNMSPlugin/yoloNMSInference.h | 32 + plugin/yoloNMSPlugin/yoloNMSParameters.h | 63 ++ plugin/yoloNMSPlugin/yoloNMSPlugin.cpp | 621 +++++++++++++++ plugin/yoloNMSPlugin/yoloNMSPlugin.h | 122 +++ 10 files changed, 2207 insertions(+) create mode 100644 plugin/yoloNMSPlugin/CMakeLists.txt create mode 100644 plugin/yoloNMSPlugin/README.md create mode 100644 plugin/yoloNMSPlugin/YoloNMSPlugin_PluginConfig.yaml create mode 100644 plugin/yoloNMSPlugin/YoloNMSPlugin_PluginGoldenIO.json create mode 100644 plugin/yoloNMSPlugin/yoloNMSInference.cu create mode 100644 plugin/yoloNMSPlugin/yoloNMSInference.cuh create mode 100644 plugin/yoloNMSPlugin/yoloNMSInference.h create mode 100644 plugin/yoloNMSPlugin/yoloNMSParameters.h create mode 100644 plugin/yoloNMSPlugin/yoloNMSPlugin.cpp create mode 100644 plugin/yoloNMSPlugin/yoloNMSPlugin.h diff --git a/plugin/yoloNMSPlugin/CMakeLists.txt b/plugin/yoloNMSPlugin/CMakeLists.txt new file mode 100644 index 00000000..1f1d4169 --- /dev/null +++ b/plugin/yoloNMSPlugin/CMakeLists.txt @@ -0,0 +1,22 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +file(GLOB SRCS *.cpp) +set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS}) +set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE) +file(GLOB CU_SRCS *.cu) +set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS}) +set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE) diff --git a/plugin/yoloNMSPlugin/README.md b/plugin/yoloNMSPlugin/README.md new file mode 100644 index 00000000..b2af6231 --- /dev/null +++ b/plugin/yoloNMSPlugin/README.md @@ -0,0 +1,170 @@ +# Efficient NMS Plugin + +#### Table of Contents +- [Description](#description) +- [Structure](#structure) + * [Inputs](#inputs) + * [Dynamic Shape Support](#dynamic-shape-support) + * [Box Coding Type](#box-coding-type) + * [Outputs](#outputs) + * [Parameters](#parameters) +- [Algorithm](#algorithm) + * [Process Description](#process-description) + * [Performance Tuning](#performance-tuning) + * [Additional Resources](#additional-resources) +- [License](#license) + +## Description + +This TensorRT plugin implements an efficient algorithm to perform Non Maximum Suppression for object detection networks. + +This plugin is primarily intended for using with EfficientDet on TensorRT, as this network is particularly sensitive to the latencies introduced by slower NMS implementations. However, the plugin is generic enough that it will work correctly for other detections architectures, such as SSD or FasterRCNN. + +## Structure + +### Inputs + +The plugin has two modes of operation, depending on the given input data. The plugin will automatically detect which mode to operate as, depending on the number of inputs it receives, as follows: + +1. **Standard NMS Mode:** Only two input tensors are given, (i) the bounding box coordinates and (ii) the corresponding classification scores for each box. + +2. **Fused Box Decoder Mode:** Three input tensors are given, (i) the raw localization predictions for each box originating directly from the localization head of the network, (ii) the corresponding classification scores originating from the classification head of the network, and (iii) the default anchor box coordinates usually hardcoded as constant tensors in the network. + +Most object detection networks work by generating raw predictions from a "localization head" which adjust the coordinates of standard non-learned anchor coordinates to produce a tighter fitting bounding box. This process is called "box decoding", and it usually involves a large number of element-wise operations to transform the anchors to final box coordinates. As this can involve exponential operations on a large number of anchors, it can be computationally expensive, so this plugin gives the option of fusing the box decoder within the NMS operation which can be done in a far more efficient manner, resulting in lower latency for the network. + +#### Boxes Input +> **Input Shape:** `[batch_size, number_boxes, 4]` or `[batch_size, number_boxes, number_classes, 4]` +> +> **Data Type:** `float32` or `float16` + +The boxes input can have 3 dimensions in case a single box prediction is produced for all classes (such as in EfficientDet or SSD), or 4 dimensions when separate box predictions are generated for each class (such as in FasterRCNN), in which case `number_classes` >= 1 and must match the number of classes in the scores input. The final dimension represents the four coordinates that define the bounding box prediction. + +For *Standard NMS* mode, this tensor should contain the final box coordinates for each predicted detection. For *Fused Box Decoder* mode, this tensor should have the raw localization predictions. In either case, this data is given as `4` coordinates which makes up the final shape dimension. + +#### Scores Input +> **Input Shape:** `[batch_size, number_boxes, number_classes]` +> +> **Data Type:** `float32` or `float16` + +The scores input has `number_classes` elements with the predicted scores for each candidate class for each of the `number_boxes` anchor boxes. + +Usually, the score values will have passed through a sigmoid activation function before reaching the NMS operation. However, as an optimization, the pre-sigmoid raw scores can also be provided to the NMS plugin to reduce overall network latency. If raw scores are given, enable the `score_activation` parameter so they are processed accordingly. + +#### Anchors Input (Optional) +> **Input Shape:** `[1, number_boxes, 4]` or `[batch_size, number_boxes, 4]` +> +> **Data Type:** `float32` or `float16` + +Only used in *Fused Box Decoder* mode. It is much more efficient to perform the box decoding within this plugin. In this case, the boxes input will be treated as the raw localization head box corrections, and this third input should contain the default anchor/prior box coordinates. + +When used, the input must have 3 dimensions, where the first one may be either `1` in case anchors are constant for all images in a batch, or `batch_size` in case each image has different anchors -- such as in the box refinement NMS of FasterRCNN's second stage. + +### Dynamic Shape Support + +Most input shape dimensions, namely `batch_size`, `number_boxes`, and `number_classes`, for all inputs can be defined dynamically at runtime if the TensorRT engine is built with dynamic input shapes. However, once defined, these dimensions must match across all tensors that use them (e.g. the same `number_boxes` dimension must be given for both boxes and scores, etc.) + +### Box Coding Type +Different object detection networks represent their box coordinate system differently. The two types supported by this plugin are: + +1. **BoxCorners:** The four coordinates represent `[x1, y1, x2, y2]` values, where each x,y pair defines the top-left and bottom-right corners of a bounding box. +2. **BoxCenterSize:** The four coordinates represent `[x, y, w, h]` values, where the x,y pair define the box center location, and the w,h pair define its width and height. + +Note that for NMS purposes, horizontal and vertical coordinates are fully interchangeable. TensorFlow-trained networks, for example, often uses vertical-first coordinates such as `[y1, x1, y2, x2]`, but this coordinate system will work equally well under the BoxCorner coding. Similarly, `[y, x, h, w]` will be properly covered by the BoxCornerSize coding. + +In *Fused Box Decoder* mode, the boxes and anchor tensors should both use the same coding. + +### Outputs + +The following four output tensors are generated: + +- **num_detections:** + This is a `[batch_size, 1]` tensor of data type `int32`. The last dimension is a scalar indicating the number of valid detections per batch image. It can be less than `max_output_boxes`. Only the top `num_detections[i]` entries in `nms_boxes[i]`, `nms_scores[i]` and `nms_classes[i]` are valid. + +- **detection_boxes:** + This is a `[batch_size, max_output_boxes, 4]` tensor of data type `float32` or `float16`, containing the coordinates of non-max suppressed boxes. The output coordinates will always be in BoxCorner format, regardless of the input code type. + +- **detection_scores:** + This is a `[batch_size, max_output_boxes]` tensor of data type `float32` or `float16`, containing the scores for the boxes. + +- **detection_classes:** + This is a `[batch_size, max_output_boxes]` tensor of data type `int32`, containing the classes for the boxes. + +### Parameters + +| Type | Parameter | Description +|----------|--------------------------|-------------------------------------------------------- +|`float` |`score_threshold` * |The scalar threshold for score (low scoring boxes are removed). +|`float` |`iou_threshold` |The scalar threshold for IOU (additional boxes that have high IOU overlap with previously selected boxes are removed). +|`int` |`max_output_boxes` |The maximum number of detections to output per image. +|`int` |`background_class` |The label ID for the background class. If there is no background class, set it to `-1`. +|`bool` |`score_activation` * |Set to true to apply sigmoid activation to the confidence scores during NMS operation. +|`bool` |`class_agnostic` |Set to true to do class-independent NMS; otherwise, boxes of different classes would be considered separately during NMS. +|`int` |`box_coding` |Coding type used for boxes (and anchors if applicable), 0 = BoxCorner, 1 = BoxCenterSize. + +Parameters marked with a `*` have a non-negligible effect on runtime latency. See the [Performance Tuning](#performance-tuning) section below for more details on how to set them optimally. + +## Limitations + +The `EfficientNMS_ONNX_TRT` plugin's output may not always be sufficiently sized to capture all NMS-ed boxes. This is because it ignores the number of classes in the calculation of the output size (it produces an output of size `(batch_size * max_output_boxes_per_class, 3)` when in general, a tensor of size `(batch_size * max_output_boxes_per_class * num_classes, 3)`) would be required. This was a compromise made to keep the output size from growing uncontrollably since it lacks an attribute similar to `max_output_boxes` to control the number of output boxes globally. + +Due to this reason, please use TensorRT's inbuilt `INMSLayer` instead of the `EfficientNMS_ONNX_TRT` plugin wherever possible. + +## Algorithm + +### Process Description + +The NMS algorithm in this plugin first filters the scores below the given `scoreThreshold`. This subset of scores is then sorted, and their corresponding boxes are then further filtered out by removing boxes that overlap each other with an IOU above the given `iouThreshold`. + +The algorithm launcher and its relevant CUDA kernels are all defined in the `efficientNMSInference.cu` file. + +Specifically, the NMS algorithm does the following: + +- The scores are filtered with the `score_threshold` parameter to reject any scores below the score threshold, while maintaining indexing to cross-reference these scores to their corresponding box coordinates. This is done with the `EfficientNMSFilter` CUDA kernel. + +- If too many elements are kept, due to a very low (or zero) score threshold, the filter operation can become a bottleneck due to the atomic operations involved. To mitigate this, a fallback kernel `EfficientNMSDenseIndex` is used instead which passes all the score elements densely packed and indexed. This method is heuristically selected only if the score threshold is less than 0.007. + +- The selected scores that remain after filtering are sorted in descending order. The indexing is carefully handled to still maintain score to box relationships after sorting. + +- After sorting, the highest 4096 scores are processed by the `EfficientNMS` CUDA kernel. This algorithm uses the index data maintained throughout the previous steps to find the boxes corresponding to the remaining scores. If the fused box decoder is being used, decoding will happen until this stage, where only the top scoring boxes need to be decoded. + +- The NMS kernel uses an efficient filtering algorithm that largely reduces the number of IOU overlap cross-checks between box pairs. The boxes that survive the IOU filtering finally pass through to the output results. At this stage, the sigmoid activation is applied to only the final remaining scores, if `score_activation` is enabled, thereby greatly reducing the amount of sigmoid calculations required otherwise. + +### Performance Tuning + +The plugin implements a very efficient NMS algorithm which largely reduces the latency of this operation in comparison to other NMS plugins. However, there are certain considerations that can help to better fine tune its performance: + +#### Choosing the Score Threshold + +The algorithm is highly sensitive to the selected `score_threshold` parameter. With a higher threshold, fewer elements need to be processed and so the algorithm runs much faster. Therefore, it's beneficial to always select the highest possible score threshold that fulfills the application requirements. Threshold values lower than approximately 0.01 may cause substantially higher latency. + +#### Using Sigmoid Activation + +Depending on network configuration, it is usually more efficient to provide raw scores (pre-sigmoid) to the NMS plugin scores input, and enable the `score_activation` parameter. Doing so applies a sigmoid activation only to the last `max_output_boxes` selected scores, instead of all the predicted scores, largely reducing the computational cost. + +#### Class Independent NMS + +Some object detection networks/architectures like YOLO series need to use class-independent NMS operations. If `class_agnostic` is enabled, class-independent NMS is performed; otherwise, different classes would do NMS separately. + +#### Using the Fused Box Decoder + +When using networks with many anchors, such as EfficientDet or SSD, it may be more efficient to do box decoding within the NMS plugin. For this, pass the raw box predictions as the boxes input, and the default anchor coordinates as the optional third input to the plugin. + +### Additional Resources + +The following resources provide a deeper understanding of the NMS algorithm: + +#### Networks +- [EfficientDet](https://arxiv.org/abs/1911.09070) +- [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325) +- [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497) +- [Mask R-CNN](https://arxiv.org/abs/1703.06870) + + +#### Documentation +- [NMS algorithm](https://www.coursera.org/lecture/convolutional-neural-networks/non-max-suppression-dvrjH) +- [NonMaxSuppression ONNX Op](https://github.com/onnx/onnx/blob/master/docs/Operators.md#NonMaxSuppression) + +## License + +For terms and conditions for use, reproduction, and distribution, see the [TensorRT Software License Agreement](https://docs.nvidia.com/deeplearning/sdk/tensorrt-sla/index.html) +documentation. diff --git a/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginConfig.yaml b/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginConfig.yaml new file mode 100644 index 00000000..0a7ce414 --- /dev/null +++ b/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginConfig.yaml @@ -0,0 +1,111 @@ +--- +name: EfficientNMS_TRT +interface: "IPluginV2DynamicExt" +versions: + "1": + inputs: + - boxes + - scores + - anchors + outputs: + - num_detections + - detection_boxes + - detection_scores + - detection_classes + attributes: + - score_threshold + - iou_threshold + - max_output_boxes + - background_class + - score_activation + - class_agnostic + - box_coding + attribute_types: + score_threshold: float32 + iou_threshold: float32 + max_output_boxes: int32 + background_class: int32 + score_activation: int32 + class_agnostic: int32 + box_coding: int32 + attribute_length: + score_threshold: 1 + iou_threshold: 1 + max_output_boxes: 1 + background_class: 1 + score_activation: 1 + class_agnostic: 1 + box_coding: 1 + attribute_options: + score_threshold: + min: "=0" + max: "=pinf" + iou_threshold: + min: "0" + max: "=pinf" + max_output_boxes: + min: "0" + max: "=pinf" + background_class: + min: "=ninf" + max: "=pinf" + score_activation: + - 0 + - 1 + class_agnostic: + - 0 + - 1 + box_coding: + - 0 + - 1 + attributes_required: + - score_threshold + - iou_threshold + - max_output_boxes + - background_class + - score_activation + - box_coding + golden_io_path: "plugin/efficientNMSPlugin/EfficientNMSPlugin_PluginGoldenIO.json" + abs_tol: 1e-5 + rel_tol: 1e-5 + configs: + config1: + input_types: + boxes: float32 + scores: float32 + attribute_options: + "background_class": + value: -1 + shape: "1" + "score_activation": + value: 0 + shape: "1" + "class_agnostic": + value: 0 + shape: "1" + "box_coding": + value: 0 + shape: "1" + output_types: + num_detections: int32 + detection_boxes: float32 + class_agnostic: + input_types: + boxes: float32 + scores: float32 + attribute_options: + "background_class": + value: -1 + shape: "1" + "score_activation": + value: 0 + shape: "1" + "class_agnostic": + value: 1 + shape: "1" + "box_coding": + value: 0 + shape: "1" + output_types: + num_detections: int32 + detection_boxes: float32 diff --git a/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginGoldenIO.json b/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginGoldenIO.json new file mode 100644 index 00000000..ae6599c9 --- /dev/null +++ b/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginGoldenIO.json @@ -0,0 +1,80 @@ +{ + "config1": [ + { + "inputs": { + "boxes": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsIDYsIDQpLCB9ICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoAAAAAAAAAAAAAgD8AAIA/AAAAAM3MzD0AAIA/zcyMPwAAAADNzMy9AACAP2ZmZj8AAAAAAAAgQQAAgD8AADBBAAAAAJqZIUEAAIA/mpkxQQAAAAAAAMhCAACAPwAAykI=", + "polygraphy_class": "ndarray" + }, + "scores": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBUcnVlLCAnc2hhcGUnOiAoMSwgNiwgMiksIH0gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAo9Clc+CtejPpqZmT6kcD0/zczMPbgeRT+uR2E+uB4FPgrXIz6PwnU9exRuPwrXIz8=", + "polygraphy_class": "ndarray" + } + }, + "attributes": { + "score_threshold": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsKSwgfSAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoAAAAA", + "polygraphy_class": "ndarray" + }, + "iou_threshold": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsKSwgfSAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoAAAA/", + "polygraphy_class": "ndarray" + }, + "max_output_boxes": 6, + "background_class": -1, + "score_activation": false, + "class_agnostic": false, + "box_coding": 0 + }, + "outputs": { + "num_detections": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGk0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsIDEpLCB9ICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoGAAAA", + "polygraphy_class": "ndarray" + }, + "detection_boxes": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsIDYsIDQpLCB9ICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoAAAAAmpkhQQAAgD+amTFBAAAAAAAAyEIAAIA/AADKQgAAAAAAACBBAACAPwAAMEEAAAAAAADIQgAAgD8AAMpCAAAAAM3MzD0AAIA/zcyMPwAAAAAAAAAAAACAPwAAgD8=", + "polygraphy_class": "ndarray" + } + } + } + ], + "class_agnostic": [ + { + "inputs": { + "boxes": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsIDYsIDQpLCB9ICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoAAAAAAAAAAAAAgD8AAIA/AAAAAM3MzD0AAIA/zcyMPwAAAADNzMy9AACAP2ZmZj8AAAAAAAAgQQAAgD8AADBBAAAAAJqZIUEAAIA/mpkxQQAAAAAAAMhCAACAPwAAykI=", + "polygraphy_class": "ndarray" + }, + "scores": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBUcnVlLCAnc2hhcGUnOiAoMSwgNiwgMiksIH0gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAo9Clc+CtejPpqZmT6kcD0/zczMPbgeRT+uR2E+uB4FPgrXIz6PwnU9exRuPwrXIz8=", + "polygraphy_class": "ndarray" + } + }, + "attributes": { + "score_threshold": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsKSwgfSAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoAAAAA", + "polygraphy_class": "ndarray" + }, + "iou_threshold": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsKSwgfSAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoAAAA/", + "polygraphy_class": "ndarray" + }, + "max_output_boxes": 6, + "background_class": -1, + "score_activation": false, + "class_agnostic": true, + "box_coding": 0 + }, + "outputs": { + "num_detections": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGk0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsIDEpLCB9ICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoDAAAA", + "polygraphy_class": "ndarray" + }, + "detection_boxes": { + "array": "k05VTVBZAQB2AHsnZGVzY3InOiAnPGY0JywgJ2ZvcnRyYW5fb3JkZXInOiBGYWxzZSwgJ3NoYXBlJzogKDEsIDYsIDQpLCB9ICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAoAAAAAmpkhQQAAgD+amTFBAAAAAAAAyEIAAIA/AADKQgAAAADNzMw9AACAP83MjD8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "polygraphy_class": "ndarray" + } + } + } + ] +} \ No newline at end of file diff --git a/plugin/yoloNMSPlugin/yoloNMSInference.cu b/plugin/yoloNMSPlugin/yoloNMSInference.cu new file mode 100644 index 00000000..ba99cb56 --- /dev/null +++ b/plugin/yoloNMSPlugin/yoloNMSInference.cu @@ -0,0 +1,725 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/bboxUtils.h" +#include "cub/cub.cuh" +#include "cuda_runtime_api.h" + +#include "efficientNMSInference.cuh" +#include "efficientNMSInference.h" + +#define NMS_TILES 5 + +using namespace nvinfer1; +using namespace nvinfer1::plugin; + +template +__device__ float IOU(EfficientNMSParameters param, BoxCorner box1, BoxCorner box2) +{ + // Regardless of the selected box coding, IOU is always performed in BoxCorner coding. + // The boxes are copied so that they can be reordered without affecting the originals. + BoxCorner b1 = box1; + BoxCorner b2 = box2; + b1.reorder(); + b2.reorder(); + float intersectArea = BoxCorner::intersect(b1, b2).area(); + if (intersectArea <= 0.f) + { + return 0.f; + } + float unionArea = b1.area() + b2.area() - intersectArea; + if (unionArea <= 0.f) + { + return 0.f; + } + return intersectArea / unionArea; +} + +template +__device__ BoxCorner DecodeBoxes(EfficientNMSParameters param, int boxIdx, int anchorIdx, + const Tb* __restrict__ boxesInput, const Tb* __restrict__ anchorsInput) +{ + // The inputs will be in the selected coding format, as well as the decoding function. But the decoded box + // will always be returned as BoxCorner. + Tb box = boxesInput[boxIdx]; + if (!param.boxDecoder) + { + return BoxCorner(box); + } + Tb anchor = anchorsInput[anchorIdx]; + box.reorder(); + anchor.reorder(); + return BoxCorner(box.decode(anchor)); +} + +template +__device__ void MapNMSData(EfficientNMSParameters param, int idx, int imageIdx, const Tb* __restrict__ boxesInput, + const Tb* __restrict__ anchorsInput, const int* __restrict__ topClassData, const int* __restrict__ topAnchorsData, + const int* __restrict__ topNumData, const T* __restrict__ sortedScoresData, const int* __restrict__ sortedIndexData, + T& scoreMap, int& classMap, BoxCorner& boxMap, int& boxIdxMap) +{ + // idx: Holds the NMS box index, within the current batch. + // idxSort: Holds the batched NMS box index, which indexes the (filtered, but sorted) score buffer. + // scoreMap: Holds the score that corresponds to the indexed box being processed by NMS. + if (idx >= topNumData[imageIdx]) + { + return; + } + int idxSort = imageIdx * param.numScoreElements + idx; + scoreMap = sortedScoresData[idxSort]; + + // idxMap: Holds the re-mapped index, which indexes the (filtered, but unsorted) buffers. + // classMap: Holds the class that corresponds to the idx'th sorted score being processed by NMS. + // anchorMap: Holds the anchor that corresponds to the idx'th sorted score being processed by NMS. + int idxMap = imageIdx * param.numScoreElements + sortedIndexData[idxSort]; + classMap = topClassData[idxMap]; + int anchorMap = topAnchorsData[idxMap]; + + // boxIdxMap: Holds the re-re-mapped index, which indexes the (unfiltered, and unsorted) boxes input buffer. + boxIdxMap = -1; + if (param.shareLocation) // Shape of boxesInput: [batchSize, numAnchors, 1, 4] + { + boxIdxMap = imageIdx * param.numAnchors + anchorMap; + } + else // Shape of boxesInput: [batchSize, numAnchors, numClasses, 4] + { + int batchOffset = imageIdx * param.numAnchors * param.numClasses; + int anchorOffset = anchorMap * param.numClasses; + boxIdxMap = batchOffset + anchorOffset + classMap; + } + // anchorIdxMap: Holds the re-re-mapped index, which indexes the (unfiltered, and unsorted) anchors input buffer. + int anchorIdxMap = -1; + if (param.shareAnchors) // Shape of anchorsInput: [1, numAnchors, 4] + { + anchorIdxMap = anchorMap; + } + else // Shape of anchorsInput: [batchSize, numAnchors, 4] + { + anchorIdxMap = imageIdx * param.numAnchors + anchorMap; + } + // boxMap: Holds the box that corresponds to the idx'th sorted score being processed by NMS. + boxMap = DecodeBoxes(param, boxIdxMap, anchorIdxMap, boxesInput, anchorsInput); +} + +template +__device__ void WriteNMSResult(EfficientNMSParameters param, int* __restrict__ numDetectionsOutput, + T* __restrict__ nmsScoresOutput, int* __restrict__ nmsClassesOutput, BoxCorner* __restrict__ nmsBoxesOutput, + T threadScore, int threadClass, BoxCorner threadBox, int imageIdx, unsigned int resultsCounter) +{ + int outputIdx = imageIdx * param.numOutputBoxes + resultsCounter - 1; + if (param.scoreSigmoid) + { + nmsScoresOutput[outputIdx] = sigmoid_mp(threadScore); + } + else if (param.scoreBits > 0) + { + nmsScoresOutput[outputIdx] = add_mp(threadScore, (T) -1); + } + else + { + nmsScoresOutput[outputIdx] = threadScore; + } + nmsClassesOutput[outputIdx] = threadClass; + if (param.clipBoxes) + { + nmsBoxesOutput[outputIdx] = threadBox.clip((T) 0, (T) 1); + } + else + { + nmsBoxesOutput[outputIdx] = threadBox; + } + numDetectionsOutput[imageIdx] = resultsCounter; +} + +__device__ void WriteONNXResult(EfficientNMSParameters param, int* outputIndexData, int* __restrict__ nmsIndicesOutput, + int imageIdx, int threadClass, int boxIdxMap) +{ + int index = boxIdxMap % param.numAnchors; + int idx = atomicAdd((unsigned int*) &outputIndexData[0], 1); + nmsIndicesOutput[idx * 3 + 0] = imageIdx; + nmsIndicesOutput[idx * 3 + 1] = threadClass; + nmsIndicesOutput[idx * 3 + 2] = index; +} + +__global__ void PadONNXResult(EfficientNMSParameters param, int* outputIndexData, int* __restrict__ nmsIndicesOutput) +{ + if (threadIdx.x > 0) + { + return; + } + int pidx = outputIndexData[0] - 1; + if (pidx < 0) + { + return; + } + for (int idx = pidx + 1; idx < param.batchSize * param.numOutputBoxes; idx++) + { + nmsIndicesOutput[idx * 3 + 0] = nmsIndicesOutput[pidx * 3 + 0]; + nmsIndicesOutput[idx * 3 + 1] = nmsIndicesOutput[pidx * 3 + 1]; + nmsIndicesOutput[idx * 3 + 2] = nmsIndicesOutput[pidx * 3 + 2]; + } +} + +template +__global__ void EfficientNMS(EfficientNMSParameters param, const int* topNumData, int* outputIndexData, + int* outputClassData, const int* sortedIndexData, const T* __restrict__ sortedScoresData, + const int* __restrict__ topClassData, const int* __restrict__ topAnchorsData, const Tb* __restrict__ boxesInput, + const Tb* __restrict__ anchorsInput, int* __restrict__ numDetectionsOutput, T* __restrict__ nmsScoresOutput, + int* __restrict__ nmsClassesOutput, int* __restrict__ nmsIndicesOutput, BoxCorner* __restrict__ nmsBoxesOutput) +{ + unsigned int thread = threadIdx.x; + unsigned int imageIdx = blockIdx.y; + unsigned int tileSize = blockDim.x; + if (imageIdx >= param.batchSize) + { + return; + } + + int numSelectedBoxes = min(topNumData[imageIdx], param.numSelectedBoxes); + int numTiles = (numSelectedBoxes + tileSize - 1) / tileSize; + if (thread >= numSelectedBoxes) + { + return; + } + + __shared__ int blockState; + __shared__ unsigned int resultsCounter; + if (thread == 0) + { + blockState = 0; + resultsCounter = 0; + } + + int threadState[NMS_TILES]; + unsigned int boxIdx[NMS_TILES]; + T threadScore[NMS_TILES]; + int threadClass[NMS_TILES]; + BoxCorner threadBox[NMS_TILES]; + int boxIdxMap[NMS_TILES]; + for (int tile = 0; tile < numTiles; tile++) + { + threadState[tile] = 0; + boxIdx[tile] = thread + tile * blockDim.x; + MapNMSData(param, boxIdx[tile], imageIdx, boxesInput, anchorsInput, topClassData, topAnchorsData, + topNumData, sortedScoresData, sortedIndexData, threadScore[tile], threadClass[tile], threadBox[tile], + boxIdxMap[tile]); + } + + // Iterate through all boxes to NMS against. + for (int i = 0; i < numSelectedBoxes; i++) + { + int tile = i / tileSize; + + if (boxIdx[tile] == i) + { + // Iteration lead thread, figure out what the other threads should do, + // this will be signaled via the blockState shared variable. + if (threadState[tile] == -1) + { + // Thread already dead, this box was already dropped in a previous iteration, + // because it had a large IOU overlap with another lead thread previously, so + // it would never be kept anyway, therefore it can safely be skip all IOU operations + // in this iteration. + blockState = -1; // -1 => Signal all threads to skip iteration + } + else if (threadState[tile] == 0) + { + // As this box will be kept, this is a good place to find what index in the results buffer it + // should have, as this allows to perform an early loop exit if there are enough results. + if (resultsCounter >= param.numOutputBoxes) + { + blockState = -2; // -2 => Signal all threads to do an early loop exit. + } + else + { + // Thread is still alive, because it has not had a large enough IOU overlap with + // any other kept box previously. Therefore, this box will be kept for sure. However, + // we need to check against all other subsequent boxes from this position onward, + // to see how those other boxes will behave in future iterations. + blockState = 1; // +1 => Signal all (higher index) threads to calculate IOU against this box + threadState[tile] = 1; // +1 => Mark this box's thread to be kept and written out to results + + // If the numOutputBoxesPerClass check is enabled, write the result only if the limit for this + // class on this image has not been reached yet. Other than (possibly) skipping the write, this + // won't affect anything else in the NMS threading. + bool write = true; + if (param.numOutputBoxesPerClass >= 0) + { + int classCounterIdx = imageIdx * param.numClasses + threadClass[tile]; + write = (outputClassData[classCounterIdx] < param.numOutputBoxesPerClass); + outputClassData[classCounterIdx]++; + } + if (write) + { + // This branch is visited by one thread per iteration, so it's safe to do non-atomic increments. + resultsCounter++; + if (param.outputONNXIndices) + { + WriteONNXResult( + param, outputIndexData, nmsIndicesOutput, imageIdx, threadClass[tile], boxIdxMap[tile]); + } + else + { + WriteNMSResult(param, numDetectionsOutput, nmsScoresOutput, nmsClassesOutput, + nmsBoxesOutput, threadScore[tile], threadClass[tile], threadBox[tile], imageIdx, + resultsCounter); + } + } + } + } + else + { + // This state should never be reached, but just in case... + blockState = 0; // 0 => Signal all threads to not do any updates, nothing happens. + } + } + + __syncthreads(); + + if (blockState == -2) + { + // This is the signal to exit from the loop. + return; + } + + if (blockState == -1) + { + // This is the signal for all threads to just skip this iteration, as no IOU's need to be checked. + continue; + } + + // Grab a box and class to test the current box against. The test box corresponds to iteration i, + // therefore it will have a lower index than the current thread box, and will therefore have a higher score + // than the current box because it's located "before" in the sorted score list. + T testScore; + int testClass; + BoxCorner testBox; + int testBoxIdxMap; + MapNMSData(param, i, imageIdx, boxesInput, anchorsInput, topClassData, topAnchorsData, topNumData, + sortedScoresData, sortedIndexData, testScore, testClass, testBox, testBoxIdxMap); + + for (int tile = 0; tile < numTiles; tile++) + { + bool ignoreClass = true; + if (!param.classAgnostic) + { + ignoreClass = threadClass[tile] == testClass; + } + + // IOU + if (boxIdx[tile] > i && // Make sure two different boxes are being tested, and that it's a higher index; + boxIdx[tile] < numSelectedBoxes && // Make sure the box is within numSelectedBoxes; + blockState == 1 && // Signal that allows IOU checks to be performed; + threadState[tile] == 0 && // Make sure this box hasn't been either dropped or kept already; + ignoreClass && // Compare only boxes of matching classes when classAgnostic is false; + lte_mp(threadScore[tile], testScore) && // Make sure the sorting order of scores is as expected; + IOU(param, threadBox[tile], testBox) >= param.iouThreshold) // And... IOU overlap. + { + // Current box overlaps with the box tested in this iteration, this box will be skipped. + threadState[tile] = -1; // -1 => Mark this box's thread to be dropped. + } + } + } +} + +template +cudaError_t EfficientNMSLauncher(EfficientNMSParameters& param, int* topNumData, int* outputIndexData, + int* outputClassData, int* sortedIndexData, T* sortedScoresData, int* topClassData, int* topAnchorsData, + const void* boxesInput, const void* anchorsInput, int* numDetectionsOutput, T* nmsScoresOutput, + int* nmsClassesOutput, int* nmsIndicesOutput, void* nmsBoxesOutput, cudaStream_t stream) +{ + unsigned int tileSize = param.numSelectedBoxes / NMS_TILES; + if (param.numSelectedBoxes <= 512) + { + tileSize = 512; + } + if (param.numSelectedBoxes <= 256) + { + tileSize = 256; + } + + const dim3 blockSize = {tileSize, 1, 1}; + const dim3 gridSize = {1, (unsigned int) param.batchSize, 1}; + + if (param.boxCoding == 0) + { + EfficientNMS><<>>(param, topNumData, outputIndexData, + outputClassData, sortedIndexData, sortedScoresData, topClassData, topAnchorsData, + (BoxCorner*) boxesInput, (BoxCorner*) anchorsInput, numDetectionsOutput, nmsScoresOutput, + nmsClassesOutput, nmsIndicesOutput, (BoxCorner*) nmsBoxesOutput); + } + else if (param.boxCoding == 1) + { + // Note that nmsBoxesOutput is always coded as BoxCorner, regardless of the input coding type. + EfficientNMS><<>>(param, topNumData, outputIndexData, + outputClassData, sortedIndexData, sortedScoresData, topClassData, topAnchorsData, + (BoxCenterSize*) boxesInput, (BoxCenterSize*) anchorsInput, numDetectionsOutput, nmsScoresOutput, + nmsClassesOutput, nmsIndicesOutput, (BoxCorner*) nmsBoxesOutput); + } + + if (param.outputONNXIndices) + { + PadONNXResult<<<1, 1, 0, stream>>>(param, outputIndexData, nmsIndicesOutput); + } + + return cudaGetLastError(); +} + +__global__ void EfficientNMSFilterSegments(EfficientNMSParameters param, const int* __restrict__ topNumData, + int* __restrict__ topOffsetsStartData, int* __restrict__ topOffsetsEndData) +{ + int imageIdx = threadIdx.x; + if (imageIdx > param.batchSize) + { + return; + } + topOffsetsStartData[imageIdx] = imageIdx * param.numScoreElements; + topOffsetsEndData[imageIdx] = imageIdx * param.numScoreElements + topNumData[imageIdx]; +} + +template +__global__ void EfficientNMSFilter(EfficientNMSParameters param, const T* __restrict__ scoresInput, + int* __restrict__ topNumData, int* __restrict__ topIndexData, int* __restrict__ topAnchorsData, + T* __restrict__ topScoresData, int* __restrict__ topClassData) +{ + int elementIdx = blockDim.x * blockIdx.x + threadIdx.x; + int imageIdx = blockDim.y * blockIdx.y + threadIdx.y; + + // Boundary Conditions + if (elementIdx >= param.numScoreElements || imageIdx >= param.batchSize) + { + return; + } + + // Shape of scoresInput: [batchSize, numAnchors, numClasses] + int scoresInputIdx = imageIdx * param.numScoreElements + elementIdx; + + // For each class, check its corresponding score if it crosses the threshold, and if so select this anchor, + // and keep track of the maximum score and the corresponding (argmax) class id + T score = scoresInput[scoresInputIdx]; + if (gte_mp(score, (T) param.scoreThreshold)) + { + // Unpack the class and anchor index from the element index + int classIdx = elementIdx % param.numClasses; + int anchorIdx = elementIdx / param.numClasses; + + // If this is a background class, ignore it. + if (classIdx == param.backgroundClass) + { + return; + } + + // Use an atomic to find an open slot where to write the selected anchor data. + if (topNumData[imageIdx] >= param.numScoreElements) + { + return; + } + int selectedIdx = atomicAdd((unsigned int*) &topNumData[imageIdx], 1); + if (selectedIdx >= param.numScoreElements) + { + topNumData[imageIdx] = param.numScoreElements; + return; + } + + // Shape of topScoresData / topClassData: [batchSize, numScoreElements] + int topIdx = imageIdx * param.numScoreElements + selectedIdx; + + if (param.scoreBits > 0) + { + score = add_mp(score, (T) 1); + if (gt_mp(score, (T) (2.f - 1.f / 1024.f))) + { + // Ensure the incremented score fits in the mantissa without changing the exponent + score = (2.f - 1.f / 1024.f); + } + } + + topIndexData[topIdx] = selectedIdx; + topAnchorsData[topIdx] = anchorIdx; + topScoresData[topIdx] = score; + topClassData[topIdx] = classIdx; + } +} + +template +__global__ void EfficientNMSDenseIndex(EfficientNMSParameters param, int* __restrict__ topNumData, + int* __restrict__ topIndexData, int* __restrict__ topAnchorsData, int* __restrict__ topOffsetsStartData, + int* __restrict__ topOffsetsEndData, T* __restrict__ topScoresData, int* __restrict__ topClassData) +{ + int elementIdx = blockDim.x * blockIdx.x + threadIdx.x; + int imageIdx = blockDim.y * blockIdx.y + threadIdx.y; + + if (elementIdx >= param.numScoreElements || imageIdx >= param.batchSize) + { + return; + } + + int dataIdx = imageIdx * param.numScoreElements + elementIdx; + int anchorIdx = elementIdx / param.numClasses; + int classIdx = elementIdx % param.numClasses; + if (param.scoreBits > 0) + { + T score = topScoresData[dataIdx]; + if (lt_mp(score, (T) param.scoreThreshold)) + { + score = (T) 1; + } + else if (classIdx == param.backgroundClass) + { + score = (T) 1; + } + else + { + score = add_mp(score, (T) 1); + if (gt_mp(score, (T) (2.f - 1.f / 1024.f))) + { + // Ensure the incremented score fits in the mantissa without changing the exponent + score = (2.f - 1.f / 1024.f); + } + } + topScoresData[dataIdx] = score; + } + else + { + T score = topScoresData[dataIdx]; + if (lt_mp(score, (T) param.scoreThreshold)) + { + topScoresData[dataIdx] = -(1 << 15); + } + else if (classIdx == param.backgroundClass) + { + topScoresData[dataIdx] = -(1 << 15); + } + } + + topIndexData[dataIdx] = elementIdx; + topAnchorsData[dataIdx] = anchorIdx; + topClassData[dataIdx] = classIdx; + + if (elementIdx == 0) + { + // Saturate counters + topNumData[imageIdx] = param.numScoreElements; + topOffsetsStartData[imageIdx] = imageIdx * param.numScoreElements; + topOffsetsEndData[imageIdx] = (imageIdx + 1) * param.numScoreElements; + } +} + +template +cudaError_t EfficientNMSFilterLauncher(EfficientNMSParameters& param, const T* scoresInput, int* topNumData, + int* topIndexData, int* topAnchorsData, int* topOffsetsStartData, int* topOffsetsEndData, T* topScoresData, + int* topClassData, cudaStream_t stream) +{ + const unsigned int elementsPerBlock = 512; + const unsigned int imagesPerBlock = 1; + const unsigned int elementBlocks = (param.numScoreElements + elementsPerBlock - 1) / elementsPerBlock; + const unsigned int imageBlocks = (param.batchSize + imagesPerBlock - 1) / imagesPerBlock; + const dim3 blockSize = {elementsPerBlock, imagesPerBlock, 1}; + const dim3 gridSize = {elementBlocks, imageBlocks, 1}; + + float kernelSelectThreshold = 0.007f; + if (param.scoreSigmoid) + { + // Inverse Sigmoid + if (param.scoreThreshold <= 0.f) + { + param.scoreThreshold = -(1 << 15); + } + else + { + param.scoreThreshold = logf(param.scoreThreshold / (1.f - param.scoreThreshold)); + } + kernelSelectThreshold = logf(kernelSelectThreshold / (1.f - kernelSelectThreshold)); + // Disable Score Bits Optimization + param.scoreBits = -1; + } + + if (param.scoreThreshold < kernelSelectThreshold) + { + // A full copy of the buffer is necessary because sorting will scramble the input data otherwise. + PLUGIN_CHECK_CUDA(cudaMemcpyAsync(topScoresData, scoresInput, + param.batchSize * param.numScoreElements * sizeof(T), cudaMemcpyDeviceToDevice, stream)); + + EfficientNMSDenseIndex<<>>(param, topNumData, topIndexData, topAnchorsData, + topOffsetsStartData, topOffsetsEndData, topScoresData, topClassData); + } + else + { + EfficientNMSFilter<<>>( + param, scoresInput, topNumData, topIndexData, topAnchorsData, topScoresData, topClassData); + + EfficientNMSFilterSegments<<<1, param.batchSize, 0, stream>>>( + param, topNumData, topOffsetsStartData, topOffsetsEndData); + } + + return cudaGetLastError(); +} + +template +size_t EfficientNMSSortWorkspaceSize(int batchSize, int numScoreElements) +{ + size_t sortedWorkspaceSize = 0; + cub::DoubleBuffer keysDB(nullptr, nullptr); + cub::DoubleBuffer valuesDB(nullptr, nullptr); + cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, sortedWorkspaceSize, keysDB, valuesDB, + numScoreElements, batchSize, (const int*) nullptr, (const int*) nullptr); + return sortedWorkspaceSize; +} + +size_t EfficientNMSWorkspaceSize(int batchSize, int numScoreElements, int numClasses, DataType datatype) +{ + size_t total = 0; + const size_t align = 256; + // Counters + // 3 for Filtering + // 1 for Output Indexing + // C for Max per Class Limiting + size_t size = (3 + 1 + numClasses) * batchSize * sizeof(int); + total += size + (size % align ? align - (size % align) : 0); + // Int Buffers + for (int i = 0; i < 4; i++) + { + size = batchSize * numScoreElements * sizeof(int); + total += size + (size % align ? align - (size % align) : 0); + } + // Float Buffers + for (int i = 0; i < 2; i++) + { + size = batchSize * numScoreElements * dataTypeSize(datatype); + total += size + (size % align ? align - (size % align) : 0); + } + // Sort Workspace + if (datatype == DataType::kHALF) + { + size = EfficientNMSSortWorkspaceSize<__half>(batchSize, numScoreElements); + total += size + (size % align ? align - (size % align) : 0); + } + else if (datatype == DataType::kFLOAT) + { + size = EfficientNMSSortWorkspaceSize(batchSize, numScoreElements); + total += size + (size % align ? align - (size % align) : 0); + } + + return total; +} + +template +T* EfficientNMSWorkspace(void* workspace, size_t& offset, size_t elements) +{ + T* buffer = (T*) ((size_t) workspace + offset); + size_t align = 256; + size_t size = elements * sizeof(T); + size_t sizeAligned = size + (size % align ? align - (size % align) : 0); + offset += sizeAligned; + return buffer; +} + +template +pluginStatus_t EfficientNMSDispatch(EfficientNMSParameters param, const void* boxesInput, const void* scoresInput, + const void* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, void* nmsScoresOutput, + void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream) +{ + // Clear Outputs (not all elements will get overwritten by the kernels, so safer to clear everything out) + if (param.outputONNXIndices) + { + CSC(cudaMemsetAsync(nmsIndicesOutput, 0xFF, param.batchSize * param.numOutputBoxes * 3 * sizeof(int), stream), STATUS_FAILURE); + } + else + { + CSC(cudaMemsetAsync(numDetectionsOutput, 0x00, param.batchSize * sizeof(int), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsScoresOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(T), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsBoxesOutput, 0x00, param.batchSize * param.numOutputBoxes * 4 * sizeof(T), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsClassesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE); + } + + // Empty Inputs + if (param.numScoreElements < 1) + { + return STATUS_SUCCESS; + } + + // Counters Workspace + size_t workspaceOffset = 0; + int countersTotalSize = (3 + 1 + param.numClasses) * param.batchSize; + int* topNumData = EfficientNMSWorkspace(workspace, workspaceOffset, countersTotalSize); + int* topOffsetsStartData = topNumData + param.batchSize; + int* topOffsetsEndData = topNumData + 2 * param.batchSize; + int* outputIndexData = topNumData + 3 * param.batchSize; + int* outputClassData = topNumData + 4 * param.batchSize; + CSC(cudaMemsetAsync(topNumData, 0x00, countersTotalSize * sizeof(int), stream), STATUS_FAILURE); + cudaError_t status = cudaGetLastError(); + CSC(status, STATUS_FAILURE); + + // Other Buffers Workspace + int* topIndexData + = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + int* topClassData + = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + int* topAnchorsData + = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + int* sortedIndexData + = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + T* topScoresData = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + T* sortedScoresData + = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + size_t sortedWorkspaceSize = EfficientNMSSortWorkspaceSize(param.batchSize, param.numScoreElements); + char* sortedWorkspaceData = EfficientNMSWorkspace(workspace, workspaceOffset, sortedWorkspaceSize); + cub::DoubleBuffer scoresDB(topScoresData, sortedScoresData); + cub::DoubleBuffer indexDB(topIndexData, sortedIndexData); + + // Kernels + status = EfficientNMSFilterLauncher(param, (T*) scoresInput, topNumData, topIndexData, topAnchorsData, + topOffsetsStartData, topOffsetsEndData, topScoresData, topClassData, stream); + CSC(status, STATUS_FAILURE); + + status = cub::DeviceSegmentedRadixSort::SortPairsDescending(sortedWorkspaceData, sortedWorkspaceSize, scoresDB, + indexDB, param.batchSize * param.numScoreElements, param.batchSize, topOffsetsStartData, topOffsetsEndData, + param.scoreBits > 0 ? (10 - param.scoreBits) : 0, param.scoreBits > 0 ? 10 : sizeof(T) * 8, stream); + CSC(status, STATUS_FAILURE); + + status = EfficientNMSLauncher(param, topNumData, outputIndexData, outputClassData, indexDB.Current(), + scoresDB.Current(), topClassData, topAnchorsData, boxesInput, anchorsInput, (int*) numDetectionsOutput, + (T*) nmsScoresOutput, (int*) nmsClassesOutput, (int*) nmsIndicesOutput, nmsBoxesOutput, stream); + CSC(status, STATUS_FAILURE); + + return STATUS_SUCCESS; +} + +pluginStatus_t EfficientNMSInference(EfficientNMSParameters param, const void* boxesInput, const void* scoresInput, + const void* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, void* nmsScoresOutput, + void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream) +{ + if (param.datatype == DataType::kFLOAT) + { + param.scoreBits = -1; + return EfficientNMSDispatch(param, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, + nmsBoxesOutput, nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, workspace, stream); + } + else if (param.datatype == DataType::kHALF) + { + if (param.scoreBits <= 0 || param.scoreBits > 10) + { + param.scoreBits = -1; + } + return EfficientNMSDispatch<__half>(param, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, + nmsBoxesOutput, nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, workspace, stream); + } + else + { + return STATUS_NOT_SUPPORTED; + } +} diff --git a/plugin/yoloNMSPlugin/yoloNMSInference.cuh b/plugin/yoloNMSPlugin/yoloNMSInference.cuh new file mode 100644 index 00000000..bf12c359 --- /dev/null +++ b/plugin/yoloNMSPlugin/yoloNMSInference.cuh @@ -0,0 +1,261 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRT_EFFICIENT_NMS_INFERENCE_CUH +#define TRT_EFFICIENT_NMS_INFERENCE_CUH + +#include + +// FP32 Intrinsics + +float __device__ __inline__ exp_mp(const float a) +{ + return __expf(a); +} +float __device__ __inline__ sigmoid_mp(const float a) +{ + return __frcp_rn(__fadd_rn(1.f, __expf(-a))); +} +float __device__ __inline__ add_mp(const float a, const float b) +{ + return __fadd_rn(a, b); +} +float __device__ __inline__ sub_mp(const float a, const float b) +{ + return __fsub_rn(a, b); +} +float __device__ __inline__ mul_mp(const float a, const float b) +{ + return __fmul_rn(a, b); +} +bool __device__ __inline__ gt_mp(const float a, const float b) +{ + return a > b; +} +bool __device__ __inline__ lt_mp(const float a, const float b) +{ + return a < b; +} +bool __device__ __inline__ lte_mp(const float a, const float b) +{ + return a <= b; +} +bool __device__ __inline__ gte_mp(const float a, const float b) +{ + return a >= b; +} + +#if __CUDA_ARCH__ >= 530 + +// FP16 Intrinsics + +__half __device__ __inline__ exp_mp(const __half a) +{ + return hexp(a); +} +__half __device__ __inline__ sigmoid_mp(const __half a) +{ + return hrcp(__hadd((__half) 1, hexp(__hneg(a)))); +} +__half __device__ __inline__ add_mp(const __half a, const __half b) +{ + return __hadd(a, b); +} +__half __device__ __inline__ sub_mp(const __half a, const __half b) +{ + return __hsub(a, b); +} +__half __device__ __inline__ mul_mp(const __half a, const __half b) +{ + return __hmul(a, b); +} +bool __device__ __inline__ gt_mp(const __half a, const __half b) +{ + return __hgt(a, b); +} +bool __device__ __inline__ lt_mp(const __half a, const __half b) +{ + return __hlt(a, b); +} +bool __device__ __inline__ lte_mp(const __half a, const __half b) +{ + return __hle(a, b); +} +bool __device__ __inline__ gte_mp(const __half a, const __half b) +{ + return __hge(a, b); +} + +#else + +// FP16 Fallbacks on older architectures that lack support + +__half __device__ __inline__ exp_mp(const __half a) +{ + return __float2half(exp_mp(__half2float(a))); +} +__half __device__ __inline__ sigmoid_mp(const __half a) +{ + return __float2half(sigmoid_mp(__half2float(a))); +} +__half __device__ __inline__ add_mp(const __half a, const __half b) +{ + return __float2half(add_mp(__half2float(a), __half2float(b))); +} +__half __device__ __inline__ sub_mp(const __half a, const __half b) +{ + return __float2half(sub_mp(__half2float(a), __half2float(b))); +} +__half __device__ __inline__ mul_mp(const __half a, const __half b) +{ + return __float2half(mul_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ gt_mp(const __half a, const __half b) +{ + return __float2half(gt_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ lt_mp(const __half a, const __half b) +{ + return __float2half(lt_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ lte_mp(const __half a, const __half b) +{ + return __float2half(lte_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ gte_mp(const __half a, const __half b) +{ + return __float2half(gte_mp(__half2float(a), __half2float(b))); +} + +#endif + +template +struct __align__(4 * sizeof(T)) BoxCorner; + +template +struct __align__(4 * sizeof(T)) BoxCenterSize; + +template +struct __align__(4 * sizeof(T)) BoxCorner +{ + // For NMS/IOU purposes, YXYX coding is identical to XYXY + T y1, x1, y2, x2; + + __device__ void reorder() + { + if (gt_mp(y1, y2)) + { + // Swap values, so y1 < y2 + y1 = sub_mp(y1, y2); + y2 = add_mp(y1, y2); + y1 = sub_mp(y2, y1); + } + if (gt_mp(x1, x2)) + { + // Swap values, so x1 < x2 + x1 = sub_mp(x1, x2); + x2 = add_mp(x1, x2); + x1 = sub_mp(x2, x1); + } + } + + __device__ BoxCorner clip(T low, T high) const + { + return {lt_mp(y1, low) ? low : (gt_mp(y1, high) ? high : y1), + lt_mp(x1, low) ? low : (gt_mp(x1, high) ? high : x1), lt_mp(y2, low) ? low : (gt_mp(y2, high) ? high : y2), + lt_mp(x2, low) ? low : (gt_mp(x2, high) ? high : x2)}; + } + + __device__ BoxCorner decode(BoxCorner anchor) const + { + return {add_mp(y1, anchor.y1), add_mp(x1, anchor.x1), add_mp(y2, anchor.y2), add_mp(x2, anchor.x2)}; + } + + __device__ float area() const + { + T w = sub_mp(x2, x1); + T h = sub_mp(y2, y1); + if (lte_mp(h, (T) 0)) + { + return 0; + } + if (lte_mp(w, (T) 0)) + { + return 0; + } + return (float) h * (float) w; + } + + __device__ operator BoxCenterSize() const + { + T w = sub_mp(x2, x1); + T h = sub_mp(y2, y1); + return BoxCenterSize{add_mp(y1, mul_mp((T) 0.5, h)), add_mp(x1, mul_mp((T) 0.5, w)), h, w}; + } + + __device__ static BoxCorner intersect(BoxCorner a, BoxCorner b) + { + return {gt_mp(a.y1, b.y1) ? a.y1 : b.y1, gt_mp(a.x1, b.x1) ? a.x1 : b.x1, lt_mp(a.y2, b.y2) ? a.y2 : b.y2, + lt_mp(a.x2, b.x2) ? a.x2 : b.x2}; + } +}; + +template +struct __align__(4 * sizeof(T)) BoxCenterSize +{ + // For NMS/IOU purposes, YXHW coding is identical to XYWH + T y, x, h, w; + + __device__ void reorder() {} + + __device__ BoxCenterSize clip(T low, T high) const + { + return BoxCenterSize(BoxCorner(*this).clip(low, high)); + } + + __device__ BoxCenterSize decode(BoxCenterSize anchor) const + { + return {add_mp(mul_mp(y, anchor.h), anchor.y), add_mp(mul_mp(x, anchor.w), anchor.x), + mul_mp(anchor.h, exp_mp(h)), mul_mp(anchor.w, exp_mp(w))}; + } + + __device__ float area() const + { + if (h <= (T) 0) + { + return 0; + } + if (w <= (T) 0) + { + return 0; + } + return (float) h * (float) w; + } + + __device__ operator BoxCorner() const + { + T h2 = mul_mp(h, (T) 0.5); + T w2 = mul_mp(w, (T) 0.5); + return BoxCorner{sub_mp(y, h2), sub_mp(x, w2), add_mp(y, h2), add_mp(x, w2)}; + } + __device__ static BoxCenterSize intersect(BoxCenterSize a, BoxCenterSize b) + { + return BoxCenterSize(BoxCorner::intersect(BoxCorner(a), BoxCorner(b))); + } +}; + +#endif diff --git a/plugin/yoloNMSPlugin/yoloNMSInference.h b/plugin/yoloNMSPlugin/yoloNMSInference.h new file mode 100644 index 00000000..d9ec3192 --- /dev/null +++ b/plugin/yoloNMSPlugin/yoloNMSInference.h @@ -0,0 +1,32 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRT_EFFICIENT_NMS_INFERENCE_H +#define TRT_EFFICIENT_NMS_INFERENCE_H + +#include "common/plugin.h" + +#include "efficientNMSParameters.h" + +size_t EfficientNMSWorkspaceSize( + int32_t batchSize, int32_t numScoreElements, int32_t numClasses, nvinfer1::DataType datatype); + +pluginStatus_t EfficientNMSInference(nvinfer1::plugin::EfficientNMSParameters param, void const* boxesInput, + void const* scoresInput, void const* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, + void* nmsScoresOutput, void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream); + +#endif diff --git a/plugin/yoloNMSPlugin/yoloNMSParameters.h b/plugin/yoloNMSPlugin/yoloNMSParameters.h new file mode 100644 index 00000000..89829089 --- /dev/null +++ b/plugin/yoloNMSPlugin/yoloNMSParameters.h @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRT_EFFICIENT_NMS_PARAMETERS_H +#define TRT_EFFICIENT_NMS_PARAMETERS_H + +#include "common/plugin.h" + +namespace nvinfer1 +{ +namespace plugin +{ + +struct EfficientNMSParameters +{ + // Related to NMS Options + float iouThreshold = 0.5F; + float scoreThreshold = 0.5F; + int32_t numOutputBoxes = 100; + int32_t numOutputBoxesPerClass = -1; + bool padOutputBoxesPerClass = false; + int32_t backgroundClass = -1; + bool scoreSigmoid = false; + bool clipBoxes = false; + int32_t boxCoding = 0; + bool classAgnostic = false; + + // Related to NMS Internals + int32_t numSelectedBoxes = 4096; + int32_t scoreBits = -1; + bool outputONNXIndices = false; + + // Related to Tensor Configuration + // (These are set by the various plugin configuration methods, no need to define them during plugin creation.) + int32_t batchSize = -1; + int32_t numClasses = 1; + int32_t numBoxElements = -1; + int32_t numScoreElements = -1; + int32_t numAnchors = -1; + bool shareLocation = true; + bool shareAnchors = true; + bool boxDecoder = false; + nvinfer1::DataType datatype = nvinfer1::DataType::kFLOAT; +}; + +} // namespace plugin +} // namespace nvinfer1 + +#endif diff --git a/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp b/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp new file mode 100644 index 00000000..2f5d428b --- /dev/null +++ b/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp @@ -0,0 +1,621 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "efficientNMSPlugin.h" +#include "efficientNMSInference.h" + +using namespace nvinfer1; +using nvinfer1::plugin::EfficientNMSPlugin; +using nvinfer1::plugin::EfficientNMSParameters; +using nvinfer1::plugin::EfficientNMSPluginCreator; +using nvinfer1::plugin::EfficientNMSONNXPluginCreator; + +namespace +{ +char const* const kEFFICIENT_NMS_PLUGIN_VERSION{"1"}; +char const* const kEFFICIENT_NMS_PLUGIN_NAME{"EfficientNMS_TRT"}; +char const* const kEFFICIENT_NMS_ONNX_PLUGIN_VERSION{"1"}; +char const* const kEFFICIENT_NMS_ONNX_PLUGIN_NAME{"EfficientNMS_ONNX_TRT"}; +} // namespace + +EfficientNMSPlugin::EfficientNMSPlugin(EfficientNMSParameters param) + : mParam(std::move(param)) +{ +} + +EfficientNMSPlugin::EfficientNMSPlugin(void const* data, size_t length) +{ + deserialize(static_cast(data), length); +} + +void EfficientNMSPlugin::deserialize(int8_t const* data, size_t length) +{ + auto const* d{data}; + mParam = read(d); + PLUGIN_VALIDATE(d == data + length); +} + +char const* EfficientNMSPlugin::getPluginType() const noexcept +{ + return kEFFICIENT_NMS_PLUGIN_NAME; +} + +char const* EfficientNMSPlugin::getPluginVersion() const noexcept +{ + return kEFFICIENT_NMS_PLUGIN_VERSION; +} + +int32_t EfficientNMSPlugin::getNbOutputs() const noexcept +{ + if (mParam.outputONNXIndices) + { + // ONNX NonMaxSuppression Compatibility + return 1; + } + + // Standard Plugin Implementation + return 4; +} + +int32_t EfficientNMSPlugin::initialize() noexcept +{ + if (!initialized) + { + int32_t device; + CSC(cudaGetDevice(&device), STATUS_FAILURE); + struct cudaDeviceProp properties; + CSC(cudaGetDeviceProperties(&properties, device), STATUS_FAILURE); + if (properties.regsPerBlock >= 65536) + { + // Most Devices + mParam.numSelectedBoxes = 5000; + } + else + { + // Jetson TX1/TX2 + mParam.numSelectedBoxes = 2000; + } + initialized = true; + } + return STATUS_SUCCESS; +} + +void EfficientNMSPlugin::terminate() noexcept {} + +size_t EfficientNMSPlugin::getSerializationSize() const noexcept +{ + return sizeof(EfficientNMSParameters); +} + +void EfficientNMSPlugin::serialize(void* buffer) const noexcept +{ + char *d = reinterpret_cast(buffer), *a = d; + write(d, mParam); + PLUGIN_ASSERT(d == a + getSerializationSize()); +} + +void EfficientNMSPlugin::destroy() noexcept +{ + delete this; +} + +void EfficientNMSPlugin::setPluginNamespace(char const* pluginNamespace) noexcept +{ + try + { + mNamespace = pluginNamespace; + } + catch (std::exception const& e) + { + caughtError(e); + } +} + +char const* EfficientNMSPlugin::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} + +nvinfer1::DataType EfficientNMSPlugin::getOutputDataType( + int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept +{ + if (mParam.outputONNXIndices) + { + // ONNX NMS uses an integer output + return nvinfer1::DataType::kINT32; + } + + // On standard NMS, num_detections and detection_classes use integer outputs + if (index == 0 || index == 3) + { + return nvinfer1::DataType::kINT32; + } + // All others should use the same datatype as the input + return inputTypes[0]; +} + +IPluginV2DynamicExt* EfficientNMSPlugin::clone() const noexcept +{ + try + { + auto* plugin = new EfficientNMSPlugin(mParam); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +DimsExprs EfficientNMSPlugin::getOutputDimensions( + int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept +{ + try + { + DimsExprs out_dim; + + // When pad per class is set, the output size may need to be reduced: + // i.e.: outputBoxes = min(outputBoxes, outputBoxesPerClass * numClasses) + // As the number of classes may not be static, numOutputBoxes must be a dynamic + // expression. The corresponding parameter can not be set at this time, so the + // value will be calculated again in configurePlugin() and the param overwritten. + IDimensionExpr const* numOutputBoxes = exprBuilder.constant(mParam.numOutputBoxes); + if (mParam.padOutputBoxesPerClass && mParam.numOutputBoxesPerClass > 0) + { + IDimensionExpr const* numOutputBoxesPerClass = exprBuilder.constant(mParam.numOutputBoxesPerClass); + IDimensionExpr const* numClasses = inputs[1].d[2]; + numOutputBoxes = exprBuilder.operation(DimensionOperation::kMIN, *numOutputBoxes, + *exprBuilder.operation(DimensionOperation::kPROD, *numOutputBoxesPerClass, *numClasses)); + } + + if (mParam.outputONNXIndices) + { + // ONNX NMS + PLUGIN_ASSERT(outputIndex == 0); + + // detection_indices + out_dim.nbDims = 2; + out_dim.d[0] = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[0], *numOutputBoxes); + out_dim.d[1] = exprBuilder.constant(3); + } + else + { + // Standard NMS + PLUGIN_ASSERT(outputIndex >= 0 && outputIndex <= 3); + + // num_detections + if (outputIndex == 0) + { + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = exprBuilder.constant(1); + } + // detection_boxes + else if (outputIndex == 1) + { + out_dim.nbDims = 3; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = numOutputBoxes; + out_dim.d[2] = exprBuilder.constant(4); + } + // detection_scores: outputIndex == 2 + // detection_classes: outputIndex == 3 + else if (outputIndex == 2 || outputIndex == 3) + { + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = numOutputBoxes; + } + } + + return out_dim; + } + catch (std::exception const& e) + { + caughtError(e); + } + return DimsExprs{}; +} + +bool EfficientNMSPlugin::supportsFormatCombination( + int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept +{ + if (inOut[pos].format != PluginFormat::kLINEAR) + { + return false; + } + + if (mParam.outputONNXIndices) + { + PLUGIN_ASSERT(nbInputs == 2); + PLUGIN_ASSERT(nbOutputs == 1); + + // detection_indices output: int32_t + if (pos == 2) + { + return inOut[pos].type == DataType::kINT32; + } + + // boxes and scores input: fp32 or fp16 + return (inOut[pos].type == DataType::kHALF || inOut[pos].type == DataType::kFLOAT) + && (inOut[0].type == inOut[pos].type); + } + + PLUGIN_ASSERT(nbInputs == 2 || nbInputs == 3); + PLUGIN_ASSERT(nbOutputs == 4); + if (nbInputs == 2) + { + PLUGIN_ASSERT(0 <= pos && pos <= 5); + } + if (nbInputs == 3) + { + PLUGIN_ASSERT(0 <= pos && pos <= 6); + } + + // num_detections and detection_classes output: int32_t + int32_t const posOut = pos - nbInputs; + if (posOut == 0 || posOut == 3) + { + return inOut[pos].type == DataType::kINT32 && inOut[pos].format == PluginFormat::kLINEAR; + } + + // all other inputs/outputs: fp32 or fp16 + return (inOut[pos].type == DataType::kHALF || inOut[pos].type == DataType::kFLOAT) + && (inOut[0].type == inOut[pos].type); +} + +void EfficientNMSPlugin::configurePlugin( + DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept +{ + try + { + if (mParam.outputONNXIndices) + { + // Accepts two inputs + // [0] boxes, [1] scores + PLUGIN_ASSERT(nbInputs == 2); + PLUGIN_ASSERT(nbOutputs == 1); + } + else + { + // Accepts two or three inputs + // If two inputs: [0] boxes, [1] scores + // If three inputs: [0] boxes, [1] scores, [2] anchors + PLUGIN_ASSERT(nbInputs == 2 || nbInputs == 3); + PLUGIN_ASSERT(nbOutputs == 4); + } + mParam.datatype = in[0].desc.type; + + // Shape of scores input should be + // [batch_size, num_boxes, num_classes] or [batch_size, num_boxes, num_classes, 1] + PLUGIN_ASSERT(in[1].desc.dims.nbDims == 3 || (in[1].desc.dims.nbDims == 4 && in[1].desc.dims.d[3] == 1)); + mParam.numScoreElements = in[1].desc.dims.d[1] * in[1].desc.dims.d[2]; + mParam.numClasses = in[1].desc.dims.d[2]; + + // When pad per class is set, the total output boxes size may need to be reduced. + // This operation is also done in getOutputDimension(), but for dynamic shapes, the + // numOutputBoxes param can't be set until the number of classes is fully known here. + if (mParam.padOutputBoxesPerClass && mParam.numOutputBoxesPerClass > 0) + { + if (mParam.numOutputBoxesPerClass * mParam.numClasses < mParam.numOutputBoxes) + { + mParam.numOutputBoxes = mParam.numOutputBoxesPerClass * mParam.numClasses; + } + } + + // Shape of boxes input should be + // [batch_size, num_boxes, 4] or [batch_size, num_boxes, 1, 4] or [batch_size, num_boxes, num_classes, 4] + PLUGIN_ASSERT(in[0].desc.dims.nbDims == 3 || in[0].desc.dims.nbDims == 4); + if (in[0].desc.dims.nbDims == 3) + { + PLUGIN_ASSERT(in[0].desc.dims.d[2] == 4); + mParam.shareLocation = true; + mParam.numBoxElements = in[0].desc.dims.d[1] * in[0].desc.dims.d[2]; + } + else + { + mParam.shareLocation = (in[0].desc.dims.d[2] == 1); + PLUGIN_ASSERT(in[0].desc.dims.d[2] == mParam.numClasses || mParam.shareLocation); + PLUGIN_ASSERT(in[0].desc.dims.d[3] == 4); + mParam.numBoxElements = in[0].desc.dims.d[1] * in[0].desc.dims.d[2] * in[0].desc.dims.d[3]; + } + mParam.numAnchors = in[0].desc.dims.d[1]; + + if (nbInputs == 2) + { + // Only two inputs are used, disable the fused box decoder + mParam.boxDecoder = false; + } + if (nbInputs == 3) + { + // All three inputs are used, enable the box decoder + // Shape of anchors input should be + // Constant shape: [1, numAnchors, 4] or [batch_size, numAnchors, 4] + PLUGIN_ASSERT(in[2].desc.dims.nbDims == 3); + mParam.boxDecoder = true; + mParam.shareAnchors = (in[2].desc.dims.d[0] == 1); + } + } + catch (std::exception const& e) + { + caughtError(e); + } +} + +size_t EfficientNMSPlugin::getWorkspaceSize( + PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept +{ + int32_t batchSize = inputs[1].dims.d[0]; + int32_t numScoreElements = inputs[1].dims.d[1] * inputs[1].dims.d[2]; + int32_t numClasses = inputs[1].dims.d[2]; + return EfficientNMSWorkspaceSize(batchSize, numScoreElements, numClasses, mParam.datatype); +} + +int32_t EfficientNMSPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +{ + try + { + mParam.batchSize = inputDesc[0].dims.d[0]; + + if (mParam.outputONNXIndices) + { + // ONNX NonMaxSuppression Op Support + void const* const boxesInput = inputs[0]; + void const* const scoresInput = inputs[1]; + + void* nmsIndicesOutput = outputs[0]; + + return EfficientNMSInference(mParam, boxesInput, scoresInput, nullptr, nullptr, nullptr, nullptr, nullptr, + nmsIndicesOutput, workspace, stream); + } + + // Standard NMS Operation + void const* const boxesInput = inputs[0]; + void const* const scoresInput = inputs[1]; + void const* const anchorsInput = mParam.boxDecoder ? inputs[2] : nullptr; + + void* numDetectionsOutput = outputs[0]; + void* nmsBoxesOutput = outputs[1]; + void* nmsScoresOutput = outputs[2]; + void* nmsClassesOutput = outputs[3]; + + return EfficientNMSInference(mParam, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, nmsBoxesOutput, + nmsScoresOutput, nmsClassesOutput, nullptr, workspace, stream); + } + catch (std::exception const& e) + { + caughtError(e); + } + return -1; +} + +// Standard NMS Plugin Operation + +EfficientNMSPluginCreator::EfficientNMSPluginCreator() + : mParam{} +{ + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("max_output_boxes", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("background_class", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("score_activation", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("class_agnostic", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("box_coding", nullptr, PluginFieldType::kINT32, 1)); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* EfficientNMSPluginCreator::getPluginName() const noexcept +{ + return kEFFICIENT_NMS_PLUGIN_NAME; +} + +char const* EfficientNMSPluginCreator::getPluginVersion() const noexcept +{ + return kEFFICIENT_NMS_PLUGIN_VERSION; +} + +PluginFieldCollection const* EfficientNMSPluginCreator::getFieldNames() noexcept +{ + return &mFC; +} + +IPluginV2DynamicExt* EfficientNMSPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept +{ + try + { + PLUGIN_VALIDATE(fc != nullptr); + PluginField const* fields = fc->fields; + PLUGIN_VALIDATE(fields != nullptr); + plugin::validateRequiredAttributesExist({"score_threshold", "iou_threshold", "max_output_boxes", + "background_class", "score_activation", "box_coding"}, + fc); + for (int32_t i{0}; i < fc->nbFields; ++i) + { + char const* attrName = fields[i].name; + if (!strcmp(attrName, "score_threshold")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kFLOAT32); + auto const scoreThreshold = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(scoreThreshold >= 0.0F); + mParam.scoreThreshold = scoreThreshold; + } + if (!strcmp(attrName, "iou_threshold")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kFLOAT32); + auto const iouThreshold = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(iouThreshold > 0.0F); + mParam.iouThreshold = iouThreshold; + } + if (!strcmp(attrName, "max_output_boxes")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); + auto const numOutputBoxes = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(numOutputBoxes > 0); + mParam.numOutputBoxes = numOutputBoxes; + } + if (!strcmp(attrName, "background_class")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); + mParam.backgroundClass = *(static_cast(fields[i].data)); + } + if (!strcmp(attrName, "score_activation")) + { + auto const scoreSigmoid = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(scoreSigmoid == 0 || scoreSigmoid == 1); + mParam.scoreSigmoid = static_cast(scoreSigmoid); + } + if (!strcmp(attrName, "class_agnostic")) + { + auto const classAgnostic = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(classAgnostic == 0 || classAgnostic == 1); + mParam.classAgnostic = static_cast(classAgnostic); + } + if (!strcmp(attrName, "box_coding")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); + auto const boxCoding = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(boxCoding == 0 || boxCoding == 1); + mParam.boxCoding = boxCoding; + } + } + + auto* plugin = new EfficientNMSPlugin(mParam); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +IPluginV2DynamicExt* EfficientNMSPluginCreator::deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept +{ + try + { + // This object will be deleted when the network is destroyed, which will + // call EfficientNMSPlugin::destroy() + auto* plugin = new EfficientNMSPlugin(serialData, serialLength); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +// ONNX NonMaxSuppression Op Compatibility + +EfficientNMSONNXPluginCreator::EfficientNMSONNXPluginCreator() + : mParam{} +{ + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("max_output_boxes_per_class", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("center_point_box", nullptr, PluginFieldType::kINT32, 1)); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* EfficientNMSONNXPluginCreator::getPluginName() const noexcept +{ + return kEFFICIENT_NMS_ONNX_PLUGIN_NAME; +} + +char const* EfficientNMSONNXPluginCreator::getPluginVersion() const noexcept +{ + return kEFFICIENT_NMS_ONNX_PLUGIN_VERSION; +} + +PluginFieldCollection const* EfficientNMSONNXPluginCreator::getFieldNames() noexcept +{ + return &mFC; +} + +IPluginV2DynamicExt* EfficientNMSONNXPluginCreator::createPlugin( + char const* name, PluginFieldCollection const* fc) noexcept +{ + try + { + PluginField const* fields = fc->fields; + for (int32_t i = 0; i < fc->nbFields; ++i) + { + char const* attrName = fields[i].name; + if (!strcmp(attrName, "score_threshold")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kFLOAT32); + mParam.scoreThreshold = *(static_cast(fields[i].data)); + } + if (!strcmp(attrName, "iou_threshold")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kFLOAT32); + mParam.iouThreshold = *(static_cast(fields[i].data)); + } + if (!strcmp(attrName, "max_output_boxes_per_class")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); + mParam.numOutputBoxesPerClass = *(static_cast(fields[i].data)); + } + if (!strcmp(attrName, "center_point_box")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); + mParam.boxCoding = *(static_cast(fields[i].data)); + } + } + + // This enables ONNX compatibility mode + mParam.outputONNXIndices = true; + mParam.numOutputBoxes = mParam.numOutputBoxesPerClass; + + auto* plugin = new EfficientNMSPlugin(mParam); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +IPluginV2DynamicExt* EfficientNMSONNXPluginCreator::deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept +{ + try + { + // This object will be deleted when the network is destroyed, which will + // call EfficientNMSPlugin::destroy() + auto* plugin = new EfficientNMSPlugin(serialData, serialLength); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} diff --git a/plugin/yoloNMSPlugin/yoloNMSPlugin.h b/plugin/yoloNMSPlugin/yoloNMSPlugin.h new file mode 100644 index 00000000..afceec01 --- /dev/null +++ b/plugin/yoloNMSPlugin/yoloNMSPlugin.h @@ -0,0 +1,122 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TRT_EFFICIENT_NMS_PLUGIN_H +#define TRT_EFFICIENT_NMS_PLUGIN_H + +#include + +#include "common/plugin.h" +#include "efficientNMSPlugin/efficientNMSParameters.h" + +namespace nvinfer1 +{ +namespace plugin +{ + +class EfficientNMSPlugin : public IPluginV2DynamicExt +{ +public: + explicit EfficientNMSPlugin(EfficientNMSParameters param); + EfficientNMSPlugin(void const* data, size_t length); + ~EfficientNMSPlugin() override = default; + + // IPluginV2 methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* libNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext methods + nvinfer1::DataType getOutputDataType( + int32_t index, nvinfer1::DataType const* inputType, int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt methods + IPluginV2DynamicExt* clone() const noexcept override; + DimsExprs getOutputDimensions( + int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination( + int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + void configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept override; + int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, + void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + +protected: + EfficientNMSParameters mParam{}; + bool initialized{false}; + std::string mNamespace; + +private: + void deserialize(int8_t const* data, size_t length); +}; + +// Standard NMS Plugin Operation +class EfficientNMSPluginCreator : public nvinfer1::pluginInternal::BaseCreator +{ +public: + EfficientNMSPluginCreator(); + ~EfficientNMSPluginCreator() override = default; + + char const* getPluginName() const noexcept override; + char const* getPluginVersion() const noexcept override; + PluginFieldCollection const* getFieldNames() noexcept override; + + IPluginV2DynamicExt* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override; + IPluginV2DynamicExt* deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept override; + +protected: + PluginFieldCollection mFC; + EfficientNMSParameters mParam; + std::vector mPluginAttributes; + std::string mPluginName; +}; + +// ONNX NonMaxSuppression Op Compatibility +class EfficientNMSONNXPluginCreator : public nvinfer1::pluginInternal::BaseCreator +{ +public: + EfficientNMSONNXPluginCreator(); + ~EfficientNMSONNXPluginCreator() override = default; + + char const* getPluginName() const noexcept override; + char const* getPluginVersion() const noexcept override; + PluginFieldCollection const* getFieldNames() noexcept override; + + IPluginV2DynamicExt* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override; + IPluginV2DynamicExt* deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept override; + +protected: + PluginFieldCollection mFC; + EfficientNMSParameters mParam; + std::vector mPluginAttributes; + std::string mPluginName; +}; + +} // namespace plugin +} // namespace nvinfer1 + +#endif // TRT_EFFICIENT_NMS_PLUGIN_H From f0e85ba522b7a7d17dc19e7ec991b5dc0a96660e Mon Sep 17 00:00:00 2001 From: Levi Pereira Date: Thu, 9 May 2024 20:16:23 -0300 Subject: [PATCH 2/6] Added Plugin yoloNMSPlugin Signed-off-by: Levi Pereira --- plugin/CMakeLists.txt | 1 + plugin/api/inferPlugin.cpp | 3 + .../YoloNMSPlugin_PluginConfig.yaml | 5 +- plugin/yoloNMSPlugin/yoloNMSInference.cu | 90 +++++------ plugin/yoloNMSPlugin/yoloNMSInference.cuh | 4 +- plugin/yoloNMSPlugin/yoloNMSInference.h | 10 +- plugin/yoloNMSPlugin/yoloNMSParameters.h | 6 +- plugin/yoloNMSPlugin/yoloNMSPlugin.cpp | 142 +++++++++--------- plugin/yoloNMSPlugin/yoloNMSPlugin.h | 36 ++--- 9 files changed, 156 insertions(+), 141 deletions(-) diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index 393d4891..15f649b0 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -69,6 +69,7 @@ set(PLUGIN_LISTS specialSlicePlugin splitPlugin voxelGeneratorPlugin + yoloNMSPlugin ) # Add BERT sources if ${BERT_GENCODES} was populated diff --git a/plugin/api/inferPlugin.cpp b/plugin/api/inferPlugin.cpp index b55f9388..da40766c 100644 --- a/plugin/api/inferPlugin.cpp +++ b/plugin/api/inferPlugin.cpp @@ -53,6 +53,7 @@ #include "specialSlicePlugin/specialSlicePlugin.h" #include "splitPlugin/split.h" #include "voxelGeneratorPlugin/voxelGenerator.h" +#include "yoloNMSPlugin/yoloNMSPlugin.h" #include #include @@ -218,6 +219,8 @@ extern "C" initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); return true; } } // extern "C" diff --git a/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginConfig.yaml b/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginConfig.yaml index 0a7ce414..7a601cad 100644 --- a/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginConfig.yaml +++ b/plugin/yoloNMSPlugin/YoloNMSPlugin_PluginConfig.yaml @@ -1,5 +1,5 @@ --- -name: EfficientNMS_TRT +name: YoloNMS_TRT interface: "IPluginV2DynamicExt" versions: "1": @@ -12,6 +12,7 @@ versions: - detection_boxes - detection_scores - detection_classes + - detection_index attributes: - score_threshold - iou_threshold @@ -65,7 +66,7 @@ versions: - background_class - score_activation - box_coding - golden_io_path: "plugin/efficientNMSPlugin/EfficientNMSPlugin_PluginGoldenIO.json" + golden_io_path: "plugin/yoloNMSPlugin/YoloNMSPlugin_PluginGoldenIO.json" abs_tol: 1e-5 rel_tol: 1e-5 configs: diff --git a/plugin/yoloNMSPlugin/yoloNMSInference.cu b/plugin/yoloNMSPlugin/yoloNMSInference.cu index ba99cb56..61822815 100644 --- a/plugin/yoloNMSPlugin/yoloNMSInference.cu +++ b/plugin/yoloNMSPlugin/yoloNMSInference.cu @@ -19,8 +19,8 @@ #include "cub/cub.cuh" #include "cuda_runtime_api.h" -#include "efficientNMSInference.cuh" -#include "efficientNMSInference.h" +#include "yoloNMSInference.cuh" +#include "yoloNMSInference.h" #define NMS_TILES 5 @@ -28,7 +28,7 @@ using namespace nvinfer1; using namespace nvinfer1::plugin; template -__device__ float IOU(EfficientNMSParameters param, BoxCorner box1, BoxCorner box2) +__device__ float IOU(YoloNMSParameters param, BoxCorner box1, BoxCorner box2) { // Regardless of the selected box coding, IOU is always performed in BoxCorner coding. // The boxes are copied so that they can be reordered without affecting the originals. @@ -50,7 +50,7 @@ __device__ float IOU(EfficientNMSParameters param, BoxCorner box1, BoxCorner< } template -__device__ BoxCorner DecodeBoxes(EfficientNMSParameters param, int boxIdx, int anchorIdx, +__device__ BoxCorner DecodeBoxes(YoloNMSParameters param, int boxIdx, int anchorIdx, const Tb* __restrict__ boxesInput, const Tb* __restrict__ anchorsInput) { // The inputs will be in the selected coding format, as well as the decoding function. But the decoded box @@ -67,7 +67,7 @@ __device__ BoxCorner DecodeBoxes(EfficientNMSParameters param, int boxIdx, in } template -__device__ void MapNMSData(EfficientNMSParameters param, int idx, int imageIdx, const Tb* __restrict__ boxesInput, +__device__ void MapNMSData(YoloNMSParameters param, int idx, int imageIdx, const Tb* __restrict__ boxesInput, const Tb* __restrict__ anchorsInput, const int* __restrict__ topClassData, const int* __restrict__ topAnchorsData, const int* __restrict__ topNumData, const T* __restrict__ sortedScoresData, const int* __restrict__ sortedIndexData, T& scoreMap, int& classMap, BoxCorner& boxMap, int& boxIdxMap) @@ -116,9 +116,10 @@ __device__ void MapNMSData(EfficientNMSParameters param, int idx, int imageIdx, } template -__device__ void WriteNMSResult(EfficientNMSParameters param, int* __restrict__ numDetectionsOutput, +__device__ void WriteNMSResult(YoloNMSParameters param, int* __restrict__ numDetectionsOutput, T* __restrict__ nmsScoresOutput, int* __restrict__ nmsClassesOutput, BoxCorner* __restrict__ nmsBoxesOutput, - T threadScore, int threadClass, BoxCorner threadBox, int imageIdx, unsigned int resultsCounter) + int* __restrict__ nmsIndicesOutput, T threadScore, int threadClass, BoxCorner threadBox, int imageIdx, + unsigned int resultsCounter, int boxIdxMap) { int outputIdx = imageIdx * param.numOutputBoxes + resultsCounter - 1; if (param.scoreSigmoid) @@ -143,9 +144,13 @@ __device__ void WriteNMSResult(EfficientNMSParameters param, int* __restrict__ n nmsBoxesOutput[outputIdx] = threadBox; } numDetectionsOutput[imageIdx] = resultsCounter; + + int index = boxIdxMap % param.numAnchors; + + nmsIndicesOutput[outputIdx] = index; } -__device__ void WriteONNXResult(EfficientNMSParameters param, int* outputIndexData, int* __restrict__ nmsIndicesOutput, +__device__ void WriteONNXResult(YoloNMSParameters param, int* outputIndexData, int* __restrict__ nmsIndicesOutput, int imageIdx, int threadClass, int boxIdxMap) { int index = boxIdxMap % param.numAnchors; @@ -155,7 +160,7 @@ __device__ void WriteONNXResult(EfficientNMSParameters param, int* outputIndexDa nmsIndicesOutput[idx * 3 + 2] = index; } -__global__ void PadONNXResult(EfficientNMSParameters param, int* outputIndexData, int* __restrict__ nmsIndicesOutput) +__global__ void PadONNXResult(YoloNMSParameters param, int* outputIndexData, int* __restrict__ nmsIndicesOutput) { if (threadIdx.x > 0) { @@ -175,7 +180,7 @@ __global__ void PadONNXResult(EfficientNMSParameters param, int* outputIndexData } template -__global__ void EfficientNMS(EfficientNMSParameters param, const int* topNumData, int* outputIndexData, +__global__ void YoloNMS(YoloNMSParameters param, const int* topNumData, int* outputIndexData, int* outputClassData, const int* sortedIndexData, const T* __restrict__ sortedScoresData, const int* __restrict__ topClassData, const int* __restrict__ topAnchorsData, const Tb* __restrict__ boxesInput, const Tb* __restrict__ anchorsInput, int* __restrict__ numDetectionsOutput, T* __restrict__ nmsScoresOutput, @@ -275,8 +280,8 @@ __global__ void EfficientNMS(EfficientNMSParameters param, const int* topNumData else { WriteNMSResult(param, numDetectionsOutput, nmsScoresOutput, nmsClassesOutput, - nmsBoxesOutput, threadScore[tile], threadClass[tile], threadBox[tile], imageIdx, - resultsCounter); + nmsBoxesOutput, nmsIndicesOutput, threadScore[tile], threadClass[tile], threadBox[tile], imageIdx, + resultsCounter, boxIdxMap[tile]); } } } @@ -337,7 +342,7 @@ __global__ void EfficientNMS(EfficientNMSParameters param, const int* topNumData } template -cudaError_t EfficientNMSLauncher(EfficientNMSParameters& param, int* topNumData, int* outputIndexData, +cudaError_t YoloNMSLauncher(YoloNMSParameters& param, int* topNumData, int* outputIndexData, int* outputClassData, int* sortedIndexData, T* sortedScoresData, int* topClassData, int* topAnchorsData, const void* boxesInput, const void* anchorsInput, int* numDetectionsOutput, T* nmsScoresOutput, int* nmsClassesOutput, int* nmsIndicesOutput, void* nmsBoxesOutput, cudaStream_t stream) @@ -357,7 +362,7 @@ cudaError_t EfficientNMSLauncher(EfficientNMSParameters& param, int* topNumData, if (param.boxCoding == 0) { - EfficientNMS><<>>(param, topNumData, outputIndexData, + YoloNMS><<>>(param, topNumData, outputIndexData, outputClassData, sortedIndexData, sortedScoresData, topClassData, topAnchorsData, (BoxCorner*) boxesInput, (BoxCorner*) anchorsInput, numDetectionsOutput, nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, (BoxCorner*) nmsBoxesOutput); @@ -365,7 +370,7 @@ cudaError_t EfficientNMSLauncher(EfficientNMSParameters& param, int* topNumData, else if (param.boxCoding == 1) { // Note that nmsBoxesOutput is always coded as BoxCorner, regardless of the input coding type. - EfficientNMS><<>>(param, topNumData, outputIndexData, + YoloNMS><<>>(param, topNumData, outputIndexData, outputClassData, sortedIndexData, sortedScoresData, topClassData, topAnchorsData, (BoxCenterSize*) boxesInput, (BoxCenterSize*) anchorsInput, numDetectionsOutput, nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, (BoxCorner*) nmsBoxesOutput); @@ -379,7 +384,7 @@ cudaError_t EfficientNMSLauncher(EfficientNMSParameters& param, int* topNumData, return cudaGetLastError(); } -__global__ void EfficientNMSFilterSegments(EfficientNMSParameters param, const int* __restrict__ topNumData, +__global__ void YoloNMSFilterSegments(YoloNMSParameters param, const int* __restrict__ topNumData, int* __restrict__ topOffsetsStartData, int* __restrict__ topOffsetsEndData) { int imageIdx = threadIdx.x; @@ -392,7 +397,7 @@ __global__ void EfficientNMSFilterSegments(EfficientNMSParameters param, const i } template -__global__ void EfficientNMSFilter(EfficientNMSParameters param, const T* __restrict__ scoresInput, +__global__ void YoloNMSFilter(YoloNMSParameters param, const T* __restrict__ scoresInput, int* __restrict__ topNumData, int* __restrict__ topIndexData, int* __restrict__ topAnchorsData, T* __restrict__ topScoresData, int* __restrict__ topClassData) { @@ -456,7 +461,7 @@ __global__ void EfficientNMSFilter(EfficientNMSParameters param, const T* __rest } template -__global__ void EfficientNMSDenseIndex(EfficientNMSParameters param, int* __restrict__ topNumData, +__global__ void YoloNMSDenseIndex(YoloNMSParameters param, int* __restrict__ topNumData, int* __restrict__ topIndexData, int* __restrict__ topAnchorsData, int* __restrict__ topOffsetsStartData, int* __restrict__ topOffsetsEndData, T* __restrict__ topScoresData, int* __restrict__ topClassData) { @@ -520,7 +525,7 @@ __global__ void EfficientNMSDenseIndex(EfficientNMSParameters param, int* __rest } template -cudaError_t EfficientNMSFilterLauncher(EfficientNMSParameters& param, const T* scoresInput, int* topNumData, +cudaError_t YoloNMSFilterLauncher(YoloNMSParameters& param, const T* scoresInput, int* topNumData, int* topIndexData, int* topAnchorsData, int* topOffsetsStartData, int* topOffsetsEndData, T* topScoresData, int* topClassData, cudaStream_t stream) { @@ -554,15 +559,15 @@ cudaError_t EfficientNMSFilterLauncher(EfficientNMSParameters& param, const T* s PLUGIN_CHECK_CUDA(cudaMemcpyAsync(topScoresData, scoresInput, param.batchSize * param.numScoreElements * sizeof(T), cudaMemcpyDeviceToDevice, stream)); - EfficientNMSDenseIndex<<>>(param, topNumData, topIndexData, topAnchorsData, + YoloNMSDenseIndex<<>>(param, topNumData, topIndexData, topAnchorsData, topOffsetsStartData, topOffsetsEndData, topScoresData, topClassData); } else { - EfficientNMSFilter<<>>( + YoloNMSFilter<<>>( param, scoresInput, topNumData, topIndexData, topAnchorsData, topScoresData, topClassData); - EfficientNMSFilterSegments<<<1, param.batchSize, 0, stream>>>( + YoloNMSFilterSegments<<<1, param.batchSize, 0, stream>>>( param, topNumData, topOffsetsStartData, topOffsetsEndData); } @@ -570,7 +575,7 @@ cudaError_t EfficientNMSFilterLauncher(EfficientNMSParameters& param, const T* s } template -size_t EfficientNMSSortWorkspaceSize(int batchSize, int numScoreElements) +size_t YoloNMSSortWorkspaceSize(int batchSize, int numScoreElements) { size_t sortedWorkspaceSize = 0; cub::DoubleBuffer keysDB(nullptr, nullptr); @@ -580,7 +585,7 @@ size_t EfficientNMSSortWorkspaceSize(int batchSize, int numScoreElements) return sortedWorkspaceSize; } -size_t EfficientNMSWorkspaceSize(int batchSize, int numScoreElements, int numClasses, DataType datatype) +size_t YoloNMSWorkspaceSize(int batchSize, int numScoreElements, int numClasses, DataType datatype) { size_t total = 0; const size_t align = 256; @@ -605,12 +610,12 @@ size_t EfficientNMSWorkspaceSize(int batchSize, int numScoreElements, int numCla // Sort Workspace if (datatype == DataType::kHALF) { - size = EfficientNMSSortWorkspaceSize<__half>(batchSize, numScoreElements); + size = YoloNMSSortWorkspaceSize<__half>(batchSize, numScoreElements); total += size + (size % align ? align - (size % align) : 0); } else if (datatype == DataType::kFLOAT) { - size = EfficientNMSSortWorkspaceSize(batchSize, numScoreElements); + size = YoloNMSSortWorkspaceSize(batchSize, numScoreElements); total += size + (size % align ? align - (size % align) : 0); } @@ -618,7 +623,7 @@ size_t EfficientNMSWorkspaceSize(int batchSize, int numScoreElements, int numCla } template -T* EfficientNMSWorkspace(void* workspace, size_t& offset, size_t elements) +T* YoloNMSWorkspace(void* workspace, size_t& offset, size_t elements) { T* buffer = (T*) ((size_t) workspace + offset); size_t align = 256; @@ -629,7 +634,7 @@ T* EfficientNMSWorkspace(void* workspace, size_t& offset, size_t elements) } template -pluginStatus_t EfficientNMSDispatch(EfficientNMSParameters param, const void* boxesInput, const void* scoresInput, +pluginStatus_t YoloNMSDispatch(YoloNMSParameters param, const void* boxesInput, const void* scoresInput, const void* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, void* nmsScoresOutput, void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream) { @@ -644,6 +649,7 @@ pluginStatus_t EfficientNMSDispatch(EfficientNMSParameters param, const void* bo CSC(cudaMemsetAsync(nmsScoresOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(T), stream), STATUS_FAILURE); CSC(cudaMemsetAsync(nmsBoxesOutput, 0x00, param.batchSize * param.numOutputBoxes * 4 * sizeof(T), stream), STATUS_FAILURE); CSC(cudaMemsetAsync(nmsClassesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsIndicesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE); } // Empty Inputs @@ -655,7 +661,7 @@ pluginStatus_t EfficientNMSDispatch(EfficientNMSParameters param, const void* bo // Counters Workspace size_t workspaceOffset = 0; int countersTotalSize = (3 + 1 + param.numClasses) * param.batchSize; - int* topNumData = EfficientNMSWorkspace(workspace, workspaceOffset, countersTotalSize); + int* topNumData = YoloNMSWorkspace(workspace, workspaceOffset, countersTotalSize); int* topOffsetsStartData = topNumData + param.batchSize; int* topOffsetsEndData = topNumData + 2 * param.batchSize; int* outputIndexData = topNumData + 3 * param.batchSize; @@ -666,23 +672,23 @@ pluginStatus_t EfficientNMSDispatch(EfficientNMSParameters param, const void* bo // Other Buffers Workspace int* topIndexData - = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + = YoloNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); int* topClassData - = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + = YoloNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); int* topAnchorsData - = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + = YoloNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); int* sortedIndexData - = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); - T* topScoresData = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + = YoloNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + T* topScoresData = YoloNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); T* sortedScoresData - = EfficientNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); - size_t sortedWorkspaceSize = EfficientNMSSortWorkspaceSize(param.batchSize, param.numScoreElements); - char* sortedWorkspaceData = EfficientNMSWorkspace(workspace, workspaceOffset, sortedWorkspaceSize); + = YoloNMSWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + size_t sortedWorkspaceSize = YoloNMSSortWorkspaceSize(param.batchSize, param.numScoreElements); + char* sortedWorkspaceData = YoloNMSWorkspace(workspace, workspaceOffset, sortedWorkspaceSize); cub::DoubleBuffer scoresDB(topScoresData, sortedScoresData); cub::DoubleBuffer indexDB(topIndexData, sortedIndexData); // Kernels - status = EfficientNMSFilterLauncher(param, (T*) scoresInput, topNumData, topIndexData, topAnchorsData, + status = YoloNMSFilterLauncher(param, (T*) scoresInput, topNumData, topIndexData, topAnchorsData, topOffsetsStartData, topOffsetsEndData, topScoresData, topClassData, stream); CSC(status, STATUS_FAILURE); @@ -691,7 +697,7 @@ pluginStatus_t EfficientNMSDispatch(EfficientNMSParameters param, const void* bo param.scoreBits > 0 ? (10 - param.scoreBits) : 0, param.scoreBits > 0 ? 10 : sizeof(T) * 8, stream); CSC(status, STATUS_FAILURE); - status = EfficientNMSLauncher(param, topNumData, outputIndexData, outputClassData, indexDB.Current(), + status = YoloNMSLauncher(param, topNumData, outputIndexData, outputClassData, indexDB.Current(), scoresDB.Current(), topClassData, topAnchorsData, boxesInput, anchorsInput, (int*) numDetectionsOutput, (T*) nmsScoresOutput, (int*) nmsClassesOutput, (int*) nmsIndicesOutput, nmsBoxesOutput, stream); CSC(status, STATUS_FAILURE); @@ -699,14 +705,14 @@ pluginStatus_t EfficientNMSDispatch(EfficientNMSParameters param, const void* bo return STATUS_SUCCESS; } -pluginStatus_t EfficientNMSInference(EfficientNMSParameters param, const void* boxesInput, const void* scoresInput, +pluginStatus_t YoloNMSInference(YoloNMSParameters param, const void* boxesInput, const void* scoresInput, const void* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, void* nmsScoresOutput, void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream) { if (param.datatype == DataType::kFLOAT) { param.scoreBits = -1; - return EfficientNMSDispatch(param, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, + return YoloNMSDispatch(param, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, nmsBoxesOutput, nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, workspace, stream); } else if (param.datatype == DataType::kHALF) @@ -715,7 +721,7 @@ pluginStatus_t EfficientNMSInference(EfficientNMSParameters param, const void* b { param.scoreBits = -1; } - return EfficientNMSDispatch<__half>(param, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, + return YoloNMSDispatch<__half>(param, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, nmsBoxesOutput, nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, workspace, stream); } else diff --git a/plugin/yoloNMSPlugin/yoloNMSInference.cuh b/plugin/yoloNMSPlugin/yoloNMSInference.cuh index bf12c359..4081ca85 100644 --- a/plugin/yoloNMSPlugin/yoloNMSInference.cuh +++ b/plugin/yoloNMSPlugin/yoloNMSInference.cuh @@ -15,8 +15,8 @@ * limitations under the License. */ -#ifndef TRT_EFFICIENT_NMS_INFERENCE_CUH -#define TRT_EFFICIENT_NMS_INFERENCE_CUH +#ifndef TRT_YOLO_NMS_INFERENCE_CUH +#define TRT_YOLO_NMS_INFERENCE_CUH #include diff --git a/plugin/yoloNMSPlugin/yoloNMSInference.h b/plugin/yoloNMSPlugin/yoloNMSInference.h index d9ec3192..5c2b4bb9 100644 --- a/plugin/yoloNMSPlugin/yoloNMSInference.h +++ b/plugin/yoloNMSPlugin/yoloNMSInference.h @@ -15,17 +15,17 @@ * limitations under the License. */ -#ifndef TRT_EFFICIENT_NMS_INFERENCE_H -#define TRT_EFFICIENT_NMS_INFERENCE_H +#ifndef TRT_YOLO_NMS_INFERENCE_H +#define TRT_YOLO_NMS_INFERENCE_H #include "common/plugin.h" -#include "efficientNMSParameters.h" +#include "yoloNMSParameters.h" -size_t EfficientNMSWorkspaceSize( +size_t YoloNMSWorkspaceSize( int32_t batchSize, int32_t numScoreElements, int32_t numClasses, nvinfer1::DataType datatype); -pluginStatus_t EfficientNMSInference(nvinfer1::plugin::EfficientNMSParameters param, void const* boxesInput, +pluginStatus_t YoloNMSInference(nvinfer1::plugin::YoloNMSParameters param, void const* boxesInput, void const* scoresInput, void const* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, void* nmsScoresOutput, void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream); diff --git a/plugin/yoloNMSPlugin/yoloNMSParameters.h b/plugin/yoloNMSPlugin/yoloNMSParameters.h index 89829089..58430206 100644 --- a/plugin/yoloNMSPlugin/yoloNMSParameters.h +++ b/plugin/yoloNMSPlugin/yoloNMSParameters.h @@ -15,8 +15,8 @@ * limitations under the License. */ -#ifndef TRT_EFFICIENT_NMS_PARAMETERS_H -#define TRT_EFFICIENT_NMS_PARAMETERS_H +#ifndef TRT_YOLO_NMS_PARAMETERS_H +#define TRT_YOLO_NMS_PARAMETERS_H #include "common/plugin.h" @@ -25,7 +25,7 @@ namespace nvinfer1 namespace plugin { -struct EfficientNMSParameters +struct YoloNMSParameters { // Related to NMS Options float iouThreshold = 0.5F; diff --git a/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp b/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp index 2f5d428b..aab6554d 100644 --- a/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp +++ b/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp @@ -15,51 +15,51 @@ * limitations under the License. */ -#include "efficientNMSPlugin.h" -#include "efficientNMSInference.h" +#include "yoloNMSPlugin.h" +#include "yoloNMSInference.h" using namespace nvinfer1; -using nvinfer1::plugin::EfficientNMSPlugin; -using nvinfer1::plugin::EfficientNMSParameters; -using nvinfer1::plugin::EfficientNMSPluginCreator; -using nvinfer1::plugin::EfficientNMSONNXPluginCreator; +using nvinfer1::plugin::YoloNMSPlugin; +using nvinfer1::plugin::YoloNMSParameters; +using nvinfer1::plugin::YoloNMSPluginCreator; +using nvinfer1::plugin::YoloNMSONNXPluginCreator; namespace { -char const* const kEFFICIENT_NMS_PLUGIN_VERSION{"1"}; -char const* const kEFFICIENT_NMS_PLUGIN_NAME{"EfficientNMS_TRT"}; -char const* const kEFFICIENT_NMS_ONNX_PLUGIN_VERSION{"1"}; -char const* const kEFFICIENT_NMS_ONNX_PLUGIN_NAME{"EfficientNMS_ONNX_TRT"}; +char const* const kYOLO_NMS_PLUGIN_VERSION{"1"}; +char const* const kYOLO_NMS_PLUGIN_NAME{"YOLO_NMS_TRT"}; +char const* const kYOLO_NMS_ONNX_PLUGIN_VERSION{"1"}; +char const* const kYOLO_NMS_ONNX_PLUGIN_NAME{"YOLO_NMS_ONNX_TRT"}; } // namespace -EfficientNMSPlugin::EfficientNMSPlugin(EfficientNMSParameters param) +YoloNMSPlugin::YoloNMSPlugin(YoloNMSParameters param) : mParam(std::move(param)) { } -EfficientNMSPlugin::EfficientNMSPlugin(void const* data, size_t length) +YoloNMSPlugin::YoloNMSPlugin(void const* data, size_t length) { deserialize(static_cast(data), length); } -void EfficientNMSPlugin::deserialize(int8_t const* data, size_t length) +void YoloNMSPlugin::deserialize(int8_t const* data, size_t length) { auto const* d{data}; - mParam = read(d); + mParam = read(d); PLUGIN_VALIDATE(d == data + length); } -char const* EfficientNMSPlugin::getPluginType() const noexcept +char const* YoloNMSPlugin::getPluginType() const noexcept { - return kEFFICIENT_NMS_PLUGIN_NAME; + return kYOLO_NMS_PLUGIN_NAME; } -char const* EfficientNMSPlugin::getPluginVersion() const noexcept +char const* YoloNMSPlugin::getPluginVersion() const noexcept { - return kEFFICIENT_NMS_PLUGIN_VERSION; + return kYOLO_NMS_PLUGIN_VERSION; } -int32_t EfficientNMSPlugin::getNbOutputs() const noexcept +int32_t YoloNMSPlugin::getNbOutputs() const noexcept { if (mParam.outputONNXIndices) { @@ -68,10 +68,10 @@ int32_t EfficientNMSPlugin::getNbOutputs() const noexcept } // Standard Plugin Implementation - return 4; + return 5; } -int32_t EfficientNMSPlugin::initialize() noexcept +int32_t YoloNMSPlugin::initialize() noexcept { if (!initialized) { @@ -94,26 +94,26 @@ int32_t EfficientNMSPlugin::initialize() noexcept return STATUS_SUCCESS; } -void EfficientNMSPlugin::terminate() noexcept {} +void YoloNMSPlugin::terminate() noexcept {} -size_t EfficientNMSPlugin::getSerializationSize() const noexcept +size_t YoloNMSPlugin::getSerializationSize() const noexcept { - return sizeof(EfficientNMSParameters); + return sizeof(YoloNMSParameters); } -void EfficientNMSPlugin::serialize(void* buffer) const noexcept +void YoloNMSPlugin::serialize(void* buffer) const noexcept { char *d = reinterpret_cast(buffer), *a = d; write(d, mParam); PLUGIN_ASSERT(d == a + getSerializationSize()); } -void EfficientNMSPlugin::destroy() noexcept +void YoloNMSPlugin::destroy() noexcept { delete this; } -void EfficientNMSPlugin::setPluginNamespace(char const* pluginNamespace) noexcept +void YoloNMSPlugin::setPluginNamespace(char const* pluginNamespace) noexcept { try { @@ -125,12 +125,12 @@ void EfficientNMSPlugin::setPluginNamespace(char const* pluginNamespace) noexcep } } -char const* EfficientNMSPlugin::getPluginNamespace() const noexcept +char const* YoloNMSPlugin::getPluginNamespace() const noexcept { return mNamespace.c_str(); } -nvinfer1::DataType EfficientNMSPlugin::getOutputDataType( +nvinfer1::DataType YoloNMSPlugin::getOutputDataType( int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept { if (mParam.outputONNXIndices) @@ -140,7 +140,7 @@ nvinfer1::DataType EfficientNMSPlugin::getOutputDataType( } // On standard NMS, num_detections and detection_classes use integer outputs - if (index == 0 || index == 3) + if (index == 0 || index == 3 || index == 4) { return nvinfer1::DataType::kINT32; } @@ -148,11 +148,11 @@ nvinfer1::DataType EfficientNMSPlugin::getOutputDataType( return inputTypes[0]; } -IPluginV2DynamicExt* EfficientNMSPlugin::clone() const noexcept +IPluginV2DynamicExt* YoloNMSPlugin::clone() const noexcept { try { - auto* plugin = new EfficientNMSPlugin(mParam); + auto* plugin = new YoloNMSPlugin(mParam); plugin->setPluginNamespace(mNamespace.c_str()); return plugin; } @@ -163,7 +163,7 @@ IPluginV2DynamicExt* EfficientNMSPlugin::clone() const noexcept return nullptr; } -DimsExprs EfficientNMSPlugin::getOutputDimensions( +DimsExprs YoloNMSPlugin::getOutputDimensions( int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept { try @@ -197,7 +197,7 @@ DimsExprs EfficientNMSPlugin::getOutputDimensions( else { // Standard NMS - PLUGIN_ASSERT(outputIndex >= 0 && outputIndex <= 3); + PLUGIN_ASSERT(outputIndex >= 0 && outputIndex <= 4); // num_detections if (outputIndex == 0) @@ -216,7 +216,8 @@ DimsExprs EfficientNMSPlugin::getOutputDimensions( } // detection_scores: outputIndex == 2 // detection_classes: outputIndex == 3 - else if (outputIndex == 2 || outputIndex == 3) + // detection_indices: outputIndex == 4 + else if (outputIndex == 2 || outputIndex == 3 || outputIndex == 4) { out_dim.nbDims = 2; out_dim.d[0] = inputs[0].d[0]; @@ -233,7 +234,7 @@ DimsExprs EfficientNMSPlugin::getOutputDimensions( return DimsExprs{}; } -bool EfficientNMSPlugin::supportsFormatCombination( +bool YoloNMSPlugin::supportsFormatCombination( int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept { if (inOut[pos].format != PluginFormat::kLINEAR) @@ -258,19 +259,19 @@ bool EfficientNMSPlugin::supportsFormatCombination( } PLUGIN_ASSERT(nbInputs == 2 || nbInputs == 3); - PLUGIN_ASSERT(nbOutputs == 4); + PLUGIN_ASSERT(nbOutputs == 5); if (nbInputs == 2) { - PLUGIN_ASSERT(0 <= pos && pos <= 5); + PLUGIN_ASSERT(0 <= pos && pos <= 6); } if (nbInputs == 3) { - PLUGIN_ASSERT(0 <= pos && pos <= 6); + PLUGIN_ASSERT(0 <= pos && pos <= 7); } // num_detections and detection_classes output: int32_t int32_t const posOut = pos - nbInputs; - if (posOut == 0 || posOut == 3) + if (posOut == 0 || posOut == 3 || posOut == 4) { return inOut[pos].type == DataType::kINT32 && inOut[pos].format == PluginFormat::kLINEAR; } @@ -280,7 +281,7 @@ bool EfficientNMSPlugin::supportsFormatCombination( && (inOut[0].type == inOut[pos].type); } -void EfficientNMSPlugin::configurePlugin( +void YoloNMSPlugin::configurePlugin( DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept { try @@ -298,7 +299,7 @@ void EfficientNMSPlugin::configurePlugin( // If two inputs: [0] boxes, [1] scores // If three inputs: [0] boxes, [1] scores, [2] anchors PLUGIN_ASSERT(nbInputs == 2 || nbInputs == 3); - PLUGIN_ASSERT(nbOutputs == 4); + PLUGIN_ASSERT(nbOutputs == 5); } mParam.datatype = in[0].desc.type; @@ -358,16 +359,16 @@ void EfficientNMSPlugin::configurePlugin( } } -size_t EfficientNMSPlugin::getWorkspaceSize( +size_t YoloNMSPlugin::getWorkspaceSize( PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept { int32_t batchSize = inputs[1].dims.d[0]; int32_t numScoreElements = inputs[1].dims.d[1] * inputs[1].dims.d[2]; int32_t numClasses = inputs[1].dims.d[2]; - return EfficientNMSWorkspaceSize(batchSize, numScoreElements, numClasses, mParam.datatype); + return YoloNMSWorkspaceSize(batchSize, numScoreElements, numClasses, mParam.datatype); } -int32_t EfficientNMSPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, +int32_t YoloNMSPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { try @@ -382,7 +383,7 @@ int32_t EfficientNMSPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTen void* nmsIndicesOutput = outputs[0]; - return EfficientNMSInference(mParam, boxesInput, scoresInput, nullptr, nullptr, nullptr, nullptr, nullptr, + return YoloNMSInference(mParam, boxesInput, scoresInput, nullptr, nullptr, nullptr, nullptr, nullptr, nmsIndicesOutput, workspace, stream); } @@ -395,9 +396,10 @@ int32_t EfficientNMSPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTen void* nmsBoxesOutput = outputs[1]; void* nmsScoresOutput = outputs[2]; void* nmsClassesOutput = outputs[3]; + void* nmsIndicesOutput = outputs[4]; - return EfficientNMSInference(mParam, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, nmsBoxesOutput, - nmsScoresOutput, nmsClassesOutput, nullptr, workspace, stream); + return YoloNMSInference(mParam, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, nmsBoxesOutput, + nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, workspace, stream); } catch (std::exception const& e) { @@ -408,7 +410,7 @@ int32_t EfficientNMSPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTen // Standard NMS Plugin Operation -EfficientNMSPluginCreator::EfficientNMSPluginCreator() +YoloNMSPluginCreator::YoloNMSPluginCreator() : mParam{} { mPluginAttributes.clear(); @@ -423,22 +425,22 @@ EfficientNMSPluginCreator::EfficientNMSPluginCreator() mFC.fields = mPluginAttributes.data(); } -char const* EfficientNMSPluginCreator::getPluginName() const noexcept +char const* YoloNMSPluginCreator::getPluginName() const noexcept { - return kEFFICIENT_NMS_PLUGIN_NAME; + return kYOLO_NMS_PLUGIN_NAME; } -char const* EfficientNMSPluginCreator::getPluginVersion() const noexcept +char const* YoloNMSPluginCreator::getPluginVersion() const noexcept { - return kEFFICIENT_NMS_PLUGIN_VERSION; + return kYOLO_NMS_PLUGIN_VERSION; } -PluginFieldCollection const* EfficientNMSPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* YoloNMSPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2DynamicExt* EfficientNMSPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept +IPluginV2DynamicExt* YoloNMSPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { try { @@ -498,7 +500,7 @@ IPluginV2DynamicExt* EfficientNMSPluginCreator::createPlugin(char const* name, P } } - auto* plugin = new EfficientNMSPlugin(mParam); + auto* plugin = new YoloNMSPlugin(mParam); plugin->setPluginNamespace(mNamespace.c_str()); return plugin; } @@ -509,14 +511,14 @@ IPluginV2DynamicExt* EfficientNMSPluginCreator::createPlugin(char const* name, P return nullptr; } -IPluginV2DynamicExt* EfficientNMSPluginCreator::deserializePlugin( +IPluginV2DynamicExt* YoloNMSPluginCreator::deserializePlugin( char const* name, void const* serialData, size_t serialLength) noexcept { try { // This object will be deleted when the network is destroyed, which will - // call EfficientNMSPlugin::destroy() - auto* plugin = new EfficientNMSPlugin(serialData, serialLength); + // call YoloNMSPlugin::destroy() + auto* plugin = new YoloNMSPlugin(serialData, serialLength); plugin->setPluginNamespace(mNamespace.c_str()); return plugin; } @@ -529,7 +531,7 @@ IPluginV2DynamicExt* EfficientNMSPluginCreator::deserializePlugin( // ONNX NonMaxSuppression Op Compatibility -EfficientNMSONNXPluginCreator::EfficientNMSONNXPluginCreator() +YoloNMSONNXPluginCreator::YoloNMSONNXPluginCreator() : mParam{} { mPluginAttributes.clear(); @@ -541,22 +543,22 @@ EfficientNMSONNXPluginCreator::EfficientNMSONNXPluginCreator() mFC.fields = mPluginAttributes.data(); } -char const* EfficientNMSONNXPluginCreator::getPluginName() const noexcept +char const* YoloNMSONNXPluginCreator::getPluginName() const noexcept { - return kEFFICIENT_NMS_ONNX_PLUGIN_NAME; + return kYOLO_NMS_ONNX_PLUGIN_NAME; } -char const* EfficientNMSONNXPluginCreator::getPluginVersion() const noexcept +char const* YoloNMSONNXPluginCreator::getPluginVersion() const noexcept { - return kEFFICIENT_NMS_ONNX_PLUGIN_VERSION; + return kYOLO_NMS_ONNX_PLUGIN_VERSION; } -PluginFieldCollection const* EfficientNMSONNXPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* YoloNMSONNXPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2DynamicExt* EfficientNMSONNXPluginCreator::createPlugin( +IPluginV2DynamicExt* YoloNMSONNXPluginCreator::createPlugin( char const* name, PluginFieldCollection const* fc) noexcept { try @@ -591,7 +593,7 @@ IPluginV2DynamicExt* EfficientNMSONNXPluginCreator::createPlugin( mParam.outputONNXIndices = true; mParam.numOutputBoxes = mParam.numOutputBoxesPerClass; - auto* plugin = new EfficientNMSPlugin(mParam); + auto* plugin = new YoloNMSPlugin(mParam); plugin->setPluginNamespace(mNamespace.c_str()); return plugin; } @@ -602,14 +604,14 @@ IPluginV2DynamicExt* EfficientNMSONNXPluginCreator::createPlugin( return nullptr; } -IPluginV2DynamicExt* EfficientNMSONNXPluginCreator::deserializePlugin( +IPluginV2DynamicExt* YoloNMSONNXPluginCreator::deserializePlugin( char const* name, void const* serialData, size_t serialLength) noexcept { try { // This object will be deleted when the network is destroyed, which will - // call EfficientNMSPlugin::destroy() - auto* plugin = new EfficientNMSPlugin(serialData, serialLength); + // call YoloNMSPlugin::destroy() + auto* plugin = new YoloNMSPlugin(serialData, serialLength); plugin->setPluginNamespace(mNamespace.c_str()); return plugin; } diff --git a/plugin/yoloNMSPlugin/yoloNMSPlugin.h b/plugin/yoloNMSPlugin/yoloNMSPlugin.h index afceec01..5f2d2175 100644 --- a/plugin/yoloNMSPlugin/yoloNMSPlugin.h +++ b/plugin/yoloNMSPlugin/yoloNMSPlugin.h @@ -14,25 +14,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef TRT_EFFICIENT_NMS_PLUGIN_H -#define TRT_EFFICIENT_NMS_PLUGIN_H +#ifndef TRT_YOLO_NMS_PLUGIN_H +#define TRT_YOLO_NMS_PLUGIN_H #include #include "common/plugin.h" -#include "efficientNMSPlugin/efficientNMSParameters.h" +#include "yoloNMSPlugin/yoloNMSParameters.h" + +using namespace nvinfer1::plugin; namespace nvinfer1 { namespace plugin { -class EfficientNMSPlugin : public IPluginV2DynamicExt +class YoloNMSPlugin : public IPluginV2DynamicExt { public: - explicit EfficientNMSPlugin(EfficientNMSParameters param); - EfficientNMSPlugin(void const* data, size_t length); - ~EfficientNMSPlugin() override = default; + explicit YoloNMSPlugin(YoloNMSParameters param); + YoloNMSPlugin(void const* data, size_t length); + ~YoloNMSPlugin() override = default; // IPluginV2 methods char const* getPluginType() const noexcept override; @@ -64,7 +66,7 @@ class EfficientNMSPlugin : public IPluginV2DynamicExt void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; protected: - EfficientNMSParameters mParam{}; + YoloNMSParameters mParam{}; bool initialized{false}; std::string mNamespace; @@ -73,11 +75,11 @@ class EfficientNMSPlugin : public IPluginV2DynamicExt }; // Standard NMS Plugin Operation -class EfficientNMSPluginCreator : public nvinfer1::pluginInternal::BaseCreator +class YoloNMSPluginCreator : public nvinfer1::pluginInternal::BaseCreator { public: - EfficientNMSPluginCreator(); - ~EfficientNMSPluginCreator() override = default; + YoloNMSPluginCreator(); + ~YoloNMSPluginCreator() override = default; char const* getPluginName() const noexcept override; char const* getPluginVersion() const noexcept override; @@ -89,17 +91,17 @@ class EfficientNMSPluginCreator : public nvinfer1::pluginInternal::BaseCreator protected: PluginFieldCollection mFC; - EfficientNMSParameters mParam; + YoloNMSParameters mParam; std::vector mPluginAttributes; std::string mPluginName; }; // ONNX NonMaxSuppression Op Compatibility -class EfficientNMSONNXPluginCreator : public nvinfer1::pluginInternal::BaseCreator +class YoloNMSONNXPluginCreator : public nvinfer1::pluginInternal::BaseCreator { public: - EfficientNMSONNXPluginCreator(); - ~EfficientNMSONNXPluginCreator() override = default; + YoloNMSONNXPluginCreator(); + ~YoloNMSONNXPluginCreator() override = default; char const* getPluginName() const noexcept override; char const* getPluginVersion() const noexcept override; @@ -111,7 +113,7 @@ class EfficientNMSONNXPluginCreator : public nvinfer1::pluginInternal::BaseCreat protected: PluginFieldCollection mFC; - EfficientNMSParameters mParam; + YoloNMSParameters mParam; std::vector mPluginAttributes; std::string mPluginName; }; @@ -119,4 +121,4 @@ class EfficientNMSONNXPluginCreator : public nvinfer1::pluginInternal::BaseCreat } // namespace plugin } // namespace nvinfer1 -#endif // TRT_EFFICIENT_NMS_PLUGIN_H +#endif // TRT_YOLO_NMS_PLUGIN_H From 6b751dbcde08a2527e343fd01a5ca14153b082bd Mon Sep 17 00:00:00 2001 From: Levi Pereira Date: Thu, 9 May 2024 20:29:29 -0300 Subject: [PATCH 3/6] Updated yoloNMSPlugin README.md Signed-off-by: Levi Pereira --- plugin/yoloNMSPlugin/README.md | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/plugin/yoloNMSPlugin/README.md b/plugin/yoloNMSPlugin/README.md index b2af6231..f0f93300 100644 --- a/plugin/yoloNMSPlugin/README.md +++ b/plugin/yoloNMSPlugin/README.md @@ -1,4 +1,4 @@ -# Efficient NMS Plugin +# Yolo NMS Plugin #### Table of Contents - [Description](#description) @@ -89,6 +89,10 @@ The following four output tensors are generated: - **detection_classes:** This is a `[batch_size, max_output_boxes]` tensor of data type `int32`, containing the classes for the boxes. +- **detection_indices:** + This is a `[batch_size, max_output_boxes]` tensor of data type `int32`, containing the classes for the boxes. + + ### Parameters | Type | Parameter | Description @@ -105,9 +109,9 @@ Parameters marked with a `*` have a non-negligible effect on runtime latency. Se ## Limitations -The `EfficientNMS_ONNX_TRT` plugin's output may not always be sufficiently sized to capture all NMS-ed boxes. This is because it ignores the number of classes in the calculation of the output size (it produces an output of size `(batch_size * max_output_boxes_per_class, 3)` when in general, a tensor of size `(batch_size * max_output_boxes_per_class * num_classes, 3)`) would be required. This was a compromise made to keep the output size from growing uncontrollably since it lacks an attribute similar to `max_output_boxes` to control the number of output boxes globally. +The `YOLO_NMS_ONNX_TRT` plugin's output may not always be sufficiently sized to capture all NMS-ed boxes. This is because it ignores the number of classes in the calculation of the output size (it produces an output of size `(batch_size * max_output_boxes_per_class, 3)` when in general, a tensor of size `(batch_size * max_output_boxes_per_class * num_classes, 3)`) would be required. This was a compromise made to keep the output size from growing uncontrollably since it lacks an attribute similar to `max_output_boxes` to control the number of output boxes globally. -Due to this reason, please use TensorRT's inbuilt `INMSLayer` instead of the `EfficientNMS_ONNX_TRT` plugin wherever possible. +Due to this reason, please use TensorRT's inbuilt `INMSLayer` instead of the `YOLO_NMS_ONNX_TRT` plugin wherever possible. ## Algorithm @@ -115,17 +119,17 @@ Due to this reason, please use TensorRT's inbuilt `INMSLayer` instead of the `Ef The NMS algorithm in this plugin first filters the scores below the given `scoreThreshold`. This subset of scores is then sorted, and their corresponding boxes are then further filtered out by removing boxes that overlap each other with an IOU above the given `iouThreshold`. -The algorithm launcher and its relevant CUDA kernels are all defined in the `efficientNMSInference.cu` file. +The algorithm launcher and its relevant CUDA kernels are all defined in the `yoloNMSInference.cu` file. Specifically, the NMS algorithm does the following: -- The scores are filtered with the `score_threshold` parameter to reject any scores below the score threshold, while maintaining indexing to cross-reference these scores to their corresponding box coordinates. This is done with the `EfficientNMSFilter` CUDA kernel. +- The scores are filtered with the `score_threshold` parameter to reject any scores below the score threshold, while maintaining indexing to cross-reference these scores to their corresponding box coordinates. This is done with the `YoloNMSFilter` CUDA kernel. -- If too many elements are kept, due to a very low (or zero) score threshold, the filter operation can become a bottleneck due to the atomic operations involved. To mitigate this, a fallback kernel `EfficientNMSDenseIndex` is used instead which passes all the score elements densely packed and indexed. This method is heuristically selected only if the score threshold is less than 0.007. +- If too many elements are kept, due to a very low (or zero) score threshold, the filter operation can become a bottleneck due to the atomic operations involved. To mitigate this, a fallback kernel `YoloNMSDenseIndex` is used instead which passes all the score elements densely packed and indexed. This method is heuristically selected only if the score threshold is less than 0.007. - The selected scores that remain after filtering are sorted in descending order. The indexing is carefully handled to still maintain score to box relationships after sorting. -- After sorting, the highest 4096 scores are processed by the `EfficientNMS` CUDA kernel. This algorithm uses the index data maintained throughout the previous steps to find the boxes corresponding to the remaining scores. If the fused box decoder is being used, decoding will happen until this stage, where only the top scoring boxes need to be decoded. +- After sorting, the highest 4096 scores are processed by the `YoloNMS` CUDA kernel. This algorithm uses the index data maintained throughout the previous steps to find the boxes corresponding to the remaining scores. If the fused box decoder is being used, decoding will happen until this stage, where only the top scoring boxes need to be decoded. - The NMS kernel uses an efficient filtering algorithm that largely reduces the number of IOU overlap cross-checks between box pairs. The boxes that survive the IOU filtering finally pass through to the output results. At this stage, the sigmoid activation is applied to only the final remaining scores, if `score_activation` is enabled, thereby greatly reducing the amount of sigmoid calculations required otherwise. From 0da82349dced47cd06be12b0150c8822049bd405 Mon Sep 17 00:00:00 2001 From: Levi Pereira Date: Fri, 10 May 2024 23:57:40 -0300 Subject: [PATCH 4/6] Update README.md Signed-off-by: Levi Pereira --- plugin/yoloNMSPlugin/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/yoloNMSPlugin/README.md b/plugin/yoloNMSPlugin/README.md index f0f93300..6d91bb5b 100644 --- a/plugin/yoloNMSPlugin/README.md +++ b/plugin/yoloNMSPlugin/README.md @@ -1,4 +1,4 @@ -# Yolo NMS Plugin +# YoloNMS Plugin #### Table of Contents - [Description](#description) From 0a76d6f47245b1b32a6f4f45666e770a9ff2fbb8 Mon Sep 17 00:00:00 2001 From: Levi Pereira Date: Mon, 13 May 2024 12:03:50 -0300 Subject: [PATCH 5/6] Removed YoloNMSONNX Plugin Signed-off-by: Levi Pereira --- plugin/api/inferPlugin.cpp | 1 - plugin/yoloNMSPlugin/README.md | 8 +- plugin/yoloNMSPlugin/yoloNMSInference.cu | 67 ++----- plugin/yoloNMSPlugin/yoloNMSParameters.h | 2 +- plugin/yoloNMSPlugin/yoloNMSPlugin.cpp | 217 ++++------------------- plugin/yoloNMSPlugin/yoloNMSPlugin.h | 22 --- 6 files changed, 43 insertions(+), 274 deletions(-) diff --git a/plugin/api/inferPlugin.cpp b/plugin/api/inferPlugin.cpp index da40766c..b1d1a5d9 100644 --- a/plugin/api/inferPlugin.cpp +++ b/plugin/api/inferPlugin.cpp @@ -220,7 +220,6 @@ extern "C" initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); - initializePlugin(logger, libNamespace); return true; } } // extern "C" diff --git a/plugin/yoloNMSPlugin/README.md b/plugin/yoloNMSPlugin/README.md index 6d91bb5b..12fd88fb 100644 --- a/plugin/yoloNMSPlugin/README.md +++ b/plugin/yoloNMSPlugin/README.md @@ -18,7 +18,7 @@ This TensorRT plugin implements an efficient algorithm to perform Non Maximum Suppression for object detection networks. -This plugin is primarily intended for using with EfficientDet on TensorRT, as this network is particularly sensitive to the latencies introduced by slower NMS implementations. However, the plugin is generic enough that it will work correctly for other detections architectures, such as SSD or FasterRCNN. +This plugin is a version of EfficientNMS but it returns the indices of the detections. ## Structure @@ -107,12 +107,6 @@ The following four output tensors are generated: Parameters marked with a `*` have a non-negligible effect on runtime latency. See the [Performance Tuning](#performance-tuning) section below for more details on how to set them optimally. -## Limitations - -The `YOLO_NMS_ONNX_TRT` plugin's output may not always be sufficiently sized to capture all NMS-ed boxes. This is because it ignores the number of classes in the calculation of the output size (it produces an output of size `(batch_size * max_output_boxes_per_class, 3)` when in general, a tensor of size `(batch_size * max_output_boxes_per_class * num_classes, 3)`) would be required. This was a compromise made to keep the output size from growing uncontrollably since it lacks an attribute similar to `max_output_boxes` to control the number of output boxes globally. - -Due to this reason, please use TensorRT's inbuilt `INMSLayer` instead of the `YOLO_NMS_ONNX_TRT` plugin wherever possible. - ## Algorithm ### Process Description diff --git a/plugin/yoloNMSPlugin/yoloNMSInference.cu b/plugin/yoloNMSPlugin/yoloNMSInference.cu index 61822815..edb7ec1b 100644 --- a/plugin/yoloNMSPlugin/yoloNMSInference.cu +++ b/plugin/yoloNMSPlugin/yoloNMSInference.cu @@ -150,35 +150,6 @@ __device__ void WriteNMSResult(YoloNMSParameters param, int* __restrict__ numDet nmsIndicesOutput[outputIdx] = index; } -__device__ void WriteONNXResult(YoloNMSParameters param, int* outputIndexData, int* __restrict__ nmsIndicesOutput, - int imageIdx, int threadClass, int boxIdxMap) -{ - int index = boxIdxMap % param.numAnchors; - int idx = atomicAdd((unsigned int*) &outputIndexData[0], 1); - nmsIndicesOutput[idx * 3 + 0] = imageIdx; - nmsIndicesOutput[idx * 3 + 1] = threadClass; - nmsIndicesOutput[idx * 3 + 2] = index; -} - -__global__ void PadONNXResult(YoloNMSParameters param, int* outputIndexData, int* __restrict__ nmsIndicesOutput) -{ - if (threadIdx.x > 0) - { - return; - } - int pidx = outputIndexData[0] - 1; - if (pidx < 0) - { - return; - } - for (int idx = pidx + 1; idx < param.batchSize * param.numOutputBoxes; idx++) - { - nmsIndicesOutput[idx * 3 + 0] = nmsIndicesOutput[pidx * 3 + 0]; - nmsIndicesOutput[idx * 3 + 1] = nmsIndicesOutput[pidx * 3 + 1]; - nmsIndicesOutput[idx * 3 + 2] = nmsIndicesOutput[pidx * 3 + 2]; - } -} - template __global__ void YoloNMS(YoloNMSParameters param, const int* topNumData, int* outputIndexData, int* outputClassData, const int* sortedIndexData, const T* __restrict__ sortedScoresData, @@ -272,17 +243,11 @@ __global__ void YoloNMS(YoloNMSParameters param, const int* topNumData, int* out { // This branch is visited by one thread per iteration, so it's safe to do non-atomic increments. resultsCounter++; - if (param.outputONNXIndices) - { - WriteONNXResult( - param, outputIndexData, nmsIndicesOutput, imageIdx, threadClass[tile], boxIdxMap[tile]); - } - else - { - WriteNMSResult(param, numDetectionsOutput, nmsScoresOutput, nmsClassesOutput, - nmsBoxesOutput, nmsIndicesOutput, threadScore[tile], threadClass[tile], threadBox[tile], imageIdx, - resultsCounter, boxIdxMap[tile]); - } + + WriteNMSResult(param, numDetectionsOutput, nmsScoresOutput, nmsClassesOutput, + nmsBoxesOutput, nmsIndicesOutput, threadScore[tile], threadClass[tile], threadBox[tile], imageIdx, + resultsCounter, boxIdxMap[tile]); + } } } @@ -376,11 +341,6 @@ cudaError_t YoloNMSLauncher(YoloNMSParameters& param, int* topNumData, int* outp nmsClassesOutput, nmsIndicesOutput, (BoxCorner*) nmsBoxesOutput); } - if (param.outputONNXIndices) - { - PadONNXResult<<<1, 1, 0, stream>>>(param, outputIndexData, nmsIndicesOutput); - } - return cudaGetLastError(); } @@ -639,18 +599,11 @@ pluginStatus_t YoloNMSDispatch(YoloNMSParameters param, const void* boxesInput, void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream) { // Clear Outputs (not all elements will get overwritten by the kernels, so safer to clear everything out) - if (param.outputONNXIndices) - { - CSC(cudaMemsetAsync(nmsIndicesOutput, 0xFF, param.batchSize * param.numOutputBoxes * 3 * sizeof(int), stream), STATUS_FAILURE); - } - else - { - CSC(cudaMemsetAsync(numDetectionsOutput, 0x00, param.batchSize * sizeof(int), stream), STATUS_FAILURE); - CSC(cudaMemsetAsync(nmsScoresOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(T), stream), STATUS_FAILURE); - CSC(cudaMemsetAsync(nmsBoxesOutput, 0x00, param.batchSize * param.numOutputBoxes * 4 * sizeof(T), stream), STATUS_FAILURE); - CSC(cudaMemsetAsync(nmsClassesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE); - CSC(cudaMemsetAsync(nmsIndicesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE); - } + CSC(cudaMemsetAsync(numDetectionsOutput, 0x00, param.batchSize * sizeof(int), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsScoresOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(T), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsBoxesOutput, 0x00, param.batchSize * param.numOutputBoxes * 4 * sizeof(T), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsClassesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsIndicesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE); // Empty Inputs if (param.numScoreElements < 1) diff --git a/plugin/yoloNMSPlugin/yoloNMSParameters.h b/plugin/yoloNMSPlugin/yoloNMSParameters.h index 58430206..c4f59b0a 100644 --- a/plugin/yoloNMSPlugin/yoloNMSParameters.h +++ b/plugin/yoloNMSPlugin/yoloNMSParameters.h @@ -42,7 +42,7 @@ struct YoloNMSParameters // Related to NMS Internals int32_t numSelectedBoxes = 4096; int32_t scoreBits = -1; - bool outputONNXIndices = false; + // Related to Tensor Configuration // (These are set by the various plugin configuration methods, no need to define them during plugin creation.) diff --git a/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp b/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp index aab6554d..f05512ea 100644 --- a/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp +++ b/plugin/yoloNMSPlugin/yoloNMSPlugin.cpp @@ -22,14 +22,11 @@ using namespace nvinfer1; using nvinfer1::plugin::YoloNMSPlugin; using nvinfer1::plugin::YoloNMSParameters; using nvinfer1::plugin::YoloNMSPluginCreator; -using nvinfer1::plugin::YoloNMSONNXPluginCreator; namespace { char const* const kYOLO_NMS_PLUGIN_VERSION{"1"}; char const* const kYOLO_NMS_PLUGIN_NAME{"YOLO_NMS_TRT"}; -char const* const kYOLO_NMS_ONNX_PLUGIN_VERSION{"1"}; -char const* const kYOLO_NMS_ONNX_PLUGIN_NAME{"YOLO_NMS_ONNX_TRT"}; } // namespace YoloNMSPlugin::YoloNMSPlugin(YoloNMSParameters param) @@ -61,12 +58,6 @@ char const* YoloNMSPlugin::getPluginVersion() const noexcept int32_t YoloNMSPlugin::getNbOutputs() const noexcept { - if (mParam.outputONNXIndices) - { - // ONNX NonMaxSuppression Compatibility - return 1; - } - // Standard Plugin Implementation return 5; } @@ -133,12 +124,6 @@ char const* YoloNMSPlugin::getPluginNamespace() const noexcept nvinfer1::DataType YoloNMSPlugin::getOutputDataType( int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept { - if (mParam.outputONNXIndices) - { - // ONNX NMS uses an integer output - return nvinfer1::DataType::kINT32; - } - // On standard NMS, num_detections and detection_classes use integer outputs if (index == 0 || index == 3 || index == 4) { @@ -184,46 +169,36 @@ DimsExprs YoloNMSPlugin::getOutputDimensions( *exprBuilder.operation(DimensionOperation::kPROD, *numOutputBoxesPerClass, *numClasses)); } - if (mParam.outputONNXIndices) - { - // ONNX NMS - PLUGIN_ASSERT(outputIndex == 0); - // detection_indices + + // Standard NMS + PLUGIN_ASSERT(outputIndex >= 0 && outputIndex <= 4); + + // num_detections + if (outputIndex == 0) + { out_dim.nbDims = 2; - out_dim.d[0] = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[0], *numOutputBoxes); - out_dim.d[1] = exprBuilder.constant(3); + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = exprBuilder.constant(1); } - else + // detection_boxes + else if (outputIndex == 1) { - // Standard NMS - PLUGIN_ASSERT(outputIndex >= 0 && outputIndex <= 4); - - // num_detections - if (outputIndex == 0) - { - out_dim.nbDims = 2; - out_dim.d[0] = inputs[0].d[0]; - out_dim.d[1] = exprBuilder.constant(1); - } - // detection_boxes - else if (outputIndex == 1) - { - out_dim.nbDims = 3; - out_dim.d[0] = inputs[0].d[0]; - out_dim.d[1] = numOutputBoxes; - out_dim.d[2] = exprBuilder.constant(4); - } - // detection_scores: outputIndex == 2 - // detection_classes: outputIndex == 3 - // detection_indices: outputIndex == 4 - else if (outputIndex == 2 || outputIndex == 3 || outputIndex == 4) - { - out_dim.nbDims = 2; - out_dim.d[0] = inputs[0].d[0]; - out_dim.d[1] = numOutputBoxes; - } + out_dim.nbDims = 3; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = numOutputBoxes; + out_dim.d[2] = exprBuilder.constant(4); + } + // detection_scores: outputIndex == 2 + // detection_classes: outputIndex == 3 + // detection_indices: outputIndex == 4 + else if (outputIndex == 2 || outputIndex == 3 || outputIndex == 4) + { + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = numOutputBoxes; } + return out_dim; } @@ -242,22 +217,6 @@ bool YoloNMSPlugin::supportsFormatCombination( return false; } - if (mParam.outputONNXIndices) - { - PLUGIN_ASSERT(nbInputs == 2); - PLUGIN_ASSERT(nbOutputs == 1); - - // detection_indices output: int32_t - if (pos == 2) - { - return inOut[pos].type == DataType::kINT32; - } - - // boxes and scores input: fp32 or fp16 - return (inOut[pos].type == DataType::kHALF || inOut[pos].type == DataType::kFLOAT) - && (inOut[0].type == inOut[pos].type); - } - PLUGIN_ASSERT(nbInputs == 2 || nbInputs == 3); PLUGIN_ASSERT(nbOutputs == 5); if (nbInputs == 2) @@ -286,21 +245,12 @@ void YoloNMSPlugin::configurePlugin( { try { - if (mParam.outputONNXIndices) - { - // Accepts two inputs - // [0] boxes, [1] scores - PLUGIN_ASSERT(nbInputs == 2); - PLUGIN_ASSERT(nbOutputs == 1); - } - else - { - // Accepts two or three inputs - // If two inputs: [0] boxes, [1] scores - // If three inputs: [0] boxes, [1] scores, [2] anchors - PLUGIN_ASSERT(nbInputs == 2 || nbInputs == 3); - PLUGIN_ASSERT(nbOutputs == 5); - } + // Accepts two or three inputs + // If two inputs: [0] boxes, [1] scores + // If three inputs: [0] boxes, [1] scores, [2] anchors + PLUGIN_ASSERT(nbInputs == 2 || nbInputs == 3); + PLUGIN_ASSERT(nbOutputs == 5); + mParam.datatype = in[0].desc.type; // Shape of scores input should be @@ -375,18 +325,6 @@ int32_t YoloNMSPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDe { mParam.batchSize = inputDesc[0].dims.d[0]; - if (mParam.outputONNXIndices) - { - // ONNX NonMaxSuppression Op Support - void const* const boxesInput = inputs[0]; - void const* const scoresInput = inputs[1]; - - void* nmsIndicesOutput = outputs[0]; - - return YoloNMSInference(mParam, boxesInput, scoresInput, nullptr, nullptr, nullptr, nullptr, nullptr, - nmsIndicesOutput, workspace, stream); - } - // Standard NMS Operation void const* const boxesInput = inputs[0]; void const* const scoresInput = inputs[1]; @@ -528,96 +466,3 @@ IPluginV2DynamicExt* YoloNMSPluginCreator::deserializePlugin( } return nullptr; } - -// ONNX NonMaxSuppression Op Compatibility - -YoloNMSONNXPluginCreator::YoloNMSONNXPluginCreator() - : mParam{} -{ - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); - mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); - mPluginAttributes.emplace_back(PluginField("max_output_boxes_per_class", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("center_point_box", nullptr, PluginFieldType::kINT32, 1)); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -char const* YoloNMSONNXPluginCreator::getPluginName() const noexcept -{ - return kYOLO_NMS_ONNX_PLUGIN_NAME; -} - -char const* YoloNMSONNXPluginCreator::getPluginVersion() const noexcept -{ - return kYOLO_NMS_ONNX_PLUGIN_VERSION; -} - -PluginFieldCollection const* YoloNMSONNXPluginCreator::getFieldNames() noexcept -{ - return &mFC; -} - -IPluginV2DynamicExt* YoloNMSONNXPluginCreator::createPlugin( - char const* name, PluginFieldCollection const* fc) noexcept -{ - try - { - PluginField const* fields = fc->fields; - for (int32_t i = 0; i < fc->nbFields; ++i) - { - char const* attrName = fields[i].name; - if (!strcmp(attrName, "score_threshold")) - { - PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kFLOAT32); - mParam.scoreThreshold = *(static_cast(fields[i].data)); - } - if (!strcmp(attrName, "iou_threshold")) - { - PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kFLOAT32); - mParam.iouThreshold = *(static_cast(fields[i].data)); - } - if (!strcmp(attrName, "max_output_boxes_per_class")) - { - PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); - mParam.numOutputBoxesPerClass = *(static_cast(fields[i].data)); - } - if (!strcmp(attrName, "center_point_box")) - { - PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); - mParam.boxCoding = *(static_cast(fields[i].data)); - } - } - - // This enables ONNX compatibility mode - mParam.outputONNXIndices = true; - mParam.numOutputBoxes = mParam.numOutputBoxesPerClass; - - auto* plugin = new YoloNMSPlugin(mParam); - plugin->setPluginNamespace(mNamespace.c_str()); - return plugin; - } - catch (std::exception const& e) - { - caughtError(e); - } - return nullptr; -} - -IPluginV2DynamicExt* YoloNMSONNXPluginCreator::deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept -{ - try - { - // This object will be deleted when the network is destroyed, which will - // call YoloNMSPlugin::destroy() - auto* plugin = new YoloNMSPlugin(serialData, serialLength); - plugin->setPluginNamespace(mNamespace.c_str()); - return plugin; - } - catch (std::exception const& e) - { - caughtError(e); - } - return nullptr; -} diff --git a/plugin/yoloNMSPlugin/yoloNMSPlugin.h b/plugin/yoloNMSPlugin/yoloNMSPlugin.h index 5f2d2175..451bcddf 100644 --- a/plugin/yoloNMSPlugin/yoloNMSPlugin.h +++ b/plugin/yoloNMSPlugin/yoloNMSPlugin.h @@ -96,28 +96,6 @@ class YoloNMSPluginCreator : public nvinfer1::pluginInternal::BaseCreator std::string mPluginName; }; -// ONNX NonMaxSuppression Op Compatibility -class YoloNMSONNXPluginCreator : public nvinfer1::pluginInternal::BaseCreator -{ -public: - YoloNMSONNXPluginCreator(); - ~YoloNMSONNXPluginCreator() override = default; - - char const* getPluginName() const noexcept override; - char const* getPluginVersion() const noexcept override; - PluginFieldCollection const* getFieldNames() noexcept override; - - IPluginV2DynamicExt* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override; - IPluginV2DynamicExt* deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept override; - -protected: - PluginFieldCollection mFC; - YoloNMSParameters mParam; - std::vector mPluginAttributes; - std::string mPluginName; -}; - } // namespace plugin } // namespace nvinfer1 From 36d6911c03ef8395e20fddee107eea51e306fd97 Mon Sep 17 00:00:00 2001 From: Levi Pereira Date: Mon, 13 May 2024 12:10:43 -0300 Subject: [PATCH 6/6] Update README.md Signed-off-by: Levi Pereira --- plugin/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/plugin/README.md b/plugin/README.md index 2a3c68d3..4f5ade67 100644 --- a/plugin/README.md +++ b/plugin/README.md @@ -47,6 +47,7 @@ | [specialSlicePlugin](specialSlicePlugin) | SpecialSlice_TRT | 1 | | [splitPlugin](splitPlugin) | Split | 1 | | [voxelGeneratorPlugin](voxelGeneratorPlugin) | VoxelGeneratorPlugin | 1 | +| [yoloNMSPlugin](yoloNMSPlugin) | yoloNMSPlugin | 1 | ## Known Limitations