Detectron2 conversion to TensorRT #2546

Closed
andreysher opened this issue Dec 13, 2022 · 35 comments
Labels
triaged Issue has been triaged by maintainers

Comments

@andreysher
Description

I followed your sample that supports converting a Detectron2 model to TensorRT:
https://github.com/NVIDIA/TensorRT/tree/main/samples/python/detectron2
By default, this sample uses export_model.py from the detectron2 repo. That script uses torch.onnx.export for the conversion, and the resulting ONNX model contains no Caffe2 nodes.

The next step is editing the ONNX model with your samples/python/detectron2/create_onnx.py. This script hardcodes many ONNX node names, such as Gemm_1685, Softmax_1796, Gemm_1690, and so on. In the latest detectron2 version these node names have changed, so the script no longer works correctly. I get this error:

Traceback (most recent call last):
  File "TensorRT/samples/python/detectron2/create_onnx.py", line 556, in <module>
    main(args)
  File "TensorRT/samples/python/detectron2/create_onnx.py", line 537, in main
    det2_gs.process_graph(anchors, args.first_nms_threshold, args.second_nms_threshold)
  File "TensorRT/samples/python/detectron2/create_onnx.py", line 525, in process_graph
    box_head_outputs, mask_head_output = roi_heads(rpn_outputs, p2, p3, p4, p5, second_nms_threshold)
  File "TensorRT/samples/python/detectron2/create_onnx.py", line 442, in roi_heads
    first_box_head_gemm.inputs[0] = box_pooler_reshape[0]
AttributeError: 'NoneType' object has no attribute 'inputs'

I got past this error after changing some ONNX node names in your script, but I am not sure I did it correctly.
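The AttributeError above is a symptom rather than the cause: the traceback shows that the node lookup returned None when no node matched the hardcoded name, and the failure only surfaced on the next line. A minimal sketch of a stricter lookup that fails at the lookup itself and suggests similarly named nodes is shown here. `Node` and `find_node_strict` are hypothetical stand-ins, not part of onnx_graphsurgeon or the TensorRT sample's onnx_utils helpers; similarity is computed with Python's `difflib`.

```python
import difflib
from dataclasses import dataclass
from typing import Iterable


@dataclass
class Node:
    # Hypothetical stand-in for an onnx_graphsurgeon node; only the fields used here.
    op: str
    name: str


def find_node_strict(nodes: Iterable[Node], op: str, name: str) -> Node:
    """Return the node with the given op type and name, or raise LookupError."""
    nodes = list(nodes)
    for node in nodes:
        if node.op == op and node.name == name:
            return node
    # Suggest same-op nodes with similar names, which helps after a renaming.
    candidates = [n.name for n in nodes if n.op == op]
    close = difflib.get_close_matches(name, candidates, n=3, cutoff=0.4)
    raise LookupError("No {} node named {!r}; similar {} nodes: {}".format(op, name, op, close))


if __name__ == "__main__":
    nodes = [Node("Gemm", "/roi_heads/box_head/fc1/Gemm"), Node("Softmax", "/roi_heads/Softmax")]
    print(find_node_strict(nodes, "Gemm", "/roi_heads/box_head/fc1/Gemm").name)
    try:
        find_node_strict(nodes, "Gemm", "Gemm_1685")  # old-style name from the original script
    except LookupError as err:
        print(err)
```

With onnx_graphsurgeon you could pass `graph.nodes` directly, since its nodes expose `.op` and `.name`.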

Could you update your script to match the new ONNX node names and support Detectron2-to-TensorRT conversion?

@williamf-searidgetech
williamf-searidgetech commented Dec 15, 2022

I was also facing the same problem and ended up editing the create_onnx.py script.
Please note that I also changed the STABLE_ONNX_OPSET_VERSION value to 16. The value can be found in the detectron2 repo under detectron2/detectron2/export/__init__.py (edit it before installing detectron2 and running the export_model.py script to generate model.onnx). I ran all of this in the 22.11-py3 PyTorch container from NVIDIA NGC.
Below is the result of running the newly built TensorRT engine generated by the edited create_onnx.py script. It was obtained with the infer.py script from the detectron2 sample in TensorRT.
[Image: airplanes]

Below is the newly edited create_onnx.py. Note that all find_node_by_op_name() calls now use a different second argument (the node name) than before.
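Because every find_node_by_op_name() call depends on node names that can change between Detectron 2 and PyTorch versions, it may help to check all of them before any graph surgery starts. A minimal sketch, assuming you can obtain (op, name) pairs from the loaded graph, for example `[(n.op, n.name) for n in graph.nodes]` with onnx_graphsurgeon. `EXPECTED_NODES` and `missing_nodes` are hypothetical names; the pairs are copied from the edited script that follows.

```python
from typing import Iterable, List, Tuple

# (op type, node name) pairs looked up by the edited create_onnx.py.
EXPECTED_NODES: List[Tuple[str, str]] = (
    [("Conv", f"/backbone/fpn_output{i}/Conv") for i in range(2, 6)]
    + [("Flatten", "/proposal_generator/Flatten")]
    + [("Flatten", f"/proposal_generator/Flatten_{i}") for i in range(1, 5)]
    + [("Reshape", f"/proposal_generator/Reshape_{i}") for i in (1, 3, 5, 7, 9)]
    + [
        ("Gemm", "/roi_heads/box_head/fc1/Gemm"),
        ("Softmax", "/roi_heads/Softmax"),
        ("Gemm", "/roi_heads/box_predictor/bbox_pred/Gemm"),
        ("Conv", "/roi_heads/mask_head/mask_fcn1/Conv"),
        ("Conv", "/roi_heads/mask_head/predictor/Conv"),
        ("Sigmoid", "/roi_heads/mask_head/Sigmoid"),
    ]
)


def missing_nodes(graph_nodes: Iterable[Tuple[str, str]],
                  expected: Iterable[Tuple[str, str]] = EXPECTED_NODES) -> List[Tuple[str, str]]:
    """Return the expected (op, name) pairs absent from the graph, in the expected order."""
    present = set(graph_nodes)
    return [pair for pair in expected if pair not in present]
```

An empty result means every lookup in the script should succeed; anything else lists exactly which names need updating for your detectron2/PyTorch combination.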

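The mask branch of roi_heads() in the script returns one mask per detection instead of one per class by flattening the mask head output and gathering with per-detection offsets. The NumPy sketch below reproduces that index arithmetic outside of ONNX so it can be checked in isolation. `select_class_masks` is a hypothetical name, and the shapes assume the layout the script's Reshape nodes use: [detections, classes, H, W] flattened to [detections * classes, H, W].

```python
import numpy as np


def select_class_masks(all_masks: np.ndarray, classes: np.ndarray) -> np.ndarray:
    """Pick, for each detection, the mask of its predicted class.

    all_masks: [num_dets, num_classes, H, W] mask head output.
    classes:   [num_dets] class index per detection (e.g. NMS detection_classes).
    """
    num_dets, num_classes, h, w = all_masks.shape
    # Like "mask_head/reshape_all_masks": flatten detections and classes into one axis.
    flat_masks = all_masks.reshape(num_dets * num_classes, h, w)
    # Like "box_outputs/add": detection i's block of masks starts at row i * num_classes.
    offsets = np.arange(0, num_dets * num_classes, num_classes, dtype=np.int32)
    indices = classes.astype(np.int32) + offsets
    # Like "mask_head/final_gather": Gather on axis 0.
    return np.take(flat_masks, indices, axis=0)


if __name__ == "__main__":
    masks = np.random.default_rng(0).random((4, 3, 2, 2)).astype(np.float32)
    print(select_class_masks(masks, np.array([2, 0, 1, 2])).shape)
```

The result equals `all_masks[np.arange(num_dets), classes]`; the flattened form is used in the graph because it maps onto a single ONNX Gather over axis 0.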
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import re
import sys
import argparse
import logging
import cv2
import onnx_graphsurgeon as gs
import numpy as np
import onnx
from onnx import shape_inference
import torch

try:
    from detectron2.engine.defaults import DefaultPredictor
    from detectron2.modeling import build_model
    from detectron2.config import get_cfg
    from detectron2.structures import ImageList
except ImportError:
    print("Could not import Detectron 2 modules. Maybe you did not install Detectron 2")
    print("Please install Detectron 2, check https://github.com/facebookresearch/detectron2/blob/main/INSTALL.md")
    sys.exit(1)

import onnx_utils

logging.basicConfig(level=logging.INFO)
logging.getLogger("ModelHelper").setLevel(logging.INFO)
log = logging.getLogger("ModelHelper")


class DET2GraphSurgeon:
    def __init__(self, saved_model_path, config_file, weights):
        """
        Constructor of the Model Graph Surgeon object, to do the conversion of a Detectron 2 Mask R-CNN exported model
        to an ONNX-TensorRT parsable model.
        :param saved_model_path: The path pointing to the exported Detectron 2 Mask R-CNN ONNX model.
        :param config_file: The path pointing to the Detectron 2 yaml file which describes the model.
        :param weights: Weights to load for the Detectron 2 model.
        """

        def det2_setup(config_file, weights):
            """
            Create configs and perform basic setups.
            """
            cfg = get_cfg()
            cfg.merge_from_file(config_file)
            cfg.merge_from_list(["MODEL.WEIGHTS", weights])
            cfg.freeze()
            return cfg

        # Import exported Detectron 2 Mask R-CNN ONNX model as GraphSurgeon object.
        self.graph = gs.import_onnx(onnx.load(saved_model_path))
        assert self.graph
        log.info("ONNX graph loaded successfully")

        # Fold constants via ONNX-GS that exported script might've missed.
        self.graph.fold_constants()

        # Set up Detectron 2 model configuration.
        self.det2_cfg = det2_setup(config_file, weights)

        # Getting model characteristics.
        self.fpn_out_channels = self.det2_cfg.MODEL.FPN.OUT_CHANNELS
        self.num_classes = self.det2_cfg.MODEL.ROI_HEADS.NUM_CLASSES
        self.first_NMS_max_proposals = self.det2_cfg.MODEL.RPN.POST_NMS_TOPK_TEST
        self.first_NMS_iou_threshold = self.det2_cfg.MODEL.RPN.NMS_THRESH
        self.first_NMS_score_threshold = 0.01
        self.first_ROIAlign_pooled_size = self.det2_cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
        self.first_ROIAlign_sampling_ratio = self.det2_cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
        self.first_ROIAlign_type = self.det2_cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
        self.second_NMS_max_proposals = self.det2_cfg.TEST.DETECTIONS_PER_IMAGE
        self.second_NMS_iou_threshold = self.det2_cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST
        self.second_NMS_score_threshold = self.det2_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST
        self.second_ROIAlign_pooled_size = self.det2_cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        self.second_ROIAlign_sampling_ratio = self.det2_cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO
        self.second_ROIAlign_type = self.det2_cfg.MODEL.ROI_MASK_HEAD.POOLER_TYPE
        self.mask_out_res = 28

        # Model characteristics.
        log.info("Number of FPN output channels is {}".format(self.fpn_out_channels))
        log.info("Number of classes is {}".format(self.num_classes))
        log.info("First NMS max proposals is {}".format(self.first_NMS_max_proposals))
        log.info("First NMS iou threshold is {}".format(self.first_NMS_iou_threshold))
        log.info("First NMS score threshold is {}".format(self.first_NMS_score_threshold))
        log.info("First ROIAlign type is {}".format(self.first_ROIAlign_type))
        log.info("First ROIAlign pooled size is {}".format(self.first_ROIAlign_pooled_size))
        log.info("First ROIAlign sampling ratio is {}".format(self.first_ROIAlign_sampling_ratio))
        log.info("Second NMS max proposals is {}".format(self.second_NMS_max_proposals))
        log.info("Second NMS iou threshold is {}".format(self.second_NMS_iou_threshold))
        log.info("Second NMS score threshold is {}".format(self.second_NMS_score_threshold))
        log.info("Second ROIAlign type is {}".format(self.second_ROIAlign_type))
        log.info("Second ROIAlign pooled size is {}".format(self.second_ROIAlign_pooled_size))
        log.info("Second ROIAlign sampling ratio is {}".format(self.second_ROIAlign_sampling_ratio))
        log.info("Individual mask output resolution is {}x{}".format(self.mask_out_res, self.mask_out_res))
        
        self.batch_size = None

    def sanitize(self):
        """
        Sanitize the graph by cleaning any unconnected nodes, do a topological resort, and fold constant inputs values.
        When possible, run shape inference on the ONNX graph to determine tensor shapes.
        """

        for i in range(3):
            count_before = len(self.graph.nodes)
            self.graph.cleanup().toposort()
            try:
                for node in self.graph.nodes:
                    for o in node.outputs:
                        o.shape = None
                model = gs.export_onnx(self.graph)
                model = shape_inference.infer_shapes(model)
                self.graph = gs.import_onnx(model)
            except Exception as e:
                log.info("Shape inference could not be performed at this time:\n{}".format(e))
            try:
                self.graph.fold_constants(fold_shapes=True)
            except TypeError as e:
                log.error("This version of ONNX GraphSurgeon does not support folding shapes, please upgrade your "
                          "onnx_graphsurgeon module. Error:\n{}".format(e))
                raise

            count_after = len(self.graph.nodes)
            if count_before == count_after:
                # No new folding occurred in this iteration, so we can stop for now.
                break

    def get_anchors(self, sample_image):
        """
        The exported Detectron 2 ONNX model does not contain the anchors required by the EfficientNMS plugin, so they must be
        generated "offline" by calling the actual Detectron 2 model and getting the anchors from it.
        :param sample_image: Sample image required to run through the model and obtain anchors.
        Can be any image from a dataset. Make sure the Detectron 2 preprocessing steps listed here
        actually match your preprocessing steps. Otherwise, behavior can be unpredictable.
        Additionally, anchors have to be generated for fixed input dimensions,
        meaning that as soon as an image leaves the preprocessor and enters predictor.model.backbone(), it must have
        a fixed dimension (1344x1344 in my case) that every single image in the dataset must follow, since
        TensorRT plugins currently do not support dynamic shapes.
        """
        # Get Detectron 2 model config and build it.
        predictor = DefaultPredictor(self.det2_cfg)
        model = build_model(self.det2_cfg)

        # Image preprocessing.
        input_im = cv2.imread(sample_image)
        raw_height, raw_width = input_im.shape[:2]
        image = predictor.aug.get_transform(input_im).apply_image(input_im)
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

        # Model preprocessing.
        inputs = [{"image": image, "height": raw_height, "width": raw_width}]
        images = [x["image"].to(model.device) for x in inputs]
        images = [(x - model.pixel_mean) / model.pixel_std for x in images]
        imagelist_images = ImageList.from_tensors(images, 1344)

        # Get feature maps from backbone.
        features = predictor.model.backbone(imagelist_images.tensor)

        # Get proposals from Region Proposal Network and obtain anchors from anchor generator.
        features = [features[f] for f in predictor.model.proposal_generator.in_features]
        det2_anchors = predictor.model.proposal_generator.anchor_generator(features)

        # Extract anchors based on feature maps in ascending order (P2->P6).
        p2_anchors = det2_anchors[0].tensor.detach().cpu().numpy()
        p3_anchors = det2_anchors[1].tensor.detach().cpu().numpy()
        p4_anchors = det2_anchors[2].tensor.detach().cpu().numpy()
        p5_anchors = det2_anchors[3].tensor.detach().cpu().numpy()
        p6_anchors = det2_anchors[4].tensor.detach().cpu().numpy()
        final_anchors = np.concatenate((p2_anchors, p3_anchors, p4_anchors, p5_anchors, p6_anchors))
        
        return final_anchors

    def save(self, output_path):
        """
        Save the ONNX model to the given location.
        :param output_path: Path pointing to the location where to write out the updated ONNX model.
        """
        self.graph.cleanup().toposort()
        model = gs.export_onnx(self.graph)
        output_path = os.path.realpath(output_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        onnx.save(model, output_path)
        log.info("Saved ONNX model to {}".format(output_path))

    def update_preprocessor(self, batch_size):
        """
        Remove all the pre-processing nodes in the ONNX graph and leave only the image normalization essentials.
        :param batch_size: The batch size to use for the ONNX graph.
        """
        # Set graph inputs.
        self.batch_size = batch_size
        self.height = self.graph.inputs[0].shape[1]
        self.width = self.graph.inputs[0].shape[2]

        input_shape = [self.batch_size, 3, self.height, self.width]
        self.graph.inputs[0].shape = input_shape
        self.graph.inputs[0].dtype = np.float32
        self.graph.inputs[0].name = "input_tensor"

        self.sanitize()
        log.info("ONNX graph input shape: {} [NCHW format set]".format(self.graph.inputs[0].shape))
        
        # Find the initial nodes of the graph, whatever the input is first connected to, and disconnect them.
        for node in [node for node in self.graph.nodes if self.graph.inputs[0] in node.inputs]:
            node.inputs.clear()

        # Get input tensor.
        input_tensor = self.graph.inputs[0]
        
        # Create preprocessing Sub node and connect input tensor to it.
        sub_const = np.expand_dims(np.asarray([255 * 0.406, 255 * 0.456, 255 * 0.485], dtype=np.float32), axis=(1, 2))
        sub_out = self.graph.op_with_const("Sub", "preprocessor/mean", input_tensor, sub_const)

        # Find first Div node and connect to output of Sub node.
        div_node = self.graph.find_node_by_op("Div")
        log.info("Found {} node".format(div_node.op))
        div_node.inputs[0] = sub_out[0]
        
        # Find first Conv and connect preprocessor directly to it.
        conv_node = self.graph.find_node_by_op("Conv")
        log.info("Found {} node".format(conv_node.op))
        conv_node.inputs[0] = div_node.outputs[0]

        # Reshape nodes tend to update the batch dimension to a fixed value of 1, they should use the batch size instead.
        for node in [node for node in self.graph.nodes if node.op == "Reshape"]:
            if isinstance(node.inputs[1], gs.Constant) and node.inputs[1].values[0] == 1:
                node.inputs[1].values[0] = self.batch_size

    def NMS(self, boxes, scores, anchors, background_class, score_activation, max_proposals, iou_threshold, nms_score_threshold, user_threshold, nms_name=None):
        # Helper function to create the NMS Plugin node with the selected inputs. 
        # EfficientNMS_TRT TensorRT Plugin is suitable for our use case.
        # :param boxes: The box predictions from the Box Net.      
        # :param scores: The class predictions from the Class Net.
        # :param anchors: The default anchor coordinates.
        # :param background_class: The label ID for the background class.
        # :param max_proposals: Number of proposals made by NMS.
        # :param score_activation: If set to True - apply sigmoid activation to the confidence scores during NMS operation, 
        # if false - no activation.
        # :param iou_threshold: NMS intersection over union threshold, given by self.det2_cfg.
        # :param nms_score_threshold: NMS score threshold, given by self.det2_cfg.
        # :param user_threshold: User's given threshold to overwrite default NMS score threshold. 
        # :param nms_name: Name of NMS node in a graph, renames NMS elements accordingly in order to eliminate cycles.

        if nms_name is None:
            nms_name = ""
        else:
            nms_name = "_" + nms_name
        
        # Set score threshold.
        score_threshold = nms_score_threshold if user_threshold is None else user_threshold

        # NMS Outputs.
        nms_output_num_detections = gs.Variable(name="num_detections"+nms_name, dtype=np.int32, shape=[self.batch_size, 1])
        nms_output_boxes = gs.Variable(name="detection_boxes"+nms_name, dtype=np.float32,
                                       shape=[self.batch_size, max_proposals, 4])
        nms_output_scores = gs.Variable(name="detection_scores"+nms_name, dtype=np.float32,
                                        shape=[self.batch_size, max_proposals])
        nms_output_classes = gs.Variable(name="detection_classes"+nms_name, dtype=np.int32,
                                         shape=[self.batch_size, max_proposals])

        nms_outputs = [nms_output_num_detections, nms_output_boxes, nms_output_scores, nms_output_classes]

        # Plugin.
        self.graph.plugin(
            op="EfficientNMS_TRT",
            name="nms"+nms_name,
            inputs=[boxes, scores, anchors],
            outputs=nms_outputs,
            attrs={
                'plugin_version': "1",
                'background_class': background_class,
                'max_output_boxes': max_proposals,
                'score_threshold': max(0.01, score_threshold),
                'iou_threshold': iou_threshold,
                'score_activation': score_activation,
                'box_coding': 1,
            } 
        )
        log.info("Created nms{} with EfficientNMS_TRT plugin".format(nms_name))

        return nms_outputs

    def ROIAlign(self, rois, p2, p3, p4, p5, pooled_size, sampling_ratio, roi_align_type, num_rois, ra_name):
        # Helper function to create the ROIAlign Plugin node with the selected inputs. 
        # PyramidROIAlign_TRT TensorRT Plugin is suitable for our use case.
        # :param rois: Regions of interest/detection boxes outputs from preceding NMS node. 
        # :param p2: Output of p2 feature map. 
        # :param p3: Output of p3 feature map. 
        # :param p4: Output of p4 feature map. 
        # :param p5: Output of p5 feature map. 
        # :param pooled_size: Pooled output dimensions.
        # :param sampling_ratio: Number of sampling points in the interpolation grid used to compute the output value of each pooled output bin. 
        # :param roi_align_type: Type of Detectron 2 ROIAlign op, either ROIAlign (vanilla) or ROIAlignV2 (0.5 coordinate offset).
        # :param num_rois: Number of ROIs resulting from ROIAlign operation. 
        # :param ra_name: Name of ROIAlign node in a graph, renames ROIAlign elements accordingly in order to eliminate cycles.

        # Different types of Detectron 2's ROIAlign ops require coordinate offset that is supported by PyramidROIAlign_TRT.
        if roi_align_type == "ROIAlignV2":
            roi_coords_transform = 2
        elif roi_align_type == "ROIAlign":
            roi_coords_transform = 0
        else:
            raise ValueError("Unsupported ROIAlign type: {}".format(roi_align_type))

        # ROIAlign outputs.
        roi_align_output = gs.Variable(name="roi_align/output_"+ra_name, dtype=np.float32,
                                       shape=[self.batch_size, num_rois, self.fpn_out_channels, pooled_size, pooled_size])
        
        # Plugin.
        self.graph.plugin(
            op="PyramidROIAlign_TRT",
            name="roi_align_"+ra_name,
            inputs=[rois, p2, p3, p4, p5],
            outputs=[roi_align_output],
            attrs={
                'plugin_version': "1",
                'fpn_scale': 224,
                'pooled_size': pooled_size,
                'image_size': [self.height, self.width],
                'roi_coords_absolute': 0,
                'roi_coords_swap': 0,
                'roi_coords_transform': roi_coords_transform,
                'sampling_ratio': sampling_ratio,
            } 
        )
        log.info("Created {} with PyramidROIAlign_TRT plugin".format(ra_name))

        return roi_align_output

    def process_graph(self, anchors, first_nms_threshold=None, second_nms_threshold=None):
        """
        Processes the graph to replace the GenerateProposals and BoxWithNMSLimit operations with EfficientNMS_TRT 
        TensorRT plugin nodes and ROIAlign operations with PyramidROIAlign_TRT plugin nodes.
        :param anchors: Anchors generated from sample image "offline" by Detectron 2, since anchors are not provided
        inside the graph.
        :param first_nms_threshold: Override the 1st NMS score threshold value. If set to None, use the value in the graph.
        :param second_nms_threshold: Override the 2nd NMS score threshold value. If set to None, use the value in the graph.
        """
        def backbone():
            """
            Returns the output tensors of the final FPN convolutions (P2-P5) of the backbone.
            """
            # Get final backbone outputs.
            p2 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output2/Conv")
            p3 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output3/Conv")
            p4 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output4/Conv")
            p5 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output5/Conv")

            return p2.outputs[0], p3.outputs[0], p4.outputs[0], p5.outputs[0]

        def proposal_generator(anchors, first_nms_threshold):
            """
            Updates the graph to replace all GenerateProposals Caffe ops with one single NMS for proposals generation. 
            :param anchors: Anchors generated from sample image "offline" by Detectron 2, since anchors are not provided
            inside the graph
            :param first_nms_threshold: Override the 1st NMS score threshold value. If set to None, use the value in the graph.
            """
            # Get nodes containing final objectness logits.
            p2_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten")
            p3_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_1")
            p4_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_2")
            p5_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_3")
            p6_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_4")

            # Get nodes containing final anchor_deltas.
            p2_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_1")
            p3_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_3")
            p4_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_5")
            p5_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_7")
            p6_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_9")                    

            # Concatenate all objectness logits/scores data.
            scores_inputs = [p2_logits.outputs[0], p3_logits.outputs[0], p4_logits.outputs[0], p5_logits.outputs[0], p6_logits.outputs[0]]
            scores_tensor = self.graph.layer(name="scores", op="Concat", inputs=scores_inputs, outputs=['scores'], attrs={'axis': 1})[0]
            # Unsqueeze to add 3rd dimension of 1 to match tensor dimensions of boxes tensor.
            scores = self.graph.unsqueeze("scores_unsqueeze", scores_tensor, [2])[0]

            # Concatenate all boxes/anchor_delta data.
            boxes_inputs = [p2_anchors.outputs[0], p3_anchors.outputs[0], p4_anchors.outputs[0], p5_anchors.outputs[0], p6_anchors.outputs[0]]
            boxes = self.graph.layer(name="boxes", op="Concat", inputs=boxes_inputs, outputs=['anchors'], attrs={'axis': 1})[0]

            # Convert the anchors from Corners to CenterSize encoding.
            anchors = np.matmul(anchors, [[0.5, 0, -1, 0], [0, 0.5, 0, -1], [0.5, 0, 1, 0], [0, 0.5, 0, 1]])
            anchors = anchors / [self.width, self.height, self.width, self.height] # Normalize anchors to [0-1] range
            anchors = np.expand_dims(anchors, axis=0)
            anchors = anchors.astype(np.float32)
            anchors = gs.Constant(name="default_anchors", values=anchors)

            # Create NMS node.
            nms_outputs = self.NMS(boxes, scores, anchors, -1, False, self.first_NMS_max_proposals, self.first_NMS_iou_threshold, self.first_NMS_score_threshold, first_nms_threshold, 'rpn')

            return nms_outputs

        def roi_heads(rpn_outputs, p2, p3, p4, p5, second_nms_threshold):
            """
            Updates the graph to replace all ROIAlign Caffe ops with one single pyramid ROIAlign. Eliminates CollectRpnProposals
            DistributeFpnProposals and BatchPermutation nodes that are not supported by TensorRT. Connects pyramid ROIAlign to box_head
            and connects box_head to final box head outputs in a form of second NMS. In order to implement mask head outputs,
            similar steps as in box_pooler are performed to replace mask_pooler. Finally, reimplemented mask_pooler is connected to 
            mask_head and mask head outputs are produced.
            :param rpn_outputs: Outputs of the first NMS/proposal generator. 
            :param p2: Output of p2 feature map, required for ROIAlign operation. 
            :param p3: Output of p3 feature map, required for ROIAlign operation.  
            :param p4: Output of p4 feature map, required for ROIAlign operation.  
            :param p5: Output of p5 feature map, required for ROIAlign operation.  
            :param second_nms_threshold: Override the 2nd NMS score threshold value. If set to None, use the value in the graph.
            """
            # Create ROIAlign node. 
            box_pooler_output = self.ROIAlign(rpn_outputs[1], p2, p3, p4, p5, self.first_ROIAlign_pooled_size, self.first_ROIAlign_sampling_ratio, self.first_ROIAlign_type, self.first_NMS_max_proposals, 'box_pooler')
            
            # Reshape node that prepares ROIAlign/box pooler output for Gemm node that comes next.
            box_pooler_shape = np.asarray([-1, self.fpn_out_channels*self.first_ROIAlign_pooled_size*self.first_ROIAlign_pooled_size], dtype=np.int64)
            box_pooler_reshape = self.graph.op_with_const("Reshape", "box_pooler/reshape", box_pooler_output, box_pooler_shape)
            
            # Get first Gemm op of box head and connect box pooler to it.
            first_box_head_gemm = self.graph.find_node_by_op_name("Gemm", "/roi_heads/box_head/fc1/Gemm")
            first_box_head_gemm.inputs[0] = box_pooler_reshape[0]

            # Get final two nodes of box predictor. Softmax op for cls_score, Gemm op for bbox_pred.
            cls_score = self.graph.find_node_by_op_name("Softmax", "/roi_heads/Softmax")
            bbox_pred = self.graph.find_node_by_op_name("Gemm", "/roi_heads/box_predictor/bbox_pred/Gemm")

            # Linear transformation to convert box coordinates from (TopLeft, BottomRight) Corner encoding
            # to CenterSize encoding. 1st NMS boxes are multiplied by transformation matrix in order to 
            # encode it into CenterSize format.
            matmul_const = np.array([[0.5, 0, -1, 0], [0, 0.5, 0, -1], [0.5, 0, 1, 0], [0, 0.5, 0, 1]], dtype=np.float32)
            matmul_out = self.graph.matmul("RPN_NMS/detection_boxes_conversion", rpn_outputs[1], matmul_const)

            # Reshape node that prepares bbox_pred for scaling and second NMS.
            bbox_pred_shape = np.asarray([self.batch_size, self.first_NMS_max_proposals, self.num_classes, 4], dtype=np.int64)
            bbox_pred_reshape = self.graph.op_with_const("Reshape", "bbox_pred/reshape", bbox_pred.outputs[0], bbox_pred_shape)
            
            # 0.1, 0.1, 0.2, 0.2 are localization head variance numbers, they scale bbox_pred_reshape, in order to get accurate coordinates.
            scale_adj = np.expand_dims(np.asarray([0.1, 0.1, 0.2, 0.2], dtype=np.float32), axis=(0, 1))
            final_bbox_pred = self.graph.op_with_const("Mul", "bbox_pred/scale", bbox_pred_reshape[0], scale_adj)

            # Reshape node that prepares cls_score for slicing and second NMS.
            cls_score_shape = np.array([self.batch_size, self.first_NMS_max_proposals, self.num_classes+1], dtype=np.int64)
            cls_score_reshape = self.graph.op_with_const("Reshape", "cls_score/reshape", cls_score.outputs[0], cls_score_shape)
            
            # Slice the third dimension of the cls_score tensor to drop the background class (the last column in Detectron 2, i.e. the 81st for COCO).
            final_cls_score = self.graph.slice("cls_score/slicer", cls_score_reshape[0], 0, self.num_classes, 2)

            # Create NMS node.
            nms_outputs = self.NMS(final_bbox_pred[0], final_cls_score[0], matmul_out[0], -1, False, self.second_NMS_max_proposals, self.second_NMS_iou_threshold, self.second_NMS_score_threshold, second_nms_threshold, 'box_outputs')

            # Create ROIAlign node. 
            mask_pooler_output = self.ROIAlign(nms_outputs[1], p2, p3, p4, p5, self.second_ROIAlign_pooled_size, self.second_ROIAlign_sampling_ratio, self.second_ROIAlign_type, self.second_NMS_max_proposals, 'mask_pooler')
            
            # Reshape mask pooler output. 
            mask_pooler_shape = np.asarray([self.second_NMS_max_proposals*self.batch_size, self.fpn_out_channels, self.second_ROIAlign_pooled_size, self.second_ROIAlign_pooled_size], dtype=np.int64)
            mask_pooler_reshape_node = self.graph.op_with_const("Reshape", "mask_pooler/reshape", mask_pooler_output, mask_pooler_shape)
            
            # Get first Conv op in mask head and connect ROIAlign's squeezed output to it. 
            mask_head_conv = self.graph.find_node_by_op_name("Conv", "/roi_heads/mask_head/mask_fcn1/Conv")
            mask_head_conv.inputs[0] = mask_pooler_reshape_node[0]
           
            # Reshape node that is preparing 2nd NMS class outputs for Add node that comes next.
            classes_reshape_shape = np.asarray([self.second_NMS_max_proposals*self.batch_size], dtype=np.int64)
            classes_reshape_node = self.graph.op_with_const("Reshape", "box_outputs/reshape_classes", nms_outputs[3], classes_reshape_shape)
            
            # Offsets used by the Add node below: detection i owns rows [i * num_classes, (i + 1) * num_classes) of the
            # flattened mask tensor, so adding i * num_classes to its class index lets the Gather node pick the single
            # class of interest per bounding box, instead of returning 80 masks for every single bounding box.
            add_array = np.arange(0, self.second_NMS_max_proposals*self.batch_size*self.num_classes, self.num_classes, dtype=np.int32)

            # This Add node is one of the Gather node inputs. Gather operates on the 0th axis of the data tensor
            # and requires indices that stay within bounds; this Add node provides those indices.
            classes_add_node = self.graph.op_with_const("Add", "box_outputs/add", classes_reshape_node[0], add_array)
            
            # Get the last Conv op in mask head and reshape it to correctly gather class of interest's masks. 
            last_conv = self.graph.find_node_by_op_name("Conv", "/roi_heads/mask_head/predictor/Conv")
            last_conv_reshape_shape = np.asarray([self.second_NMS_max_proposals*self.num_classes*self.batch_size, self.mask_out_res, self.mask_out_res], dtype=np.int64)
            last_conv_reshape_node = self.graph.op_with_const("Reshape", "mask_head/reshape_all_masks", last_conv.outputs[0], last_conv_reshape_shape)
            
            # Gather node that selects only masks belonging to detected class, 79 other masks are discarded. 
            final_gather = self.graph.gather("mask_head/final_gather", last_conv_reshape_node[0], classes_add_node[0], 0)
            
            # Get last Sigmoid node and connect Gather node to it. 
            mask_head_sigmoid = self.graph.find_node_by_op_name("Sigmoid", "/roi_heads/mask_head/Sigmoid")
            mask_head_sigmoid.inputs[0] = final_gather[0]
            
            # Final Reshape node, reshapes output of Sigmoid, important for various batch_size support (not tested yet).
            final_graph_reshape_shape = np.asarray([self.batch_size, self.second_NMS_max_proposals, self.mask_out_res, self.mask_out_res], dtype=np.int64)
            final_graph_reshape_node = self.graph.op_with_const("Reshape", "mask_head/final_reshape", mask_head_sigmoid.outputs[0], final_graph_reshape_shape)
            final_graph_reshape_node[0].dtype = np.float32
            final_graph_reshape_node[0].name = "detection_masks"

            return nms_outputs, final_graph_reshape_node[0]

        # Only Detectron 2's Mask-RCNN R50-FPN 3x is supported currently.
        p2, p3, p4, p5 = backbone()
        rpn_outputs = proposal_generator(anchors, first_nms_threshold)
        box_head_outputs, mask_head_output = roi_heads(rpn_outputs, p2, p3, p4, p5, second_nms_threshold)
        # Append segmentation head output.
        box_head_outputs.append(mask_head_output)
        # Set graph outputs, both bbox and segmentation heads.
        self.graph.outputs = box_head_outputs
        self.sanitize()        


def main(args):
    det2_gs = DET2GraphSurgeon(args.exported_onnx, args.det2_config, args.det2_weights)
    det2_gs.update_preprocessor(args.batch_size)
    anchors = det2_gs.get_anchors(args.sample_image)
    det2_gs.process_graph(anchors, args.first_nms_threshold, args.second_nms_threshold)
    det2_gs.save(args.onnx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--exported_onnx", help="The exported to ONNX Detectron 2 Mask R-CNN", type=str)
    parser.add_argument("-o", "--onnx", help="The output ONNX model file to write", type=str)
    parser.add_argument("-c", "--det2_config", help="The Detectron 2 config file (.yaml) for the model", type=str)
    parser.add_argument("-w", "--det2_weights", help="The Detectron 2 model weights (.pkl)", type=str)
    parser.add_argument("-s", "--sample_image", help="Sample image for anchors generation", type=str)
    parser.add_argument("-b", "--batch_size", help="Batch size for the model", type=int, default=1)
    parser.add_argument("-t1", "--first_nms_threshold", help="Override the score threshold for the 1st NMS operation", type=float)
    parser.add_argument("-t2", "--second_nms_threshold", help="Override the score threshold for the 2nd NMS operation", type=float)
    args = parser.parse_args()
    if not all([args.exported_onnx, args.onnx, args.det2_config, args.det2_weights, args.sample_image]):
        parser.print_help()
        print("\nThese arguments are required: --exported_onnx --onnx --det2_config --det2_weights and --sample_image")
        sys.exit(1)
    main(args)
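The Reshape, Add and Gather steps at the end of the mask head above can be illustrated in plain NumPy. This is a minimal sketch, assuming the mask head emits `num_classes` masks per box, flattened to `[num_boxes * num_classes, H, W]` as in `mask_head/reshape_all_masks`; the function name `select_class_masks` and the toy sizes are illustrative only.

```python
import numpy as np

def select_class_masks(all_masks, classes, num_classes):
    """Pick one mask per box, mirroring the Reshape -> Add -> Gather(axis=0) pattern.

    all_masks: [num_boxes * num_classes, H, W]; the masks of box i occupy rows
               i * num_classes ... (i + 1) * num_classes - 1.
    classes:   [num_boxes] predicted class ID per box (the 2nd NMS class output).
    """
    num_boxes = classes.shape[0]
    # Same values as add_array: the row where each box's block of masks starts.
    offsets = np.arange(num_boxes, dtype=np.int64) * num_classes
    # Add node: turn per-box class IDs into absolute row indices.
    indices = classes.astype(np.int64) + offsets
    # Gather node on axis 0: keep only the mask of the predicted class.
    return np.take(all_masks, indices, axis=0)

# Toy example: 3 boxes, 4 classes, 2x2 masks; every mask is filled with its own row index.
num_boxes, num_classes = 3, 4
all_masks = np.arange(num_boxes * num_classes, dtype=np.float32)[:, None, None] * np.ones((1, 2, 2), dtype=np.float32)
classes = np.array([2, 0, 3], dtype=np.int32)
selected = select_class_masks(all_masks, classes, num_classes)
print(selected.shape)               # (3, 2, 2)
print(selected[:, 0, 0].tolist())   # [2.0, 4.0, 11.0]
```

Box 0 picks row 0*4+2, box 1 picks row 1*4+0 and box 2 picks row 2*4+3, which is why only one mask per detection survives instead of `num_classes` of them.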

@williamf-searidgetech

Furthermore, the converted.onnx should look like this:
[Image: converted ONNX graph]

@zerollzeng
Collaborator

@azhurkevich Can you help here? thanks!

@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Dec 19, 2022
@azhurkevich
Contributor

@williamf-searidgetech Great work! Haven't tried it myself, but graph looks correct to me!
@andreysher The reason for the hardcoded names is that it was the only way to do it back when I created the latest version. William did some great work by replacing names with node prefixes, but those prefixes also sometimes change when detectron2 is updated. Back in the summer there was a det2 update that removed the Caffe2 + ONNX export and replaced it with ONNX-only export. When that happened, the node naming and the graph itself changed a lot, so there was no easy way to traverse the graph and dynamically find the nodes that are needed. Most likely the det2 devs changed something again recently, so the node naming is different: the script cannot find the right nodes, graph surgery fails, and the result is broken. Keep in mind that if the det2 export graph or the node prefixes change, dynamic node search will not work either, since one way or another you need to rely on preexisting patterns.
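One way to make the lookup less brittle, in the spirit of William's prefix-based version, is to match on op type plus a name pattern and to fail loudly unless exactly one node matches. This is a minimal sketch, assuming only that graph nodes expose `.op` and `.name` attributes (as onnx_graphsurgeon nodes do); `find_unique_node` is a hypothetical helper, not part of onnx_graphsurgeon or of the TensorRT sample. As noted above, it still breaks if the exporter changes the naming pattern itself; it only turns a silent `None` into a clear error.

```python
import re
from types import SimpleNamespace

def find_unique_node(nodes, op, name_pattern):
    """Return the single node whose op equals `op` and whose name matches the regex `name_pattern`.

    Raises instead of returning None, so a renamed node fails here rather than later
    with "'NoneType' object has no attribute 'inputs'".
    """
    regex = re.compile(name_pattern)
    matches = [n for n in nodes if n.op == op and regex.search(n.name)]
    if len(matches) != 1:
        found = [n.name for n in matches]
        raise LookupError("Expected exactly one {} node matching {!r}, found {}".format(op, name_pattern, found))
    return matches[0]

# Stand-in nodes; with a real model you would pass gs.import_onnx(onnx.load(path)).nodes instead.
nodes = [
    SimpleNamespace(op="Gemm", name="/roi_heads/box_head/fc1/Gemm"),
    SimpleNamespace(op="Gemm", name="/roi_heads/box_head/fc2/Gemm"),
    SimpleNamespace(op="Gemm", name="/roi_heads/box_predictor/bbox_pred/Gemm"),
]
fc1 = find_unique_node(nodes, "Gemm", r"box_head/fc1")
print(fc1.name)  # /roi_heads/box_head/fc1/Gemm
```

In create_onnx.py this could stand in for a call such as `find_node_by_op_name("Gemm", "/roi_heads/box_head/fc1/Gemm")`, e.g. `find_unique_node(self.graph.nodes, "Gemm", r"box_head/fc1")`.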

@andreysher
Author

Hi @williamf-searidgetech, thanks for your reply, but I can't reproduce this. I use detectron2 from commit 210809711838d5bd193e8bb3bc4bab39a660bf5b, with STABLE_ONNX_OPSET_VERSION changed to 16. After running python detectron2/tools/deploy/export_model.py I get model.onnx (see the attachment), and the node names in your edited script do not match it.
[Attachment: model.onnx]

@williamf-searidgetech

williamf-searidgetech commented Dec 21, 2022

@andreysher I have just checked your model and it seems like we are using different versions of PyTorch. The model.onnx you generated was created with PyTorch 1.12.1, whereas I was using PyTorch 1.13.0. While I am not completely certain that this causes the discrepancies between our model.onnx files, it is likely the root cause, since detectron2's export_model.py ultimately calls the underlying PyTorch ONNX exporter.
Here is my model.onnx for comparison; hopefully you can search and replace the node names to create a modified create_onnx.py for your model.onnx.
I would also suggest using the official NVIDIA PyTorch Docker container to ensure the Python environments are consistent.
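To do that search and replace, it can help to line up the node names of two exports (for example one from PyTorch 1.12.1 and one from 1.13.0) per op type. A minimal sketch using Python's standard `difflib`; the inputs are lists of `(op_type, name)` pairs, which you could build from `onnx.load(path).graph.node` via each node's `op_type` and `name` fields. The function `diff_node_names` and the sample names are illustrative, not taken from either attached model.

```python
import difflib
from collections import defaultdict

def diff_node_names(old_nodes, new_nodes):
    """Group (op_type, name) pairs by op type and report names that differ between two exports."""
    old_by_op, new_by_op = defaultdict(list), defaultdict(list)
    for op, name in old_nodes:
        old_by_op[op].append(name)
    for op, name in new_nodes:
        new_by_op[op].append(name)

    report = {}
    for op in sorted(set(old_by_op) | set(new_by_op)):
        # Keep only '-'/'+' lines: names present in one export but not in the other.
        changes = [
            line for line in difflib.ndiff(old_by_op[op], new_by_op[op])
            if line.startswith(("-", "+"))
        ]
        if changes:
            report[op] = changes
    return report

old = [("Gemm", "Gemm_1685"), ("Softmax", "Softmax_1796"), ("Conv", "Conv_10")]
new = [("Gemm", "/roi_heads/box_head/fc1/Gemm"), ("Softmax", "/roi_heads/Softmax"), ("Conv", "Conv_10")]
for op, changes in diff_node_names(old, new).items():
    print(op, changes)
```

Grouping by op type keeps the diff small, since a renamed Gemm can only correspond to another Gemm; the pairing within an op type still has to be checked by hand (e.g. in Netron).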

@azhurkevich
Contributor

Thank you @williamf-searidgetech

@Photoheyler

Photoheyler commented Dec 30, 2022

@williamf-searidgetech
I tried your approach. The model conversion works, and so does the conversion to TensorRT, but the model doesn't predict anything anymore.
I also tried to run prediction with the converted ONNX model and got the following error:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from C:/Users/Photoheyler/Desktop/YoloV7Instance/detectron2/tools/deploy/output/converted.onnx failed:This is an invalid model. In Node, ("roi_align_box_pooler", PyramidROIAlign_TRT, "", -1) : ("detection_boxes_rpn": tensor(float),"/backbone/fpn_output2/Conv_output_0": tensor(float),"/backbone/fpn_output3/Conv_output_0": tensor(float),"/backbone/fpn_output4/Conv_output_0": tensor(float),"/backbone/fpn_output5/Conv_output_0": tensor(float),) -> ("roi_align/output_box_pooler": tensor(float),) , Error Unrecognized attribute: fpn_scale for operator PyramidROIAlign_TRT

Maybe this error is why the TensorRT model can't predict anything.

@azhurkevich
Contributor

@Photoheyler I never tested ONNX-RT. I think it fails because the updated PyramidROIAlign_TRT plug-in has not been registered with ONNX-RT. We had to make some significant changes to the plug-in early this year to make it work with this model, and fpn_scale is a new argument.
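For context on `fpn_scale`: the script sets it to 224 in `ROIAlign()`. A plausible reading, which is an assumption here and not stated in this thread, is that it is the canonical box size of the FPN level-assignment rule that detectron2 also uses (a 224-pixel box maps to level 4, clamped to P2 to P5). A minimal pure-Python sketch of that rule; `assign_fpn_level` is a hypothetical name.

```python
import math

def assign_fpn_level(x1, y1, x2, y2, canonical_size=224.0, canonical_level=4, min_level=2, max_level=5):
    """Map a box to an FPN level: floor(k0 + log2(sqrt(w * h) / canonical_size)), clamped."""
    box_size = math.sqrt(max(x2 - x1, 0.0) * max(y2 - y1, 0.0))
    # Small epsilon (as in detectron2's assign_boxes_to_levels) keeps log2 defined for zero-area boxes.
    level = math.floor(canonical_level + math.log2(box_size / canonical_size + 1e-8))
    return min(max(level, min_level), max_level)

for side in (10, 112, 224, 448, 1000):
    print(side, assign_fpn_level(0, 0, side, side))
```

Halving the box side moves it down one pyramid level, so small boxes are pooled from the high-resolution P2/P3 maps and large boxes from P4/P5.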

@Photoheyler

Photoheyler commented Dec 30, 2022

@azhurkevich
I tested the converted ONNX model with the ONNX Runtime CUDA backend, not with the ONNX Runtime TensorRT backend.
Shouldn't that usually work?

@azhurkevich
Contributor

azhurkevich commented Dec 30, 2022

@Photoheyler Nope, it will not work with the ONNX CUDA backend, because the converted model relies heavily on TRT plug-ins for speed; without them, the CUDA backend would be much slower. Please convert the model to TensorRT and evaluate it to check whether it produces the expected mAP.
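A quick way to see which nodes a stock runtime would have to resolve as custom ops is to list the TensorRT plug-in node types in the converted graph. A minimal sketch, assuming plug-in ops are recognizable by the `_TRT` suffix used in this script (`EfficientNMS_TRT`, `PyramidROIAlign_TRT`); `list_trt_plugin_nodes` is a hypothetical helper. With a real file you would pass `onnx.load("converted.onnx").graph.node`, whose entries have `op_type` and `name` fields.

```python
from collections import Counter
from types import SimpleNamespace

def list_trt_plugin_nodes(nodes, suffix="_TRT"):
    """Count nodes whose op_type ends with `suffix`, i.e. ops a stock ONNX runtime will not know."""
    return Counter(n.op_type for n in nodes if n.op_type.endswith(suffix))

# Stand-in nodes (some names hypothetical) resembling the converted graph discussed in this thread.
nodes = [
    SimpleNamespace(op_type="Conv", name="/backbone/fpn_output2/Conv"),
    SimpleNamespace(op_type="EfficientNMS_TRT", name="nms_rpn"),
    SimpleNamespace(op_type="PyramidROIAlign_TRT", name="roi_align_box_pooler"),
    SimpleNamespace(op_type="PyramidROIAlign_TRT", name="roi_align_mask_pooler"),
    SimpleNamespace(op_type="EfficientNMS_TRT", name="nms_second"),
]
print(dict(list_trt_plugin_nodes(nodes)))
```

If this list is non-empty, the model can only run where those plug-ins are registered, which matches the `Unrecognized attribute: fpn_scale for operator PyramidROIAlign_TRT` error reported above.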

@Photoheyler

So when I use a sample image of 1280x800, I can export the converted ONNX and build the TRT engine, but I get no results.
When I use a sample image of 1344x1344, I can also build the converted ONNX, but I cannot build the TRT engine.
I'm getting the following error:
ERROR:EngineBuilder:In node 257 (parseGraph): INVALID_NODE: Invalid Node - box_pooler/reshape
Attribute not found: allowzero

@tomasctg

tomasctg commented Jan 6, 2023

I've been working on this problem for a week. I solved it by looking up the node names in the first ONNX conversion and then changing the arguments passed to find_node_by_op_name in TensorRT's create_onnx.py script. Both the node names and the op types had to be changed.

The first conversion from .pth to ONNX with export_model.py uses --export-method tracing:

python3 /home/appuser/detectron2_repo/tools/deploy/export_model.py \
                --config-file \
                --output  \
                --format onnx \
                --sample-image \
                --export-method tracing \
                MODEL.DEVICE cuda \
                MODEL.WEIGHTS model_final.pth

Be careful with the sample image: it must have a 1344x1344 shape, as the network expects. It must also contain at least one instance of any class you trained.
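Where the 1344 comes from can be seen with a little arithmetic. This is a minimal sketch, assuming the default detectron2 test-time resize (shortest edge 800, longest edge capped at 1333) and a backbone size divisibility of 32, and relying on `ImageList.from_tensors` padding height and width up to the next multiple of its `size_divisibility` argument (which `get_anchors()` in the script below sets to 1344). The function names are hypothetical and the resize only approximates detectron2's `ResizeShortestEdge`; check your own config values before relying on them.

```python
import math

def resize_shortest_edge(h, w, short_edge=800, max_size=1333):
    """Approximate output (h, w) of a shortest-edge resize with a cap on the longest edge."""
    scale = short_edge / min(h, w)
    newh, neww = h * scale, w * scale
    if max(newh, neww) > max_size:
        cap = max_size / max(newh, neww)
        newh, neww = newh * cap, neww * cap
    return int(newh + 0.5), int(neww + 0.5)

def pad_to_multiple(size, divisibility):
    """Round a spatial size up to the next multiple of `divisibility`."""
    return int(math.ceil(size / divisibility) * divisibility)

# A 1280x800 (W x H) image: 800 rows by 1280 columns.
h, w = resize_shortest_edge(800, 1280)
print(h, w)                                                      # 800 1280
print(pad_to_multiple(h, 32), pad_to_multiple(w, 32))            # 800 1280
# The longest edge can reach 1333, which pads to 1344 with a divisibility of 32.
print(pad_to_multiple(1333, 32))                                 # 1344
# With size_divisibility=1344, any side up to 1344 is padded to exactly 1344.
print(pad_to_multiple(800, 1344), pad_to_multiple(1280, 1344))   # 1344 1344
```

Because the anchors are generated for the padded 1344x1344 feature maps, the engine's input size and the anchor grid have to agree; a mismatch is one thing worth ruling out when an engine builds but returns no detections.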

You can see my modifications to the create_onnx.py script below, mainly in the arguments passed to find_node_by_op_name.

#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import re
import sys
import argparse
import logging
import cv2
import onnx_graphsurgeon as gs
import numpy as np
import onnx
from onnx import shape_inference
import torch

try:
    from detectron2.engine.defaults import DefaultPredictor
    from detectron2.modeling import build_model
    from detectron2.config import get_cfg
    from detectron2.structures import ImageList
except ImportError:
    print("Could not import Detectron 2 modules. Maybe you did not install Detectron 2")
    print("Please install Detectron 2, check https://github.com/facebookresearch/detectron2/blob/main/INSTALL.md")
    sys.exit(1)

import onnx_utils

logging.basicConfig(level=logging.INFO)
logging.getLogger("ModelHelper").setLevel(logging.INFO)
log = logging.getLogger("ModelHelper")


class DET2GraphSurgeon:
    def __init__(self, saved_model_path, config_file, weights):
        """
        Constructor of the Model Graph Surgeon object, to do the conversion of a Detectron 2 Mask R-CNN exported model
        to an ONNX-TensorRT parsable model.
        :param saved_model_path: The path pointing to the exported Detectron 2 Mask R-CNN ONNX model.
        :param config_file: The path pointing to the Detectron 2 yaml file which describes the model.
        :param weights: Weights to load for the Detectron 2 model.
        """

        def det2_setup(config_file, weights):
            """
            Create configs and perform basic setups.
            """
            cfg = get_cfg()
            cfg.merge_from_file(config_file)
            cfg.merge_from_list(["MODEL.WEIGHTS", weights])
            cfg.freeze()
            return cfg

        # Import exported Detectron 2 Mask R-CNN ONNX model as GraphSurgeon object.
        self.graph = gs.import_onnx(onnx.load(saved_model_path))
        assert self.graph
        log.info("ONNX graph loaded successfully")

        # Fold constants via ONNX-GS that exported script might've missed.
        self.graph.fold_constants()

        # Set up Detectron 2 model configuration.
        self.det2_cfg = det2_setup(config_file, weights)

        # Getting model characteristics.
        self.fpn_out_channels = self.det2_cfg.MODEL.FPN.OUT_CHANNELS
        self.num_classes = self.det2_cfg.MODEL.ROI_HEADS.NUM_CLASSES
        self.first_NMS_max_proposals = self.det2_cfg.MODEL.RPN.POST_NMS_TOPK_TEST
        self.first_NMS_iou_threshold = self.det2_cfg.MODEL.RPN.NMS_THRESH
        self.first_NMS_score_threshold = 0.01
        self.first_ROIAlign_pooled_size = self.det2_cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
        self.first_ROIAlign_sampling_ratio = self.det2_cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
        self.first_ROIAlign_type = self.det2_cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
        self.second_NMS_max_proposals = self.det2_cfg.TEST.DETECTIONS_PER_IMAGE
        self.second_NMS_iou_threshold = self.det2_cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST
        self.second_NMS_score_threshold = self.det2_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST
        self.second_ROIAlign_pooled_size = self.det2_cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        self.second_ROIAlign_sampling_ratio = self.det2_cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO
        self.second_ROIAlign_type = self.det2_cfg.MODEL.ROI_MASK_HEAD.POOLER_TYPE
        self.mask_out_res = 28

        # Model characteristics.
        log.info("Number of FPN output channels is {}".format(self.fpn_out_channels))
        log.info("Number of classes is {}".format(self.num_classes))
        log.info("First NMS max proposals is {}".format(self.first_NMS_max_proposals))
        log.info("First NMS iou threshold is {}".format(self.first_NMS_iou_threshold))
        log.info("First NMS score threshold is {}".format(self.first_NMS_score_threshold))
        log.info("First ROIAlign type is {}".format(self.first_ROIAlign_type))
        log.info("First ROIAlign pooled size is {}".format(self.first_ROIAlign_pooled_size))
        log.info("First ROIAlign sampling ratio is {}".format(self.first_ROIAlign_sampling_ratio))
        log.info("Second NMS max proposals is {}".format(self.second_NMS_max_proposals))
        log.info("Second NMS iou threshold is {}".format(self.second_NMS_iou_threshold))
        log.info("Second NMS score threshold is {}".format(self.second_NMS_score_threshold))
        log.info("Second ROIAlign type is {}".format(self.second_ROIAlign_type))
        log.info("Second ROIAlign pooled size is {}".format(self.second_ROIAlign_pooled_size))
        log.info("Second ROIAlign sampling ratio is {}".format(self.second_ROIAlign_sampling_ratio))
        log.info("Individual mask output resolution is {}x{}".format(self.mask_out_res, self.mask_out_res))
        
        self.batch_size = None

    def sanitize(self):
        """
        Sanitize the graph by cleaning any unconnected nodes, do a topological resort, and fold constant inputs values.
        When possible, run shape inference on the ONNX graph to determine tensor shapes.
        """

        for i in range(3):
            count_before = len(self.graph.nodes)
            self.graph.cleanup().toposort()
            try:
                for node in self.graph.nodes:
                    for o in node.outputs:
                        o.shape = None
                model = gs.export_onnx(self.graph)
                model = shape_inference.infer_shapes(model)
                self.graph = gs.import_onnx(model)
            except Exception as e:
                log.info("Shape inference could not be performed at this time:\n{}".format(e))
            try:
                self.graph.fold_constants(fold_shapes=True)
            except TypeError as e:
                log.error("This version of ONNX GraphSurgeon does not support folding shapes, please upgrade your "
                          "onnx_graphsurgeon module. Error:\n{}".format(e))
                raise

            count_after = len(self.graph.nodes)
            if count_before == count_after:
                # No new folding occurred in this iteration, so we can stop for now.
                break

    def get_anchors(self, sample_image):
        """
        Detectron 2 exported ONNX does not contain the anchors required by the EfficientNMS plug-in, so they must be generated
        "offline" by calling the actual Detectron 2 model and getting the anchors from it.
        :param sample_image: Sample image required to run through the model and obtain anchors.
        Can be any image from the dataset. Make sure the Detectron 2 preprocessing steps listed here
        actually match your preprocessing steps; otherwise, behavior can be unpredictable.
        Additionally, anchors have to be generated for fixed input dimensions,
        meaning that as soon as an image leaves the preprocessor and enters predictor.model.backbone() it must have
        a fixed dimension (1344x1344 in my case) that every single image in the dataset must follow, since
        TensorRT plug-ins currently do not support dynamic shapes.
        """
        # Get Detectron 2 model config and build it.
        predictor = DefaultPredictor(self.det2_cfg)
        model = build_model(self.det2_cfg)

        # Image preprocessing.
        input_im = cv2.imread(sample_image)
        raw_height, raw_width = input_im.shape[:2]
        image = predictor.aug.get_transform(input_im).apply_image(input_im)
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

        # Model preprocessing.
        inputs = [{"image": image, "height": raw_height, "width": raw_width}]
        images = [x["image"].to(model.device) for x in inputs]
        images = [(x - model.pixel_mean) / model.pixel_std for x in images]
        imagelist_images = ImageList.from_tensors(images, 1344)

        # Get feature maps from backbone.
        features = predictor.model.backbone(imagelist_images.tensor)

        # Get proposals from Region Proposal Network and obtain anchors from anchor generator.
        features = [features[f] for f in predictor.model.proposal_generator.in_features]
        det2_anchors = predictor.model.proposal_generator.anchor_generator(features)

        # Extract anchors based on feature maps in ascending order (P2->P6).
        p2_anchors = det2_anchors[0].tensor.detach().cpu().numpy()
        p3_anchors = det2_anchors[1].tensor.detach().cpu().numpy()
        p4_anchors = det2_anchors[2].tensor.detach().cpu().numpy()
        p5_anchors = det2_anchors[3].tensor.detach().cpu().numpy()
        p6_anchors = det2_anchors[4].tensor.detach().cpu().numpy()
        final_anchors = np.concatenate((p2_anchors,p3_anchors,p4_anchors,p5_anchors,p6_anchors))
        
        return final_anchors

    def save(self, output_path):
        """
        Save the ONNX model to the given location.
        :param output_path: Path pointing to the location where to write out the updated ONNX model.
        """
        self.graph.cleanup().toposort()
        model = gs.export_onnx(self.graph)
        output_path = os.path.realpath(output_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        onnx.save(model, output_path)
        log.info("Saved ONNX model to {}".format(output_path))

    def update_preprocessor(self, batch_size):
        """
        Remove all the pre-processing nodes in the ONNX graph and leave only the image normalization essentials.
        :param batch_size: The batch size to use for the ONNX graph.
        """
        # Set graph inputs.
        self.batch_size = batch_size
        self.height = self.graph.inputs[0].shape[1]
        self.width = self.graph.inputs[0].shape[2]

        input_shape = [self.batch_size, 3, self.height, self.width]
        self.graph.inputs[0].shape = input_shape
        self.graph.inputs[0].dtype = np.float32
        self.graph.inputs[0].name = "input_tensor"

        self.sanitize()
        log.info("ONNX graph input shape: {} [NCHW format set]".format(self.graph.inputs[0].shape))
        
        # Find the initial nodes of the graph, whatever the input is first connected to, and disconnect them.
        for node in [node for node in self.graph.nodes if self.graph.inputs[0] in node.inputs]:
            node.inputs.clear()

        # Get input tensor.
        input_tensor = self.graph.inputs[0]
        
        # Create preprocessing Sub node and connect input tensor to it.
        sub_const = np.expand_dims(np.asarray([255 * 0.406, 255 * 0.456, 255 * 0.485], dtype=np.float32), axis=(1, 2))
        sub_out = self.graph.op_with_const("Sub", "preprocessor/mean", input_tensor, sub_const)

        # Find first Div node and connect to output of Sub node.
        div_node = self.graph.find_node_by_op("Div")
        log.info("Found {} node".format(div_node.op))
        div_node.inputs[0] = sub_out[0]
        
        # Find first Conv and connect preprocessor directly to it.
        conv_node = self.graph.find_node_by_op("Conv")
        log.info("Found {} node".format(conv_node.op))
        conv_node.inputs[0] = div_node.outputs[0]

        # Reshape nodes tend to update the batch dimension to a fixed value of 1, they should use the batch size instead.
        for node in [node for node in self.graph.nodes if node.op == "Reshape"]:
            if type(node.inputs[1]) == gs.Constant and node.inputs[1].values[0] == 1:
                node.inputs[1].values[0] = self.batch_size

    def NMS(self, boxes, scores, anchors, background_class, score_activation, max_proposals, iou_threshold, nms_score_threshold, user_threshold, nms_name=None):
        # Helper function to create the NMS Plugin node with the selected inputs. 
        # EfficientNMS_TRT TensorRT Plugin is suitable for our use case.
        # :param boxes: The box predictions from the Box Net.      
        # :param scores: The class predictions from the Class Net.
        # :param anchors: The default anchor coordinates.
        # :param background_class: The label ID for the background class.
        # :param max_proposals: Number of proposals made by NMS.
        # :param score_activation: If set to True - apply sigmoid activation to the confidence scores during NMS operation, 
        # if false - no activation.
        # :param iou_threshold: NMS intersection over union threshold, given by self.det2_cfg.
        # :param nms_score_threshold: NMS score threshold, given by self.det2_cfg.
        # :param user_threshold: User's given threshold to overwrite default NMS score threshold. 
        # :param nms_name: Name of NMS node in a graph, renames NMS elements accordingly in order to eliminate cycles.

        if nms_name is None:
            nms_name = ""
        else:
            nms_name = "_" + nms_name
        
        # Set score threshold.
        score_threshold = nms_score_threshold if user_threshold is None else user_threshold

        # NMS Outputs.
        nms_output_num_detections = gs.Variable(name="num_detections"+nms_name, dtype=np.int32, shape=[self.batch_size, 1])
        nms_output_boxes = gs.Variable(name="detection_boxes"+nms_name, dtype=np.float32,
                                       shape=[self.batch_size, max_proposals, 4])
        nms_output_scores = gs.Variable(name="detection_scores"+nms_name, dtype=np.float32,
                                        shape=[self.batch_size, max_proposals])
        nms_output_classes = gs.Variable(name="detection_classes"+nms_name, dtype=np.int32,
                                         shape=[self.batch_size, max_proposals])

        nms_outputs = [nms_output_num_detections, nms_output_boxes, nms_output_scores, nms_output_classes]

        # Plugin.
        self.graph.plugin(
            op="EfficientNMS_TRT",
            name="nms"+nms_name,
            inputs=[boxes, scores, anchors],
            outputs=nms_outputs,
            attrs={
                'plugin_version': "1",
                'background_class': background_class,
                'max_output_boxes': max_proposals,
                'score_threshold': max(0.01, score_threshold),
                'iou_threshold': iou_threshold,
                'score_activation': score_activation,
                'box_coding': 1,
            } 
        )
        log.info("Created nms{} with EfficientNMS_TRT plugin".format(nms_name))

        return nms_outputs

    def ROIAlign(self, rois, p2, p3, p4, p5, pooled_size, sampling_ratio, roi_align_type, num_rois, ra_name):
        # Helper function to create the ROIAlign Plugin node with the selected inputs. 
        # PyramidROIAlign_TRT TensorRT Plugin is suitable for our use case.
        # :param rois: Regions of interest/detection boxes outputs from preceding NMS node. 
        # :param p2: Output of p2 feature map. 
        # :param p3: Output of p3 feature map. 
        # :param p4: Output of p4 feature map. 
        # :param p5: Output of p5 feature map. 
        # :param pooled_size: Pooled output dimensions.
        # :param sampling_ratio: Number of sampling points in the interpolation grid used to compute the output value of each pooled output bin. 
        # :param roi_align_type: Type of Detectron 2 ROIAlign op, either ROIAlign (vanilla) or ROIAlignV2 (0.5 coordinate offset).
        # :param num_rois: Number of ROIs resulting from ROIAlign operation. 
        # :param ra_name: Name of ROIAlign node in a graph, renames ROIAlign elements accordingly in order to eliminate cycles.

        # Different types of Detectron 2's ROIAlign ops require a coordinate offset that is supported by PyramidROIAlign_TRT.
        if roi_align_type == "ROIAlignV2":
            roi_coords_transform = 2
        elif roi_align_type == "ROIAlign":
            roi_coords_transform = 0
        else:
            raise ValueError("Unsupported ROIAlign type: {}".format(roi_align_type))
        
        # ROIAlign outputs. 
        roi_align_output = gs.Variable(name="roi_align/output_"+ra_name, dtype=np.float32,
                                shape=[self.batch_size, num_rois, self.fpn_out_channels, pooled_size, pooled_size])
        
        # Plugin.
        self.graph.plugin(
            op="PyramidROIAlign_TRT",
            name="roi_align_"+ra_name,
            inputs=[rois, p2, p3, p4, p5],
            outputs=[roi_align_output],
            attrs={
                'plugin_version': "1",
                'fpn_scale': 224,
                'pooled_size': pooled_size,
                'image_size': [self.height, self.width],
                'roi_coords_absolute': 0,
                'roi_coords_swap': 0,
                'roi_coords_transform': roi_coords_transform,
                'sampling_ratio': sampling_ratio,
            } 
        )
        log.info("Created {} with PyramidROIAlign_TRT plugin".format(ra_name))

        return roi_align_output

    def process_graph(self, anchors, first_nms_threshold=None, second_nms_threshold=None):
        """
        Processes the graph to replace the GenerateProposals and BoxWithNMSLimit operations with EfficientNMS_TRT 
        TensorRT plugin nodes and ROIAlign operations with PyramidROIAlign_TRT plugin nodes.
        :param anchors: Anchors generated from sample image "offline" by Detectron 2, since anchors are not provided
        inside the graph.
        :param first_nms_threshold: Override the 1st NMS score threshold value. If set to None, use the value in the graph.
        :param second_nms_threshold: Override the 2nd NMS score threshold value. If set to None, use the value in the graph.
        """
        def backbone():
            """
            Returns the final FPN outputs (P2 to P5) of the backbone, which are later fed to the ROIAlign plug-ins.
            """
            # Get final backbone outputs.
            p2 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output2/Conv")
            p3 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output3/Conv")
            p4 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output4/Conv")
            p5 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output5/Conv")

            return p2.outputs[0], p3.outputs[0], p4.outputs[0], p5.outputs[0]

        def proposal_generator(anchors, first_nms_threshold):
            """
            Updates the graph to replace all GenerateProposals Caffe ops with one single NMS for proposals generation. 
            :param anchors: Anchors generated from sample image "offline" by Detectron 2, since anchors are not provided
            inside the graph
            :param first_nms_threshold: Override the 1st NMS score threshold value. If set to None, use the value in the graph.
            """
            # Get nodes containing final objectness logits.
            p2_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten")
            p3_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_1")
            p4_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_2")
            p5_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_3")
            p6_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_4")

            # Get nodes containing final anchor_deltas.
            p2_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_1")
            p3_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_3")
            p4_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_5")
            p5_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_7")
            p6_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_9")                    

            # Concatenate all objectness logits/scores data.
            scores_inputs = [p2_logits.outputs[0], p3_logits.outputs[0], p4_logits.outputs[0], p5_logits.outputs[0], p6_logits.outputs[0]]
            scores_tensor = self.graph.layer(name="scores", op="Concat", inputs=scores_inputs, outputs=['scores'], attrs={'axis': 1})[0]
            # Unsqueeze to add 3rd dimension of 1 to match tensor dimensions of boxes tensor.
            scores = self.graph.unsqueeze("scores_unsqueeze", scores_tensor, [2])[0]

            # Concatenate all boxes/anchor_delta data.
            boxes_inputs = [p2_anchors.outputs[0], p3_anchors.outputs[0], p4_anchors.outputs[0], p5_anchors.outputs[0], p6_anchors.outputs[0]]
            boxes = self.graph.layer(name="boxes", op="Concat", inputs=boxes_inputs, outputs=['anchors'], attrs={'axis': 1})[0]

            # Convert the anchors from Corners to CenterSize encoding.
            anchors = np.matmul(anchors, [[0.5, 0, -1, 0], [0, 0.5, 0, -1], [0.5, 0, 1, 0], [0, 0.5, 0, 1]])
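            # Worked example of the matmul above (illustrative numbers, not taken from the model):
            # a corner-encoded anchor [x1, y1, x2, y2] = [10, 20, 50, 80] becomes
            # [(10 + 50) / 2, (20 + 80) / 2, 50 - 10, 80 - 20] = [30, 50, 40, 60],
            # i.e. (center_x, center_y, width, height).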
            anchors = anchors / [self.width, self.height, self.width, self.height] # Normalize anchors to [0-1] range
            anchors = np.expand_dims(anchors, axis=0)
            anchors = anchors.astype(np.float32)
            anchors = gs.Constant(name="default_anchors", values=anchors)

            # Create NMS node.
            nms_outputs = self.NMS(boxes, scores, anchors, -1, False, self.first_NMS_max_proposals, self.first_NMS_iou_threshold, self.first_NMS_score_threshold, first_nms_threshold, 'rpn')

            return nms_outputs

        def roi_heads(rpn_outputs, p2, p3, p4, p5, second_nms_threshold):
            """
            Updates the graph to replace all ROIAlign Caffe ops with one single pyramid ROIAlign. Eliminates CollectRpnProposals,
            DistributeFpnProposals and BatchPermutation nodes, which are not supported by TensorRT. Connects pyramid ROIAlign to box_head
            and connects box_head to the final box head outputs in the form of a second NMS. To implement the mask head outputs,
            the same steps as for box_pooler are performed to replace mask_pooler. Finally, the reimplemented mask_pooler is connected to
            mask_head and the mask head outputs are produced.
            :param rpn_outputs: Outputs of the first NMS/proposal generator.
            :param p2: Output of p2 feature map, required for ROIAlign operation.
            :param p3: Output of p3 feature map, required for ROIAlign operation.
            :param p4: Output of p4 feature map, required for ROIAlign operation.
            :param p5: Output of p5 feature map, required for ROIAlign operation.
            :param second_nms_threshold: Override the 2nd NMS score threshold value. If set to None, use the value in the graph.
            """
            # Create ROIAlign node. 
            box_pooler_output = self.ROIAlign(rpn_outputs[1], p2, p3, p4, p5, self.first_ROIAlign_pooled_size, self.first_ROIAlign_sampling_ratio, self.first_ROIAlign_type, self.first_NMS_max_proposals, 'box_pooler')
            
            # Reshape node that prepares ROIAlign/box pooler output for Gemm node that comes next.
            box_pooler_shape = np.asarray([-1, self.fpn_out_channels*self.first_ROIAlign_pooled_size*self.first_ROIAlign_pooled_size], dtype=np.int64)
            box_pooler_reshape = self.graph.op_with_const("Reshape", "box_pooler/reshape", box_pooler_output, box_pooler_shape)
            
            # Get first Gemm op of box head and connect box pooler to it.
            first_box_head_gemm = self.graph.find_node_by_op_name("Gemm", "/roi_heads/box_head/fc1/Gemm")
            first_box_head_gemm.inputs[0] = box_pooler_reshape[0]

            # Get final two nodes of box predictor. Softmax op for cls_score, Gemm op for bbox_pred.
            cls_score = self.graph.find_node_by_op_name("Softmax", "/roi_heads/Softmax")
            bbox_pred = self.graph.find_node_by_op_name("Gemm", "/roi_heads/box_predictor/bbox_pred/Gemm")

            # Linear transformation to convert box coordinates from (TopLeft, BottomRight) Corner encoding
            # to CenterSize encoding. 1st NMS boxes are multiplied by transformation matrix in order to 
            # encode it into CenterSize format.
            matmul_const = np.array([[0.5, 0, -1, 0], [0, 0.5, 0, -1], [0.5, 0, 1, 0], [0, 0.5, 0, 1]], dtype=np.float32)
            matmul_out = self.graph.matmul("RPN_NMS/detection_boxes_conversion", rpn_outputs[1], matmul_const)

            # Reshape node that prepares bbox_pred for scaling and second NMS.
            bbox_pred_shape = np.asarray([self.batch_size, self.first_NMS_max_proposals, self.num_classes, 4], dtype=np.int64)
            bbox_pred_reshape = self.graph.op_with_const("Reshape", "bbox_pred/reshape", bbox_pred.outputs[0], bbox_pred_shape)
            
            # 0.1, 0.1, 0.2, 0.2 are localization head variance numbers, they scale bbox_pred_reshape, in order to get accurate coordinates.
            scale_adj = np.expand_dims(np.asarray([0.1, 0.1, 0.2, 0.2], dtype=np.float32), axis=(0, 1))
            final_bbox_pred = self.graph.op_with_const("Mul", "bbox_pred/scale", bbox_pred_reshape[0], scale_adj)

            # Reshape node that prepares cls_score for slicing and second NMS.
            cls_score_shape = np.array([self.batch_size, self.first_NMS_max_proposals, self.num_classes+1], dtype=np.int64)
            cls_score_reshape = self.graph.op_with_const("Reshape", "cls_score/reshape", cls_score.outputs[0], cls_score_shape)
            
            # Slice the third dimension of cls_score to drop the background class (the last class, e.g. the 81st for COCO in Detectron 2).
            final_cls_score = self.graph.slice("cls_score/slicer", cls_score_reshape[0], 0, self.num_classes, 2)

            # Create NMS node.
            nms_outputs = self.NMS(final_bbox_pred[0], final_cls_score[0], matmul_out[0], -1, False, self.second_NMS_max_proposals, self.second_NMS_iou_threshold, self.second_NMS_score_threshold, second_nms_threshold, 'box_outputs')

            # Create ROIAlign node. 
            mask_pooler_output = self.ROIAlign(nms_outputs[1], p2, p3, p4, p5, self.second_ROIAlign_pooled_size, self.second_ROIAlign_sampling_ratio, self.second_ROIAlign_type, self.second_NMS_max_proposals, 'mask_pooler')
            
            # Reshape mask pooler output. 
            mask_pooler_shape = np.asarray([self.second_NMS_max_proposals*self.batch_size, self.fpn_out_channels, self.second_ROIAlign_pooled_size, self.second_ROIAlign_pooled_size], dtype=np.int64)
            mask_pooler_reshape_node = self.graph.op_with_const("Reshape", "mask_pooler/reshape", mask_pooler_output, mask_pooler_shape)
            
            # Get first Conv op in mask head and connect ROIAlign's squeezed output to it. 
            mask_head_conv = self.graph.find_node_by_op_name("Conv", "/roi_heads/mask_head/mask_fcn1/Conv")
            mask_head_conv.inputs[0] = mask_pooler_reshape_node[0]
           
            # Reshape node that is preparing 2nd NMS class outputs for Add node that comes next.
            classes_reshape_shape = np.asarray([self.second_NMS_max_proposals*self.batch_size], dtype=np.int64)
            classes_reshape_node = self.graph.op_with_const("Reshape", "box_outputs/reshape_classes", nms_outputs[3], classes_reshape_shape)
            
            # Build the per-box offsets (0, num_classes, 2*num_classes, ...) used by the Add node below. Together with the Gather
            # node they pick the single class of interest per bounding box, instead of keeping one mask per class (80 for COCO) for every box.
            add_array = [i * self.num_classes for i in range(self.second_NMS_max_proposals*self.batch_size)]

            # This Add node is one of the Gather node inputs. Gather operates on axis 0 of the data tensor
            # and requires indices that stay within bounds; this Add node provides those offsets for Gather.
            add_array = np.asarray(add_array, dtype=np.int32)
            classes_add_node = self.graph.op_with_const("Add", "box_outputs/add", classes_reshape_node[0], add_array)
            
            # Get the last Conv op in mask head and reshape it to correctly gather class of interest's masks. 
            last_conv = self.graph.find_node_by_op_name("Conv", "/roi_heads/mask_head/predictor/Conv")
            last_conv_reshape_shape = np.asarray([self.second_NMS_max_proposals*self.num_classes*self.batch_size, self.mask_out_res, self.mask_out_res], dtype=np.int64)
            last_conv_reshape_node = self.graph.op_with_const("Reshape", "mask_head/reshape_all_masks", last_conv.outputs[0], last_conv_reshape_shape)
            
            # Gather node that selects only masks belonging to detected class, 79 other masks are discarded. 
            final_gather = self.graph.gather("mask_head/final_gather", last_conv_reshape_node[0], classes_add_node[0], 0)
            
            # Get last Sigmoid node and connect Gather node to it. 
            mask_head_sigmoid = self.graph.find_node_by_op_name("Sigmoid", "/roi_heads/mask_head/Sigmoid")
            mask_head_sigmoid.inputs[0] = final_gather[0]
            
            # Final Reshape node, reshapes output of Sigmoid, important for various batch_size support (not tested yet).
            final_graph_reshape_shape = np.asarray([self.batch_size, self.second_NMS_max_proposals, self.mask_out_res, self.mask_out_res], dtype=np.int64)
            final_graph_reshape_node = self.graph.op_with_const("Reshape", "mask_head/final_reshape", mask_head_sigmoid.outputs[0], final_graph_reshape_shape)
            final_graph_reshape_node[0].dtype = np.float32
            final_graph_reshape_node[0].name = "detection_masks"

            return nms_outputs, final_graph_reshape_node[0]

        # Only Detectron 2's Mask-RCNN R50-FPN 3x is supported currently.
        p2, p3, p4, p5 = backbone()
        rpn_outputs = proposal_generator(anchors, first_nms_threshold)
        box_head_outputs, mask_head_output = roi_heads(rpn_outputs, p2, p3, p4, p5, second_nms_threshold)
        # Append segmentation head output.
        box_head_outputs.append(mask_head_output)
        # Set graph outputs, both bbox and segmentation heads.
        self.graph.outputs = box_head_outputs
        self.sanitize()        


def main(args):
    det2_gs = DET2GraphSurgeon(args.exported_onnx, args.det2_config, args.det2_weights)
    det2_gs.update_preprocessor(args.batch_size)
    anchors = det2_gs.get_anchors(args.sample_image)
    det2_gs.process_graph(anchors, args.first_nms_threshold, args.second_nms_threshold)
    det2_gs.save(args.onnx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--exported_onnx", help="The Detectron 2 Mask R-CNN model exported to ONNX", type=str)
    parser.add_argument("-o", "--onnx", help="The output ONNX model file to write", type=str)
    parser.add_argument("-c", "--det2_config", help="The Detectron 2 config file (.yaml) for the model", type=str)
    parser.add_argument("-w", "--det2_weights", help="The Detectron 2 model weights (.pkl)", type=str)
    parser.add_argument("-s", "--sample_image", help="Sample image for anchors generation", type=str)
    parser.add_argument("-b", "--batch_size", help="Batch size for the model", type=int, default=1)
    parser.add_argument("-t1", "--first_nms_threshold", help="Override the score threshold for the 1st NMS operation", type=float)
    parser.add_argument("-t2", "--second_nms_threshold", help="Override the score threshold for the 2nd NMS operation", type=float)
    args = parser.parse_args()
    if not all([args.exported_onnx, args.onnx, args.det2_config, args.det2_weights, args.sample_image]):
        parser.print_help()
        print("\nThese arguments are required: --exported_onnx --onnx --det2_config --det2_weights and --sample_image")
        sys.exit(1)
    main(args)

With these fixes to the script, you can convert the model from .pth to a .trt engine.
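For anyone checking the index arithmetic in roi_heads: below is a minimal NumPy sketch (function names are hypothetical, not part of the sample) of the two tricks the script builds out of graph ops, namely the Corner-to-CenterSize matmul and the Reshape -> Add -> Gather chain that keeps a single mask per detection. It only mimics the plain ONNX ops, not the EfficientNMS/PyramidROIAlign plugins.

```python
import numpy as np

# Same 4x4 matrix as in the script: [x1, y1, x2, y2] @ M -> [cx, cy, w, h].
CORNER_TO_CENTER = np.array(
    [[0.5, 0.0, -1.0, 0.0],
     [0.0, 0.5, 0.0, -1.0],
     [0.5, 0.0, 1.0, 0.0],
     [0.0, 0.5, 0.0, 1.0]],
    dtype=np.float32,
)


def corners_to_center_size(boxes):
    """Convert an (N, 4) array of [x1, y1, x2, y2] boxes to [cx, cy, w, h]."""
    return np.asarray(boxes, dtype=np.float32) @ CORNER_TO_CENTER


def select_class_masks(all_masks, classes):
    """Pick one mask per detection, mimicking the Reshape -> Add -> Gather trick.

    all_masks: (num_dets, num_classes, H, W) per-class masks.
    classes:   (num_dets,) integer class id of each detection.
    """
    num_dets, num_classes, h, w = all_masks.shape
    # Flatten to (num_dets * num_classes, H, W), like "mask_head/reshape_all_masks".
    flat = all_masks.reshape(num_dets * num_classes, h, w)
    # Offsets 0, C, 2C, ... play the role of add_array in the script.
    offsets = np.arange(num_dets, dtype=np.int64) * num_classes
    indices = np.asarray(classes, dtype=np.int64) + offsets
    # Gather along axis 0, like "mask_head/final_gather".
    return np.take(flat, indices, axis=0)
```

Because each detection's masks occupy one contiguous block of `num_classes` rows after the reshape, adding `i * num_classes` to the class id turns a per-detection class index into a global row index.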

@azhurkevich
Contributor

@tomasctg can you point out the specific changes?

@tomasctg

tomasctg commented Jan 6, 2023

@azhurkevich
1.

            p2 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output2/Conv")
            p3 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output3/Conv")
            p4 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output4/Conv")
            p5 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output5/Conv")
            p2_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten")
            p3_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_1")
            p4_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_2")
            p5_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_3")
            p6_logits = self.graph.find_node_by_op_name("Flatten", "/proposal_generator/Flatten_4")

            # Get nodes containing final anchor_deltas.
            p2_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_1")
            p3_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_3")
            p4_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_5")
            p5_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_7")
            p6_anchors = self.graph.find_node_by_op_name("Reshape", "/proposal_generator/Reshape_9")                    
            first_box_head_gemm = self.graph.find_node_by_op_name("Gemm", "/roi_heads/box_head/fc1/Gemm")
            cls_score = self.graph.find_node_by_op_name("Softmax", "/roi_heads/Softmax")
            bbox_pred = self.graph.find_node_by_op_name("Gemm", "/roi_heads/box_predictor/bbox_pred/Gemm")
            mask_head_conv = self.graph.find_node_by_op_name("Conv", "/roi_heads/mask_head/mask_fcn1/Conv")
            last_conv = self.graph.find_node_by_op_name("Conv", "/roi_heads/mask_head/predictor/Conv")
            mask_head_sigmoid = self.graph.find_node_by_op_name("Sigmoid", "/roi_heads/mask_head/Sigmoid")
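Each of these lookups returns None when the name doesn't exist, which is what produces the `'NoneType' object has no attribute 'inputs'` error at the top of this issue. Here is a minimal sketch of a stricter lookup (the name `require_node` is hypothetical, not part of the sample) that fails loudly and lists the existing nodes of that op type. It only assumes nodes expose `.op` and `.name`, as onnx_graphsurgeon nodes do, so in create_onnx.py it could be called on `self.graph.nodes`.

```python
def require_node(nodes, op, name):
    """Return the node with the given op type and name, or raise a descriptive error.

    `nodes` is any iterable of objects with `.op` and `.name` attributes,
    such as `graph.nodes` of an onnx_graphsurgeon Graph.
    """
    candidates = []
    for node in nodes:
        if node.op != op:
            continue
        if node.name == name:
            return node
        # Remember other nodes of the same op type to help pick the new name.
        candidates.append(node.name)
    preview = ", ".join(candidates[:10]) or "<none>"
    raise KeyError(
        f"No {op} node named {name!r}. "
        f"{len(candidates)} {op} node(s) exist, e.g.: {preview}"
    )
```

With this, a renamed op shows up immediately together with the names it could have been renamed to, instead of failing several lines later.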

@azhurkevich
Contributor

@tomasctg Great stuff, bravo! I'll revisit the sample at some point when time allows and make the changes. Thanks for helping people!

@tomasctg

tomasctg commented Jan 7, 2023

@azhurkevich Let me know if it worked.

@Photoheyler

@tomasctg your code is exactly the same as what @williamf-searidgetech already posted. I don't see any difference.

@tomasctg

tomasctg commented Jan 12, 2023

So when I use a sample image of 1280x800, I can export the converted ONNX and build the TRT engine, but I get no results. When I use a sample image of 1344x1344, I can also build the converted ONNX, but I cannot build the TRT engine. I get the following error: ERROR:EngineBuilder:In node 257 (parseGraph): INVALID_NODE: Invalid Node - box_pooler/reshape Attribute not found: allowzero

@Photoheyler Did you change the lines

         aug = T.ResizeShortestEdge(
             [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
         )

to

        aug = T.ResizeShortestEdge(
            [1344, 1344], 1344
        )

in the export_model.py script?
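To see what size the network input actually ends up with, here is a standalone re-implementation sketch of the shortest-edge rule that detectron2's ResizeShortestEdge applies (the function name is hypothetical and the `int(x + 0.5)` rounding is an assumption about detectron2's behaviour). With detectron2's default test sizes (MIN_SIZE_TEST=800, MAX_SIZE_TEST=1333) a 1344x1344 image becomes 800x800, while `[1344, 1344], 1344` keeps it at 1344x1344.

```python
def resize_shortest_edge_shape(height, width, short_edge, max_size):
    """Return (new_h, new_w) for a shortest-edge resize capped at max_size.

    Standalone sketch of the rule used by detectron2's ResizeShortestEdge;
    rounding is assumed to be int(x + 0.5).
    """
    # Scale so the shorter side becomes `short_edge`.
    scale = short_edge / min(height, width)
    if height < width:
        new_h, new_w = short_edge, scale * width
    else:
        new_h, new_w = scale * height, short_edge
    # If the longer side now exceeds `max_size`, shrink both sides again.
    if max(new_h, new_w) > max_size:
        scale = max_size / max(new_h, new_w)
        new_h, new_w = new_h * scale, new_w * scale
    return int(new_h + 0.5), int(new_w + 0.5)
```

For example, `resize_shortest_edge_shape(1344, 1344, 800, 1333)` gives `(800, 800)`, and `resize_shortest_edge_shape(800, 1280, 800, 1333)` leaves a 1280x800 image unchanged at `(800, 1280)`.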

@Photoheyler

@tomasctg Yes, I changed it.

@Photoheyler

@tomasctg When I look at lines 160-164 of create_onnx.py:

    # Image preprocessing.
    input_im = cv2.imread(sample_image)
    raw_height, raw_width = input_im.shape[:2]
    image = predictor.aug.get_transform(input_im).apply_image(input_im)
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

My input_im is 1344x1344x3, but image is 3x800x800.
Is this correct?

@tomasctg

tomasctg commented Jan 12, 2023

@Photoheyler Yes, that's correct. Here is the console output of the create_onnx.py conversion:

INFO:ModelHelper:ONNX graph loaded successfully
INFO:ModelHelper:Number of FPN output channels is 256
INFO:ModelHelper:Number of classes is 49
INFO:ModelHelper:First NMS max proposals is 1000
INFO:ModelHelper:First NMS iou threshold is 0.7
INFO:ModelHelper:First NMS score threshold is 0.01
INFO:ModelHelper:First ROIAlign type is ROIAlignV2
INFO:ModelHelper:First ROIAlign pooled size is 7
INFO:ModelHelper:First ROIAlign sampling ratio is 0
INFO:ModelHelper:Second NMS max proposals is 100
INFO:ModelHelper:Second NMS iou threshold is 0.5
INFO:ModelHelper:Second NMS score threshold is 0.85
INFO:ModelHelper:Second ROIAlign type is ROIAlignV2
INFO:ModelHelper:Second ROIAlign pooled size is 14
INFO:ModelHelper:Second ROIAlign sampling ratio is 0
INFO:ModelHelper:Individual mask output resolution is 28x28
...
INFO:ModelHelper:ONNX graph input shape: [1, 3, 1344, 1344] [NCHW format set]
INFO:ModelHelper:Found Div node
INFO:ModelHelper:Found Conv node

And the conversion to a TensorRT engine:

INFO:EngineBuilder:Input 'input_tensor' with shape (1, 3, 1344, 1344) and dtype DataType.FLOAT
INFO:EngineBuilder:Output 'num_detections_box_outputs' with shape (1, 1) and dtype DataType.INT32
INFO:EngineBuilder:Output 'detection_boxes_box_outputs' with shape (1, 100, 4) and dtype DataType.FLOAT
INFO:EngineBuilder:Output 'detection_scores_box_outputs' with shape (1, 100) and dtype DataType.FLOAT
INFO:EngineBuilder:Output 'detection_classes_box_outputs' with shape (1, 100) and dtype DataType.INT32
INFO:EngineBuilder:Output 'detection_masks' with shape (1, 100, 28, 28) and dtype DataType.FLOAT

You can check the input shape at each stage of the conversion with https://netron.app/.

@Photoheyler

Photoheyler commented Jan 12, 2023

@tomasctg
Hi, thanks for the hint, but the error is still the same.
Can you send me the complete output of build_engine.py?

[01/12/2023-19:15:53] [TRT] [I] [MemUsageChange] Init CUDA: CPU +433, GPU +0, now: CPU 15085, GPU 1157 (MiB)
[01/12/2023-19:15:54] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +337, GPU +104, now: CPU 15614, GPU 1261 (MiB)
C:\Users\Photoheyler\Desktop\YoloV7Instance\TensorRT\samples\python\detectron2\build_engine.py:147: DeprecationWarning: Use set_memory_pool_limit instead.
self.config.max_workspace_size = workspace * (2 ** 30)
[01/12/2023-19:15:54] [TRT] [W] onnx2trt_utils.cpp:369: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/12/2023-19:15:54] [TRT] [I] No importer registered for op: EfficientNMS_TRT. Attempting to import as plugin.
[01/12/2023-19:15:54] [TRT] [I] Searching for plugin: EfficientNMS_TRT, plugin_version: 1, plugin_namespace:
[01/12/2023-19:15:54] [TRT] [I] Successfully created plugin: EfficientNMS_TRT
[01/12/2023-19:15:54] [TRT] [I] No importer registered for op: PyramidROIAlign_TRT. Attempting to import as plugin.
[01/12/2023-19:15:54] [TRT] [I] Searching for plugin: PyramidROIAlign_TRT, plugin_version: 1, plugin_namespace:
[01/12/2023-19:15:54] [TRT] [W] builtin_op_importers.cpp:4716: Attribute roi_coords_plusone not found in plugin node! Ensure that the plugin creator has a default value defined or the engine may fail to build.
[01/12/2023-19:15:54] [TRT] [W] builtin_op_importers.cpp:4716: Attribute legacy not found in plugin node! Ensure that the plugin creator has a default value defined or the engine may fail to build.
[01/12/2023-19:15:54] [TRT] [I] Successfully created plugin: PyramidROIAlign_TRT
ERROR:EngineBuilder:Failed to load ONNX file: C:\Users\Photoheyler\Desktop\YoloV7Instance\detectron2\tools\deploy\output\converted.onnx
ERROR:EngineBuilder:In node 257 (parseGraph): INVALID_NODE: Invalid Node - box_pooler/reshape
Attribute not found: allowzero

This is my output of create_onnx.py (it's exactly the same):

INFO:ModelHelper:ONNX graph loaded successfully
INFO:ModelHelper:Number of FPN output channels is 256
INFO:ModelHelper:Number of classes is 49
INFO:ModelHelper:First NMS max proposals is 1000
INFO:ModelHelper:First NMS iou threshold is 0.7
INFO:ModelHelper:First NMS score threshold is 0.01
INFO:ModelHelper:First ROIAlign type is ROIAlignV2
INFO:ModelHelper:First ROIAlign pooled size is 7
INFO:ModelHelper:First ROIAlign sampling ratio is 0
INFO:ModelHelper:Second NMS max proposals is 100
INFO:ModelHelper:Second NMS iou threshold is 0.5
INFO:ModelHelper:Second NMS score threshold is 0.7
INFO:ModelHelper:Second ROIAlign type is ROIAlignV2
INFO:ModelHelper:Second ROIAlign pooled size is 14
INFO:ModelHelper:Second ROIAlign sampling ratio is 0
INFO:ModelHelper:Individual mask output resolution is 28x28
...
...
C:\Users\Photoheyler\Desktop\YoloV7Instance\venv_tensorrt\lib\site-packages\torch\functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at C:\actions-runner_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\TensorShape.cpp:3191.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
INFO:ModelHelper:Created nms_rpn with EfficientNMS_TRT plugin
INFO:ModelHelper:Created box_pooler with PyramidROIAlign_TRT plugin
INFO:ModelHelper:Created nms_box_outputs with EfficientNMS_TRT plugin
INFO:ModelHelper:Created mask_pooler with PyramidROIAlign_TRT plugin
INFO:ModelHelper:Shape inference could not be performed at this time:
Input 1 is out of bounds.
INFO:ModelHelper:Shape inference could not be performed at this time:
Input 1 is out of bounds.
INFO:ModelHelper:Saved ONNX model to C:\Users\Photoheyler\Desktop\YoloV7Instance\detectron2\tools\deploy\output\converted.onnx
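The `allowzero` error points at `box_pooler/reshape`, one of the Reshape nodes that create_onnx.py creates itself (without attributes). The thread doesn't settle the root cause, so treat this as one possible workaround under an assumption: if the parser only fails because the attribute is absent, setting it explicitly should help. ONNX added `allowzero` to Reshape in opset 14 with a default of 0, so writing 0 explicitly should not change the model's semantics. The helper name is hypothetical; it could be run on `det2_gs.graph.nodes` right before `det2_gs.save(args.onnx)` in `main()`. Upgrading TensorRT is the other obvious thing to try.

```python
def add_reshape_allowzero(nodes):
    """Give every Reshape node an explicit `allowzero=0` attribute.

    `nodes` is an iterable of objects with `.op` and a dict-like `.attrs`,
    e.g. `graph.nodes` of an onnx_graphsurgeon Graph. allowzero=0 is the
    ONNX default since opset 14, so this should be semantically a no-op.
    Returns the number of nodes that were patched.
    """
    patched = 0
    for node in nodes:
        # Leave nodes that already carry an explicit value untouched.
        if node.op == "Reshape" and "allowzero" not in node.attrs:
            node.attrs["allowzero"] = 0
            patched += 1
    return patched
```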


@azhurkevich
Contributor

@Photoheyler TRT engines are built for a specific GPU and a specific TRT version. Even if @tomasctg gives you a prebuilt TRT engine, it will most likely not work unless you match his TRT version and have the same GPU. GPUs of the same family are supported (for example, an engine built on a 3090 and used on a 3060, both Ampere), but behavior might be unpredictable, so we do not guarantee this approach.

@Photoheyler

@azhurkevich I know.
But that doesn't explain why the conversion to a TRT engine fails.
I use TRT for YOLOv7 on my RTX 3060 and it works perfectly without issues.
I only have issues with this type of model, and I have no idea what to look for.

@azhurkevich
Contributor

Visualize your converted graph and compare it with the graph here. They should be the same. Also, closely follow my instructions in the README; 90% of the problems come from people not following the instructions.

@predrag12

Hi, thank you for the custom create_onnx.py. I had several issues that I was able to overcome by building PyTorch, torchvision and detectron2 from source, and by modifying some config files for the supported CUDA arch and the ONNX opset version. Currently I'm having a problem at:

[E] No function: find_node_by_op_name registered for opset: 16
Traceback (most recent call last):
  File "create_onnx.py", line 545, in <module>
    main(args)
  File "create_onnx.py", line 526, in main
    det2_gs.process_graph(anchors, args.first_nms_threshold, args.second_nms_threshold)
  File "create_onnx.py", line 512, in process_graph
    p2, p3, p4, p5 = backbone()
  File "create_onnx.py", line 358, in backbone
    p2 = self.graph.find_node_by_op_name("Conv", "/backbone/fpn_output2/Conv")
  File "/home/me/.local/lib/python3.8/site-packages/onnx_graphsurgeon/ir/graph.py", line 147, in __getattr__
    raise err
  File "/home/me/.local/lib/python3.8/site-packages/onnx_graphsurgeon/ir/graph.py", line 137, in __getattr__
    return super().__getattribute__(name)
AttributeError: 'Graph' object has no attribute 'find_node_by_op_name'

Examining self.graph:
GLOBAL_FUNC_MAP = {dict: 14}
...
'find_node_by_op' = {function} <function find_node_by_op at 0xfffeda38bc10>
'find_node_by_op_input_output_name' = {function} <function find_node_by_op_input_output_name at 0xfffeda38bca0>
'find_descendant_by_op' = {function} <function find_descendant_by_op at 0xfffeda38bd30>
'find_ancestor_by_op' = {function} <function find_ancestor_by_op at 0xfffeda38bdc0>
...
DEFAULT_OPSET = {int} 11
opset = {int} 14

Changing the opset in detectron2 (16, 14, 11), rebuilding, and running export_model.py again makes no difference to the create_onnx.py error. I'm using TensorRT 8.5.1, CUDA 11.4, onnx 1.13, onnx-graphsurgeon 0.3.25, PyTorch 1.13, torchvision 0.14 and detectron2 0.6. I can run some other TensorRT sample models. For detectron2, however, the versions are underspecified, and the apt and pip versions conflict between README.md, requirements.txt and the NGC container, covering pytorch, torchvision, onnx, onnxruntime, onnxoptimizer, onnx-graphsurgeon, pycuda and cuda-python. To harmonize the versions, could you list the C/C++ and Python libraries that worked successfully for you? Thanks.
apt list --installed | grep cuda
pip list

@azhurkevich
Contributor

azhurkevich commented Feb 9, 2023

@tomasctg Got a confirmation that your modifications work. Fix is en route, for now please use these modifications to make det2 sample work. Great work!

@SangbumChoi

@predrag12 you need to import onnx_utils, which is defined in the repository.

@ttyio
Collaborator

ttyio commented Nov 23, 2023

closing since this is solved, thanks all!

@ttyio ttyio closed this as completed Nov 23, 2023
@Helloiamwin

Helloiamwin commented Dec 27, 2023

Hi everyone, I'm using a Faster R-CNN model. I trained the model with detectron2 and converted it to ONNX with tracing, but I get an error when converting the ONNX model to the modified ONNX graph. I would really appreciate any help. I use https://github.com/NVIDIA/TensorRT/tree/release/8.6/samples/python/detectron2 to convert the model to an ONNX graph and then to TensorRT. I think TensorRT does support this conversion. Thanks all
p2, p3, p4, p5 = backbone()
File "create_onnx.py", line 365, in backbone
return p2.outputs[0], p3.outputs[0], p4.outputs[0], p5.outputs[0]
AttributeError: 'NoneType' object has no attribute 'outputs'

@Aaronponceuv


@Helloiamwin The same thing happened to me when trying to convert. The operator names change between detectron2 versions. Most likely, the operator names set in the backbone() function of create_onnx.py don't match your graph, so self.graph.find_node_by_op_name cannot find them and returns None, which is why you get no outputs.

p2 = self.graph.find_node_by_op_name("Conv", "Conv_281") 
p3 = self.graph.find_node_by_op_name("Conv", "Conv_277")
p4 = self.graph.find_node_by_op_name("Conv", "Conv_273")
p5 = self.graph.find_node_by_op_name("Conv", "Conv_269")

@azhurkevich
Contributor

I tend to agree with @Aaronponceuv (I haven't looked into the latest det2 exported model). Most likely it's just the names of the ops you need that have changed, so selecting the correct ones will fix it. The question is how to select the correct ones. The most reliable method is to roll back to older det2 versions until you make sure it works. Visualize the ONNX graph with Netron before conversion and after plugin insertion. This will give you context on which operators specifically were chosen. Then visualize the latest pre-conversion ONNX graph and look for the general graph patterns shared by the old and new graphs. I guarantee you'll quickly get a sense of which new ops are supposed to be captured. Change their names in the snippet above -> profit 😊. I know it will take some work, but it's the fastest workaround.

Please don't hesitate to submit a PR with a fix; we are open to accepting commits. We'll review it internally and merge it. Thank you.
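As a text-based complement to comparing graphs in Netron, here is a minimal sketch that groups node names by op type, optionally filtered by a name prefix, so that an old and a new export can be diffed side by side. The helper name and the default op list are assumptions for illustration. It only needs nodes with `.op` and `.name`, e.g. `gs.import_onnx(onnx.load("model.onnx")).nodes` from onnx_graphsurgeon.

```python
from collections import defaultdict


def nodes_by_op(nodes, ops=("Conv", "Gemm", "Softmax", "Sigmoid", "Flatten", "Reshape"), prefix=""):
    """Group node names by op type, keeping only the ops of interest.

    `nodes` is any iterable of objects with `.op` and `.name` (for example
    onnx_graphsurgeon `graph.nodes`); `prefix` optionally filters names
    such as "/backbone/" or "/roi_heads/". Graph order is preserved.
    """
    groups = defaultdict(list)
    for node in nodes:
        if node.op in ops and node.name.startswith(prefix):
            groups[node.op].append(node.name)
    return dict(groups)
```

Printing the result for the old and the new export (for example with `prefix="/backbone/"`) makes renames like `Conv_281` -> `/backbone/fpn_output2/Conv` easy to spot by position.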

@UcanYusuf

I'm having this problem too. AttributeError: 'Graph' object has no attribute 'find_node_by_op_name'. When I examine the contents of the onnx_graphsurgeon/ir/graph.py file, there is no such function. How can I solve this problem?

@azhurkevich
Contributor

@UcanYusuf find_node_by_op_name is located here; it is registered as a GS graph function, so something is not right with typing. I would also recommend upgrading to the latest version of TRT: just run with the latest NGC container nvcr.io/nvidia/pytorch:23.12-py3, or downgrade to an older torch if needed.
