In [1]:
import numpy as np
import torch
from torchvision import models
import onnx
import onnxsim
from onnxsim import simplify
import onnxruntime as ort

In [2]:
# Display the installed onnxsim version so the run records which simplifier was used.
onnxsim.__version__

'0.4.28'

# Export example: ResNet-18 to ONNX

In [3]:
# Build a torchvision ResNet-18. No `weights` argument is passed, so no
# pretrained checkpoint is requested for this example.
model = models.resnet18()

In [4]:
# All-ones dummy tensor used as the example input for the ONNX export below.
# Presumably laid out as (batch, channels, height, width) - TODO confirm.
x = torch.ones(1, 3, 224, 224)

In [5]:
# Export the model built above to an ONNX file, using `x` as the example input.
onnx_path = "resnet18.onnx"
torch.onnx.export(
    model,      # model to export
    x,          # example model input
    onnx_path,  # destination (can be a file or file-like object)
    # If True, conversion logs are printed and the ONNX model carries
    # doc_string information.
    verbose=False,
    # ONNX opset version; must equal _onnx_main_opset or be listed in
    # _onnx_stable_opsets (see torch/onnx/symbolic_helper.py).
    opset_version=16,
    # Names assigned, in order, to the graph's input nodes.
    input_names=["image"],
    # Names assigned, in order, to the graph's output nodes.
    output_names=["classes"],
)

verbose: False, log level: Level.ERROR



In [6]:
# Reload the exported file as an ONNX model object.
# NOTE: this rebinds `model`, which until this cell held the torchvision module.
model = onnx.load(onnx_path)

In [7]:
# Simplify the loaded ONNX graph. `simplify` returns the simplified model and a
# flag reporting whether onnxsim could validate it.
model_, ok = simplify(model)
# Stop here rather than carrying on with a model that failed validation.
assert ok, "Simplified ONNX model could not be validated"

In [8]:
# Write the simplified model under its own file name, separate from the
# "resnet18.onnx" file produced by the export cell.
onnx.save(model_, "resnet18.sim.onnx")

In [24]:
def ort_run(onnx_path, x):
    """Run one ONNX Runtime inference and print the shape of the first output.

    Args:
        onnx_path: Path of the ONNX model to load.
        x: Value fed to the model's first declared input.

    Returns:
        None. The shape of the first output is printed.
    """
    session_options = ort.SessionOptions()
    session_options.log_severity_level = 3

    # Providers are listed in preference order: CUDA first, then CPU.
    session = ort.InferenceSession(
        onnx_path,
        sess_options=session_options,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )

    # Feed `x` under whatever name the model gives its first input.
    first_input_name = session.get_inputs()[0].name
    outputs = session.run(None, {first_input_name: x})
    print(outputs[0].shape)

In [25]:
# Smoke-test the simplified model with an all-ones float32 input; ort_run
# prints the shape of the first output.
ort_run("resnet18.sim.onnx", np.ones((1, 3, 224, 224), dtype=np.float32))

(1, 1000)


# Export function

In [22]:
def export_onnx(model, input, onnx_path, dynamic_axes=False, half=False):
    """Export a PyTorch model to ONNX, check it, simplify it and save it.

    The model is exported with ``torch.onnx.export`` (opset 16, graph input
    named ``"image"``, graph output named ``"classes"``). The file is then
    reloaded, passed through ``onnx.checker.check_model`` and simplified with
    onnxsim; the simplified model overwrites the file at ``onnx_path``.

    Args:
        model: Module to export. It is switched to eval mode in place. When
            ``half`` is true the FP16 conversion is applied to a deep copy, so
            the caller's module keeps its original dtype.
        input: Example input tensor handed to the exporter. (The name shadows
            the ``input`` builtin; it is kept for backward compatibility.)
        onnx_path: Destination path of the ONNX file.
        dynamic_axes: If true, axes 0, 2 and 3 of ``"image"`` are exported as
            ``batch_size``, ``height`` and ``width``, and axis 0 of
            ``"classes"`` as ``batch_size``. Otherwise no dynamic axes are set.
        half: If true, export an FP16 model. Per the original author's note,
            FP16 export is not supported on CPU, so ``model`` and ``input``
            are expected to already be on a CUDA device.

    Raises:
        AssertionError: If onnxsim could not validate the simplified model.
    """
    import copy  # local import so the block does not depend on the notebook's import cell

    model.eval()
    if half:  # per the original note: half export is not supported on CPU, CUDA is required
        # Module.half() converts parameters in place, so convert a copy to
        # avoid silently leaving the caller's model in FP16.
        model = copy.deepcopy(model).half()
        input = input.half()
    torch.onnx.export(model,                        # model to export
                        input,                      # example model input
                        onnx_path,                  # destination (can be a file or file-like object)
                        export_params=True,         # if True (the default), parameters are exported too; set to False to export an untrained model
                        verbose=False,              # if True, conversion logs are printed and the ONNX model carries doc_string information
                        opset_version=16,           # ONNX opset version; must equal _onnx_main_opset or be in _onnx_stable_opsets (see torch/onnx/symbolic_helper.py)
                        do_constant_folding=True,   # enable constant folding: nodes whose inputs are all constants are replaced by precomputed constants
                        input_names=["image"],      # names assigned, in order, to the graph's input nodes
                        output_names=["classes"],   # names assigned, in order, to the graph's output nodes
                        # Dynamic shapes: leave the channel axis fixed; transformers may have problems with dynamic axes
                        dynamic_axes={"image": {0: "batch_size", 2: "height", 3:"width"}, "classes": {0: "batch_size"}} if dynamic_axes else None
                        )

    # Reload the exported ONNX model
    model_ = onnx.load(onnx_path)

    # Check that the IR is well formed. This is best-effort: a failure is
    # reported, with its reason, and the function still goes on to simplify.
    try:
        onnx.checker.check_model(model_)
    except Exception as exc:
        print(f"{onnx_path} incorrect: {exc}")
    else:
        print(f"{onnx_path} correct")

    # Simplify the model and overwrite the exported file with the result
    model_simple, check = simplify(model_)
    assert check, "Simplified ONNX model could not be validated"
    onnx.save(model_simple, onnx_path)
    print("simplified ONNX model success")
    print('finished exporting ' + onnx_path)

In [23]:
# Dummy all-ones input, also read by the later export cells.
# NOTE: the name shadows the `input` builtin; kept because later cells use it.
input = torch.ones(1, 3, 224, 224)
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# FP32 export with dynamic axes enabled.
export_onnx(resnet18, input, onnx_path="resnet18.onnx", dynamic_axes=True)

# FP16 export: model and input are moved to CUDA device 0 first.
export_onnx(
    resnet18.cuda(0),
    input.cuda(0),
    onnx_path="resnet18.half.onnx",
    dynamic_axes=True,
    half=True,
)

verbose: False, log level: Level.ERROR

resnet18.onnx correct
simplified ONNX model success
finished exporting resnet18.onnx
verbose: False, log level: Level.ERROR

resnet18.half.onnx correct
simplified ONNX model success
finished exporting resnet18.half.onnx


In [29]:
# Smoke-test the half-precision export; the input is float16 to match the
# model exported with half=True.
ort_run("resnet18.half.onnx", np.ones((1, 3, 224, 224), dtype=np.float16))

(1, 1000)


In [None]:
# Export EfficientNet-B0 with its default torchvision weights.
# NOTE: relies on `input` defined in the ResNet-18 export cell above, so that
# cell must have been run first.
efficientnet_b0 = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
export_onnx(efficientnet_b0, input, "efficientnet_b0.onnx", dynamic_axes=True)

In [None]:
# timm is imported in this cell rather than in the top import cell, so only
# this cell depends on it being installed.
import timm
# Twins-SVT-small requested with num_classes=10.
twins_svt_small = timm.models.twins_svt_small(num_classes=10)
# dynamic_axes is left at its default (False), so no dynamic axes are exported.
# NOTE: relies on `input` defined in the ResNet-18 export cell above.
export_onnx(twins_svt_small, input, "twins_svt_small.onnx")