
RT-DETR TensorRT output bounding boxes are all 0 #8248

Closed
2 of 3 tasks
aliencaocao opened this issue May 16, 2023 · 7 comments
Labels
bug Something isn't working status/close

Comments

@aliencaocao

Search before asking

  • I have searched the issues and found no similar bug report.

Bug Component

Inference, Export, Deploy

Describe the Bug

Following the commands in the tutorial, I exported RT-DETR-X to ONNX and then converted it to TensorRT. The ONNX model works fine, including when run with TensorRTExecutionProvider, but if I convert it to a TensorRT engine with trtexec, the model's class ids and confidence scores look plausible (whether they are actually correct, I don't know), while x1, y1, x2, y2 are all 0.000. Reproducible on both Windows 10 with CUDA 11.8 + cuDNN 8.9 and WSL2 with CUDA 11.8 + cuDNN 8.9.
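For context, the conversion step looks roughly like the following. This is a sketch, not the exact command from the tutorial; the file names and the optional `--fp16` flag are assumptions:

```shell
# Build a TensorRT engine from the exported ONNX model with trtexec.
# Input/output file names are placeholders matching the model used below.
trtexec --onnx=rtdetr_hgnetv2_x_6x_coco.onnx \
        --saveEngine=rtdetr_hgnetv2_x_6x_coco.trt \
        --fp16   # optional; reproduces with or without reduced precision
```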


The input image used was "soccer.jpg".

Environment

Windows 10 and WSL2
CUDA 11.8
CUDNN 8.9
PaddlePaddle-gpu 2.4.2-post117
PaddleDetection develop
Reproducible with both Python 3.9.13 and 3.10.6

Bug description confirmation

  • I confirm that the bug replication steps, code change instructions, and environment information have been provided, and the problem can be reproduced.

Are you willing to submit a PR?

  • I'd like to help by submitting a PR!
@aliencaocao aliencaocao added the bug Something isn't working label May 16, 2023
@aliencaocao
Author

With a different version of the inference code, the bboxes are no longer 0, but the whole result is garbled: the boxes are nowhere near the objects, and the confidence scores are all very low.

The inference code I used:

import cv2
import numpy as np
import pycuda.driver as cuda
import tensorrt as trt
from collections import OrderedDict, namedtuple


class TRTInference:
    def __init__(self, engine_path, output_names_mapping: dict = None, verbose=False):
        cuda.init()
        self.device_ctx = cuda.Device(0).make_context()
        self.engine_path = engine_path
        self.output_names_mapping = output_names_mapping or {}
        self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO)
        self.engine = None
        self.load_engine()
        assert self.engine is not None, 'Failed to load TensorRT engine.'

        self.context = self.engine.create_execution_context()
        self.stream = cuda.Stream()

        self.bindings = self.get_bindings()
        self.bindings_addr = OrderedDict((n, v.ptr) for n, v in self.bindings.items())

        self.input_names = self.get_input_names()
        self.output_names = self.get_output_names()

    def load_engine(self):
        with open(self.engine_path, 'rb') as f, trt.Runtime(self.logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())

    def get_input_names(self):
        names = []
        for _, name in enumerate(self.engine):
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                names.append(name)
        return names

    def get_output_names(self):
        names = []
        for _, name in enumerate(self.engine):
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                names.append(name)
        return names

    def get_bindings(self):
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
        bindings = OrderedDict()

        for i, name in enumerate(self.engine):
            shape = self.engine.get_tensor_shape(name)
            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                data = np.random.randn(*shape).astype(dtype)
                ptr = cuda.mem_alloc(data.nbytes)
                bindings[name] = Binding(name, dtype, shape, data, ptr)
            else:
                data = cuda.pagelocked_empty(trt.volume(shape), dtype)
                ptr = cuda.mem_alloc(data.nbytes)
                bindings[name] = Binding(name, dtype, shape, data, ptr)

        return bindings

    def __call__(self, blob):
        blob = {n: np.ascontiguousarray(v) for n, v in blob.items()}
        for n in self.input_names:
            cuda.memcpy_htod_async(self.bindings_addr[n], blob[n], self.stream)

        bindings_addr = [int(v) for _, v in self.bindings_addr.items()]
        self.context.execute_async_v2(bindings=bindings_addr, stream_handle=self.stream.handle)

        outputs = {}
        for n in self.output_names:
            cuda.memcpy_dtoh_async(self.bindings[n].data, self.bindings[n].ptr, self.stream)
            o = self.bindings[n].data
            # reshape to correct output shape
            if o.shape != self.bindings[n].shape:
                o = o.reshape(self.bindings[n].shape)
            outputs[self.output_names_mapping.get(n, n)] = o

        self.stream.synchronize()

        return outputs

    def warmup(self, blob, n):
        for _ in range(n):
            _ = self(blob)

    def __del__(self):
        try:
            self.device_ctx.pop()
        except cuda.LogicError as _:
            pass

model = TRTInference('rtdetr_hgnetv2_x_6x_coco.trt', output_names_mapping={'tile_3.tmp_0': 'bbox_num', 'reshape2_83.tmp_0': 'bbox'}, verbose=True)

img = cv2.imread("soccer.jpg")
h, w = 640, 640
org_img = img
img = cv2.resize(img, (h, w))
im_shape = np.array([[float(img.shape[0]), float(img.shape[1])]]).astype('float32')
scale_factor = np.array([[float(h/org_img.shape[0]), float(w/org_img.shape[1])]]).astype('float32')
img = img.astype(np.float32) / 255.0
img = np.expand_dims(img, axis=0)  # add batch dimension
inputs_dict = {
    'image': img,
    'im_shape': im_shape,
    'scale_factor': scale_factor
}
result = model(inputs_dict)['bbox']

confidence_threshold = 0.1
boxes = result[result[:, 1] > confidence_threshold]
for box in boxes:
    cv2.rectangle(org_img, (int(box[2]), int(box[3])), (int(box[4]), int(box[5])), (0, 255, 0), 2)
cv2.imwrite("output/soccer_trt.jpg", org_img)

@aliencaocao
Author

After switching to the inputs_dict part of the inference code from https://aistudio.baidu.com/aistudio/projectdetail/6000200, the problem is solved.

img = cv2.imread("soccer.jpg")
org_img = img
im_shape = np.array([[float(img.shape[0]), float(img.shape[1])]]).astype('float32')
img = cv2.resize(img, (640, 640))
scale_factor = np.array([[float(640/img.shape[0]), float(640/img.shape[1])]]).astype('float32')
img = img.astype(np.float32) / 255.0
input_img = np.transpose(img, [2, 0, 1])
image = input_img[np.newaxis, :, :, :]
output_dict = ["reshape2_83.tmp_0", "tile_3.tmp_0"]
inputs_dict = {
    'im_shape': im_shape,
    'image': image,
    'scale_factor': scale_factor
}
result = model(inputs_dict)['bbox']

Although I don't understand why. Personally I think this code is wrong, and I'd appreciate it if someone could explain. Based on https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md, two things are unclear to me:

  1. im_shape should be the size after resizing, but this code uses the size before resizing; I would expect it to be 640x640.
  2. scale_factor should be the ratio of the input image size to the real image size, but in this code it is always 1 regardless; I would expect it to be 640/im_shape.

My original code differed only in these two points, but I don't know why it fails.

@aliencaocao
Author

aliencaocao commented May 16, 2023

@lyuwenyu Hoping you can help explain this, thanks.

@lyuwenyu
Collaborator

lyuwenyu commented May 17, 2023

The logic here is in DETRPostProcess: as long as dividing im_shape by scale_factor yields origin_shape, it works.

https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/post_process.py#L507
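In other words (a minimal sketch with made-up image sizes; the variable names are just for illustration):

```python
import numpy as np

# Hypothetical original image size (H, W); the network input is 640x640.
orig_h, orig_w = 1080.0, 1920.0

# Parameterization used by the working snippet above: im_shape is the
# ORIGINAL size and scale_factor is 1, so im_shape / scale_factor
# recovers the original shape, which is all DETRPostProcess needs.
im_shape = np.array([[orig_h, orig_w]], dtype=np.float32)
scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)

recovered = im_shape / scale_factor  # what the post-process derives
assert np.allclose(recovered, [[orig_h, orig_w]])
```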

@lyuwenyu lyuwenyu reopened this May 17, 2023
@lyuwenyu
Collaborator

The TRT inference results look fine on my side; your bad results are most likely because the im_shape and scale_factor parameters are not right.

@aliencaocao
Author

Yes, it was indeed these two parameters that were wrong. Thanks for the explanation!

@luoshiyong

Hi, I ran into the same problem (boxes all zero, labels and scores normal). After switching to @aliencaocao's inference code, the boxes all became inf...
