In [1]:
import torch
print(torch.__version__)

import onnxscript
print(onnxscript.__version__)

from onnxscript import opset18  # opset 18 is the latest (and only) supported version for now

import onnxruntime
print(onnxruntime.__version__)

import onnx

2.4.0
0.1.0.dev20240723
1.18.1


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):

    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
model = MyModel()
torch_input = torch.randn(1, 1, 32, 32)
onnx_program = torch.onnx.dynamo_export(model, torch_input)



In [4]:
type(onnx_program)

torch.onnx.ONNXProgram

In [5]:
onnx_program.save("model_test.onnx")

In [6]:
onnx_model = onnx.load("model_test.onnx")
onnx.checker.check_model(onnx_model)

In [7]:
type(onnx_model)

onnx.onnx_ml_pb2.ModelProto

In [8]:
import onnxruntime

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(torch_input)
print(f"Input length: {len(onnx_input)}")
print(f"Sample input: {onnx_input}")

ort_session = onnxruntime.InferenceSession("./model_test.onnx", providers=['CPUExecutionProvider'])

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}

onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

Input length: 1
Sample input: (tensor([[[[-0.0738, -1.2470, -0.4809,  ..., -1.2223, -0.6317,  0.6253],
          [-0.0137, -1.5498, -0.2800,  ...,  0.7993, -0.8143,  2.4987],
          [-0.8035, -1.2920, -0.6143,  ..., -1.2958, -0.0844,  0.0715],
          ...,
          [ 0.5365,  0.2230,  0.5253,  ...,  0.6625, -0.0130, -1.3151],
          [-0.8605, -1.6757, -0.5549,  ...,  0.1405, -0.7908, -1.1467],
          [ 0.7817, -0.2687, -0.0614,  ...,  0.1668,  0.8369, -0.5398]]]]),)


In [9]:
type(onnxruntime_outputs)

list

In [10]:
onnxruntime_outputs

[array([[-0.15734848,  0.10147543, -0.08366099, -0.00755851,  0.07382713,
         -0.07885513, -0.0932669 , -0.1155888 , -0.09175342,  0.05691574]],
       dtype=float32)]

## ready to transform to native Pytorch model

In [11]:
from onnx2torch import convert
# from onnx2pytorch import ConvertModel

In [12]:
torch_model = convert(onnx_model)
torch_model

GraphModule(
  (torch_nn_modules_conv_Conv2d_conv1_1_1_Conv_0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (aten_relu_2_n0): ReLU()
  (MaxPool_3): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=False)
  (torch_nn_modules_conv_Conv2d_conv2_1_10_Conv_0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (aten_relu_11_n0): ReLU()
  (MaxPool_12): MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=False)
  (Constant): OnnxConstant()
  (Reshape_21): OnnxReshape()
  (Constant_1): OnnxConstant()
  (initializers): Module()
  (torch_nn_modules_linear_Linear_fc1_1_22_aten_addmm_1_n0): OnnxGemm()
  (aten_relu_23_n0): ReLU()
  (Constant_2): OnnxConstant()
  (torch_nn_modules_linear_Linear_fc2_1_24_aten_addmm_1_n0): OnnxGemm()
  (aten_relu_25_n0): ReLU()
  (Constant_3): OnnxConstant()
  (torch_nn_modules_linear_Linear_fc3_1_26_aten_addmm_1_n0): OnnxGemm()
)

### compare native model parameters with transformed model's parameters

In [14]:
for value in model.parameters():
    print(value)
    break

Parameter containing:
tensor([[[[-1.9584e-01,  1.9015e-01,  1.6752e-04, -1.9745e-02,  1.7147e-01],
          [-1.9642e-01,  5.9971e-02,  7.6872e-02,  1.3239e-01,  8.9234e-03],
          [ 1.0216e-01, -3.6773e-02,  1.5789e-01, -1.9236e-01,  1.0799e-01],
          [-7.1792e-02,  1.2483e-01, -1.9534e-01,  1.8139e-01, -8.4642e-02],
          [ 1.8736e-01, -1.2367e-01,  1.0867e-01, -1.6235e-01,  1.6699e-01]]],


        [[[ 1.8887e-01, -1.3021e-01,  1.2391e-01, -1.5503e-01,  3.3777e-02],
          [-5.1727e-02,  1.8902e-01, -1.4596e-01,  5.5943e-02,  1.5232e-01],
          [-1.1762e-01,  7.0078e-02,  8.8413e-02, -1.6833e-01,  2.6654e-02],
          [-1.9923e-01,  1.3626e-01, -1.6662e-01, -6.1196e-02,  1.6797e-01],
          [-1.9300e-01,  3.9912e-02, -1.8984e-01, -3.9656e-02, -1.3479e-02]]],


        [[[-8.5220e-02,  1.3631e-01, -7.7719e-02,  1.3761e-01,  1.3421e-01],
          [-1.5418e-01,  1.5575e-01, -1.3451e-01,  1.6811e-01, -9.4692e-02],
          [-3.8944e-03, -1.6381e-02,  7.4568e-

In [20]:
for value in torch_model.parameters():
    print(value)
    break

Parameter containing:
tensor([[[[-0.1652, -0.0676,  0.0429,  0.1445,  0.0531],
          [ 0.0162,  0.0668,  0.0195,  0.1646, -0.0267],
          [ 0.1495, -0.1044,  0.1614, -0.0026, -0.0051],
          [ 0.1450,  0.0476, -0.1513, -0.1519,  0.1008],
          [ 0.0947, -0.1343, -0.1398,  0.1571,  0.1432]]],


        [[[ 0.0950, -0.0642,  0.1295,  0.1532,  0.1561],
          [-0.0454,  0.1509,  0.1706, -0.0137, -0.0611],
          [ 0.0958, -0.0243,  0.0590, -0.0385, -0.0975],
          [ 0.0471, -0.0112, -0.0990, -0.0416,  0.0571],
          [-0.0812, -0.0369, -0.0034,  0.1992, -0.1905]]],


        [[[ 0.1827,  0.1212,  0.1868, -0.1294,  0.1722],
          [ 0.0441,  0.0315,  0.1384, -0.0743, -0.1234],
          [ 0.1914, -0.0528,  0.1493, -0.0461,  0.1197],
          [ 0.1537,  0.0287,  0.0793,  0.0761, -0.1018],
          [-0.1051,  0.1933,  0.1802,  0.0079,  0.1572]]],


        [[[ 0.1726, -0.1499,  0.0470, -0.1310, -0.0217],
          [-0.1805,  0.1859, -0.1632,  0.1520, -0.1076