In [2]:
import onnx
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
#!pip install onnxruntime
import onnxruntime

In [None]:
# Download and unpack the prebuilt ONNX Runtime 1.8.1 (linux-x64) release and
# publish its location through environment variables.
#
# The previous version of this cell mixed bare shell commands (wget / tar /
# export) with a Python statement, which cannot be parsed as a code cell.
# It is now plain Python, so it runs under "Restart & Run All".
# Unlike the old `cd`, it does not change the notebook's working directory,
# so relative paths used by later cells keep resolving from the same place.
import os
import tarfile
import urllib.request

ORT_VERSION = '1.8.1'
ORT_NAME = f'onnxruntime-linux-x64-{ORT_VERSION}'
ORT_ARCHIVE = f'{ORT_NAME}.tgz'
ORT_URL = ('https://github.com/microsoft/onnxruntime/releases/download/'
           f'v{ORT_VERSION}/{ORT_ARCHIVE}')

# Skip work that an earlier run already did so the cell can be re-run safely.
if not os.path.exists(ORT_ARCHIVE):
    urllib.request.urlretrieve(ORT_URL, ORT_ARCHIVE)
if not os.path.isdir(ORT_NAME):
    with tarfile.open(ORT_ARCHIVE, 'r:gz') as archive:
        archive.extractall()

# On Kaggle the working directory is /kaggle/working, so this resolves to the
# same path the original cell hard-coded.
ort_dir = os.path.abspath(ORT_NAME)
os.environ['ONNXRUNTIME_DIR'] = ort_dir

# Prepend the runtime's lib directory; drop empty entries so no stray ':' is left.
os.environ['LD_LIBRARY_PATH'] = os.pathsep.join(
    entry
    for entry in (os.path.join(ort_dir, 'lib'),
                  os.environ.get('LD_LIBRARY_PATH', ''))
    if entry)

In [None]:
# Build the network and restore its trained weights.
# NOTE(review): `test_model` is not defined in the part of the notebook shown
# here; it has to be imported or defined before this cell runs.
model = test_model()

# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only host. The export further down feeds a CPU tensor, so the weights
# are wanted on the CPU anyway.
state = torch.load('test.pth', map_location='cpu')
model.load_state_dict(state['model'], strict=True)

# Put the model into eval() mode before running or exporting it.
model.eval()

In [None]:
# Input to the model.
# The dummy input must have the complete shape the model expects:
# (batch, channels, height, width).
batch_size = 1
x = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

# forward() takes extra arguments that are awkward to hand to the exporter,
# so bind them up front and export the resulting single-input callable.
from functools import partial
temp_forward = model.forward  # reference to the unpatched forward
model.forward = partial(model.forward, img_metas={}, return_loss=False)
torch_out = model(x)

# Alternative to patching forward(): extra arguments can be passed as a dict,
#   torch.onnx.export(model, (x, {'y': None, 'z': z}), 'test.onnx')

# Export the model.
# The file name is 'test.onnx' because that is the file the loading and
# inference cells of this notebook open.
with torch.no_grad():
    torch.onnx.export(
        model,                      # model being run
        x,                          # model input (or a tuple for multiple inputs)
        "test.onnx",                # where to save the model (file or file-like object)
        export_params=True,         # store the trained parameter weights inside the model file
        opset_version=10,           # the ONNX version to export the model to
        do_constant_folding=True,   # whether to execute constant folding for optimization
        input_names=['input'],      # custom input names
        output_names=['output'],    # custom output names
        # dynamic_axes lists the dimensions that may change at run time.
        # With the (batch, channels, height, width) layout, axis 2 is the
        # height and axis 3 the width.
        dynamic_axes={
            'input': {0: 'batch_size', 2: 'height', 3: 'width'},
            'output': {0: 'batch_size'},
        })

In [None]:
# Load the model stored in ONNX format.
import onnx
onnx_model=onnx.load('test.onnx')# load the ONNX graph directly from the file

# Run the ONNX checker on the loaded model.
onnx.checker.check_model(onnx_model)

In [None]:
# Read one image and turn it into a (1, 3, 224, 224) batch.
from PIL import Image
import torchvision.transforms as transforms
pic_path = '/kaggle/input/10-monkey-species/validation/validation/n8/n801.jpg'

# Force three RGB channels: grayscale, palette or RGBA files would otherwise
# produce a tensor whose channel count does not match the model input.
img = Image.open(pic_path).convert('RGB')

# Resize to the spatial size used for the export.
resize = transforms.Resize([224, 224])
img = resize(img)

# Convert to a tensor (any further preprocessing can also be done here) and
# add the leading batch dimension.
to_tensor = transforms.ToTensor()
img_y = to_tensor(img).unsqueeze(0)

# Show only the shape instead of dumping the whole tensor as cell output.
img_y.shape

In [None]:
# Run inference.
# Execution happens in ONNX Runtime, so the .onnx file is passed in directly
# as the graph to build a session.
ort_session = onnxruntime.InferenceSession('test.onnx')

# Helper that hands a tensor to ONNX Runtime as a NumPy array.
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array located on the CPU.

    A tensor that tracks gradients is detached first, because ``.numpy()``
    cannot be called on it directly.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

# Compute the ONNX Runtime output prediction.
# The feed is a dict mapping the session's first input name to the input array.
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}# input: {'name': input tensor}
ort_outs = ort_session.run(None, ort_inputs)# outputs; these still need post-processing to produce the final prediction

In [None]:
# Compare ONNX Runtime and PyTorch results.
# assert_allclose raises if the first ONNX Runtime output differs from the
# PyTorch output beyond the given relative / absolute tolerances.
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

# Only reached when the assertion above passed.
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

In [None]:
# Load a sample image, resize it to 224x224 and keep only its luma channel.
from PIL import Image
import torchvision.transforms as transforms

img = Image.open("./_static/img/cat.jpg")
resize = transforms.Resize([224, 224])
img = resize(img)

# Split the image into its Y, Cb and Cr components; only Y is used below.
img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

# NOTE(review): img_y holds a single channel here, while the dummy input used
# for the export in this notebook has 3 channels — confirm that the session
# this tensor is fed to really expects a 1-channel input.
to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
# Add the leading batch dimension in place.
img_y.unsqueeze_(0)

In [None]:
# Specify the model input as a dict keyed by the session's first input name.
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
# Keep the first output of the session.
img_out_y = ort_outs[0]

# The result produced here can be post-processed into the final prediction.

# Flops


In [None]:
# Ask mmcv for the model's FLOPs and parameter figures.
from mmcv.cnn import get_model_complexity_info
model.eval()
# Shape passed to get_model_complexity_info.
# NOTE(review): presumably (channels, height, width) without the batch axis — confirm.
input_shape=(3,384,384)
# Route forward() to forward_dummy for the measurement. This replaces whatever
# forward the model had before, and it is not restored afterwards.
model.forward = model.forward_dummy
flops, params = get_model_complexity_info(model, input_shape)

# 修改ONNX模型

In [None]:
# Example of editing an ONNX model in place to resolve a conversion conflict.
#
# Fixes over the earlier sketch:
#   * protobuf repeated fields (ints / input / output) cannot be assigned to,
#     so they are edited in place;
#   * the target node is located first and the node list is changed afterwards,
#     instead of removing from the list while iterating over it;
#   * `idx`, `node_inputs` and `node_outputs` are defined before use.
graph = onnx_model.graph          # the graph
nodes = graph.node                # list of nodes
initializers = graph.initializer  # weight parameters

TARGET_NODE_NAME = 'Conv_25'

# Position of the first node with the requested name, or None when absent.
idx = next(
    (position for position, candidate in enumerate(nodes)
     if candidate.name == TARGET_NODE_NAME),
    None)

if idx is not None:
    node = nodes[idx]

    # Change an attribute of the node.
    # NOTE(review): assumes attribute[1] exists and is an INTS attribute, and
    # that the intended value is the single element 10 — confirm.
    del node.attribute[1].ints[:]
    node.attribute[1].ints.append(10)

    # Rewire the node's inputs and outputs.
    del node.input[:]
    node.input.extend(["1", "2"])
    del node.output[:]
    node.output.extend(['3', '4'])

    # NOTE(review): the replacement node is assumed to reuse the rewired
    # inputs / outputs of the node it replaces — confirm.
    node_inputs = list(node.input)
    node_outputs = list(node.output)

    # Build the replacement node.
    # NOTE(review): `score_threshold` and `offset` are not defined in this
    # cell; they must be set earlier in the notebook before running it.
    new_node = onnx.helper.make_node(
        'NonMaxSuppression',
        node_inputs[:2],
        node_outputs,
        score_threshold=score_threshold,
        offset=offset)

    # Swap the old node for the new one at the same position.
    del nodes[idx]
    nodes.insert(idx, new_node)


# 测试使用函数

In [3]:
import os
import warnings
from functools import partial

import numpy as np
import onnx
import onnxruntime as rt

import torch
import torch.nn as nn
from packaging import version

In [4]:
# Path of the temporary ONNX file that process_grid_sample exports to and removes.
onnx_file = 'tmp.onnx'

In [5]:
class WrapFunction(nn.Module):
    """Adapter that exposes a plain callable as an ``nn.Module``.

    Every positional and keyword argument given to ``forward`` is passed
    straight through to the wrapped callable, and its result is returned
    unchanged.
    """

    def __init__(self, wrapped_function):
        super().__init__()
        # The callable that forward() delegates to.
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        """Call the wrapped callable with the given arguments."""
        target = self.wrapped_function
        return target(*args, **kwargs)

In [4]:
def process_grid_sample(func, input, grid, ort_custom_op_path=''):
    """Export ``func`` to ONNX and check that ONNX Runtime matches PyTorch.

    ``func`` is wrapped in a ``WrapFunction`` module, exported with the inputs
    ``(input, grid)`` to the module-level ``onnx_file`` path, executed with
    ONNX Runtime, and the outputs are compared with the PyTorch outputs.

    Args:
        func: Callable taking ``(input, grid)``.
        input: Tensor passed as the first argument / ONNX input ``'input'``.
        grid: Tensor passed as the second argument / ONNX input ``'grid'``.
        ort_custom_op_path: Optional path of a custom-op library to register
            with the ONNX Runtime session. Skipped when empty.

    Raises:
        AssertionError: If the exported graph does not have exactly two feed
            inputs, or if the outputs differ by more than ``atol=1e-3``.
    """
    wrapped_model = WrapFunction(func).eval()

    input_names = ['input', 'grid']
    output_names = ['output']

    # The exported file is temporary: remove it even when export, loading or
    # inference raises, so a failed run does not leave it behind.
    try:
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (input, grid),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=input_names,
                output_names=output_names,
                opset_version=11)

        onnx_model = onnx.load(onnx_file)

        session_options = rt.SessionOptions()
        if ort_custom_op_path:
            session_options.register_custom_ops_library(ort_custom_op_path)

        # Initializers are also listed as graph inputs (see
        # keep_initializers_as_inputs above); what remains after removing
        # them are the inputs that must actually be fed.
        graph_inputs = {node.name for node in onnx_model.graph.input}
        initializer_names = {
            node.name for node in onnx_model.graph.initializer}
        net_feed_input = graph_inputs - initializer_names
        assert len(net_feed_input) == 2, (
            f'expected 2 feed inputs, got {sorted(net_feed_input)}')

        # get onnx output
        sess = rt.InferenceSession(onnx_file, session_options)
        ort_result = sess.run(None, {
            # .cpu() makes the conversion work for tensors on other devices.
            'input': input.detach().cpu().numpy(),
            'grid': grid.detach().cpu().numpy()
        })
    finally:
        if os.path.exists(onnx_file):
            os.remove(onnx_file)

    # Reference result; no_grad keeps the outputs convertible to NumPy even
    # when the inputs track gradients.
    with torch.no_grad():
        pytorch_results = wrapped_model(input.clone(), grid.clone())

    # Compare output by output: sess.run returns a list of arrays.
    if isinstance(pytorch_results, torch.Tensor):
        pytorch_results = [pytorch_results]
    expected = [result.detach().cpu().numpy() for result in pytorch_results]
    assert len(expected) == len(ort_result), (
        f'PyTorch produced {len(expected)} output(s), '
        f'ONNX Runtime {len(ort_result)}')
    for position, (torch_out, ort_out) in enumerate(zip(expected, ort_result)):
        assert np.allclose(torch_out, ort_out, atol=1e-3), (
            f'output {position} differs between PyTorch and ONNX Runtime')

# 实时视频检测

In [None]:
import cv2
import time
import warnings
from functools import partial
import numpy as np
import onnx
import onnxruntime as rt
import mmcv
import torch
from packaging import version
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
from helper import preprocess_example_input,bbox2result
from torchvision import transforms

In [None]:
# Load the model stored in ONNX format.
output_file='yolox.onnx'
onnx_model = onnx.load(output_file)# load the ONNX graph directly from the file
# Run the ONNX checker on the loaded model.
onnx.checker.check_model(onnx_model)

In [2]:
# Build one preprocessed example input with the helper's preprocessing.
input_shape = (1, 3, 640, 640)
# Raw string: '\d' is not a valid escape sequence, so the non-raw literal
# triggers an invalid-escape warning. The resulting path text is identical.
# NOTE(review): absolute local path — adjust for the machine running this.
input_img = r'E:\dog.jpg'
normalize_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
input_config = {
    'input_shape': input_shape,
    'input_path': input_img,
    'normalize_cfg': normalize_cfg
}
# preprocess_example_input returns the image together with its meta information.
one_img, one_meta = preprocess_example_input(input_config)
tensor_data = [one_img]

In [None]:
# Read one image and turn it into a (1, 3, 640, 640) batch.
from PIL import Image
import torchvision.transforms as transforms
pic_path = 'E:/dog.jpg'

# Force three RGB channels: grayscale, palette or RGBA files would otherwise
# produce a tensor with a different channel count.
img = Image.open(pic_path).convert('RGB')

# Resize to 640x640.
resize = transforms.Resize([640, 640])
img = resize(img)

# Convert to a tensor (any further preprocessing can also be done here) and
# add the leading batch dimension.
to_tensor = transforms.ToTensor()
img_y = to_tensor(img).unsqueeze(0)

# Show only the shape instead of dumping the whole tensor as cell output.
img_y.shape

In [10]:
# Helper that hands a tensor to ONNX Runtime as a NumPy array.
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array located on the CPU.

    A tensor that tracks gradients is detached before the conversion.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()

# Compute the ONNX Runtime output prediction (currently disabled).
#ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}# input: {'name': input tensor}
#ort_outs = ort_session.run(None, ort_inputs)# outputs; these still need post-processing to produce the final prediction

In [13]:
# Per-channel mean and standard deviation handed to mmcv.imnormalize in the
# capture loop; the values match normalize_cfg used for the example input.
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
# Drawing colors, one 3-tuple per class; the capture loop indexes this list
# with the class index (80 entries, the same length as `classes`).
colors = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
               (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
               (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
               (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
               (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
               (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
               (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
               (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
               (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
               (134, 134, 103), (145, 148, 174), (255, 208, 186),
               (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
               (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
               (166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149),
               (179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205),
               (147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0),
               (119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88),
               (95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118),
               (219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15),
               (127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0),
               (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
               (191, 162, 208)]

# Class names; the position in the tuple is the class index used by the
# capture loop for both the label text and the color lookup.
classes = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
               'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
               'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
               'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
               'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
               'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
               'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
# Number of classes, passed to bbox2result and used as the loop bound when drawing.
n = len(classes)
# Detections whose score is not above this value are not drawn.
thr_score = 0.3
# Run inference.
# Execution happens in ONNX Runtime, so the .onnx file is passed in directly
# as the graph to build a session.
# NOTE(review): an earlier cell loads and checks 'yolox.onnx', while this
# session opens 'yolo.onnx' — confirm which file is intended.
ort_session = rt.InferenceSession('yolo.onnx')

In [None]:
# Real-time detection on the default camera: grab a frame, run the ONNX model
# on it, draw every detection above `thr_score`, and show the result until
# 'q' is pressed.
#
# Fixes over the previous version of this cell:
#   * a failed cap.read() ends the loop instead of crashing in cvtColor;
#   * the frame is converted BGR -> RGB exactly once (mmcv.imnormalize swaps
#     the channels again unless to_rgb=False is passed);
#   * boxes are drawn on a displayable copy of the frame, not on the
#     mean/std-normalised float image that is fed to the network;
#   * the label's y coordinate uses y1 (it used x2 before);
#   * every frame is shown, not only frames that contain a detection;
#   * the camera and windows are released even when an error occurs.

# (width, height) the frame is resized to before inference; the boxes are
# drawn on a frame of the same size.
# NOTE(review): the example-input cell uses 640x640 — confirm that 320x320 is
# the size this model should be fed.
INPUT_SIZE = (320, 320)

# Loop-invariant objects, created once instead of on every frame.
to_tensor = transforms.ToTensor()  # (H, W, C) ndarray -> (C, H, W) tensor
input_name = ort_session.get_inputs()[0].name

cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FPS, 30)
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame available (camera missing or stream ended).
            break

        # Network input: RGB, float32, resized, then normalised with the
        # per-channel mean / std.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        net_img = cv2.resize(rgb.astype(np.float32), INPUT_SIZE)
        net_img = mmcv.imnormalize(net_img, mean, std, to_rgb=False)

        # Convert to a tensor and add the batch dimension.
        img_y = to_tensor(net_img).unsqueeze(0)
        ort_inputs = {input_name: img_y.numpy()}  # {'name': input array}

        start_time = time.time()
        ort_outs = ort_session.run(None, ort_inputs)
        stop_time = time.time()

        # The first two outputs are the boxes (with scores) and the labels;
        # bbox2result groups them per class.
        ort_dets, ort_labels = ort_outs[:2]
        onnx_results = bbox2result(ort_dets, ort_labels, n)

        # Displayable image with the same size as the network input, so the
        # predicted box coordinates can be used on it directly.
        canvas = cv2.resize(frame, INPUT_SIZE)
        font_scale = rgb.shape[1] / 1000

        for i in range(n):
            for x1, y1, x2, y2, score in onnx_results[i]:
                if score <= thr_score:
                    continue

                # choose color for the label
                color = tuple(map(int, colors[i]))

                # draw box (OpenCV needs integer pixel coordinates)
                cv2.rectangle(img=canvas,
                              pt1=(int(x1), int(y1)),
                              pt2=(int(x2), int(y2)),
                              color=color,
                              thickness=3)

                # draw label name at the top-left corner of the box
                cv2.putText(img=canvas,
                            text=f"{classes[i]}{score:.2f}time:{stop_time-start_time}",
                            org=(int(x1) + 10, int(y1) + 5),
                            fontFace=cv2.FONT_HERSHEY_COMPLEX,
                            fontScale=font_scale,
                            color=color,
                            thickness=1,
                            lineType=cv2.LINE_AA)

        # Show every frame, with or without detections.
        cv2.imshow('img', canvas)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()