In [1]:
import os

import numpy as np
import torch
import accelerate

from rd3d.datasets import build_dataloader
from rd3d.models import build_detector
from rd3d.core import Config
from rd3d import PROJECT_ROOT
from rd3d.models.dense_heads.point_segment_head import PointSegmentor

os.chdir(PROJECT_ROOT)
acc = accelerate.Accelerator()
cfg = Config.fromfile_py("configs/voxformer/voxformer_4x2_80e_kitti_3cls.py")
dataloader = build_dataloader(cfg.DATASET, cfg.RUN,training=True)
model = build_detector(cfg.MODEL, dataset=dataloader.dataset).cuda()
batch_dict = next(iter(dataloader))
dataloader.dataset.load_data_to_gpu(batch_dict)

[2023-12-19 15:27:23,995 cfg INFO] import module at root: /home/nrsl/workspace/temp/voxformer
[2023-12-19 15:27:23,996 cfg INFO] import module as config: configs.voxformer.voxformer_4x2_80e_kitti_3cls
[2023-12-19 15:27:24,141 dataset INFO] Database filter by min points Car: 14357 => 13532
[2023-12-19 15:27:24,143 dataset INFO] Database filter by min points Pedestrian: 2207 => 2168
[2023-12-19 15:27:24,144 dataset INFO] Database filter by min points Cyclist: 734 => 705
[2023-12-19 15:27:24,170 dataset INFO] Database filter by difficulty Car: 13532 => 10759
[2023-12-19 15:27:24,177 dataset INFO] Database filter by difficulty Pedestrian: 2168 => 2075
[2023-12-19 15:27:24,178 dataset INFO] Database filter by difficulty Cyclist: 705 => 581
[2023-12-19 15:27:24,185 dataset INFO] Loading KITTI dataset
[2023-12-19 15:27:24,312 dataset INFO] Total samples for KITTI dataset: 3712


In [2]:
from rd3d.core.ckpt import load_from_file

load_from_file("/home/nrsl/workspace/temp/voxformer/output/kitti_3cls/voxformer_4x2_80e_kitti_3cls/ssd/train/ckpt/checkpoint_epoch_80.pth", model)
model.eval()
pred_dicts, _ = model(batch_dict)
pred_dict = pred_dicts[0]
print(pred_dict)
from rd3d.utils import viz_utils
points = batch_dict['points']
points = points[points[:,0]==0]
viz_utils.viz_scene(points,pred_dict['pred_boxes'].detach()[:, :7])



[2023-12-19 15:27:29,290 ckpt INFO] load checkpoint /home/nrsl/workspace/temp/voxformer/output/kitti_3cls/voxformer_4x2_80e_kitti_3cls/ssd/train/ckpt/checkpoint_epoch_80.pth to cuda:0
[2023-12-19 15:27:29,290 ckpt INFO] checkpoint trained from version: 0.5.2+770d684
[2023-12-19 15:27:29,309 ckpt INFO] loaded params for model (237/237)


{'pred_boxes': tensor([[24.2103, 10.6113, -0.7056,  4.2479,  1.7034,  1.5628,  3.2845],
        [16.3473,  8.7749, -0.6641,  1.8315,  0.5711,  1.7981,  3.6880],
        [12.0736,  0.5378, -0.7651,  4.0129,  1.6483,  1.5954,  0.1473],
        [12.3804, -4.6490, -0.8502,  3.7806,  1.5780,  1.4652,  0.1190],
        [ 4.8694,  4.5225, -0.9243,  0.8708,  0.6281,  1.7320,  3.6963],
        [13.7108, 12.4291, -1.1486,  3.8717,  1.6159,  1.4784,  0.5592],
        [20.3610,  7.1796, -0.7372,  3.8472,  1.5999,  1.4860,  0.5008],
        [19.8355,  9.1793, -0.6157,  0.8214,  0.5939,  1.7549,  3.6005],
        [13.9456,  3.7279, -0.7868,  3.8826,  1.6130,  1.5248,  0.0583]],
       device='cuda:0', grad_fn=<IndexBackward0>), 'pred_scores': tensor([0.8032, 0.7908, 0.6947, 0.6122, 0.5848, 0.4756, 0.4441, 0.3651, 0.2930],
       device='cuda:0', grad_fn=<IndexBackward0>), 'pred_labels': tensor([1, 3, 1, 1, 2, 1, 1, 2, 1], device='cuda:0')}
Jupyter environment detected. Enabling Open3D WebVisualizer.

In [None]:

from rd3d.utils import viz_utils
from matplotlib import pyplot as plt
cls_scores = batch_dict['cls_logits'].sigmoid()
cls_scores, pred_labels = cls_scores.max(dim=-1)
pred_labels += 1
pred_labels[cls_scores < 0.1] = 0
points = batch_dict['point_coords']

ends = torch.cumsum(batch_dict['voxel_numbers'], dim=0)
begins = ends - batch_dict['voxel_numbers']

print(plt.cm.Set1(0))
print(plt.cm.Set1(1))
print(plt.cm.Set1(2))
print(plt.cm.Set1(3))
for begin, end in zip(begins, ends):
    this_points = points[begin:end]
    this_pred_labels = pred_labels[begin:end].detach().cpu().numpy()
    this_colors = plt.cm.Set1(this_pred_labels)[:, :3]
    viz_utils.viz_scene((this_points, this_colors))