In [23]:
import numpy as np
from datasets import kitti, sampler
import yaml
from easydict import EasyDict
import matplotlib.pyplot as plt
import os, sys, time
from datasets.data_classes import Box, PointCloud
from datasets.points_utils import get_in_box_mask, crop_pc_axis_aligned, getOffsetBB
import open3d as o3d

In [24]:
# Visualization settings for the tracking-result viewer below.
file_name = "cfgs/STNet_Car.yaml"    # experiment config (YAML)
pred_bb_root_dir = 'visual_results'  # root folder holding kitti_<tracker>/<category>/<id>.txt

# Trackers to overlay, each given as [name, RGB edge colour]. Alternative line-ups:
#   [['BAT', [0, 0, 1]],   # blue
#    ['PVT', [1, 0, 0]],   # red
#    ['P2B', [0, 1, 1]]]   # cyan
#   [['LTTR', [0, 0, 1]],  # blue
#    ['PVT', [1, 0, 0]]]   # red
trackers = [
    ['SRSTNet', [0, 1, 1]],  # cyan
]

# Edges of a 3D box as pairs of corner indices: the edges among corners 0-3,
# the edges among corners 4-7, then the four edges joining corner k to corner k + 4.
lines_box = np.array(
    [[0, 1], [1, 2], [0, 3], [2, 3]]
    + [[4, 5], [4, 7], [5, 6], [6, 7]]
    + [[k, k + 4] for k in range(4)]
)
tracklet_id = 4  # index of the tracklet to visualize

In [25]:
# Load the experiment config. Old PyYAML releases (< 5.1) have no FullLoader;
# there we fall back to yaml.Loader, which is what a Loader-less yaml.load used.
# This replaces a bare `except:` whose retry re-read an already-consumed stream
# after a genuine parse error and hid the real exception.
with open(file_name, 'r') as f:
    config = EasyDict(yaml.load(f, Loader=getattr(yaml, 'FullLoader', yaml.Loader)))
# Override config.category_name loaded from the YAML file.
config.category_name = "Pedestrian"  # one of: Car, Van, Pedestrian, Cyclist

# NOTE: this previously passed `config.preload_offset if type != 'test' else -1`,
# but `type` was the Python builtin, so the condition was always true and
# config.preload_offset was always used. That value is passed explicitly here.
# TODO(review): confirm whether -1 was intended for the test split.
data = kitti.kittiDataset(path=config.path,
                          split=config.test_split,
                          category_name=config.category_name,
                          coordinate_mode=config.coordinate_mode,
                          preloading=False,
                          preload_offset=config.preload_offset)
data_loader = sampler.TestTrackingSampler(dataset=data, config=config)
box_list = []
pcd_list = []

In [26]:
# Build the per-frame inputs for the viewer:
#   - the points, split into in-box (target) and background sets;
#   - the boxes to draw: ground truth in green, then one box per tracker in its
#     configured colour.
# The accumulators are (re)created here so that re-running this cell does not
# append duplicate frames to lists left over from a previous run.
pcd_list = []
box_list = []

instance = data_loader[tracklet_id]  # {"pc": pc, "3d_bbox": bb, 'meta': anno}

# Predicted corners per tracker, one row per frame. ndmin=2 keeps a single-row
# file 2-D, so `rows[frame_id]` always selects one box. Previously that case
# relied on catching the ValueError raised by reshaping a scalar.
pred_bb = {}
for tracker, _ in trackers:
    pred_bb_path = os.path.join(pred_bb_root_dir, 'kitti_' + tracker,
                                config.category_name, f'{tracklet_id:04d}.txt')
    rows = np.loadtxt(pred_bb_path, ndmin=2)
    if len(rows) < len(instance):
        raise ValueError(f'{pred_bb_path} has {len(rows)} rows but tracklet '
                         f'{tracklet_id} has {len(instance)} frames')
    pred_bb[tracker] = rows

for frame_id in range(len(instance)):
    search_pc = instance[frame_id]['pc']
    gt_bb = instance[frame_id]['3d_bbox']
    # search_pc = crop_pc_axis_aligned(search_pc, gt_bb, 10)
    # search_pc = getOffsetBB(search_pc, gt_bb)

    # Select columns with a boolean mask. The former "zero out, then drop
    # all-zero columns" approach also silently dropped real points lying
    # exactly at the origin.
    in_box = np.reshape(get_in_box_mask(search_pc, gt_bb), -1) == 1
    # Assumes points is (C, N) with x, y, z in the first three rows (the
    # previous code required exactly 3 rows) — TODO confirm.
    xyz = np.asarray(search_pc.points)[:3]
    pcd_list.append({
        'target_pcd': xyz[:, in_box].T,
        'background_pcd': xyz[:, ~in_box].T,
    })

    boxes = [{
        'corners': gt_bb.corners().T,
        'lines': lines_box,
        'colors': [0, 1, 0],  # ground truth: green
    }]
    for tracker, color in trackers:
        boxes.append({
            # Each row holds 24 values, reshaped to (3, 8) and transposed to one row per corner.
            'corners': pred_bb[tracker][frame_id].reshape([3, 8]).T,
            'lines': lines_box,
            'colors': color,
        })
    box_list.append(boxes)

In [27]:
def show_pointcloud_dir(pcd_list, box_list, add_box):
    """Play a sequence of point clouds (and optionally boxes) in an Open3D window.

    Each frame is drawn with its target points in black and its background
    points in light grey, on a white background.

    Controls:
        Space -- toggle automatic playback.
        Right -- step one frame forward (auto-advance pauses while the key is held).
        Left  -- step one frame back (auto-advance pauses while the key is held).

    Args:
        pcd_list: list of per-frame dicts with keys 'target_pcd' and
            'background_pcd', each an (N, 3) array of point coordinates.
        box_list: list of per-frame lists of box dicts with keys 'corners'
            ((8, 3) vertex coordinates), 'lines' ((E, 2) vertex-index pairs) and
            'colors' (one RGB triple applied to every edge). Only used when
            ``add_box`` is true.
        add_box: whether to draw the boxes.

    Raises:
        ValueError: if ``pcd_list`` is empty, or if ``add_box`` is true and
            ``box_list`` has fewer entries than ``pcd_list``.
    """
    num_frames = len(pcd_list)
    if num_frames == 0:
        raise ValueError('pcd_list is empty: nothing to visualize')
    vis_box = add_box
    if vis_box and len(box_list) < num_frames:
        raise ValueError('box_list must have one entry per frame of pcd_list')

    last_idx = num_frames - 1
    file_idx = 0
    play_status = False
    backward_status = False
    forward_status = False

    vis = o3d.visualization.VisualizerWithKeyCallback()
    target_pcd = o3d.geometry.PointCloud()
    background_pcd = o3d.geometry.PointCloud()
    lineset_list = []

    def draw_points(frame_idx):
        """Load frame ``frame_idx``'s points and colours into the two clouds."""
        frame = pcd_list[frame_idx]
        target_pcd.points = o3d.utility.Vector3dVector(frame['target_pcd'])
        target_pcd.paint_uniform_color([0, 0, 0])  # target: black
        background_pcd.points = o3d.utility.Vector3dVector(frame['background_pcd'])
        background_pcd.paint_uniform_color([0.7, 0.7, 0.7])  # background: light grey

    def draw_box(vis, boxes):
        """Copy ``boxes`` into the pre-allocated LineSets and refresh them."""
        for lineset, box in zip(lineset_list, boxes):
            lines = box['lines']
            lineset.points = o3d.utility.Vector3dVector(box['corners'])
            lineset.lines = o3d.utility.Vector2iVector(lines)
            # One colour per edge, sized from this box's own edge list.
            lineset.colors = o3d.utility.Vector3dVector(
                np.asarray([box['colors']] * len(lines), dtype=np.float64))
            vis.update_geometry(lineset)
        # Blank out LineSets this frame does not use, so boxes from an earlier frame don't linger.
        for lineset in lineset_list[len(boxes):]:
            lineset.clear()
            vis.update_geometry(lineset)
        return vis

    def key_space_callback(vis, action, mods):
        nonlocal play_status
        if action == 1:  # key press
            play_status = not play_status
        return True

    def key_forward_callback(vis, action, mods):
        nonlocal play_status, file_idx, forward_status
        if action == 1:  # press: step forward, clamped to the last frame
            file_idx = min(file_idx + 1, last_idx)
            play_status = True
            forward_status = True
        elif action == 0:  # release
            play_status = False
            forward_status = False
        return True

    def key_back_callback(vis, action, mods):
        nonlocal play_status, file_idx, backward_status
        if action == 1:  # press: step back, clamped to the first frame
            # Clamp first: after playback ends file_idx sits one past the last frame.
            file_idx = max(min(file_idx, last_idx) - 1, 0)
            play_status = True
            backward_status = True
        elif action == 0:  # release
            play_status = False
            backward_status = False
        return True

    def animation_callback(vis):
        nonlocal file_idx
        if not play_status or file_idx > last_idx:
            return False
        shown_idx = file_idx
        draw_points(shown_idx)
        vis.update_geometry(target_pcd)
        vis.update_geometry(background_pcd)
        if vis_box:
            draw_box(vis, box_list[shown_idx])

        # Auto-advance only while no arrow key is held.
        if not backward_status and not forward_status:
            file_idx += 1

        # Overwrite the same terminal line with "<frame shown>/<total>".
        sys.stdout.write('\r{}/{}'.format(shown_idx + 1, num_frames))
        sys.stdout.flush()
        return False

    vis.register_key_action_callback(32, key_space_callback)     # GLFW_KEY_SPACE
    vis.register_key_action_callback(262, key_forward_callback)  # GLFW_KEY_RIGHT
    vis.register_key_action_callback(263, key_back_callback)     # GLFW_KEY_LEFT
    vis.register_animation_callback(animation_callback)
    vis.create_window()

    draw_points(0)
    vis.add_geometry(target_pcd)
    vis.add_geometry(background_pcd)
    if vis_box:
        # Allocate enough LineSets for the frame with the most boxes, not just frame 0.
        for _ in range(max(len(boxes) for boxes in box_list)):
            lineset = o3d.geometry.LineSet()
            lineset_list.append(lineset)
            vis.add_geometry(lineset)
        draw_box(vis, box_list[0])

    render_option = vis.get_render_option()
    render_option.point_size = 4
    render_option.background_color = np.asarray([1, 1, 1])  # 0 = black, 1 = white
    try:
        vis.run()
    finally:
        # Release the window so the cell can be re-run in the same kernel.
        vis.destroy_window()

In [28]:
# Open the interactive viewer on the frames collected above, with boxes drawn.
show_pointcloud_dir(pcd_list=pcd_list, box_list=box_list, add_box=True)

