In [1]:
import torch
import torchvision.models as models
model = models.resnet18(pretrained=True)
model.eval()

dummy_input = torch.randn(1,3,224,224)

try:
    import onnx
    torch.onnx.export(model,dummy_input,'resnet18.onnx')
except ImportError as e:
    print('no import onnx')



verbose: False, log level: Level.ERROR



we can get layer information of model.

In [6]:
import onnx
model = onnx.load('resnet18.onnx')
print(onnx.helper.printable_graph(model.graph))


graph torch_jit (
  %input.1[FLOAT, 1x3x224x224]
) initializers (
  %fc.weight[FLOAT, 1000x512]
  %fc.bias[FLOAT, 1000]
  %onnx::Conv_193[FLOAT, 64x3x7x7]
  %onnx::Conv_194[FLOAT, 64]
  %onnx::Conv_196[FLOAT, 64x64x3x3]
  %onnx::Conv_197[FLOAT, 64]
  %onnx::Conv_199[FLOAT, 64x64x3x3]
  %onnx::Conv_200[FLOAT, 64]
  %onnx::Conv_202[FLOAT, 64x64x3x3]
  %onnx::Conv_203[FLOAT, 64]
  %onnx::Conv_205[FLOAT, 64x64x3x3]
  %onnx::Conv_206[FLOAT, 64]
  %onnx::Conv_208[FLOAT, 128x64x3x3]
  %onnx::Conv_209[FLOAT, 128]
  %onnx::Conv_211[FLOAT, 128x128x3x3]
  %onnx::Conv_212[FLOAT, 128]
  %onnx::Conv_214[FLOAT, 128x64x1x1]
  %onnx::Conv_215[FLOAT, 128]
  %onnx::Conv_217[FLOAT, 128x128x3x3]
  %onnx::Conv_218[FLOAT, 128]
  %onnx::Conv_220[FLOAT, 128x128x3x3]
  %onnx::Conv_221[FLOAT, 128]
  %onnx::Conv_223[FLOAT, 256x128x3x3]
  %onnx::Conv_224[FLOAT, 256]
  %onnx::Conv_226[FLOAT, 256x256x3x3]
  %onnx::Conv_227[FLOAT, 256]
  %onnx::Conv_229[FLOAT, 256x128x1x1]
  %onnx::Conv_230[FLOAT, 256]
  %onnx::Conv_

### model inference

when we inference onnx file, we typically use onnxruntime.




In [20]:
import numpy as np
import onnxruntime as ort

ort_session = ort.InferenceSession('resnet18.onnx')
outputs = ort_session.run(None,{"input.1": np.random.randn(1, 3, 224, 224).astype(np.float32)})
# print(outputs)
print(len(outputs[0][0]))

1000


inferencesession 을 활용해서 모델을 불러오고, run 명령을 실행.
run을 실행 할때 inference output과 input에 대해서 dictionary 형태로 입력을 줘야함.
ouput에 none을 넣음녀 모든 outpu 출력.
input 을 넣어줄때는 위에서 확인한 그래프의 input layer 의 이름을 활용하여 위와 같이 넣어주면된다.

### 모델 추론 결과

onnxruntime vs torch model

In [25]:
import time
import torch
import torchvision.models as models

import numpy as np
import onnx
import onnxruntime as ort

# torch model
dummy_input = torch.randn(1,3,224,224)
model = models.resnet18(pretrained=True)
model.eval()
start = time.time()
for _ in range(100):
    torch_output = model(dummy_input)
print("torch inference:", time.time() - start)

## onnxruntime
ort_session = ort.InferenceSession('resnet18.onnx')
ort_outputs = ort_session.run(None, {"input.1": dummy_input.numpy()})
np.testing.assert_allclose(torch_output.detach().numpy(),ort_outputs[0],rtol=1e-03,atol=1e-05)

start = time.time()
for _ in range(100):
    ort_outputs = ort_session.run(None, {"input.1": dummy_input.numpy()})
print("ort inference:", time.time() - start)

torch inference: 9.868140697479248
ort inference: 1.827899694442749
