In [16]:
import numpy as np
import os
import matplotlib.pyplot as plt

import open3d as o3d
import os

In [6]:
class Visualizer:
    """Interactive Open3D viewer for per-point segmentation results of one scan.

    Loads the scan's point cloud and label arrays from
    ``<work_dir>/<video>/<experiment_name>/<experiment_run>/`` and renders each
    labelling with a deterministic color per label.
    """

    # Open3D's RenderOption.background_color is expressed in [0, 1]; this is white.
    _BACKGROUND_COLOR = np.array([1.0, 1.0, 1.0])
    _WINDOW_SIZE = 1220

    def __init__(self, config) -> None:
        """
        Args:
            config: dict with keys "experiment_name", "experiment_run",
                "video" (scan id) and "work_dir". Loading happens immediately.
        """
        self.config = config

        self.pcds = {}  # NOTE(review): not read or written anywhere else in this class
        self.experiment_name = config["experiment_name"]
        self.experiment_run = config["experiment_run"]
        self.video_name = config['video']
        self.work_dir = os.path.join(config['work_dir'], self.video_name)
        self.load_work_dir()

    def idx_to_color(self, idx, rand=False):
        """Map an integer label to an RGB triple with components in [0, 255].

        Args:
            idx: label. With ``rand=True`` it is used as a NumPy seed, so it must
                be a non-negative integer below 2**32.
            rand: if True, derive a deterministic pseudo-random color from ``idx``
                (label 0 is always black). If False, look the packed color id up
                in ``self.idx_to_ids_dict``.

        Returns:
            np.ndarray of shape (3,), dtype float32.
        """
        if rand:
            if idx == 0:
                return np.zeros((3,), dtype=np.float32)
            # A private RandomState seeded with idx yields the same value as seeding
            # the global generator, without clobbering np.random's global state.
            color_id = int(np.random.RandomState(idx).randint(0, 256**3))
        else:
            # NOTE(review): idx_to_ids_dict is never assigned in this class; the caller
            # must set it before using rand=False.
            color_id = self.idx_to_ids_dict[idx]

        # Unpack the 24-bit id into three bytes, least-significant byte first.
        rgb = np.zeros((3,), dtype=np.float32)
        for channel in range(3):
            rgb[channel] = color_id % 256
            color_id = color_id // 256
        return rgb

    def _labels_to_colors(self, labels):
        """Return an (N, 3) float64 array in [0, 1] coloring each label via idx_to_color(rand=True).

        Colors are computed once per distinct label and broadcast back to the
        points, instead of calling idx_to_color for every single point.
        """
        labels = np.asarray(labels)
        unique_labels, inverse = np.unique(labels, return_inverse=True)
        palette = np.array(
            [self.idx_to_color(label, rand=True) for label in unique_labels],
            dtype=np.float64,
        ).reshape(-1, 3)
        return palette[inverse.reshape(-1)] / 255.0

    def load_work_dir(self):
        """Load the point cloud, label arrays and clustering results of the configured run.

        Raises:
            FileNotFoundError: if the .ply or any .npy file is missing.
        """
        run_dir = os.path.join(self.work_dir, self.experiment_name, self.experiment_run)

        ply_path = os.path.join(run_dir, f"{self.video_name}_vh_clean_2.ply")
        # o3d.io.read_point_cloud does not raise on a missing file, so check explicitly.
        if not os.path.isfile(ply_path):
            raise FileNotFoundError(ply_path)
        self.pcd_orig = o3d.io.read_point_cloud(ply_path)

        self.pred_inst = np.load(os.path.join(run_dir, "pred_inst.npy"))
        self.pred_super = np.load(os.path.join(run_dir, "pred_super.npy"))
        self.gt_inst = np.load(os.path.join(run_dir, "gt_inst.npy"))

        # Each row labels superpoints (it is lifted through pred_super below);
        # presumably one row per clustering "cut" — TODO confirm against the producer.
        self.cluster_vis_tmp = np.load(os.path.join(run_dir, "clustering_vis.npy"))
        self.cluster_vis = np.array(
            [self.finalize_clustering(cluster, self.pred_super) for cluster in self.cluster_vis_tmp]
        )

    def gen_pcds(self, pcd_points, color):
        """Build a new PointCloud from the given points and per-point colors.

        Args:
            pcd_points: point coordinates, assigned directly to ``points``
                (callers pass ``self.pcd_orig.points``).
            color: per-point RGB, shape (N, 3). Values are treated as 0-255 and
                rescaled when the maximum exceeds 2, otherwise used as 0-1.

        Returns:
            o3d.geometry.PointCloud
        """
        pcd_res = o3d.geometry.PointCloud()
        pcd_res.points = pcd_points
        colors = np.asarray(color, dtype=np.float64)
        # Range heuristic; guard against np.max on an empty array.
        if colors.size and colors.max() > 2:
            colors = colors / 255
        pcd_res.colors = o3d.utility.Vector3dVector(colors)
        return pcd_res

    def finalize_clustering(self, cluster_super, cluster_point):
        """Lift a superpoint-level clustering to one label per point.

        Args:
            cluster_super: cluster label of each distinct superpoint id, ordered
                like ``np.unique(cluster_point)``.
            cluster_point: superpoint id of each point (1-D).

        Returns:
            int32 array with the cluster label of every point.

        Raises:
            IndexError: if ``cluster_super`` does not have one entry per distinct
                superpoint id.
        """
        cluster_super = np.asarray(cluster_super)
        # inverse[k] is the position of point k's superpoint id among the sorted unique ids.
        super_ids, inverse = np.unique(np.asarray(cluster_point), return_inverse=True)
        if cluster_super.shape[0] != super_ids.shape[0]:
            raise IndexError(
                f"cluster_super has {cluster_super.shape[0]} entries but there are "
                f"{super_ids.shape[0]} distinct superpoint ids"
            )
        return cluster_super[inverse.reshape(-1)].astype(np.int32)

    def _show_window(self, vis, pcd):
        """Open vis's window showing pcd on a white background; block until it is closed."""
        vis.create_window(width=self._WINDOW_SIZE, height=self._WINDOW_SIZE)
        vis.get_render_option().background_color = self._BACKGROUND_COLOR
        vis.add_geometry(pcd)
        vis.poll_events()
        vis.update_renderer()
        vis.run()
        # Release the window and its GL context once the user closes it.
        vis.destroy_window()

    def visualize_clustering(self):
        """
        visualize clustering one cut at a time

        Starts with every point black. Releasing 'A' shows the next clustering in
        ``self.cluster_vis``; releasing Ctrl+A shows the previous one. Both wrap
        around. Blocks until the window is closed.

        Raises:
            ValueError: if there are no clusterings to show.
        """
        num_cuts = len(self.cluster_vis)
        if num_cuts == 0:
            raise ValueError("clustering_vis.npy contains no clusterings to visualize")

        vis = o3d.visualization.VisualizerWithKeyCallback()
        pcd = self.gen_pcds(self.pcd_orig.points, np.zeros((len(self.pcd_orig.points), 3)))
        idx = -1  # -1 means no clustering shown yet

        def step(delta):
            nonlocal idx
            if idx < 0:
                # First move: forward starts at the first cut, backward at the last.
                idx = 0 if delta > 0 else num_cuts - 1
            else:
                idx = (idx + delta) % num_cuts
            pcd.colors = o3d.utility.Vector3dVector(self._labels_to_colors(self.cluster_vis[idx]))

        def key_action_callback(vis, action, mods):
            # Open3D calls this as (visualizer, action, mods) with GLFW codes:
            # action 0 = release, 1 = press, 2 = repeat; mods 2 = Ctrl.
            if action != 0:
                return False
            if mods == 0:
                step(+1)
            elif mods == 2:
                step(-1)
            else:
                return False
            vis.update_geometry(pcd)
            # Returning True asks Open3D to redraw; no nested vis.run() is needed.
            return True

        vis.register_key_action_callback(ord("A"), key_action_callback)  # GLFW_KEY_A == 65
        self._show_window(vis, pcd)

    def visualize_all(self, mode=0):
        """Show one complete labelling of the scan in a blocking Open3D window.

        Args:
            mode: 0 = predicted superpoints (``pred_super``), 1 = predicted
                instances (``pred_inst``), 2 = ground-truth instances (``gt_inst``)
                with every label whose ``label // 1000`` is 1 or 3 drawn black.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 0:
            labels = self.pred_super
        elif mode == 1:
            labels = self.pred_inst
        elif mode == 2:
            # Presumably a sem*1000+inst encoding where classes 1 and 3 are background
            # (e.g. wall/floor) — TODO confirm. Builds a new array; self.gt_inst is untouched.
            labels = np.where(np.isin(self.gt_inst // 1000, [1, 3]), 0, self.gt_inst)
        else:
            raise ValueError(f"mode must be 0, 1 or 2, got {mode!r}")

        pcd = self.gen_pcds(self.pcd_orig.points, self._labels_to_colors(labels))
        self._show_window(o3d.visualization.Visualizer(), pcd)

In [8]:
config = {
    "video": "",   # scan id to visualize; files are looked up as <video>_vh_clean_2.ply etc.
    "work_dir":"../data/data_val_preprocessed/vis/",
    "experiment_name": "test",
    "experiment_run": "test",
}

# Fail fast: an empty scan id silently produces wrong paths (e.g. "_vh_clean_2.ply").
if not config["video"]:
    raise ValueError('Set config["video"] to the scan id to visualize before running this cell.')

vis = Visualizer(config)

In [None]:
vis.visualize_all(mode=1)      # predicted instances (mode=0: superpoints, mode=2: ground truth)

In [None]:
vis.visualize_clustering()     # step through the clusterings one cut at a time: release A = next, Ctrl+A = previous