In [135]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
import onnx
import torch.onnx

In [136]:
class SimpleNet(nn.Module):
    """Small CNN classifier: three stride-2 convs followed by two linear layers.

    Every conv uses ``kernel_size=4``, ``stride=2`` and no padding, and is
    followed by ReLU. The conv output is flattened and passed through
    ``linear1`` (ReLU) and then ``linear2``; ``forward`` returns the raw
    ``linear2`` output (no softmax).

    Args:
        num_classes: Output features of ``linear2`` (default 10).
        init_weights: Accepted for backward compatibility but not used;
            all layers keep PyTorch's default initialisation.
        in_channels: Input channels expected by ``conv1`` (default 1).
        input_size: Spatial size of the input image, either an ``int`` for
            square inputs or a ``(height, width)`` pair. It only determines
            ``linear1.in_features``; the default 32 reproduces the original
            ``5 * 2 * 2 = 20`` features.

    Raises:
        ValueError: If ``input_size`` is too small for the three convs to
            leave at least a 1x1 feature map.
    """

    # Shared geometry of all three conv layers.
    _KERNEL_SIZE = 4
    _STRIDE = 2

    def __init__(self, num_classes=10, init_weights=True, in_channels=1, input_size=32):
        super().__init__()
        # Layers are created in the same order as before so that, for a fixed
        # torch seed, the initial weights match the original implementation.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=3,
                               kernel_size=self._KERNEL_SIZE, stride=self._STRIDE)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=5,
                               kernel_size=self._KERNEL_SIZE, stride=self._STRIDE)
        self.conv3 = nn.Conv2d(in_channels=5, out_channels=5,
                               kernel_size=self._KERNEL_SIZE, stride=self._STRIDE)

        # Derive the flattened feature count from the input size instead of
        # hard-coding 5 * 2 * 2, which is only correct for 32x32 inputs.
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        for _ in range(3):  # one shrink per conv layer
            height = self._conv_out_size(height)
            width = self._conv_out_size(width)
        if height < 1 or width < 1:
            raise ValueError(
                f"input_size={input_size!r} is too small: the three conv layers "
                f"would produce an empty feature map"
            )

        self.linear1 = nn.Linear(in_features=self.conv3.out_channels * height * width,
                                 out_features=10)
        self.linear2 = nn.Linear(in_features=10, out_features=num_classes)

    @classmethod
    def _conv_out_size(cls, size):
        """Output length of one unpadded conv along a single spatial axis."""
        return (size - cls._KERNEL_SIZE) // cls._STRIDE + 1

    def forward(self, x):
        """Return class logits for the batch ``x``.

        ``x`` is presumably ``(N, in_channels, H, W)`` with ``H, W`` matching
        ``input_size``; otherwise ``linear1`` rejects the flattened size.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, start_dim=1)  # keep batch dim, flatten C*H*W
        x = F.relu(self.linear1(x))
        return self.linear2(x)

In [137]:
# Seed the RNG so the randomly initialised weights (and the dummy export
# input drawn in a later cell) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = SimpleNet()

torchinfo.summary(model, input_size=(16, 1, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
SimpleNet                                [16, 10]                  --
├─Conv2d: 1-1                            [16, 3, 15, 15]           51
├─Conv2d: 1-2                            [16, 5, 6, 6]             245
├─Conv2d: 1-3                            [16, 5, 2, 2]             405
├─Linear: 1-4                            [16, 10]                  210
├─Linear: 1-5                            [16, 10]                  110
Total params: 1,021
Trainable params: 1,021
Non-trainable params: 0
Total mult-adds (M): 0.36
Input size (MB): 0.07
Forward/backward pass size (MB): 0.11
Params size (MB): 0.00
Estimated Total Size (MB): 0.18

In [138]:
# Dummy input used only to trace the model for export; its batch size does
# not constrain the exported model because axis 0 is declared dynamic below.
x = torch.randn(16, 1, 32, 32, requires_grad=True)
model = model.cpu()
# Export in inference mode. This is a no-op for the current layers, but it
# keeps the export correct if dropout or batch-norm are added later.
model.eval()
x = x.cpu()

# Export the model
torch.onnx.export(model,                     # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "model.onnx",              # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=13,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

In [139]:
def onnx_check_model(onnx_model, full_check=False):
    """Run the ONNX checker on ``onnx_model`` and report the result.

    Args:
        onnx_model: The model to validate (here the result of ``onnx.load``).
        full_check: Forwarded to ``onnx.checker.check_model``; when True the
            checker also runs shape inference.

    Returns:
        True if the model passed validation, False if the checker raised
        ``ValidationError``. A message is printed in both cases.
    """
    try:
        onnx.checker.check_model(onnx_model, full_check=full_check)
    except onnx.checker.ValidationError as e:
        print(f'The model is invalid: {e}')
        return False
    print('The model is valid!')
    return True

model_onnx = onnx.load("model.onnx")
# Stop Run-All here instead of continuing to parse an invalid model.
assert onnx_check_model(model_onnx), "exported model.onnx failed ONNX validation"

The model is valid!


In [140]:
# Shapes of the first four ONNX initializers, one per line. Judging by the
# shapes, these are the conv1 weight/bias and the conv2 weight/bias.
print(*(model_onnx.graph.initializer[idx].dims for idx in range(4)), sep="\n")

[3, 1, 4, 4]
[3]
[5, 3, 4, 4]
[5]


In [143]:
from converter import parse_onnx_model

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [144]:
# Parse the exported ONNX graph with the project's converter. Judging by the
# output below, the result is an OrderedDict with an 'input_nodes' entry plus
# one entry per graph node (op_type and weight/bias initializers).
parse_onnx_model(model_onnx)

OrderedDict([('input_nodes', ['input']),
             ('Conv_0',
              {'op_type': CONV2D,
               'initializer': {'weight': {'name': 'conv1.weight',
                 'raw_data': b'\xb0\xfeE= \xfa*\xbc\xb6\xf4)\xbe\xc0\xf3\xb0\xbc\xfc8Z\xbeT\xe3\xaf=\xec*\x8b\xbd\xb4\xb4\x12>\x04\x19\x1e\xbe\xd6\x0f\x04>$\x9f]>\x08\x04\x16\xbd\x082\xaa\xbd&\xefN\xbe\xc0\xf3\x8b;\xb4\x9d\xe8=\x00\xf4\xab\xbb,\x8eI>*Y>>\xf4\xae\xe4\xbd\xa8`\x93\xbd\x04\xc8\x87\xbd@\xf4\xe7<:\'\x04>\x98\x8a\xa6\xbd<Ss>P\xa6\xac\xbdn\rS>~\x0f\x1e\xbe\x00z\xc2\xbd\xf6M%>f\xcbO>"?\x01>0\xb1\x9e\xbcD3\xcb=\x00\xd8G>DD\xa9=\xd8\xf0\xaf\xbd\xc0\xbdM\xbd\xf8\x8d\xb8\xbdf\xedr\xbe\xb4\x80\xd5=\xd6\xfe\x07>\x10*B>b&\r\xbe\xd8;5\xbd\x00\xd5\xe4\xbcpx\xc2<',
                 'dims': [3, 1, 4, 4],
                 'data_type': FLOAT},
                'bias': {'name': 'conv1.bias',
                 'raw_data': b'l\x98\x0f>\xc8\xde\xa0=\xc0\x81U\xbd',
                 'dims': [3],
                 'data_type': FLOAT}},
 