In [None]:
!pip install onnx
!pip install onnxscript
!pip install onnxruntime

## Import Libraries

In [1]:
import torch
import torchvision
import onnx
import onnxruntime as ort
import torch.nn as nn
import torch.nn.functional as F

## Create PyTorch Model

In [2]:
class MyModel(nn.Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages followed by three fully connected layers.

    The flattened feature size ``16 * 5 * 5`` used by ``fc1`` assumes a 32x32 spatial input:
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.

    Args:
        in_channels: Number of input channels (default 1).
        num_classes: Length of the output logit vector (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch ``(N, in_channels, 32, 32)`` to unnormalized logits ``(N, num_classes)``."""
        # A scalar kernel size of 2 is equivalent to (2, 2); stride defaults to the kernel size.
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten C*H*W
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Seed so the randomly initialised weights (and every output below) are reproducible.
torch.manual_seed(0)
# eval() does not change these layers' behaviour, but it is the correct mode for export/inference.
torch_model = MyModel().eval()

## Convert PyTorch Model To ONNX

In [3]:
# Example input that fixes the exported graph's input signature: batch 1, 1 channel, 32x32.
# The 32x32 size matters: MyModel.fc1 expects 16 * 5 * 5 features after the conv/pool stack.
example_shape = (1, 1, 32, 32)
torch_input = torch.randn(*example_shape)

# Trace the eager model into an ONNXProgram, then write it next to the notebook (relative path).
onnx_model = torch.onnx.dynamo_export(torch_model, torch_input)
onnx_model.save("pytorch_to_onnx_model.onnx")



In [4]:
# Reload the file from the same relative path the export cell saved it to. The previous
# hard-coded "/content/..." path only worked when the working directory was Colab's /content.
onnx_model_check = onnx.load_model("pytorch_to_onnx_model.onnx")
# Validate the model against the ONNX spec; check_model raises if the model is invalid.
onnx.checker.check_model(onnx_model_check)

## Use ONNX Model

In [5]:
# Open the exported model from the same relative path the export cell saved it to
# (not a Colab-only absolute path), running on CPU only.
onnx_model_session = ort.InferenceSession(
    "pytorch_to_onnx_model.onnx",
    providers=["CPUExecutionProvider"],
)
# The exporter generates the graph's input/output names, so look them up rather than
# hard-coding them. The model has a single input tensor and a single output tensor.
input_name = onnx_model_session.get_inputs()[0].name
output_name = onnx_model_session.get_outputs()[0].name
print(input_name, output_name)

l_x_ fc3_1


In [6]:
# ONNX Runtime consumes NumPy arrays: drop the autograd link and move the tensor to the CPU first.
input_data = torch_input.detach().cpu().numpy()
ort_feeds = {input_name: input_data}

# run() returns a list holding one array per requested output name.
predict = onnx_model_session.run([output_name], ort_feeds)
predict

[array([[-0.03642596,  0.09679464, -0.10962822,  0.05402166, -0.13292298,
          0.03535164, -0.06996707, -0.02941129, -0.07467412,  0.0602996 ]],
       dtype=float32)]

### Compare ONNX Runtime Output Against PyTorch

In [7]:
# Reference outputs from the eager PyTorch model. no_grad() avoids building an autograd
# graph that is never used for backward.
with torch.no_grad():
    torch_outputs = torch_model(torch_input)
# Convert the PyTorch outputs into the same flat sequence layout that ONNX Runtime returns.
torch_outputs = onnx_model.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(predict), (
    f"output count mismatch: PyTorch {len(torch_outputs)} vs ONNX Runtime {len(predict)}"
)
for torch_output, onnxruntime_output in zip(torch_outputs, predict):
    # assert_close picks default tolerances based on the dtype (float32 here) and also checks dtype/shape.
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))

print("PyTorch and ONNX Runtime output matched!")
print(f"Output length: {len(predict)}")
print(f"Sample output: {predict}")

PyTorch and ONNX Runtime output matched!
Output length: 1
Sample output: [array([[-0.03642596,  0.09679464, -0.10962822,  0.05402166, -0.13292298,
         0.03535164, -0.06996707, -0.02941129, -0.07467412,  0.0602996 ]],
      dtype=float32)]
