In [None]:
"""
flip_coors:
5.679806465846243
[0.00000000e+00 2.53344143e-06 8.78724159e-03 6.77619578e-02
 1.74091761e-01 2.42535215e-01 2.22794639e-01 1.48205057e-01
 7.90547730e-02 3.53250405e-02 1.44127483e-02 4.86800770e-03
 1.61253547e-03 4.06617349e-04 1.03871099e-04 3.16680178e-05
 3.80016214e-06 2.53344143e-06]
reverse_coors:
3.0068771376087198
[0.00000000e+00 4.74460333e-04 2.84478297e-01 4.73738494e-01
 1.96769103e-01 3.86634427e-02 5.34211888e-03 4.89683659e-04
 4.31327576e-05 1.26861052e-06]
shift:
3.9410063569260574
[0.00000000e+00 7.65772505e-04 9.80607192e-02 3.05146397e-01
 3.06869385e-01 1.80180945e-01 7.49075114e-02 2.48128041e-02
 6.89955955e-03 1.79145124e-03 4.15849969e-04 1.15373010e-04
 2.91602113e-05 5.07134109e-06]
"""

In [None]:
from rd3d import set_workspace

# Project-level workspace setup, run in its own cell ahead of the model/dataloader cell.
# NOTE(review): presumably this fixes the working directory / import root so that the relative
# config and cache paths used in later cells resolve — confirm against rd3d.set_workspace.
set_workspace()

In [273]:
import numpy as np
import accelerate
from rd3d.core import Config
from rd3d import build_dataloader, build_detector

# `acc` is not referenced again in the visible cells.
# NOTE(review): presumably constructed for its device/process initialisation side effects —
# confirm before removing.
acc = accelerate.Accelerator()
# Experiment configuration, loaded from a path relative to the current working directory.
cfg = Config.fromfile_py("configs/voxformer/voxformer_4x2_80e_kitti_3cls.py")
# Dataloader built with training=False from the DATASET and RUN sections of the config.
dataloader = build_dataloader(cfg.DATASET, cfg.RUN, training=False)
# Detector built from the MODEL section and moved to the GPU. The analysis functions in this
# notebook read `model.vfe` and `model.backbone_3d` through this module-level name.
model = build_detector(cfg.MODEL, dataset=dataloader.dataset).cuda()


def collate_mapping_result():
    """Collect, for every sample, which first-pass groups feed each second-pass group.

    The whole dataloader is pushed through ``model.vfe`` and, for each grouping
    type ('reverse_coors', 'flip_coors', 'shift'), through
    ``model.backbone_3d.mapping``. The third element of the mapping tuple is
    split per sample, integer-divided by the group size and shifted by the
    number of groups already seen in the batch, giving sample-local group ids.
    Per the original author's note, these ids say which g1 group each element
    of a g2 group was gathered from.
    NOTE(review): the running offset assumes every sample has as many g1 groups
    as g2 groups — confirm against ``backbone_3d.mapping``.

    Side effects:
        * Writes the result to ``data/cache/multigroup.pkl`` (the directory is
          created when missing).
        * Temporarily overwrites ``model.backbone_3d.cfg.MULTI_GROUP``; the
          value found on entry is restored afterwards, even on error.

    Returns:
        dict: grouping-type name -> list with one entry per sample; each entry
        is a list (one item per g2 group) of the sorted, de-duplicated g1 group
        ids found in that g2 group.
    """
    import os
    import pickle

    import torch
    from tqdm import tqdm

    cache_path = "data/cache/multigroup.pkl"
    which_g1_in_g2 = {'reverse_coors': [], 'flip_coors': [], 'shift': []}
    group_size = model.backbone_3d.group_size

    backbone_cfg = model.backbone_3d.cfg
    unset = object()
    configured_type = getattr(backbone_cfg, 'MULTI_GROUP', unset)
    try:
        with torch.no_grad():
            for batch_dict in tqdm(iterable=dataloader):
                dataloader.dataset.load_data_to_gpu(batch_dict)
                batch_dict = model.vfe(batch_dict)
                vox_numbs = batch_dict['voxel_numbers']
                vox_coors = batch_dict['voxel_coords']
                for multigroup_type, per_sample in which_g1_in_g2.items():
                    # The mapping reads the grouping type from the backbone config.
                    backbone_cfg.MULTI_GROUP = multigroup_type
                    (_, _, map1to2_batch, _), vox_nums = model.backbone_3d.mapping(vox_numbs, vox_coors)
                    # Groups of earlier samples in the batch; subtracting it turns
                    # batch-global group ids into sample-local ones.
                    num_groups = 0
                    for map1to2 in torch.split(map1to2_batch, vox_nums.tolist(), dim=0):
                        g1_in_g2 = torch.div(map1to2.view(-1, group_size), group_size,
                                             rounding_mode='trunc') - num_groups
                        num_groups += len(g1_in_g2)
                        # sorted() keeps the cached content deterministic.
                        per_sample.append([sorted(set(g1)) for g1 in g1_in_g2.tolist()])
    finally:
        # Do not leave the model configured with whichever type was tried last.
        if configured_type is not unset:
            backbone_cfg.MULTI_GROUP = configured_type

    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(which_g1_in_g2, f)
    return which_g1_in_g2


def analysis_connect_nums_1step():
    """Print one-step connectivity statistics per grouping type and plot their histograms.

    Reads the module-level ``which_g1_in_g2`` mapping. For every grouping type it
    prints the mean number of groups per sample and the mean number of distinct
    g1 groups referenced by a g2 group, then draws the normalised histogram of
    that count as a bar series. Each series is lifted by half of the running sum
    of the histogram maxima so that the series do not sit on the same baseline.
    """
    from matplotlib import pyplot as plt

    bar_width = 1
    longest_hist = 0
    baseline = 0
    for order, (name, per_sample) in enumerate(which_g1_in_g2.items()):
        groups_per_sample = [len(sample) for sample in per_sample]
        degrees = [len(members) for sample in per_sample for members in sample]

        mean_groups = np.mean(groups_per_sample)
        mean_degree = np.mean(degrees)
        counts = np.bincount(degrees)
        fractions = counts / counts.sum()
        longest_hist = max(longest_hist, len(counts))

        print('------')
        print(name)
        print(f"group nums {mean_groups}\n"
              f" connected {mean_degree}")

        # Bin 0 is skipped: bars start at a connection count of 1.
        positions = np.arange(1, len(counts))
        baseline = baseline + max(fractions)
        plt.bar(
            x=positions,
            bottom=baseline / 2,
            height=fractions[1:],
            width=bar_width,
            alpha=0.9,
            label=name,
            # Earlier series are drawn on top of later ones.
            zorder=255 - order,
        )

    plt.xticks(np.arange(1, longest_hist + 1))
    plt.legend()
    plt.show()


def graph_analysis(sample_ind, group_type):
    """Draw the group-connectivity graph of one sample for one grouping type.

    Nodes are group ids; two groups are joined when they appear together in the
    same entry of the cached ``which_g1_in_g2`` mapping. Each node is placed at
    the mean of its group's first two coordinate columns (normalised by the grid
    size and rescaled to [-1, 1]) and sized by the summed per-axis variance of
    the group's normalised coordinates.

    Args:
        sample_ind: Index of the sample, both in ``dataloader.dataset`` and in
            the per-sample lists of ``which_g1_in_g2``.
        group_type: Integer position of the grouping type in the key order of
            ``which_g1_in_g2``.

    Side effects:
        Draws into a new 8x8 matplotlib figure. ``model.backbone_3d.cfg.MULTI_GROUP``
        is overwritten while the mapping runs and restored afterwards.
    """
    from itertools import combinations

    import networkx as nx
    from matplotlib import pyplot as plt

    multigroup_type = list(which_g1_in_g2.keys())[group_type]
    connective = which_g1_in_g2[multigroup_type][sample_ind]
    group_size = model.backbone_3d.group_size

    def get_coors(ind):
        """Return the voxel coordinates of sample ``ind`` ordered by the first mapping index."""
        batch_dict = dataloader.dataset.collate_batch([dataloader.dataset[ind]])
        dataloader.dataset.load_data_to_gpu(batch_dict)
        batch_dict = model.vfe(batch_dict)
        vox_numbs = batch_dict['voxel_numbers']
        vox_coors = batch_dict['voxel_coords']

        backbone_cfg = model.backbone_3d.cfg
        unset = object()
        configured_type = getattr(backbone_cfg, 'MULTI_GROUP', unset)
        try:
            # Only the requested grouping type is needed, so only that mapping is run.
            backbone_cfg.MULTI_GROUP = multigroup_type
            (ind1, _, _, _), _ = model.backbone_3d.mapping(vox_numbs, vox_coors)
        finally:
            if configured_type is not unset:
                backbone_cfg.MULTI_GROUP = configured_type
        return vox_coors[ind1].cpu().numpy()

    grid = np.array(model.vfe.grid_size)
    # Columns 1..3 are kept; presumably column 0 is the batch index — TODO confirm.
    # NOTE(review): dividing by `grid` assumes these columns follow the same axis
    # order as `model.vfe.grid_size` — verify.
    coors = get_coors(sample_ind)[:, [1, 2, 3]].reshape((-1, group_size, 3))

    G = nx.Graph()
    G.add_nodes_from(range(len(connective)))
    for nodes in connective:
        # The graph is undirected, so each unordered pair is added once.
        G.add_edges_from(combinations(nodes, 2))

    group_pos = coors.mean(axis=1)[:, :2] / grid[:2]
    size = (coors / grid).var(axis=1).sum(axis=1) * (100 ** 2)
    # Positions come straight from the group centres. No force-directed layout is
    # needed: nodes are paired with centres in graph insertion order.
    pos = dict(zip(G.nodes, (group_pos * 2) - 1))

    plt.figure(figsize=(8, 8))  # Set a larger figure size
    nx.draw(G, pos, with_labels=True, font_weight='bold', node_size=size, node_color='skyblue', font_color='black',
            font_size=10, edge_color='gray', linewidths=1, alpha=0.7)


def highlight_group(pts, group_id, gs):
    """Show a point set with the selected groups painted in distinct colours.

    Points are treated as consecutive groups of ``gs`` points. Every group
    listed in ``group_id`` gets one colour from the 'tab20c' colormap (assigned
    in random order); all other points stay black (zeros).

    Args:
        pts: Tensor of points. It is viewed as ``(-1, gs, 3)``, so presumably
            it has shape ``(num_groups * gs, 3)`` — TODO confirm with callers.
        group_id: Sequence of group indices to highlight. May be empty, in
            which case the scene is shown without any highlight.
        gs: Number of points per group.

    Side effects:
        Calls ``viz_utils.viz_scene`` with the points and their colours.
    """
    import torch
    from rd3d.utils import viz_utils
    from matplotlib import pyplot as plt

    colors = torch.zeros_like(pts).view(-1, gs, 3).cpu().float()
    num_selected = len(group_id)
    if num_selected > 0:
        # One random rank per selected group, repeated for each of its points.
        rank = torch.randperm(num_selected)[:, None].repeat(1, gs).view(-1)
        # The largest rank is num_selected - 1. With a single group that is 0,
        # and dividing by it would give NaN (the colormap's "bad" colour), so
        # the denominator is clamped to 1.
        scale = max(num_selected - 1, 1)
        rgba = plt.get_cmap('tab20c')((rank / scale).numpy())
        palette = torch.tensor(rgba[:, :3]).float()
        colors[group_id] = palette.view(-1, gs, 3)
    viz_utils.viz_scene((pts, colors.view(-1, 3)))


def viz_group2_connect_which_group1(sample_ind, group_type, g2_id):
    """Highlight the first-pass (g1) groups that contribute to one second-pass (g2) group.

    The g1 group ids are read from the cached ``which_g1_in_g2`` mapping. The
    points come from the first three voxel-feature columns of the sample,
    reordered by the first index returned by ``model.backbone_3d.mapping``.

    Args:
        sample_ind: Index of the sample, both in ``dataloader.dataset`` and in
            the per-sample lists of ``which_g1_in_g2``.
        group_type: Integer position of the grouping type in the key order of
            ``which_g1_in_g2``.
        g2_id: Index of the g2 group whose contributing g1 groups are shown.

    Side effects:
        Opens the scene viewer through ``highlight_group``.
        ``model.backbone_3d.cfg.MULTI_GROUP`` is overwritten while the mapping
        runs and restored afterwards.
    """
    multigroup_type = list(which_g1_in_g2.keys())[group_type]
    connections = which_g1_in_g2[multigroup_type][sample_ind]

    batch_dict = dataloader.dataset.collate_batch([dataloader.dataset[sample_ind]])
    dataloader.dataset.load_data_to_gpu(batch_dict)
    batch_dict = model.vfe(batch_dict)
    vox_numbs = batch_dict['voxel_numbers']
    vox_coors = batch_dict['voxel_coords']
    # Presumably the first three feature columns are xyz — TODO confirm.
    vox_points = batch_dict['voxel_features'][:, :3]

    backbone_cfg = model.backbone_3d.cfg
    unset = object()
    configured_type = getattr(backbone_cfg, 'MULTI_GROUP', unset)
    try:
        # Only the requested grouping type is needed, so only that mapping is run.
        backbone_cfg.MULTI_GROUP = multigroup_type
        (ind1, _, _, _), _ = model.backbone_3d.mapping(vox_numbs, vox_coors)
    finally:
        if configured_type is not unset:
            backbone_cfg.MULTI_GROUP = configured_type

    points1 = vox_points[ind1].cpu()
    highlight_group(points1, connections[g2_id], model.backbone_3d.group_size)


def viz_group1_connection_by_step(sample_ind, group_type, g1_id, steps):
    """Highlight every group reachable from one g1 group within ``steps`` hops.

    A graph is built from the cached ``which_g1_in_g2`` mapping: nodes are group
    ids, and groups that appear together in the same entry are joined. The
    groups within ``steps`` edges of ``g1_id`` (including ``g1_id`` itself) are
    then highlighted in the sample's point set.

    Args:
        sample_ind: Index of the sample, both in ``dataloader.dataset`` and in
            the per-sample lists of ``which_g1_in_g2``.
        group_type: Integer position of the grouping type in the key order of
            ``which_g1_in_g2``.
        g1_id: Id of the starting group (a node of the graph).
        steps: Maximum number of hops to follow from ``g1_id``.

    Side effects:
        Opens the scene viewer through ``highlight_group``.
        ``model.backbone_3d.cfg.MULTI_GROUP`` is overwritten while the mapping
        runs and restored afterwards.
    """
    from itertools import combinations

    import networkx as nx

    multigroup_type = list(which_g1_in_g2.keys())[group_type]
    connections = which_g1_in_g2[multigroup_type][sample_ind]

    batch_dict = dataloader.dataset.collate_batch([dataloader.dataset[sample_ind]])
    dataloader.dataset.load_data_to_gpu(batch_dict)
    batch_dict = model.vfe(batch_dict)
    vox_numbs = batch_dict['voxel_numbers']
    vox_coors = batch_dict['voxel_coords']
    # Presumably the first three feature columns are xyz — TODO confirm.
    vox_points = batch_dict['voxel_features'][:, :3]

    backbone_cfg = model.backbone_3d.cfg
    unset = object()
    configured_type = getattr(backbone_cfg, 'MULTI_GROUP', unset)
    try:
        # Only the requested grouping type is needed, so only that mapping is run.
        backbone_cfg.MULTI_GROUP = multigroup_type
        (ind1, _, _, _), _ = model.backbone_3d.mapping(vox_numbs, vox_coors)
    finally:
        if configured_type is not unset:
            backbone_cfg.MULTI_GROUP = configured_type
    points1 = vox_points[ind1].cpu()

    G = nx.Graph()
    G.add_nodes_from(range(len(connections)))
    for nodes in connections:
        # The graph is undirected, so each unordered pair is added once.
        G.add_edges_from(combinations(nodes, 2))

    # An ego graph is connected by construction, so its node set is exactly the
    # set of groups within `steps` hops of `g1_id`.
    reachable = list(nx.ego_graph(G, g1_id, radius=steps).nodes)
    highlight_group(points1, reachable, model.backbone_3d.group_size)


[2023-12-19 18:55:00,444 cfg INFO] import module at root: /home/nrsl/workspace/temp/voxformer
[2023-12-19 18:55:00,444 cfg INFO] import module as config: configs.voxformer.voxformer_4x2_80e_kitti_3cls
[2023-12-19 18:55:00,449 dataset INFO] Loading KITTI dataset
[2023-12-19 18:55:00,582 dataset INFO] Total samples for KITTI dataset: 3769


In [283]:
# Uncomment the next line to rebuild the cache: it walks the whole dataloader and rewrites
# the pickle that is read just below.
# collate_mapping_result()
with open(f"data/cache/multigroup.pkl", 'rb') as f:
    import pickle

    # The analysis/visualisation functions defined earlier read this module-level name
    # directly, so it must be assigned before any of them is called.
    # NOTE: pickle.load can execute arbitrary code — only load cache files produced locally.
    which_g1_in_g2 = pickle.load(f)

# Optional analyses; `group_type` is the integer position of the grouping type in the key
# order of `which_g1_in_g2`.
# analysis_connect_nums_1step()
# graph_analysis(sample_ind=0, group_type=0)
# graph_analysis(sample_ind=0, group_type=1)
# graph_analysis(sample_ind=0, group_type=2)
# Show which g1 groups feed g2 group 0 of sample 0, then everything reachable from
# g1 group 3 within 4 hops (both for the first grouping type).
viz_group2_connect_which_group1(sample_ind=0, group_type=0, g2_id=0)
viz_group1_connection_by_step(sample_ind=0, group_type=0, g1_id=3, steps=4)