In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

In [3]:
my_model = MyModel()
my_model

MyModel(
  (linear): Linear(in_features=10, out_features=1, bias=True)
)

In [4]:
my_model.eval() # model.train()

MyModel(
  (linear): Linear(in_features=10, out_features=1, bias=True)
)

# 1. Export to ONNX

In [5]:
torch.save(my_model, 'model.pt')

In [6]:
torch.save(my_model.state_dict(), 'model_state_dict.pt')

In [7]:
dummy_input = torch.randn(1, 10)
dummy_input.shape

torch.Size([1, 10])

In [8]:
torch.onnx.export(my_model, dummy_input, 'model.onnx')

verbose: False, log level: Level.ERROR



# 2. Export to TorchScript

In [9]:
dummy_input = torch.randn(1, 10)
dummy_input.shape

torch.Size([1, 10])

In [10]:
scripted_model = torch.jit.trace(my_model, dummy_input)

In [11]:
my_model

MyModel(
  (linear): Linear(in_features=10, out_features=1, bias=True)
)

In [12]:
scripted_model

MyModel(
  original_name=MyModel
  (linear): Linear(original_name=Linear)
)

In [13]:
scripted_model.save("scripted_model.pt")

# 3. Read ONNX model

In [14]:
import onnx
import onnxruntime

In [15]:
model_path = 'model.onnx'

In [16]:
onnx_model = onnx.load(model_path)
onnx_model

ir_version: 7
producer_name: "pytorch"
producer_version: "2.0.0"
graph {
  node {
    input: "onnx::Gemm_0"
    input: "linear.weight"
    input: "linear.bias"
    output: "3"
    name: "/linear/Gemm"
    op_type: "Gemm"
    attribute {
      name: "alpha"
      f: 1
      type: FLOAT
    }
    attribute {
      name: "beta"
      f: 1
      type: FLOAT
    }
    attribute {
      name: "transB"
      i: 1
      type: INT
    }
  }
  name: "torch_jit"
  initializer {
    dims: 1
    dims: 10
    data_type: 1
    name: "linear.weight"
    raw_data: "Qq7\276>Pu\276X\245)\2758\366\034>0m\346\275\234\346\250=7\220\010\276\302"u\276\323c\221\27607\357="
  }
  initializer {
    dims: 1
    data_type: 1
    name: "linear.bias"
    raw_data: "\025\233\004\276"
  }
  input {
    name: "onnx::Gemm_0"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 10
          }
        }
      }
    }

In [17]:
# create onnx runtime session
ort_session = onnxruntime.InferenceSession(model_path)
ort_session

<onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x116917430>

In [18]:
input_data = np.ones((1, 10)).astype(np.float32)
input_data.shape

(1, 10)

In [19]:
output = ort_session.run(None, {'onnx::Gemm_0': input_data})
output

[array([[-1.0062946]], dtype=float32)]

In [20]:
input_data = np.zeros((1, 10)).astype(np.float32)
input_data.shape

(1, 10)

In [21]:
output = ort_session.run(None, {'onnx::Gemm_0': input_data})
output

[array([[-0.12949784]], dtype=float32)]