In [13]:
import torch
import torch.nn as nn
import torch.onnx
import onnx
import onnxruntime as ort
import numpy as np

In [14]:
class SuperSimpleNet(nn.Module):
    """Minimal feed-forward network: a single linear layer followed by ReLU.

    Args:
        in_features: Size of the last dimension of the input (default 10).
        out_features: Size of the last dimension of the output (default 5).
    """

    def __init__(self, in_features=10, out_features=5):
        super().__init__()
        # The defaults reproduce the original fixed 10 -> 5 layout, so
        # SuperSimpleNet() builds exactly the same architecture as before.
        self.fc1 = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the linear layer and then the ReLU activation.

        Args:
            x: Tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor whose last dimension equals ``out_features``; ReLU makes
            every element non-negative.
        """
        return self.relu(self.fc1(x))

In [15]:
# Build the network and switch it to inference mode in one step: eval()
# returns the module itself, so the name is bound to the eval-mode model.
pytorch_model = SuperSimpleNet().eval()
# Echo the model so the cell still displays its layer summary.
pytorch_model

SuperSimpleNet(
  (fc1): Linear(in_features=10, out_features=5, bias=True)
  (relu): ReLU()
)

In [16]:
# One random sample shaped (1, 10); it is the example input handed to the
# exporter below and is reused later to compare the two runtimes.
dummy_input = torch.randn((1, 10), requires_grad=True)
dummy_input

tensor([[-2.7076, -0.6751, -1.4080, -1.7014, -0.1946, -0.1022, -2.0475,  0.3873,
         -0.5109, -0.1874]], requires_grad=True)

In [17]:
# Destination of the exported graph; later cells reload it from this path.
onnx_filename = "torch_net.onnx"

torch.onnx.export(
    pytorch_model,   # module being exported
    dummy_input,     # example input used to trace the graph
    onnx_filename,   # where the .onnx file is written
    # --- graph interface: stable tensor names, batch dimension left symbolic
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    # --- export options
    opset_version=11,          # ONNX operator-set version to target
    export_params=True,        # embed the parameter weights in the file
    do_constant_folding=True,  # optimization: fold constants during export
)

In [18]:
# Verify the ONNX file: reload the exported model from disk and run the ONNX
# checker on it before it is handed to the runtime.
onnx_model = onnx.load(onnx_filename)
onnx.checker.check_model(onnx_model)

In [20]:
# Create an inference session.
# The execution provider is named explicitly: ONNX Runtime builds that ship
# several providers expect the caller to choose one, and running on CPU keeps
# the comparison against the CPU PyTorch model like-for-like.
ort_session = ort.InferenceSession(
    onnx_filename,
    providers=["CPUExecutionProvider"],
)

def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array.

    Tensors that track gradients are detached first, because ``.numpy()``
    refuses tensors that require grad; the result is then moved to the CPU
    and converted.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()

# ONNX Runtime is fed NumPy arrays (float32 here), keyed by graph input name.
first_input = ort_session.get_inputs()[0]
ort_inputs = {first_input.name: to_numpy(dummy_input)}

# Run the session; None as the output-name list requests every graph output.
ort_outs = ort_session.run(None, ort_inputs)

In [22]:
# Show the first three activations of the single sample from each backend.
torch_preview = to_numpy(pytorch_model(dummy_input))[0][:3]
print("PyTorch Output : ", torch_preview)
onnx_preview = ort_outs[0][0][:3]
print("ONNX Output : ", onnx_preview)

PyTorch Output :  [0.24815577 0.03326878 0.        ]
ONNX Output :  [0.24815576 0.03326878 0.        ]


In [23]:
# Check that both backends agree within floating-point tolerance. This is an
# approximate comparison (rtol/atol), not a bit-for-bit equality test.
np.testing.assert_allclose(
    actual=to_numpy(pytorch_model(dummy_input)),
    desired=ort_outs[0],
    rtol=1e-03,
    atol=1e-05,
)

In [24]:
# Display the full result returned by the ONNX Runtime session run.
ort_outs

[array([[0.24815576, 0.03326878, 0.        , 0.6358243 , 0.        ]],
       dtype=float32)]

In [25]:
# Display the PyTorch model's output for the same input, converted to NumPy.
to_numpy(pytorch_model(dummy_input))

array([[0.24815577, 0.03326878, 0.        , 0.6358243 , 0.        ]],
      dtype=float32)