# Inference with Torch

## 导包

In [2]:
# ---standard----
import sys
import os 
import numpy as np
from time import time
from matplotlib import pyplot as plt
from skimage import io,transform
from PIL import Image
sys.path.append('../')

# dl
import torch
import tensorrt as trt
import torch
import pycuda.driver as cuda
import pycuda.autoinit # 非常重要

# mine
from unet import UNet



## 构建TensorRT Engine

In [3]:
def ONNX_build_engine(onnx_file_path):
        '''
        通过加载onnx文件，构建engine
        :param onnx_file_path: onnx文件路径
        :return: engine
        '''
        # 打印日志
        G_LOGGER = trt.Logger(trt.Logger.VERBOSE)
        # create_network() without parameters will make parser.parse() return False
        with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
#         with trt.Builder(G_LOGGER) as builder, \
#                 builder.create_network(flags=1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network,\
#                 trt.OnnxParser(network, G_LOGGER) as parser:
#             print(builder)
#             print(network)
#             print(parser)
            builder.max_batch_size = 1
            builder.max_workspace_size = 1 << 20

            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                print(sys.getsizeof(model))
                parser.parse(model.read())
            print('Completed parsing of ONNX file')
            print(sys.getsizeof(parser))
            
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
            engine = builder.build_cuda_engine(network)
            print(sys.getsizeof(engine))
            print(engine)
            print("Completed creating Engine")

            # 保存计划文件
            with open("model/unet.engine", "wb") as f:
                f.write(engine.serialize())
            return engine

engine = ONNX_build_engine("model/unet.onnx")

Loading ONNX file from path model/unet.onnx...
Beginning ONNX file parsing
4272
Completed parsing of ONNX file
56
Building an engine from file model/unet.onnx; this may take a while...
16
None
Completed creating Engine


AttributeError: 'NoneType' object has no attribute 'serialize'

## 读取图片

In [None]:
img = Image.open(os.path.join("images", "input.jpg"))
print("img type:", type(img),"| img size", img.size)
plt.imshow(img)

In [None]:
# PIL -> ndarray
img = np.array(img)
print("---PIL2ndarray----")
print("type img:",type(img), "| img shape :", img.shape, "| img[0][0]:", img[0][0], "| img.dtype:", img.dtype)


# 换dims & 增加dims
img = img.transpose((2, 0, 1))
print("---换dims----")
print("type img:",type(img), "| img shape :", img.shape, "| img[0][0]:", img[0][0], "| img.dtype:", img.dtype)

img = np.expand_dims(img, axis=0)
print("---增加dims----")
print("type img:",type(img), "| img shape :", img.shape, "| img[0][0]:", img[0][0], "| img.dtype:", img.dtype)


# uint8 -> float
if img.max() > 1:
    img = img / 255
img = img.astype(np.float32)
print("---uint8 2 float----")
print("type img:",type(img), "| img shape :", img.shape, "| img[0][0]:", img[0][0], "| img.dtype:", img.dtype)


output = np.empty((1, 1, 1280, 1918), dtype=np.float32)

## 执行

In [None]:
context = engine.create_execution_context()
 
# 分配内存
d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
bindings = [int(d_input), int(d_output)]

# pycuda操作缓冲区
stream = cuda.Stream()

In [None]:
%%time
# 将输入数据放入device
cuda.memcpy_htod_async(d_input, img, stream)

In [None]:
%%time
# 执行模型
context.execute_async(1, bindings, stream.handle, None)

In [None]:
%%time
# 将预测结果从从缓冲区取出
cuda.memcpy_dtoh_async(output, d_output, stream)

In [None]:
%%time
# 线程同步
stream.synchronize()

## 展示

In [None]:
print(type(output))
print(output.shape)
print(output[0][0])
print(output.dtype)


## 显示图片

In [None]:
result_numpy = result.cpu().detach().numpy()
print(type(result_numpy))
print(result_numpy.shape)
print(result_numpy.dtype)

In [None]:
res = np.array(result_numpy[0][0])
print(type(res))
print(res.shape)

In [None]:
res

In [None]:
# res= Image.open("1.png")
# mat = np.array(src)
res = res.astype(np.uint8)
res = Image.fromarray(res, 'P')
# dst.save('new.png')
print(type(res))

In [None]:
plt.imshow(res)