<a href="https://colab.research.google.com/github/PacktPublishing/Hands-On-Computer-Vision-with-Detectron2/blob/main/Chapter13/Detectron2_Chapter13_Intro2ONNX.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chapter 13 - Introduction to ONNX

## A simple PyTorch model

In [None]:
import torch
import torch.nn as nn

class SimplePyTorchModel(nn.Module):
  """A single Linear(4 -> 1) layer followed by ReLU, with all parameters set to 0.01.

  The constant initialization (instead of nn.Linear's random one) makes the
  output deterministic: for an input of shape (N, 4) the result has shape
  (N, 1) and equals relu(0.01 * row_sum + 0.01), e.g. [[2, 3, 4, 5]] -> [[0.15]].
  """

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(4, 1)
    # Overwrite the random default initialization with fixed values.
    nn.init.constant_(self.linear.weight, 0.01)
    nn.init.constant_(self.linear.bias, 0.01)

  def forward(self, X):
    projected = self.linear(X)
    return torch.relu(projected)

In [None]:
# eval() switches the module to evaluation mode and returns the module itself,
# so it can be chained onto the constructor; the bare name displays its repr.
pt_model = SimplePyTorchModel().eval()
pt_model

SimplePyTorchModel(
  (linear): Linear(in_features=4, out_features=1, bias=True)
)

## Export the PyTorch model to ONNX

In [None]:
# The exporter runs the model once on an example input and records the
# operations executed. The exported graph input takes this tensor's shape and
# dtype (shown as FLOAT 1x4 in the printed graph further down).
dummy_X = torch.tensor([[1, 2, 3, 4]], dtype=torch.float32)
model_name = 'onnx_model.onnx'  # file written here and read by the later cells
torch.onnx.export(pt_model, dummy_X, model_name, verbose=True)

## Load the ONNX model back and run it

In [None]:
!pip install onnx

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting onnx
  Downloading onnx-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.5/13.5 MB[0m [31m74.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting protobuf<4,>=3.20.2
  Downloading protobuf-3.20.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m53.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: protobuf, onnx
  Attempting uninstall: protobuf
    Found existing installation: protobuf 3.19.6
    Uninstalling protobuf-3.19.6:
      Successfully uninstalled protobuf-3.19.6
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.11.0 requires protobu

In [None]:
import onnx
from onnx.helper import printable_graph
# load
loaded_model = onnx.load(model_name)    # parse the file written by the export cell
onnx.checker.check_model(loaded_model)  # verify the model is well formed
graph_text = printable_graph(loaded_model.graph)  # human-readable graph dump
print(graph_text)

graph torch_jit (
  %onnx::Gemm_0[FLOAT, 1x4]
) initializers (
  %linear.weight[FLOAT, 1x4]
  %linear.bias[FLOAT, 1]
) {
  %/linear/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transB = 1](%onnx::Gemm_0, %linear.weight, %linear.bias)
  %4 = Relu(%/linear/Gemm_output_0)
  return %4
}


In [None]:
!pip install onnxruntime

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting onnxruntime
  Downloading onnxruntime-1.14.0-cp38-cp38-manylinux_2_27_x86_64.whl (5.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m57.2 MB/s[0m eta [36m0:00:00[0m
Collecting coloredlogs
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 KB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 KB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: humanfriendly, coloredlogs, onnxruntime
Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnxruntime-1.14.0


In [None]:
import onnxruntime as ort
# Pass the execution provider explicitly: onnxruntime builds that ship more than
# one provider (e.g. onnxruntime-gpu, since ORT 1.9) refuse to create a session
# when `providers` is omitted. CPUExecutionProvider is the default CPU backend.
ort_session = ort.InferenceSession(model_name, providers=['CPUExecutionProvider'])
# The exporter auto-generated the input name (see the printed graph above),
# so look it up from the session instead of hard-coding it.
input_name = ort_session.get_inputs()[0].name
print(input_name)

onnx::Gemm_0


In [None]:
import numpy as np
# Feed a sample row through ONNX Runtime. The graph input is declared FLOAT,
# so the integer array is cast to float32 before it is passed in.
X = np.array([[2, 3, 4, 5]])
outputs = ort_session.run(
    output_names=None,  # None: fetch every graph output
    input_feed={input_name: X.astype(np.float32)},
)
print(outputs[0])

[[0.15]]


In [None]:
# Run the same sample through the original PyTorch model for comparison.
X = torch.tensor([[2, 3, 4, 5]], dtype=torch.float32)
with torch.no_grad():
  y = pt_model(X)
print(y)
# Check, rather than eyeball, that the ONNX Runtime result (`outputs[0]` from the
# previous cell) matches PyTorch up to float32 rounding.
np.testing.assert_allclose(y.numpy(), outputs[0], rtol=1e-5, atol=1e-6)

tensor([[0.1500]])


## Download the model for future use

In [None]:
from google.colab import files
# Reuse `model_name` from the export cell so the downloaded file is always the
# one that was actually written, even if the name is changed there.
files.download(model_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>