In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnxruntime

import onnx

class MyModel(nn.Module):
    """Small convolutional network: two conv/pool stages, then three linear layers.

    Expects single-channel input and returns 10 unnormalised scores per
    sample. The first linear layer takes 16 * 5 * 5 flattened features,
    which is what 32x32 inputs produce after two rounds of 5x5 convolution
    (no padding) and 2x2 max-pooling.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 6 -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the last linear layer."""
        # Each stage is conv -> ReLU -> 2x2 max-pool.
        features = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2)
        features = F.max_pool2d(F.relu(self.conv2(features)), kernel_size=2)
        # Keep the batch dimension and collapse everything after it.
        hidden = torch.flatten(features, start_dim=1)
        hidden = F.relu(self.fc1(hidden))
        hidden = F.relu(self.fc2(hidden))
        # No activation on the final layer: raw scores are returned.
        return self.fc3(hidden)

In [3]:
# Build the network with freshly initialised weights (nothing is loaded or trained here).
torch_model = MyModel()

# A single random 1x32x32 sample serves as the example input for the export.
torch_input = torch.randn((1, 1, 32, 32))
onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()
  self.param_schema = self.onnxfunction.param_schemas()


In [4]:
# Write the exported program to an ONNX file; the relative path resolves against the working directory.
onnx_program.save("my_image_classifier.onnx")

In [5]:
# Read the saved file back and run ONNX's model checker on it as a sanity check
# (check_model raises if it finds the model invalid).
onnx_model = onnx.load("my_image_classifier.onnx")
onnx.checker.check_model(onnx_model)

In [6]:
# Convert the PyTorch-side example input into the form the exported program uses for its inputs.
onnx_input = onnx_program.adapt_torch_inputs_to_onnx(torch_input)
print(f"Input length: {len(onnx_input)}")
print(f"Sample input: {onnx_input}")

# Open the saved model in ONNX Runtime, restricted to the CPU execution provider.
ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])

def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    A tensor that requires grad is detached first, because ``Tensor.numpy()``
    refuses to convert a tensor that autograd is still tracking.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

# Pair each graph input (in the order the session reports them) with the
# corresponding adapted tensor, keyed by the graph input's name.
onnxruntime_input = {
    input_meta.name: to_numpy(tensor)
    for input_meta, tensor in zip(ort_session.get_inputs(), onnx_input)
}

# Passing None for the output names asks ONNX Runtime for every model output.
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

Input length: 1
Sample input: (tensor([[[[ 0.1942, -0.4748, -0.9118,  ...,  0.1208,  0.3907,  0.5167],
          [ 1.1289,  0.1684,  1.0558,  ...,  0.2496, -0.6566,  1.5778],
          [ 1.5094, -1.0881, -1.4730,  ..., -0.8824, -2.6115, -1.6222],
          ...,
          [-0.2878, -0.5005, -1.0448,  ...,  1.5891,  0.8803,  0.1257],
          [-3.7665,  0.2684,  0.5040,  ..., -0.7130, -0.7664, -0.2990],
          [-0.5156,  2.4762, -0.6279,  ...,  0.2797,  0.1082, -1.7672]]]]),)


In [7]:
# Display the feed dictionary (graph input name -> NumPy array) that was passed to the session.
onnxruntime_input

{'l_x_': array([[[[ 0.19419159, -0.47479498, -0.9117583 , ...,  0.12077317,
            0.39066866,  0.5167254 ],
          [ 1.1289116 ,  0.16836381,  1.0557916 , ...,  0.24960539,
           -0.65658426,  1.5778296 ],
          [ 1.509396  , -1.0881438 , -1.4730294 , ..., -0.88236344,
           -2.6114843 , -1.6222423 ],
          ...,
          [-0.2878471 , -0.5005211 , -1.0448407 , ...,  1.5890512 ,
            0.88028604,  0.12574488],
          [-3.7664676 ,  0.26840717,  0.50401944, ..., -0.71299934,
           -0.7663836 , -0.2989551 ],
          [-0.51561105,  2.4761722 , -0.6278699 , ...,  0.27974784,
            0.10822505, -1.767246  ]]]], dtype=float32)}

In [8]:
# Run the eager PyTorch model on the same input and convert its result to the
# exported program's output format in a single step.
torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_model(torch_input))

In [9]:
# Both runtimes must yield the same number of outputs before a pairwise comparison makes sense.
assert len(onnxruntime_outputs) == len(torch_outputs)

# assert_close raises if any pair differs beyond its default tolerances,
# so reaching the report below means every output matched.
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))

print(
    "PyTorch and ONNX Runtime output matched!",
    f"Output length: {len(onnxruntime_outputs)}",
    f"Sample output: {onnxruntime_outputs}",
    sep="\n",
)

PyTorch and ONNX Runtime output matched!
Output length: 1
Sample output: [array([[-0.01925719, -0.13492969, -0.17477617, -0.04467422,  0.04082778,
        -0.03521884,  0.14429587,  0.07565545, -0.02656089,  0.01284647]],
      dtype=float32)]
