## Torch2Onnx

In [None]:
import os
import sys
sys.path.append("..")

In [None]:
# load packages
import cv2
import numpy as np
import onnx
from onnxsim import simplify
import onnxruntime
from pathlib import Path
import time

import torch
import torch.nn as nn
import torchvision

from pytocr.modeling.architectures import build_model
from pytocr.utils.save_load import load_pretrained_params
from utils import load_config, draw_det_res

In [None]:
# Locations of the detector's YAML config and its trained PyTorch weights.
# Replace the "..." placeholders with real paths before running.
config_path, model_path = (
    ".../PyTorchOCR/configs/det/det_r18_db.yml",
    ".../PyTorchOCR/models/torch/db_r18.pth",
)

In [None]:
# Load the YAML config and turn off distributed mode for this single-process export.
config = load_config(config_path)
config["Global"]["distributed"] = False

# Build the network described by the "Architecture" section of the config.
model = build_model(config["Architecture"])
# Use the GPU only when the config requests it AND CUDA is actually available;
# otherwise fall back to CPU so the notebook still runs on CPU-only machines.
use_gpu = config["Global"]["use_gpu"] and torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")
model = model.to(device)
# Load the trained weights first and switch to eval mode last, so the model
# is guaranteed to be in inference mode regardless of what the loader does.
model = load_pretrained_params(model, model_path)
model.eval()
# print(model)

In [None]:
# Export the PyTorch model to ONNX with dynamic batch size and spatial size.
# The dummy input only drives tracing; the dynamic_axes below let the exported
# graph accept other N / H / W values at inference time.
input_img = torch.ones(1, 3, 736, 736)  # N C H W
input_img = input_img.to(device)
out_onnx_path = ".../models/onnx/db_r18_op10.onnx"
out_onnx_sim_path = ".../models/onnx/db_r18_op10_sim.onnx"
input_name = "ocr_det_input"     # name of the graph's input node
output_name = "ocr_det_output"   # name of the graph's output node
# Make sure the destination directory exists before writing the .onnx files.
Path(out_onnx_path).parent.mkdir(parents=True, exist_ok=True)
with torch.no_grad():
    torch.onnx.export(
        model,
        input_img,
        out_onnx_path,
        verbose=False,      # do not print the exported graph
        input_names=[input_name],
        output_names=[output_name],
        # Tensors are NCHW: axis 2 is height, axis 3 is width.
        dynamic_axes={
            input_name: {0: 'batch_size', 2: 'in_height', 3: 'in_width'},
            output_name: {0: 'batch_size', 2: 'out_height', 3: 'out_width'}},
        opset_version=10)  # raise opset_version if an operator fails to export

# Simplify with onnx-simplifier: merge/remove redundant nodes.
input_shapes = {input_name: list(input_img.shape)}   # required for dynamic inputs

onnx_model = onnx.load(out_onnx_path)
# NOTE(review): newer onnx-simplifier releases deprecate `dynamic_input_shape`
# and `input_shapes`; confirm against the installed onnxsim version.
model_simp, check = simplify(
    onnx_model,
    dynamic_input_shape=True,
    input_shapes=input_shapes)
# Explicit raise instead of assert so the check is not stripped under `python -O`.
if not check:
    raise RuntimeError("Simplified ONNX model could not be validated")
onnx.save(model_simp, out_onnx_sim_path)

## 验证模型 (Verify the ONNX model against PyTorch)

In [None]:
# Load a test image and turn it into a float32 C x H x W array scaled to [0, 1].
img_path = r".../test_img.png"
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
# cv2.imread reports failure by returning None rather than raising.
if img is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# cv2.resize takes dsize as (width, height): the result is 736 rows x 960 columns.
resized_img = cv2.resize(img, (960, 736))
image_data = np.array(resized_img, dtype='float32')
image_data /= 255.
image_data = np.transpose(image_data, (2, 0, 1))  # H W C -> C H W

In [None]:
# Run the PyTorch model on the preprocessed image and time the forward pass.
torch_input = torch.from_numpy(image_data).unsqueeze(0)  # N C H W
torch_input = torch_input.to(device)
print(torch_input.shape, torch_input.dtype)

# The debug print and the host->device copy stay outside the timed region;
# the forward pass plus copying the "maps" output back to NumPy is timed.
# perf_counter is monotonic and high-resolution, unlike time.time().
st_time = time.perf_counter()
with torch.no_grad():
    torch_preds = model(torch_input)["maps"].cpu().numpy()
print("torch infer cost time", time.perf_counter() - st_time)
print(torch_preds.shape)

In [None]:
# Load the simplified ONNX model into an ONNX Runtime inference session.
onnx_path = out_onnx_sim_path
# Since ONNX Runtime 1.9, builds that ship more than one execution provider
# require `providers` to be given explicitly. Prefer CUDA when present, else CPU.
available_providers = onnxruntime.get_available_providers()
providers = [
    p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if p in available_providers
]
session = onnxruntime.InferenceSession(onnx_path, providers=providers)
session.get_modelmeta()
# Read the graph's I/O names back from the session.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
input_name, output_name

In [None]:
# Run ONNX Runtime on the same preprocessed image and time session.run.
onnx_input = np.expand_dims(image_data, 0)  # add batch dimension: N C H W
# perf_counter is monotonic and high-resolution, unlike time.time().
st_time = time.perf_counter()
onnx_preds = session.run([output_name], {input_name: onnx_input})[0]
print("onnx infer cost time", time.perf_counter() - st_time)
print(onnx_preds.shape)

In [None]:
# Compare the ONNX and PyTorch outputs element-wise.
# Refuse to compare mismatched shapes: NumPy broadcasting could otherwise
# silently produce a meaningless difference.
if onnx_preds.shape != torch_preds.shape:
    raise ValueError(
        f"Output shape mismatch: onnx {onnx_preds.shape} vs torch {torch_preds.shape}")
diff = onnx_preds - torch_preds
# Report the largest ABSOLUTE deviation: a signed max would miss cases where
# ONNX is lower than torch. The vectorized max avoids a Python-level loop.
print("difference between onnx and torch: ", np.abs(diff).max())