In [108]:
# ! pip install onnx
# ! pip install onnxruntime
# ! python3 -m pip install --upgrade pip
# ! pip uninstall onnxruntime-training

In [64]:
import onnx
import torch
from torchvision.models import mobilenetv2
from pathlib import Path
from onnxruntime.tools.convert_onnx_models_to_ort import convert_onnx_models_to_ort, OptimizationStyle
import onnxruntime as ort
import numpy as np

In [65]:
# Export a torchvision MobileNetV2 to ONNX by tracing it with a random dummy input.

# torch.onnx.export writes to the given path but does not create parent
# directories, so make sure the output folder exists (fresh checkout / clean run).
Path("temp").mkdir(parents=True, exist_ok=True)

# Dummy input used only to trace the model: shape (1, 3, 224, 224).
x = torch.randn(1, 3, 224, 224, requires_grad=True)

# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of the `weights=` argument — confirm the installed version before switching.
model = mobilenetv2.mobilenet_v2(pretrained=True)

# Put the model in inference mode before exporting so layers such as dropout and
# batch-norm are traced with their inference behaviour.
model.eval()

torch.onnx.export(model, x, "temp/mobilenet_v2.onnx")

In [66]:
# Convert the exported ONNX file to ORT format, once per requested optimization
# style (Fixed and Runtime). Converted models and operator config files are
# written next to the source model.
mobilenet_onnx_path = Path("temp/mobilenet_v2.onnx")
convert_onnx_models_to_ort(
    model_path_or_dir=mobilenet_onnx_path,
    optimization_styles=[OptimizationStyle.Fixed, OptimizationStyle.Runtime],
)

2023-12-20 14:35:20,520 ort_format_model.utils [INFO] - Created config in /Users/takenoko/Develop/onnx-ort/temp/mobilenet_v2.required_operators.config
2023-12-20 14:35:20,641 ort_format_model.utils [INFO] - Created config in /Users/takenoko/Develop/onnx-ort/temp/mobilenet_v2.required_operators.with_runtime_opt.config


Converting models with optimization style 'Fixed' and level 'all'
Converting optimized ONNX model /Users/takenoko/Develop/onnx-ort/temp/mobilenet_v2.onnx to ORT format model /Users/takenoko/Develop/onnx-ort/temp/mobilenet_v2.ort
Converted 1/1 models successfully.
Generating config file from ORT format models with optimization style 'Fixed' and level 'all'
Converting models with optimization style 'Runtime' and level 'all'
Converting optimized ONNX model /Users/takenoko/Develop/onnx-ort/temp/mobilenet_v2.onnx to ORT format model /Users/takenoko/Develop/onnx-ort/temp/mobilenet_v2.with_runtime_opt.ort
Converted 1/1 models successfully.
Converting models again without runtime optimizations to generate a complete config file. These converted models are temporary and will be deleted.
Converting optimized ONNX model /Users/takenoko/Develop/onnx-ort/temp/mobilenet_v2.onnx to ORT format model /Users/takenoko/Develop/onnx-ort/temp/tmpx109gy0m.without_runtime_opt/mobilenet_v2.ort
Converted 1/1 mo

In [67]:
# Read the exported model back from disk and run the ONNX checker on it.
# check_model raises if the model is not well-formed and returns nothing otherwise.
exported_model_file = "temp/mobilenet_v2.onnx"
onnx_model = onnx.load(exported_model_file)
onnx.checker.check_model(onnx_model)

In [102]:
# Run the exported ONNX model through onnxruntime on an all-ones float32 input
# of shape (1, 3, 224, 224) and report the model's input/output metadata.
x = np.ones((1, 3, 224, 224), dtype=np.float32)
ort_sess = ort.InferenceSession('temp/mobilenet_v2.onnx')

# Query the session metadata once and derive names and shapes from it.
session_inputs = ort_sess.get_inputs()
session_outputs = ort_sess.get_outputs()

input_names = [node.name for node in session_inputs]
input_shapes = [node.shape for node in session_inputs]
output_names = [node.name for node in session_outputs]
output_shapes = [node.shape for node in session_outputs]

print(f"{input_names}: {input_shapes}")
print(f"{output_names}: {output_shapes}")

# Passing None as the first argument asks for every model output; the dummy
# tensor is fed to the first declared input.
feed = {input_names[0]: x}
outputs = ort_sess.run(None, feed)

# `outputs` is a list of arrays; stacking them adds a leading axis of length
# len(outputs) in front of each output's own shape.
print(np.array(outputs).shape)

['input.1']: [[1, 3, 224, 224]]
['536']: [[1, 1000]]
(1, 1, 1000)
