-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_pytorch_ImageClassification_onnx.py
73 lines (58 loc) · 2.19 KB
/
main_pytorch_ImageClassification_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import cv2
import time
import numpy as np
from PIL import Image
import onnxruntime as ort
import torchvision.transforms as trns
onnxmodel_path='./weight/mobilenetv2.onnx'
class_def = './weight/imagenet_classes.txt'
def softmax(x):
    """Return the softmax of *x* as a 1-D array (numerically stable)."""
    flat = x.reshape(-1)
    # Subtracting the max before exponentiating avoids overflow; it does
    # not change the result because softmax is shift-invariant.
    shifted = np.exp(flat - np.max(flat))
    return shifted / shifted.sum(axis=0)
def postprocess(result):
    """Convert raw model output into a plain-Python list of probabilities."""
    probabilities = softmax(np.array(result))
    return probabilities.tolist()
def main():
    """Run live ImageNet classification on webcam frames with an ONNX model.

    Reads frames from the camera, runs them through the MobileNetV2 ONNX
    model, and overlays the top-3 class predictions on each frame until
    the camera fails or the user presses 'q'/ESC.
    """
    # Run the model on the backend
    session = ort.InferenceSession(onnxmodel_path, None)
    # get the name of the first input of the model
    input_name = session.get_inputs()[0].name

    # Load ImageNet class labels, one per line.
    with open(class_def) as f:
        classes = [line.strip() for line in f]

    # Standard ImageNet preprocessing for MobileNetV2.
    transforms = trns.Compose([
        trns.Resize((224, 224)),
        trns.ToTensor(),
        trns.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225]),
    ])

    # NOTE(review): camera index 1 (second device) kept from the original —
    # confirm this is intentional; index 0 is the usual default camera.
    cap = cv2.VideoCapture(1)
    try:
        while True:
            ret, img = cap.read()
            if not ret:
                break

            # FIX: OpenCV frames are BGR; the model was trained on RGB
            # images, so convert before handing the array to PIL.
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(rgb)
            image_tensor = transforms(image).unsqueeze(0)  # add batch dim
            image_np = image_tensor.numpy()

            # FIX: pass None (not []) as the output-names argument to
            # request all model outputs, per the onnxruntime API.
            outputs = session.run(None, {input_name: image_np})[0]
            print("Output size:{}".format(outputs.shape))

            # Indices of classes sorted by descending score.
            sort_idx = np.flip(np.squeeze(np.argsort(outputs)))
            probs = postprocess(outputs)

            top_k = 3
            cv2.putText(img, "Inference results:", (0, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1,
                        cv2.LINE_AA)
            print("Inference results:")
            for i, index in enumerate(sort_idx[:top_k]):
                py = 35 + 15 * i
                # FIX: no "\n" in the overlay text — putText renders it
                # as a garbage glyph instead of a line break.
                text = "Label {}: {} ({:5f})".format(
                    index, classes[index], probs[index])
                cv2.putText(img, text, (0, py), cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (0, 255, 255), 1, cv2.LINE_AA)
                print(text)

            cv2.imshow('demo', img)
            # FIX: honor a quit key; the original ignored waitKey's return
            # value, so the loop could only end on a camera failure.
            key = cv2.waitKey(1) & 0xFF
            if key in (ord('q'), 27):  # 'q' or ESC
                break
    finally:
        # FIX: always release the camera and close the window, even if
        # an exception interrupts the loop.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()