-
Notifications
You must be signed in to change notification settings - Fork 21
/
demo.py
101 lines (77 loc) · 3.29 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import cv2
import numpy as np
import argparse
import time
import torch
from KittiCalibration import KittiCalibration
from visualizer import Visualizer
from BiSeNetv2.model.BiseNetv2 import BiSeNetV2
from BiSeNetv2.utils.utils import preprocessing_kitti, postprocessing
from pointpainting import PointPainter
from bev_utils import boundary
import tensorrt as trt
# Pick the compute backend once at import time: CUDA when present, CPU otherwise.
# `dev` (string) is reused below as a torch.load map_location; `device` is the
# torch.device object modules are moved to.
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"
device = torch.device(dev)
def time_synchronized() -> float:
    """Return the current wall-clock time, after draining pending CUDA work.

    CUDA kernel launches are asynchronous, so ``time.time()`` alone would
    under-report GPU stage durations; synchronizing first makes the timestamps
    meaningful on GPU. On CPU-only machines this is just ``time.time()``.

    Returns:
        float: ``time.time()`` taken after ``torch.cuda.synchronize()``
        (when CUDA is available).
    """
    # Plain `if` for the side effect — the original used a conditional
    # expression whose value was discarded, which obscured the intent.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
def main(args):
    """Run the PointPainting demo on a single KITTI frame.

    Loads the BiSeNetV2 segmentation model, segments the input image, paints
    the LiDAR point cloud with the semantic classes, and visualizes the result
    either as a 2D scene (semantic image + BEV) or a colored 3D point cloud.

    Args:
        args: parsed CLI namespace with ``image_path``, ``pointcloud_path``,
            ``calib_path``, ``weights_path`` and ``mode`` ('2d' or '3d').
    """
    # --- Semantic segmentation model -------------------------------------
    bisenetv2 = BiSeNetV2()
    checkpoint = torch.load(args.weights_path, map_location=dev)
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # confirm the checkpoint actually matches this architecture.
    bisenetv2.load_state_dict(checkpoint['bisenetv2'], strict=False)
    bisenetv2.eval()
    bisenetv2.to(device)

    # --- Fusion and visualization helpers --------------------------------
    painter = PointPainter()
    visualizer = Visualizer(args.mode)

    # --- Inputs -----------------------------------------------------------
    image = cv2.imread(args.image_path)
    # KITTI velodyne binaries are flat float32 (x, y, z, reflectance) records.
    pointcloud = np.fromfile(args.pointcloud_path, dtype=np.float32).reshape((-1, 4))
    # if calib file is in kitti video format
    # calib = KittiCalibration(args.calib_path, from_video=True)
    # if calib file is in normal kitti format
    calib = KittiCalibration(args.calib_path)

    # --- Pipeline with per-stage timing -----------------------------------
    t0 = time_synchronized()
    input_image = preprocessing_kitti(image)
    t1 = time_synchronized()
    print(f'Time of preprocessing = {1000 * (t1 - t0)} ms')
    print(input_image.shape)

    # Inference only — no autograd graph needed (the original built one).
    with torch.no_grad():
        semantic = bisenetv2(input_image)
    t2 = time_synchronized()
    semantic = postprocessing(semantic)
    t3 = time_synchronized()
    painted_pointcloud = painter.paint(pointcloud, semantic, calib)
    t4 = time_synchronized()

    # t1 (post-preprocessing) is the baseline so each figure measures only
    # the labelled stage; the original included preprocessing in 'bisenetv2'.
    print(f'Time of bisenetv2 = {1000 * (t2-t1)} ms')
    print(f'Time of postprocesssing = {1000 * (t3-t2)} ms')
    print(f'Time of pointpainting = {1000 * (t4-t3)} ms')
    print(f'Time of Total = {1000 * (t4-t0)} ms')

    # --- Output ------------------------------------------------------------
    if args.mode == '3d':
        visualizer.visuallize_pointcloud(painted_pointcloud, blocking=True)
    else:
        color_image = visualizer.get_colored_image(image, semantic)
        scene_2D = visualizer.get_scene_2D(color_image, painted_pointcloud, calib)
        scene_2D = cv2.resize(scene_2D, (600, 900))
        cv2.imshow("scene", scene_2D)
        if cv2.waitKey(0) == 27:  # ESC closes the window
            cv2.destroyAllWindows()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # All plain string options share the same shape; register them in one pass.
    string_defaults = {
        'image_path': 'Kitti_sample/image_2/000038.png',
        'pointcloud_path': 'Kitti_sample/velodyne/000038.bin',
        'calib_path': 'Kitti_sample/calib/000038.txt',
        'weights_path': 'BiSeNetv2/checkpoints/BiseNetv2_150.pth',
        'save_path': 'results',
    }
    for option, default_value in string_defaults.items():
        parser.add_argument(f'--{option}', type=str, default=default_value)
    parser.add_argument('--mode', type=str, default='2d', choices=['2d', '3d'],
        help='visualization mode .. img is semantic image .. 2d is semantic + bev .. 3d is colored pointcloud')
    args = parser.parse_args()

    # Demo runs twice: once on the default frame, then on a second sample frame.
    main(args)
    args.image_path = 'Kitti_sample/image_2/000031.png'
    main(args)