-
Notifications
You must be signed in to change notification settings - Fork 8
/
predict.py
89 lines (75 loc) · 2.57 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import errno
from argparse import ArgumentParser, ArgumentTypeError
from pathlib import Path
import torch
from wpodnet.backend import Predictor
from wpodnet.model import WPODNet
from wpodnet.stream import ImageStreamer
def _positive_float(text: str) -> float:
    """Parse the ``--scale`` option: a finite float strictly greater than 0.0.

    Used as an argparse ``type=`` converter so an invalid value is reported
    through the parser's normal usage/error output (exit status 2) instead
    of an uncaught exception.

    Raises:
        ArgumentTypeError: if *text* is not a number, or is not a finite
            value greater than 0.0 (NaN and infinity are rejected).
    """
    import math

    try:
        value = float(text)
    except ValueError:
        raise ArgumentTypeError(f'invalid float value: {text!r}') from None
    # `not (... > 0.0)` also rejects NaN, which compares False to everything.
    if not (math.isfinite(value) and value > 0.0):
        raise ArgumentTypeError('scale must be greater than 0.0')
    return value


def _existing_dir(text: str) -> Path:
    """Parse a ``--save-*`` option: the path of an already existing directory.

    Raises:
        ArgumentTypeError: if *text* does not name an existing directory.
    """
    path = Path(text)
    if not path.is_dir():
        raise ArgumentTypeError(f'No such directory: {text}')
    return path


def _build_parser() -> ArgumentParser:
    """Return the command-line parser for this prediction script."""
    parser = ArgumentParser()
    parser.add_argument(
        'source',
        type=str,
        help='the path to the image'
    )
    parser.add_argument(
        '-w', '--weight',
        type=str,
        required=True,
        help='the path to the model weight'
    )
    parser.add_argument(
        '--scale',
        type=_positive_float,
        default=1.0,
        help='adjust the scaling ratio. default to 1.0.'
    )
    parser.add_argument(
        '--save-annotated',
        type=_existing_dir,
        help='save the annotated image at the given folder'
    )
    parser.add_argument(
        '--save-warped',
        type=_existing_dir,
        help='save the warped image at the given folder'
    )
    return parser


def main(argv=None) -> None:
    """Run WPODNet on every image from ``source`` and report the predictions.

    For each image, this prints the prediction's bounds and confidence. It
    optionally saves the annotated and/or warped image, using the input
    image's file name, into the requested folders.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    # Already validated and converted to Path (or left as None) by argparse.
    save_annotated = args.save_annotated
    save_warped = args.save_warped

    if (
        save_annotated is not None
        and save_warped is not None
        and save_annotated.resolve() == save_warped.resolve()
    ):
        # Both outputs reuse the input file name, so a shared folder would
        # make the warped image silently overwrite the annotated one.
        parser.error('--save-annotated and --save-warped must be different folders')

    # Prepare for the model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = WPODNet()
    model.to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # weights (consider weights_only=True on torch >= 1.13).
    checkpoint = torch.load(args.weight, map_location=device)
    model.load_state_dict(checkpoint)

    predictor = Predictor(model)

    streamer = ImageStreamer(args.source)
    for i, image in enumerate(streamer):
        prediction = predictor.predict(image, scaling_ratio=args.scale)

        print(f'Prediction #{i}')
        print('  bounds', prediction.bounds.tolist())
        print('  confidence', prediction.confidence)

        if save_annotated is not None:
            annotated_path = save_annotated / Path(image.filename).name
            annotated = prediction.annotate()
            annotated.save(annotated_path)
            print(f'Saved the annotated image at {annotated_path}')

        if save_warped is not None:
            warped_path = save_warped / Path(image.filename).name
            warped = prediction.warp()
            warped.save(warped_path)
            print(f'Saved the warped image at {warped_path}')

        print()


if __name__ == '__main__':
    main()