# Convert a Mask R-CNN PyTorch model to ONNX

Most of the code is taken from https://github.com/phamquiluan/PubLayNet/blob/master/maskrcnn/infer.py

TODO - optimise model:
- https://github.com/microsoft/onnxruntime/issues/1899

In [1]:
import random
import os
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import transforms
from onnx import optimizer
import onnx

In [2]:
# The checkpoint can be downloaded from https://github.com/phamquiluan/PubLayNet
# NOTE: machine-specific absolute path; point it at your local copy.
checkpoint_path = r'D:\MachineLearning\Models\phamquiluan-PubLayNet\model_196000.pth'

# Class id -> label name; id 0 is the background entry.
CATEGORIES2LABELS = {
    0: "bg",
    1: "text",
    2: "title",
    3: "list",
    4: "table",
    5: "figure",
}

# Number of classes passed to the model builder (one per entry above).
num_classes = 6

In [3]:
def get_instance_segmentation_model(num_classes):
    """Build a torchvision Mask R-CNN (ResNet-50 FPN) with heads sized for `num_classes`.

    The pretrained torchvision model is created first, then its box and mask
    predictors are replaced with freshly initialised ones that output
    `num_classes` classes.

    Args:
        num_classes: Number of output classes for both predictor heads.

    Returns:
        The model, switched to evaluation mode.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # Replace the box classification/regression head.
    box_in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Replace the mask head, keeping the input width of the existing one.
    mask_in_channels = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        mask_in_channels,
        hidden_layer,
        num_classes
    )

    # eval() switches the module AND all of its sub-modules to inference mode.
    # Assigning `model.training = False` would only flip the top-level flag and
    # leave every child module (RPN, ROI heads, ...) in training mode.
    model.eval()
    return model

# Fix every random number generator used here so runs are repeatable.
seed = 1234
for _seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Ask cuDNN for deterministic kernels and disable its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
# Fail fast, before the model is built, if the checkpoint is missing.
# An explicit exception is used instead of `assert`, which is stripped under
# `python -O` and reports no path.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(checkpoint_path)

model = get_instance_segmentation_model(num_classes)

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# The weights are stored under the 'model' key of the checkpoint dict.
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [None]:
# Export with opset 12.
# Make sure every sub-module is in inference mode before tracing (eval() is
# idempotent), and trace without autograd: the export only needs a forward pass.
model.eval()
with torch.no_grad():
    # Dummy input used for tracing; it does not need gradients.
    image = torch.randn(3, 1300, 1300)
    torch.onnx.export(model, [image],
                      'model_196000_v12.onnx',
                      opset_version=12,
                      do_constant_folding=True,
                      input_names = ['image'],
                      output_names = ['boxes', 'labels', 'scores', 'masks'],
                      # The first axis of every output is marked dynamic under the name 'pred'.
                      dynamic_axes=
                      {
                          'masks' : {0 : 'pred'},
                          'boxes' : {0 : 'pred'},
                          'labels' : {0 : 'pred'},
                          'scores' : {0 : 'pred'},
                      })

In [7]:
# Export with opset 11 to the file that the optimisation cells read.
with torch.no_grad():
    model_path = 'model_196000_before_opt.onnx'

    # Freeze every weight. model.parameters() walks all sub-modules, so the
    # box and mask predictors under roi_heads are already covered here.
    for param in model.parameters():
        param.requires_grad = False

    model.eval()
    # Dummy input used for tracing.
    image = torch.randn(3, 1300, 1300, requires_grad=False)
    torch.onnx.export(model,
                      [image],
                      model_path,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names = ['image'],
                      output_names = ['boxes', 'labels', 'scores', 'masks'],
                      # The first axis of every output is marked dynamic under the name 'nbox'.
                      dynamic_axes=
                      {
                          'masks' : {0 : 'nbox'},
                          'boxes' : {0 : 'nbox'},
                          'labels' : {0 : 'nbox'},
                          'scores' : {0 : 'nbox'},
                      })

In [None]:
# https://github.com/BowenBao/maskrcnn-benchmark/blob/onnx_stage/demo/export_to_onnx.py
#postprocess_model(model_path)

In [None]:
# https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py
# Drop every graph input that is also an initializer, then save the model in place.
# NOTE: this rebinds the notebook-level `model` from the torch module to the
# ONNX ModelProto; re-run the model construction cell before exporting again.
model_path = 'model_196000_opt.onnx'
model = onnx.load(model_path)

inputs = model.graph.input
# Built with a comprehension so no global named `input` is left behind to
# shadow the builtin in later cells.
name_to_input = {graph_input.name: graph_input for graph_input in inputs}

for initializer in model.graph.initializer:
    if initializer.name in name_to_input:
        inputs.remove(name_to_input[initializer.name])

onnx.save(model, model_path)

In [None]:
def add_value_info_for_constants(model : onnx.ModelProto):
    """
    Record the element type and shape of every constant initializer as a
    ValueInfoProto, because onnx.shape_inference currently does not take
    initializer shapes into account on its own.
    The model is modified in place, including graphs nested in node attributes.
    Args:
        model: The ModelProto to update.
    """
    # Before IRv4 all (top-level) constants are inputs and so already carry ValueInfos
    if model.ir_version < 4:
        return

    def add_const_value_infos_to_graph(graph : onnx.GraphProto):
        input_names = {graph_input.name for graph_input in graph.input}
        known_infos = {info.name: info for info in graph.value_info}

        for constant in graph.initializer:
            # An initializer that is also listed as an input is not a real constant
            if constant.name in input_names:
                continue

            # Reuse the value info if one exists, otherwise create it
            info = known_infos.get(constant.name)
            if info is None:
                info = graph.value_info.add()
                info.name = constant.name

            # Only fill in what is missing; existing details are never overwritten,
            # even if they disagree with the initializer
            tensor_type = info.type.tensor_type
            if tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
                tensor_type.elem_type = constant.data_type
            if not tensor_type.HasField("shape"):
                # Extending with an empty list makes sure a shape is set even
                # for a scalar constant (zero dims)
                tensor_type.shape.dim.extend([])
                for extent in constant.dims:
                    tensor_type.shape.dim.add().dim_value = extent

        # Descend into graphs stored in node attributes
        for node in graph.node:
            for attr in node.attribute:
                # Ref attrs refer to other attrs, so there is nothing to do for them
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    nested_graphs = [attr.g]
                elif attr.type == onnx.AttributeProto.GRAPHS:
                    nested_graphs = attr.graphs
                else:
                    nested_graphs = []

                for nested in nested_graphs:
                    add_const_value_infos_to_graph(nested)


    return add_const_value_infos_to_graph(model.graph)

In [None]:
# Run the ONNX optimizer over the exported model and save the result under a new name.
passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
model_path = 'model_196000_before_opt.onnx'

onnx_model = onnx.load(model_path)
# `passes` is currently not forwarded, so optimize() runs with its own default pass list.
onnx_model = optimizer.optimize(onnx_model)
onnx.save(onnx_model, 'model_196000_after_opt.onnx')

In [None]:
# https://github.com/microsoft/onnxruntime/issues/1899
# https://github.com/onnx/onnx/issues/2903
# https://github.com/microsoft/onnxruntime/issues/4033

def _graph_and_direct_subgraphs(graph):
    """Return `graph` followed by the subgraphs held directly in its nodes' attributes."""
    graphs = [graph]
    for node in graph.node:
        for attr in node.attribute:
            # Ref attrs refer to other attrs, so we don't need to do anything
            if attr.ref_attr_name != "":
                continue

            if attr.type == onnx.AttributeProto.GRAPH:
                graphs.append(attr.g)
            elif attr.type == onnx.AttributeProto.GRAPHS:
                graphs.extend(attr.graphs)
    return graphs


def _expose_initializers_as_inputs(graph):
    """Append the value info of each initializer to `graph.input`, skipping names already declared."""
    info_by_name = {info.name: info for info in graph.value_info}
    declared = {graph_input.name for graph_input in graph.input}
    for initializer in graph.initializer:
        info = info_by_name.get(initializer.name)
        # Never add a second input with the same name.
        if info is not None and initializer.name not in declared:
            graph.input.append(info)
            declared.add(initializer.name)


def _strip_initializers_from_inputs(graph):
    """Remove from `graph.input` every entry that shares its name with an initializer."""
    input_by_name = {graph_input.name: graph_input for graph_input in graph.input}
    for initializer in graph.initializer:
        graph_input = input_by_name.pop(initializer.name, None)
        if graph_input is not None:
            graph.input.remove(graph_input)


model_path = 'model_196000_before_opt.onnx'

onnx_model = onnx.load(model_path)
add_value_info_for_constants(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)

# Workaround for the issues linked at the top of this cell: list the
# initializers as graph inputs before running the optimizer.
# Each graph is matched against its OWN initializers. The previous inline
# version compared a stale top-level loop variable for single-GRAPH attributes.
for graph in _graph_and_direct_subgraphs(onnx_model.graph):
    _expose_initializers_as_inputs(graph)

passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]

# `passes` is currently not forwarded, so optimize() runs with its own default pass list.
onnx_model = optimizer.optimize(onnx_model)

# Undo the workaround so initializers are no longer listed as inputs.
for graph in _graph_and_direct_subgraphs(onnx_model.graph):
    _strip_initializers_from_inputs(graph)

onnx.save(onnx_model, 'model_196000_after_opt.onnx')

In [None]:
def remove_unused_floor(model):
    """Remove every Floor node that is fed directly by a Mul node.

    All nodes of the top-level graph are first renamed to their position
    index. For each Floor whose first input is produced by a Mul, every
    consumer of the Floor's output is rewired to read the Mul's output
    instead, and the Floor node is removed and printed.

    A Floor is left untouched when its input has no producing node (for
    example a graph input or initializer), when the producer is not a Mul,
    or when the Floor's output is itself a graph output.

    NOTE(review): only nodes of the top-level graph are inspected; consumers
    living inside subgraphs are not rewired. Confirm this is sufficient.

    Args:
        model: Model whose `graph.node` list is edited in place.

    Returns:
        The same `model` object.
    """
    nodes = model.graph.node

    for index, node in enumerate(nodes):
        node.name = str(index)

    # Output name -> node producing it. Removing a Floor or rewiring inputs
    # never changes what a Mul produces, so this map stays valid for the loop.
    producer_of = {out_name: node for node in nodes for out_name in node.output}
    graph_outputs = {graph_output.name for graph_output in model.graph.output}

    # Snapshot the Floor nodes first because `nodes` is mutated below.
    for floor in [node for node in nodes if node.op_type == 'Floor']:
        in_id = floor.input[0]
        out_id = floor.output[0]

        producer = producer_of.get(in_id)
        if producer is None or producer.op_type != 'Mul':
            continue
        # Removing the node would leave the graph output without a producer.
        if out_id in graph_outputs:
            continue

        # Rewire every consumer, in every input slot, not just the first
        # single-input consumer. Otherwise other readers would be left
        # pointing at the output of a node that no longer exists.
        for consumer in nodes:
            for slot, name in enumerate(consumer.input):
                if name == out_id:
                    consumer.input[slot] = in_id

        nodes.remove(floor)
        print(floor)

    return model