-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
149 lines (117 loc) · 5.01 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import time
import cv2
import numpy as np
from PIL import Image
from mban import MBAN, YOLO_ONNX
def iter_input_images():
    """Yield images opened from filenames typed at the prompt, indefinitely.

    Each iteration asks for a filename on stdin and tries to open it with
    ``Image.open``. If opening fails, a message is printed and the prompt is
    shown again; nothing is yielded for that attempt.

    Yields:
        The object returned by ``Image.open`` for each filename that opened.
    """
    while True:
        img = input('Input image filename:')
        try:
            image = Image.open(img)
        except Exception:
            # Deliberately not a bare ``except:`` so that Ctrl-C
            # (KeyboardInterrupt) and SystemExit still terminate the script.
            print('Open Error! Try again!')
        else:
            yield image


def output_filename(img_name):
    """Return the name under which a detection result for *img_name* is saved.

    A trailing ``.jpg`` extension is swapped for ``.png``; every other name is
    returned unchanged. Only the real extension is touched, so a ``.jpg``
    occurring in the middle of a name is left alone.

    Args:
        img_name: File name (not a full path) of the source image.

    Returns:
        The file name to use inside the output directory.
    """
    suffix = ".jpg"
    if img_name.endswith(suffix):
        return img_name[:-len(suffix)] + ".png"
    return img_name


if __name__ == "__main__":
    # ------------------------------------------------------------------------------------------------------#
    #   mode selects what the script does:
    #   'predict'       prompt for image filenames, run detection, and show each result
    #   'video'         run detection on frames read from video_path
    #   'fps'           time detection on fps_image_path and print the result
    #   'dir_predict'   run detection on every image in dir_origin_path and save the results
    #   'heatmap'       prompt for image filenames and call detect_heatmap on each
    #   'export_onnx'   call convert_to_onnx to write the model to onnx_save_path
    #   'predict_onnx'  same prompt loop as 'predict', but using the YOLO_ONNX detector
    # ------------------------------------------------------------------------------------------------------#
    mode = "predict"
    # crop and count are forwarded to detect_image in 'predict' mode only.
    crop = False
    count = False
    # video_path is handed to cv2.VideoCapture; an integer is a capture-device index.
    video_path = 0
    # An empty video_save_path means the processed video is not written to disk.
    video_save_path = ""
    # Frame rate given to the cv2.VideoWriter used for the saved video.
    video_fps = 25.0
    # test_interval and fps_image_path are the arguments for get_FPS in 'fps' mode.
    test_interval = 100
    fps_image_path = "img/000031.jpg"
    # Input and output folders for 'dir_predict' mode.
    dir_origin_path = "VOCdevkit/VOC2007/JPEGImages/"
    dir_save_path = "VOCdevkit/VOC2007/img_out/"
    # Output path passed to detect_heatmap in 'heatmap' mode.
    heatmap_save_path = "model_data/heatmap_vision.png"
    # Arguments for convert_to_onnx in 'export_onnx' mode.
    simplify = True
    onnx_save_path = "model_data/models.onnx"

    # Only 'predict_onnx' uses the ONNX wrapper; every other mode uses MBAN.
    if mode != "predict_onnx":
        yolo = MBAN()
    else:
        yolo = YOLO_ONNX()

    if mode == "predict":
        for image in iter_input_images():
            r_image = yolo.detect_image(image, crop=crop, count=count)
            r_image.show()

    elif mode == "video":
        capture = cv2.VideoCapture(video_path)
        out = None
        try:
            if video_save_path != "":
                fourcc = cv2.VideoWriter_fourcc(*'XVID')
                size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
                out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

            # Read one frame up front purely to verify that the source works.
            ref, frame = capture.read()
            if not ref:
                raise ValueError("Please pay attention to whether the camera is installed correctly (whether the video path is filled in correctly).")

            fps = 0.0
            while True:
                t1 = time.time()
                ref, frame = capture.read()
                if not ref:
                    break
                # OpenCV delivers BGR; convert to RGB before building the PIL image.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(np.uint8(frame))
                frame = np.array(yolo.detect_image(frame))
                # Back to BGR for OpenCV display and writing.
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                # Running average of the previous value and this frame's rate.
                fps = (fps + (1. / (time.time() - t1))) / 2
                print("fps= %.2f" % (fps))
                frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                cv2.imshow("video", frame)
                c = cv2.waitKey(1) & 0xff
                if out is not None:
                    out.write(frame)

                # 27 is the ESC key code.
                if c == 27:
                    break

            print("Video Detection Done!")
            if out is not None:
                print("Save processed video to the path :" + video_save_path)
        finally:
            # Release the capture, the writer, and the windows even when
            # detection or the initial read raises.
            capture.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

    elif mode == "fps":
        img = Image.open(fps_image_path)
        tact_time = yolo.get_FPS(img, test_interval)
        print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')

    elif mode == "dir_predict":
        import os

        from tqdm import tqdm

        img_names = os.listdir(dir_origin_path)
        for img_name in tqdm(img_names):
            if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
                image_path = os.path.join(dir_origin_path, img_name)
                image = Image.open(image_path)
                r_image = yolo.detect_image(image)
                # exist_ok avoids the check-then-create race of a separate exists() test.
                os.makedirs(dir_save_path, exist_ok=True)
                r_image.save(os.path.join(dir_save_path, output_filename(img_name)), quality=95, subsampling=0)

    elif mode == "heatmap":
        for image in iter_input_images():
            yolo.detect_heatmap(image, heatmap_save_path)

    elif mode == "export_onnx":
        yolo.convert_to_onnx(simplify, onnx_save_path)

    elif mode == "predict_onnx":
        for image in iter_input_images():
            r_image = yolo.detect_image(image)
            r_image.show()

    else:
        raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict', 'predict_onnx'.")