-
Notifications
You must be signed in to change notification settings - Fork 3
/
sample_onnx.py
124 lines (92 loc) · 3.54 KB
/
sample_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import time
import argparse
import cv2 as cv
import numpy as np
import onnxruntime
def get_args():
    """Parse the demo's command-line options.

    Options:
        --device:     integer camera index for cv.VideoCapture (default 0).
        --movie:      optional video file path; main() opens it instead of
                      --device when given (default None).
        --width:      requested capture width in pixels (default 960).
        --height:     requested capture height in pixels (default 540).
        --model:      path to the ONNX model file.
        --input_size: side length of the square model input (default 512).

    Returns:
        argparse.Namespace with attributes device, movie, width, height,
        model and input_size.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Capture source and requested capture resolution.
    add("--device", type=int, default=0)
    add("--movie", type=str, default=None)
    add("--width", help='cap width', type=int, default=960)
    add("--height", help='cap height', type=int, default=540)

    # Model file and the square input resolution fed to it.
    add("--model", type=str, default='model/opensketch_style_512x512.onnx')
    add("--input_size", type=int, default=512)

    return cli.parse_args()
def run_inference(onnx_session, input_size, image):
    """Run the ONNX drawing model on a single BGR frame.

    The frame is resized to ``input_size`` x ``input_size``, converted from
    BGR to RGB, and laid out as a float32 NCHW batch of one. That batch is
    fed to the session's first input. The session's first output is scaled
    by 255, clipped to the 8-bit range, cast to ``uint8``, and resized back
    to the frame's original width and height.

    Args:
        onnx_session: An ``onnxruntime.InferenceSession``, or anything that
            exposes ``get_inputs()``, ``get_outputs()`` and ``run()``.
        input_size: Side length in pixels of the square model input.
        image: 3-channel BGR frame, e.g. from ``cv.VideoCapture.read()``.
            The frame is not modified.

    Returns:
        ``uint8`` image with the same height and width as ``image``.
        The model output is presumably single-channel, since ``squeeze()``
        drops singleton axes; this should be confirmed for other models.
    """
    image_height, image_width = image.shape[:2]

    # Resize. cv.resize returns a new array, so the caller's frame is left
    # untouched and no defensive copy is needed.
    resized = cv.resize(image, dsize=(input_size, input_size))
    rgb = cv.cvtColor(resized, cv.COLOR_BGR2RGB)

    # Pre-process: HWC -> CHW, float32, contiguous, plus a leading batch axis
    # giving shape (1, 3, input_size, input_size). Pixel values stay in
    # 0-255 as in the original pipeline.
    # NOTE(review): this assumes the exported model handles any
    # normalisation itself; confirm before changing models.
    x = np.ascontiguousarray(rgb.transpose(2, 0, 1), dtype=np.float32)
    x = x[np.newaxis, ...]

    # Inference
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    # Post-process. run() returns one array per requested output, and
    # squeeze() removes the list, batch and singleton channel axes.
    # The *255 scaling assumes the model emits values in [0, 1].
    drawing = np.array(onnx_result).squeeze()
    # Clip before casting: float -> uint8 casts of out-of-range values do not
    # saturate (the result is platform-dependent) and would show up as
    # speckle noise.
    drawing = np.clip(drawing * 255, 0, 255).astype(np.uint8)

    return cv.resize(drawing, dsize=(image_width, image_height))
def main():
    """Run the interactive drawing demo until ESC is pressed or input ends.

    Opens the capture source chosen on the command line and loads the ONNX
    model. For every frame it shows two windows: the original frame,
    annotated with the time spent on capture plus inference, and the
    model's output. The capture device and the OpenCV windows are released
    even if model loading or inference raises.
    """
    # Argument parsing ########################################################
    args = get_args()

    # A movie path, when given, is opened instead of the camera index.
    cap_device = args.movie if args.movie is not None else args.device
    cap_width = args.width
    cap_height = args.height

    model_path = args.model
    input_size = args.input_size

    # Camera setup ############################################################
    cap = cv.VideoCapture(cap_device)
    cap.set(cv.CAP_PROP_FRAME_WIDTH, cap_width)
    cap.set(cv.CAP_PROP_FRAME_HEIGHT, cap_height)

    try:
        # Model loading #######################################################
        # Execution providers are listed in decreasing order of preference.
        onnx_session = onnxruntime.InferenceSession(
            model_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )

        while True:
            # perf_counter is monotonic, so the per-frame timing is not
            # affected by wall-clock adjustments.
            start_time = time.perf_counter()

            # Capture #########################################################
            ret, image = cap.read()
            if not ret:
                break

            result_image = run_inference(
                onnx_session,
                input_size,
                image,
            )

            # Time spent on capture + inference for this frame.
            elapsed_time = time.perf_counter() - start_time

            # Drawing #########################################################
            elapsed_time_text = "Elapsed time: "
            elapsed_time_text += str(round((elapsed_time * 1000), 1))
            elapsed_time_text += 'ms'
            cv.putText(image, elapsed_time_text, (10, 30),
                       cv.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 1,
                       cv.LINE_AA)

            # Display #########################################################
            cv.imshow('Informative Drawings Before', image)
            cv.imshow('Informative Drawings After', result_image)

            # Key handling (ESC: quit) ########################################
            key = cv.waitKey(1)
            if key == 27:  # ESC
                break
    finally:
        # Always release the capture device and close windows, including
        # when the loop exits via an exception.
        cap.release()
        cv.destroyAllWindows()
if __name__ == '__main__':
    # Start the demo only when run as a script, not when imported.
    main()