## Exporting a convnet using ONNX in PyTorch

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
import numpy as np
VERBOSE = True  # Suppress printing of mini-batch losses

cpu


Note the 2-D input_size.  Read the MNIST dataset and resizing the images to 32x32 pixels with zero padding includes.  They are grayscale already, so the input will be 32x32x1.

In [4]:
# Hyper Parameters 
input_size = (32, 32)
num_classes = 10
num_epochs = 10
batch_size_train = 256
batch_size_val = 256
batch_size_test = 1024
learning_rate = 2e-3
num_folds = 6  # V-fold cross validation!
v = 4  # The filter hyperparameter.  The number of activation maps is dependent.
torch.set_printoptions(threshold=1000)

This initially downloads two datasets, one for training and validation (called train_dataset) and one for test.

In [6]:
# Load image data and transform images to 32x32x1
train_dataset = dsets.MNIST(root='./data',
                         train=True,
                         transform=transforms.Compose([
                             transforms.Resize(input_size),
                             transforms.ToTensor()]),
                         download=True)
test_dataset = dsets.MNIST('./data',
                        train=False,
                        transform=transforms.Compose([
                            transforms.Resize(input_size),
                            transforms.ToTensor()]),
                        download=True)

In [7]:
# Dataset loaders (handle mini-batching of data) 
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size_train, shuffle=True) 
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size_test, shuffle=False)

A version of the well-known LeNet-5 CNN model architecture.

In [8]:
class LeNet5(nn.Module):
    # A version of LeNet-5.  Note the hyperparameter 'v' (n^v activation maps).
    def __init__(self, v=0):
        super(LeNet5, self).__init__()
        # 1 image input channel, 6 filters, 5x5 kernel
        self.convnet = nn.Sequential(
            nn.Conv2d(1, 2**v, kernel_size=(5, 5)),  # c1
            nn.ReLU(),  # relu1
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),  # s2
            nn.Conv2d(2**v, 3**v, kernel_size=(5, 5)),  # c3
            nn.ReLU(),  # relu3
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),  # s4
            nn.Conv2d(3**v, 5**v, kernel_size=(5, 5)),  # c5
            nn.ReLU(),  # relu5
        )

        self.fc = nn.Sequential(
            nn.Linear(5**v, 84),  # f6
            nn.ReLU(),  # relu6
            nn.Linear(84, 10),  # f7
        )
        
    def forward(self, input):
        convout = self.convnet(input)
        convout = convout.view(input.size(0), -1)
        output = self.fc(convout)
        return output
        

The following function will train the model for a single epoch and report the training loss.

In [9]:
def train_one_epoch(epoch_num, verbose=VERBOSE):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        optimizer.zero_grad() 
        outputs = model(images) 
        loss = criterion(outputs, labels)
        # Backward pass
        loss.backward()
        # Optimize
        optimizer.step()
        if verbose is True:
            if (i + 1) % 100 == 0:
                print('Epoch: [% d/% d], Step: [% d/% d], Loss: %.4f'
                      % (epoch_num + 1, num_epochs, i + 1,
                         len(train_dataset) // batch_size_train, loss.item())) 

The following function computes the error for one epoch of data.

In [10]:
def epoch_error(loader, length, split='validation'):
    """ Computes the error for all data points in a loader.
       
        Inputs:
            loader: Pytorch data loader (object)
            length: Number of data points (integer)
            split: Name of split, typically 'train', 'test', or 'validation' (string)
        
        Returns:
            error (floating point)
    """
    model.eval()
    # Measure the error for the entire loader split.
    i = 0
    total = 0.
    incorrect = 0.
    for images, labels in loader:  # One batch at a time!
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1) 
        total += labels.size(0) 
        incorrect += (predicted != labels).sum()

    print(f'Error of the model on the {length} {split} images: {float(incorrect) / total:3.1%}')
    return float(incorrect) / total

This procedure will initialize the model and run a training loop.

This is a 10-class classification problem.  Adam is used for optimization.

In [11]:
def run_training(v):
    # Re-initialize model and optimizer!
    global model, criterion, optimizer
    model = LeNet5(v).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        train_one_epoch(epoch)

In [12]:
# Assess the training and test accuracy of LeNet-5
run_training(v)
test_accuracy = 1. - epoch_error(test_loader, test_dataset, 'test')
train_accuracy = 1. - epoch_error(train_loader, train_dataset, 'train')

Epoch: [ 1/ 10], Step: [ 100/ 234], Loss: 0.1376
Epoch: [ 1/ 10], Step: [ 200/ 234], Loss: 0.1151
Epoch: [ 2/ 10], Step: [ 100/ 234], Loss: 0.0401
Epoch: [ 2/ 10], Step: [ 200/ 234], Loss: 0.0277
Epoch: [ 3/ 10], Step: [ 100/ 234], Loss: 0.0881
Epoch: [ 3/ 10], Step: [ 200/ 234], Loss: 0.0192
Epoch: [ 4/ 10], Step: [ 100/ 234], Loss: 0.0145
Epoch: [ 4/ 10], Step: [ 200/ 234], Loss: 0.0167
Epoch: [ 5/ 10], Step: [ 100/ 234], Loss: 0.0055
Epoch: [ 5/ 10], Step: [ 200/ 234], Loss: 0.0165
Epoch: [ 6/ 10], Step: [ 100/ 234], Loss: 0.0051
Epoch: [ 6/ 10], Step: [ 200/ 234], Loss: 0.0101
Epoch: [ 7/ 10], Step: [ 100/ 234], Loss: 0.0288
Epoch: [ 7/ 10], Step: [ 200/ 234], Loss: 0.0059
Epoch: [ 8/ 10], Step: [ 100/ 234], Loss: 0.0160
Epoch: [ 8/ 10], Step: [ 200/ 234], Loss: 0.0123
Epoch: [ 9/ 10], Step: [ 100/ 234], Loss: 0.0049
Epoch: [ 9/ 10], Step: [ 200/ 234], Loss: 0.0064
Epoch: [ 10/ 10], Step: [ 100/ 234], Loss: 0.0033
Epoch: [ 10/ 10], Step: [ 200/ 234], Loss: 0.0178
Error of the model

In [13]:
# Set the model to inference mode
model.eval()

LeNet5(
  (convnet): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 81, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(81, 625, kernel_size=(5, 5), stride=(1, 1))
    (7): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=625, out_features=84, bias=True)
    (1): ReLU()
    (2): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [14]:
for images, labels in test_loader:
    batch_of_images = images.to(device)  # one batch
    print(batch_of_images.shape)
    break

torch.Size([1024, 1, 32, 32])


In [18]:
# Input to the model
torch_out = model(batch_of_images)

# Export the model
torch.onnx.export(model,                     # model being run
                  batch_of_images,           # model input (or a tuple for multiple inputs)
                  "convnet.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

In [19]:
import onnx

onnx_model = onnx.load("convnet.onnx")
onnx.checker.check_model(onnx_model)

Verify that the ONNX runtime and PyTorch models are computing the same values for the network.  Do this by creating an inference session for the model and evaluating it.

In [20]:
import onnxruntime

ort_session = onnxruntime.InferenceSession("convnet.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(batch_of_images)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-02, atol=1e-03)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!


In [21]:
!ls -alt

total 5388
drwxr-xr-x 4 padraig padraig    4096 Nov 18 10:01 .
-rw-r--r-- 1 padraig padraig   49997 Nov 18 10:01 convnet_onnx_example.ipynb
-rw-r--r-- 1 padraig padraig 5411976 Nov 18 10:01 convnet.onnx
drwxr-xr-x 2 padraig padraig    4096 Nov 18 09:57 .ipynb_checkpoints
drwxr-xr-x 3 padraig padraig    4096 Nov 18 09:56 data
-rw-r--r-- 1 padraig padraig   30538 Nov 17 18:23 onnx.proto
drwxr-xr-x 3 padraig padraig    4096 Nov 17 18:23 ..


In [23]:
!protoc --decode=onnx.ModelProto onnx.proto < convnet.onnx

ir_version: 6
producer_name: "pytorch"
producer_version: "1.7"
graph {
  node {
    input: "input"
    input: "convnet.0.weight"
    input: "convnet.0.bias"
    output: "11"
    name: "Conv_0"
    op_type: "Conv"
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "group"
      i: 1
      type: INT
    }
    attribute {
      name: "kernel_shape"
      ints: 5
      ints: 5
      type: INTS
    }
    attribute {
      name: "pads"
      ints: 0
      ints: 0
      ints: 0
      ints: 0
      type: INTS
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
      type: INTS
    }
  }
  node {
    input: "11"
    output: "12"
    name: "Relu_1"
    op_type: "Relu"
  }
  node {
    input: "12"
    output: "13"
    name: "MaxPool_2"
    op_type: "MaxPool"
    attribute {
      name: "ceil_mode"
      i: 0
      type: INT
    }
    attribute {
     

    raw_data: "\2322\017=\226\374\'=\363#\246=\017\262\022>\246\301\214\272\225Q\347;\025\247\344\275\353]\334\275\305\236\360<l\235\007\273b\361\006=\325kb<y\324\375\274\343\303\020\276]\360+<`\352\200\275\320\021\324\274\000\362j<\323f\207\275\207\336\254<M\302D\275\224\224|\274\300\367:=%\007\206=\315\347\373<\277M\330\275\362:\026\275\026\335\022>\037\025\354=\274!|\275\265\025\374\275\335rA;x\371\251\275\210M\310\275F8?=\366J\351<Q,\t>\335\276h>\222b\277\275\221\363\027\276\032\240\013=\024)\337=/\3102>Mz >\010\010\311\275\3645\206\275\201\313\260=Z\035\317=ru\031=\244\323\334<\"\216\236<\365&\261=\247\327 >l\"{\275\304\331\303\275\034sc\276\001\375\334\275\350\263\323\274*\236\265=\305z\322=],b\276\237z\247\276\255Nu\276\316\360\334<;\001\025=\357i\264\2741q\021\275\003$\006\276\206\244\017\275\200\257\005\276\007\275\251;\320w0<N7\316\275\362\363\326\274\256\300d\2750+\251=Q\321{=-\273\013>\310\277\236=\351M\223\275\026w\331\275-\355s\275/^\001>\207\2757>\257tS=\032\006\245<T\37

    raw_data: "\347\244\354\273\316\003\304\274}\264\250<+\001\246<\334\010\255<6\311\262\274\243\251\203\274\226\345l\274\327\330\326\274\322\004\373;\312\221\006=\275\324\210;\307\225\217\274#\344\243\274\373\235\331\273\371\036\013\273\354]\225\274<y\232\274K$\016;\223\332\023<v\356\225\274\262\3542=Mv\311\274\014\006\007=\330$`\273=\312\224<\014\372\273\274$4\275\274(s\361\274\006\025\037\275%\016\023<N\225;\274b\025\000\274\232\331\023\272\024\352=\275\031\370\014\275\320\244\362\272\355I\031=Z\216\262;\016\010h\273c\324\200\274\3453\227\275]\261\225\273\224a\023\2755g\261;\320\233\035\274E:U=a\030:\275+X\007\275\2070!=\340Ae<\246[Z<\240\014\373\274\302\326a<=7\027\275\235I\310;\324g1\274\352\2551=\355\214\007\274_ox\274?\306\305;\311\323\272\272\376E\t<a\267\322\274d\270\005\275\177\241\231\274\333;7\272\376\357\310\274*K\341\2737G\236;\277$\222\273\036\356\274\274-\014\031\273\356\310\272\274X\356\364<2\313\035=Z\257\276;\025\364\373\274\354R\024;\373j\366\272MvW\274\031J@=\310\

    raw_data: "\372\355\3529\203i\301;[(/:\372\3608\274!\372\203;\002)\246<\261\031\024\273H\203\240\274\363\2176\274\027\227\177\274j\234M<\203 \301\273\006\237K\274\025x\275;\253z1\273\264\351T\274\2632\254;\322\250d\273?oi<\247\261\227\274Bf\314\274\374\275L\274+\245\234\274\026x\330\274\331\035\206;\026\324\272\274\321\235\310;\227\235\320;\373\326\273\274\036f\377;1hD\274\334\317\000\275l\335(;\351S\326\274\3636\031:\354\036\214\274{|4\273\237\257\215\274C\314/\274M\345U\274\000\200x\271\350b ;:U\022<\017\220\376\274\214#\277\274\321\244\260\274\262\210\024<X\320\025;\256K\352\274\365\206\327;\2048\245\274\331O\021<\331\363\000\274\215F\304\274\227\356\240\272\363\022\034\274@\222\024\274\211\246\265;\314\347\250\274\313\026\264\274\341\272\214\273G\rK\274y\267\221\274\036^b;\2063s\274O\245L\274\210l\320\274\255\034\025\274\251)\027\274G1\025\274e\260\374;\205\256\010<\243\351\331\274\342/\035\2749+G\274\030\261$\274\245\311\307\274\224\034\240\274R\332#<\356\031\346\274\375\203v\

\374\274\206\244Q\275\376xA\273i@\362\274\325\0346\275~F\301<\210\266Y<\240\332\031\274\316O+\2750\224+\275q5\255\275_\272D\274\257yQ\275\302\330$\275\371\200\"= \373\023\276\275\372_\275l\260\217\274\346\3632\274AL\'=\240\2177\275\354N\265\275\254>\255\275\224b\347<6\377D\275\270\245\230<UW\037\275\221E\272\274a\264\250\275\373\005\310\275^\376s\275&t\354\274N!v\274?\006\224=AY1\275\344\344u\272\202\177=\273|\253\005>k.E\275\241\335\216\275#\323%\274$\310\035<X\365\211=\022\371\264\275\263\300\364<\236.,<j\234+=\016?\r=\253\350\202\275\330\216\254\273\356\353\214\274\366\372?\273\275\273\307=\313B9<\212\327\212=\020\244V=5\305\014=\372q<=\363p\253=`L\277=A\217P=\271\341w=}0\270=\353\373\243\274\341`\270\275\343s\372<\033\016\177=\262f\016<C|\261\275/X\257<A\272\235=\210\217\n\271+\360\202\275\251H2=\362\275\016>\271- =\010\314:=\026\336\216<\357j\234;?\241\263<l=\032<\035\372V<J\306\373;.\204\224;\373\306\007;u|\247\274\n\245\233\274\021\353p;\2455L;\253\007\242<\263+\213\274~2\337:.\

2<cX\023\273\312\357\220\274\361\231\002\274\n\013]\274V\313\312\274\016\254\364;\t\261\014<\254\331q\274<\331\006\2754Y\330;\376&\205<\277\0232\274\261|\235;\3524Q\274F(t\274\0056n\274Z\n%\274Ii9\273\373\177w\274\302\235 ;\030\237\306;\033\347u<\301\375\247\274k*+\274\253\330\347<\366\223\240\274%\230\227\273\253.\007=\227\305\317\274\215\325\365<\2432X:\266\'\n\274\001/\365;\357`{<e\257\242\274Q#\024\275\226B\374\274\276\360`;H\230\001\275K\352\212;\324h\340\274\203\r\325\274$\276\367\274\217\nq\274\240$)\273\224[K;\213\224\"\274\215\365\346\274\t\336.\274\371K\326;\004\274.\273r\304]\274C\340\346\274\035<\257\272\024\000\302\273>\241\210\272\n\216\347\274Bd\322\272\020\335\256\274\355\002\317<\265\211\334:\007\000\177\273}\377\034\2759.\346\274\025\265\204\274SG\"\274u\216\t<\332\337\210;\2466\027\275\032\265\035\274Fr\204\273\306|\201\274\203\326Y\274\034\230\005\275\357\210\t\2755.\223:\352\275f\274,v\267\273\202\376M\274\232\366\2059\017\t\026\274\247\243!\274\370,\n\275\243\240W

62\3664<\344\333\206\274\301k\365\274\337\'\003\275;\306\375\274\352\000\233\274,?\243\274\272\243\211\273\224\353\223\273\245\272\371;N\221N\274\236+\314\274\202\017\224\273\017\315&\274\347\350$<4\010\236\274\340\363\022\274SO\335\274\325~\003<\223\225\374\272\310\r\377;r.\203\273\264\203\036\273[Y6\270\372\213\235\273~\266*<\360}\312\274\336\206\316\273\007\335%\274\247\354\223\274F^\252\274\253*\362\274\244^\003\274\035\320J\273Y\205\205:G\032W\274vHL<\326\346\250\2745W\311\274\321\005\206\274\351@\377;8\033\264\274\tpu\273\315\314\2707\231x\026;\315\020\276\274jK\3169\332k!\273w\"\234:\242\3167<\t@(\274\234VS\274\274\262\204\274\037\210C\274&^\321\274J;\211\274\030|f\2734\340\267\273\005+}\2747fG9\t\032v\274\376\356i\2735n\377;\020x5<)\205\315\274K\252\247;X\021\376\2735\355\001<\222{\376\274\030\304\035\274\030\367\273\274g5(<\324\"\236\273u\030\372\274\360\017\027<\243\256\342\274\344\247K\274\031\302\265;z\235\020<?L\257;\235\226o\274\037\302\352\273\351\024[\274W.\240\274\262\

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)

