In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
import onnx
import torch.onnx

In [2]:
class SimpleNet(nn.Module):
    def __init__(self, num_classes=10, init_weights=True):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=5, out_channels=5, kernel_size=4, stride=2)

        self.linear1 = nn.Linear(in_features=5 * 2 * 2, out_features=10)
        self.linear2 = nn.Linear(in_features=10, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

In [3]:
model = SimpleNet()
    
torchinfo.summary(model, input_size=(16, 1, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
SimpleNet                                [16, 10]                  --
├─Conv2d: 1-1                            [16, 3, 15, 15]           51
├─Conv2d: 1-2                            [16, 5, 6, 6]             245
├─Conv2d: 1-3                            [16, 5, 2, 2]             405
├─Linear: 1-4                            [16, 10]                  210
├─Linear: 1-5                            [16, 10]                  110
Total params: 1,021
Trainable params: 1,021
Non-trainable params: 0
Total mult-adds (M): 0.36
Input size (MB): 0.07
Forward/backward pass size (MB): 0.11
Params size (MB): 0.00
Estimated Total Size (MB): 0.18

In [4]:
x = torch.randn(16, 1, 32, 32, requires_grad=True)
model = model.cpu()
x = x.cpu()

# Export the model
torch.onnx.export(model,                     # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "model.onnx",              # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=13,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

In [5]:
def onnx_check_model(onnx_model):
    try:
        onnx.checker.check_model(onnx_model)
    except onnx.checker.ValidationError as e:
        print('The model is invalid: %s' % e)
    else:
        print('The model is valid!')

model_onnx = onnx.load("model.onnx")
onnx_check_model(model_onnx)

The model is valid!


In [28]:
print(model_onnx.graph.initializer[0].dims)
print(model_onnx.graph.initializer[1].dims)
print(model_onnx.graph.initializer[2].dims)
print(model_onnx.graph.initializer[3].dims)

[3, 1, 4, 4]
[3]
[5, 3, 4, 4]
[5]


In [20]:
model_onnx.graph.node[0]

input: "input"
input: "conv1.weight"
input: "conv1.bias"
output: "x"
name: "Conv_0"
op_type: "Conv"
attribute {
  name: "dilations"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "group"
  i: 1
  type: INT
}
attribute {
  name: "kernel_shape"
  ints: 4
  ints: 4
  type: INTS
}
attribute {
  name: "pads"
  ints: 0
  ints: 0
  ints: 0
  ints: 0
  type: INTS
}
attribute {
  name: "strides"
  ints: 2
  ints: 2
  type: INTS
}

In [158]:
from collections import OrderedDict
from enum import IntEnum
from typing import Dict, Optional, List

class TYPEDEF:
    ONNX_MODEL = onnx.onnx_ml_pb2.ModelProto
    ONNX_NODE = onnx.onnx_ml_pb2.NodeProto
    ONNX_IR = OrderedDict

class OPTYPE(IntEnum):
    CONV2D = 0
    RELU = 1
    FLATTEN = 2
    LINEAR = 3

    def __repr__(self):
        return self.name
    
    def __str__(self):
        return self.name

class DTYPE(IntEnum):
    # https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L483-L485
    FLOAT = 1,   # float
    INT8 =  3,   # int8_t
    INT16 = 4,   # int16_t
    INT32 = 6,   # int32_t
    INT64 = 7    # int64_t

    def __repr__(self):
        return self.name
    
    def __str__(self):
        return self.name

def dump_onnx_model(ir: TYPEDEF.ONNX_IR) -> None:
    pass

def get_attributes(node: TYPEDEF.ONNX_NODE) -> OrderedDict:
    return OrderedDict()

def get_dtype(data_type: int) -> DTYPE:
    if data_type == DTYPE.FLOAT:
        return DTYPE.FLOAT
    elif data_type == DTYPE.INT8:
        return DTYPE.INT8
    elif data_type == DTYPE.INT16:
        return DTYPE.INT16
    elif data_type == DTYPE.INT32:
        return DTYPE.INT32
    else:
        return DTYPE.INT64

def get_optype(node: TYPEDEF.ONNX_NODE) -> OPTYPE:
    optype = node.op_type.lower()

    if optype == "conv":
        return OPTYPE.CONV2D
    elif optype == "relu":
        return OPTYPE.RELU
    elif optype == "gemm":
        return OPTYPE.LINEAR
    elif optype == "flatten":
        return OPTYPE.FLATTEN
    else:
        raise NotImplementedError("get_optype: Operator type not implemented yet.")    

def get_initializer(node: TYPEDEF.ONNX_NODE) -> Dict[Dict, Optional[Dict]]:
    """
    {
       weight: {
            name:
            raw_data:
            dims:
            data_type:
       },
       bias: {
            name:
            raw_data:
            dims:
            data_type:
       }   
    }
    """
    initializer = dict()
    initializer["weight"], initializer["bias"] = {}, {}
    
    for inp in node.input:
        for init in model_onnx.graph.initializer:
            if inp == init.name:
                if len(init.dims) > 1:
                    initializer["weight"]["name"] = init.name
                    initializer["weight"]["raw_data"] = init.raw_data
                    initializer["weight"]["dims"] = list(init.dims)
                    initializer["weight"]["data_type"] = get_dtype(init.data_type)
                else:
                    initializer["bias"]["name"] = init.name
                    initializer["bias"]["raw_data"] = init.raw_data
                    initializer["bias"]["dims"] = list(init.dims)
                    initializer["bias"]["data_type"] = get_dtype(init.data_type)

    return initializer

def get_inputs(node: TYPEDEF.ONNX_NODE) -> List[int]:
    pass

def get_outputs(node: TYPEDEF.ONNX_NODE) -> List[int]:
    pass

def parse_onnx_model(model: TYPEDEF.ONNX_MODEL) -> OrderedDict:
    """
    OrderedDict
    {
        inputs: []

        node_ith: {
            op_type: OPTYPE
            name: str 
            inputs: List
            outputs: List
            initializer: Dict
            attributes: {}
            data_layout: [NCHW, NHWC]
        }

        outputs: []
    }

    Returns an intermediate representation of ONNX model
    """
    ir = OrderedDict()

    # ir["inputs"] = [inp.name for inp in model.graph.input]
    
    for i, node in enumerate(model.graph.node):
        ir[i] = {}
        ir[i]["name"] = node.name
        ir[i]["op_type"] = get_optype(node)
        ir[i]["inputs"] = get_inputs(node)
        ir[i]["outputs"] = get_outputs(node)
        ir[i]["initializer"] = get_initializer(node)
        ir[i]["attributes"] = get_attributes(node)
        ir[i]["data_layout"] = None
        

    # ir["outputs"] = [out.name for out in model.graph.output]

    return ir

ir = parse_onnx_model(model_onnx)

In [159]:
ir

OrderedDict([(0,
              {'name': 'Conv_0',
               'op_type': CONV2D,
               'inputs': None,
               'outputs': None,
               'initializer': {'weight': {'name': 'conv1.weight',
                 'raw_data': b'\x9e@\x08>@M\x16>\xd4\xfa\xe3\xbd,\xc4\x86=d\x8c\xa4\xbd6\x18w\xbe\x10c\xe1\xbd\x0c\x88\xde\xbd\x00j5\xbd\xa8\xb0\x84=\x8e\x0c:\xbe\xecee\xbe\xf2Y9\xbe\x16a\x1a\xbe\xd0\xf6\xb7<\x8ek\x00>\xa0\xc0\xa1\xbd@\xecW>\xc8\x8e\xc9\xbd\x98\xd8\x97=\xa88\\>\x0e\x94d\xbe\x14\xc8\xe0=\x00\x83\x94;b\xf2(>\x16G\x0c\xbe,\xe6i>x87\xbex\xb9*>\xe0VI<\x80\xfd\xd9<4\x8bJ\xbe\xe2AU\xbe\x9a\xdfy>L\xb7\xff=\xc0\xe6\x97\xbb\x0e&P\xbe\xb4\x84\xe5=\x80.\x86\xbb\xfc\x027><x\xe4=\xc0\xb8\xf1=d\xef\x8a\xbd\x88\x17\x91=@w\xa4=\xe0\xa2\x05=$\x88u>\xc0\xc12>',
                 'dims': [3, 1, 4, 4],
                 'data_type': FLOAT},
                'bias': {'name': 'conv1.bias',
                 'raw_data': b'$\xa8\xe1\xbd\xb6\xe0;>\xc0_\x98=',
                 'dims': [3],


In [165]:
model_onnx.graph.node[2]

input: "input.1"
input: "conv2.weight"
input: "conv2.bias"
output: "x.3"
name: "Conv_2"
op_type: "Conv"
attribute {
  name: "dilations"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "group"
  i: 1
  type: INT
}
attribute {
  name: "kernel_shape"
  ints: 4
  ints: 4
  type: INTS
}
attribute {
  name: "pads"
  ints: 0
  ints: 0
  ints: 0
  ints: 0
  type: INTS
}
attribute {
  name: "strides"
  ints: 2
  ints: 2
  type: INTS
}

In [115]:
for inp in model_onnx.graph.node[0].input:
    for initializer in model_onnx.graph.initializer:
        if inp == initializer.name:
            print(inp)

conv1.weight
conv1.bias


In [160]:
model_onnx.graph.node[0].input

['input', 'conv1.weight', 'conv1.bias']