In [1]:
import onnx
import torch
import torch.onnx

In [77]:
class AddConst(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = 5
        self.b = torch.nn.Linear(4, 4)
        self.o = torch.tensor([1, 2, 3, 4], dtype=torch.float32)

    def forward(self, x):
        return self.b(x) + self.w + self.o


In [78]:
model = AddConst()
x = torch.randn(3, 4)
torch.onnx.export(model, x, "add_const.onnx", verbose=True)

Exported graph: graph(%onnx::Gemm_0 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.weight : Float(4, 4, strides=[4, 1], requires_grad=1, device=cpu),
      %b.bias : Float(4, strides=[1], requires_grad=1, device=cpu)):
  %/b/Gemm_output_0 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/b/Gemm"](%onnx::Gemm_0, %b.weight, %b.bias), scope: __main__.AddConst::/torch.nn.modules.linear.Linear::b # /Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/nn/modules/linear.py:114:0
  %/Constant_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={5}, onnx_name="/Constant"](), scope: __main__.AddConst:: # /var/folders/vw/rrzy8ptj6k10mvc6tf4170gw0000gq/T/ipykernel_45008/2052299685.py:9:0
  %/Add_output_0 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = onnx::Add[onnx_name="/Add"](%/b/Gemm_output_0, %/Constant_outp

In [79]:
net = onnx.load("add_const.onnx")
len(net.graph.node)

5

In [30]:
import numpy as np

In [80]:
np.frombuffer(net.graph.node[0].attribute[0].t.raw_data, dtype=np.float32)

array([], dtype=float32)

In [39]:
net.graph.node[0].attribute[0].t.dims

[]

In [45]:
net.graph.node[1].input

['onnx::Add_0', '/Constant_output_0']

In [57]:
net.graph.node[1]

input: "onnx::Add_0"
input: "/Constant_output_0"
output: "2"
name: "/Add"
op_type: "Add"
doc_string: "/var/folders/vw/rrzy8ptj6k10mvc6tf4170gw0000gq/T/ipykernel_45008/2842780031.py(7): forward\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py(1488): _slow_forward\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py(1501): _call_impl\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py(118): wrapper\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py(127): forward\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py(1501): _call_i

In [75]:
net.graph.initializer

google._upb._message.RepeatedCompositeContainer

In [71]:
net.graph.input[0].type.tensor_type

elem_type: 1
shape {
  dim {
    dim_value: 3
  }
  dim {
    dim_value: 4
  }
}

In [89]:
net.graph.node[-4].attribute[0].t.dims

[]

In [90]:
net.graph.node[-2].attribute[0].t.raw_data

b'\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x80@'

In [91]:
class LinearProjection(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.ones([5, 5])
        self.b = torch.ones([5])

    def forward(self, x):
        return torch.matmul(x, self.w) + self.b

    @classmethod
    def get_dummy_input_tensor(cls):
        return torch.ones(3, 5)

In [92]:
model = LinearProjection()
x = torch.randn(3, 5)
torch.onnx.export(model, x, "linear_projection.onnx", verbose=True)

Exported graph: graph(%onnx::MatMul_0 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu)):
  %/Constant_output_0 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cpu) = onnx::Constant[value=<Tensor>, onnx_name="/Constant"](), scope: __main__.LinearProjection:: # /var/folders/vw/rrzy8ptj6k10mvc6tf4170gw0000gq/T/ipykernel_45008/907790685.py:8:0
  %/MatMul_output_0 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu) = onnx::MatMul[onnx_name="/MatMul"](%onnx::MatMul_0, %/Constant_output_0), scope: __main__.LinearProjection:: # /var/folders/vw/rrzy8ptj6k10mvc6tf4170gw0000gq/T/ipykernel_45008/907790685.py:8:0
  %/Constant_1_output_0 : Float(5, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1  1  1  1  1 [ CPUFloatType{5} ], onnx_name="/Constant_1"](), scope: __main__.LinearProjection:: # /var/folders/vw/rrzy8ptj6k10mvc6tf4170gw0000gq/T/ipykernel_45008/907790685.py:8:0
  %4 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu) = onnx::Add[onnx_name=

In [93]:
net = onnx.load("linear_projection.onnx")

In [100]:
net.graph.node[1]

input: "onnx::MatMul_0"
input: "/Constant_output_0"
output: "/MatMul_output_0"
name: "/MatMul"
op_type: "MatMul"
doc_string: "/var/folders/vw/rrzy8ptj6k10mvc6tf4170gw0000gq/T/ipykernel_45008/907790685.py(8): forward\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py(1488): _slow_forward\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py(1501): _call_impl\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py(118): wrapper\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py(127): forward\n/Users/benjamintenmann/Library/Caches/pypoetry/virtualenvs/ggml-conversion-fEb3rMr2-py3.11/lib/python3.11/site-packages/torch/nn/modules/m

In [101]:
!poetry run pip install onnxruntime

Collecting onnxruntime
  Downloading onnxruntime-1.15.1-cp311-cp311-macosx_10_15_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting coloredlogs
  Using cached coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
Collecting flatbuffers
  Using cached flatbuffers-23.5.26-py2.py3-none-any.whl (26 kB)
Collecting humanfriendly>=9.1
  Using cached humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
Installing collected packages: flatbuffers, humanfriendly, coloredlogs, onnxruntime
Successfully installed coloredlogs-15.0.1 flatbuffers-23.5.26 humanfriendly-10.0 onnxruntime-1.15.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [102]:
import onnxruntime

In [103]:
sess = onnxruntime.InferenceSession("linear_projection.onnx")

In [107]:
meta = sess.get_modelmeta()

In [113]:
onnx.shape_inference.infer_shapes(net).graph.value_info[0].type

tensor_type {
  elem_type: 1
  shape {
    dim {
      dim_value: 5
    }
    dim {
      dim_value: 5
    }
  }
}