In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
import os
import os.path as osp
import glob
from torch_geometric.data import Data
import matplotlib.pyplot as plt
from plyfile import PlyData

import sys
# Make the project importable: put the parent of the working directory (DIR)
# and that directory's parent (ROOT) on sys.path. DIR is inserted last, so it
# ends up first in the search order.
# NOTE(review): assumes the notebook is launched from a directory nested
# inside the repository — confirm the expected layout.
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.visualization.multimodal_data import visualize_mm_data, hex_to_tensor
from torch_points3d.core.multimodal.data import MMData
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData

# Dev dataset

In [None]:
from omegaconf import OmegaConf
from torch_points3d.utils.config import hydra_read

# Set root to the DATA drive, where the data was downloaded.
# On another machine, set the KITTI360_DATA_ROOT environment variable instead
# of editing this cell; the default below is the ENGIE location.
# Other known locations:
# DATA_ROOT = '/mnt/fa444ffd-fdb4-4701-88e7-f00297a8e29b/projects/datasets/kitti360'  # ???
# DATA_ROOT = '/media/drobert-admin/DATA/datasets/kitti360'  # IGN DATA
# DATA_ROOT = '/media/drobert-admin/DATA2/datasets/kitti360'  # IGN DATA2
# DATA_ROOT = '/var/data/drobert/datasets/kitti360'  # AI4GEO
# DATA_ROOT = '/home/qt/robertda/scratch/datasets/kitti360'  # CNES
DATA_ROOT = os.environ.get('KITTI360_DATA_ROOT', '/raid/dataset/pointcloud/data/kitti360')  # ENGIE

# Hydra overrides; commented-out entries are alternative settings kept for
# quick switching.
overrides = [
    'task=segmentation',
#     'data=segmentation/kitti360-sparse',
    'data=segmentation/multimodal/kitti360-sparse',
    'data.mini=True',
#     'models=segmentation/sparseconv3d',
    'models=segmentation/multimodal/sparseconv3d',
#     'model_name=Res16UNet34',
    'model_name=Res16UNet34-PointPyramid-early-cityscapes-interpolate',
    f"data.dataroot={os.path.join(DATA_ROOT, '5cm')}",
#     f"data.dataroot={os.path.join(DATA_ROOT, 'temp')}",
#     '+train_is_trainval=True',
    'data.sample_per_epoch=5'
]

cfg = hydra_read(overrides)

In [None]:
# from torch_points3d.datasets.segmentation.kitti360 import KITTI360Dataset

# dataset = KITTI360Dataset(cfg.data)

In [None]:
from torch_points3d.datasets.segmentation.multimodal.kitti360 import KITTI360DatasetMM

# Instantiate the multimodal KITTI-360 dataset from the Hydra config.
# NB: preprocessing 3D and 2D data takes roughly 1 min per KITTI-360 window
dataset = KITTI360DatasetMM(cfg.data)

In [None]:
from torch_points3d.core.multimodal import MMData, ImageData
from torch_points3d.datasets.segmentation.kitti360 import KITTI360_NUM_CLASSES, INV_OBJECT_LABEL, OBJECT_COLOR, CLASS_NAMES, CLASS_COLORS

# Keep references to the full 2D transform lists before the next cell replaces
# them with truncated copies. NB: only the lists are preserved; the transform
# objects are shared, so in-place edits such as setting `.credit` on one of
# them also show up here.
train_2d_transforms = dataset.train_dataset.transform_image.transforms
val_2d_transforms = dataset.val_dataset.transform_image.transforms
test_2d_transforms = dataset.test_dataset[0].transform_image.transforms

In [None]:
# Disable the 3D transforms and keep only the first 4 image transforms of the
# train and val sets. The 4th transform gets a credit of 1408 * 376 * 3
# (presumably the pixel count of three full-size 1408x376 views — TODO confirm).
IMAGE_CREDIT = 1408 * 376 * 3
for _split_dataset, _full_2d_transforms in (
        (dataset.train_dataset, train_2d_transforms),
        (dataset.val_dataset, val_2d_transforms)):
    _split_dataset.transform = None
    _split_dataset.transform_image.transforms = _full_2d_transforms[:4]
    _split_dataset.transform_image.transforms[3].credit = IMAGE_CREDIT

# The test set reuses the val settings; transform_image is shared by
# reference, not copied.
dataset.test_dataset[0].transform = dataset.val_dataset.transform
dataset.test_dataset[0].transform_image = dataset.val_dataset.transform_image

In [None]:
# Pick one sample and visualize its 3D points together with its 2D views.
# Train sample (disabled; the active line below picks a fixed val sample)
# mm_data = dataset.train_dataset[dataset.train_dataset._pick_random_label_and_window()]
mm_data = dataset.val_dataset[50]
# mm_data = dataset.test_dataset[0][44]

# Val sample
# NOTE(review): the line below takes len() of a single sample
# (dataset.val_dataset[0]); presumably len(dataset.val_dataset) was intended.
# mm_data = dataset.val_dataset[np.random.randint(len(dataset.val_dataset[0]))]

visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.5, show_2d=True, front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=2)

In [None]:
from tqdm import tqdm

# Fraction of the first `samples` test samples that no image view sees.
samples = 1000
unseen = 0
for i in tqdm(range(samples)):
    mm_data = dataset.test_dataset[0][i]
    unseen += int(mm_data.modalities['image'].num_views < 1)

print(unseen / samples)

In [None]:
from tqdm import tqdm

# Same statistic as the previous cell, on the validation set.
unseen, samples = 0, 1000

for i in tqdm(range(samples)):
    mm_data = dataset.val_dataset[i]
    no_view = mm_data.modalities['image'].num_views < 1
    if no_view:
        unseen += 1

print(unseen / samples)

# Visualize a large sample

In [None]:
from torch_points3d.datasets.segmentation.kitti360_config import CLASS_NAMES, CLASS_COLORS
from torch_points3d.core.multimodal.data import MMData
from torch_points3d.core.multimodal.image import ImageData

# Access one sample first — presumably this loads window 0 into the dataset's
# in-memory `buffer` (TODO confirm) — then wrap the whole window as MMData.
dataset.val_dataset[0]
mm_window = dataset.val_dataset.buffer[0]

# dataset.test_dataset[0][0]
# mm_window = dataset.test_dataset[0].buffer[0]

mm_data_large = MMData(mm_window.data, image=ImageData(mm_window.images))

# Coarse (1m voxel) top-down view of the full window, capped at 500k points
visualize_mm_data(mm_data_large, figsize=1000, pointsize=3, voxel=1, show_2d=False, front='map', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=2, max_points=500000)

In [None]:
from torch_points3d.datasets.segmentation.kitti360_config import CLASS_NAMES, CLASS_COLORS
from torch_points3d.core.multimodal.data import MMData
from torch_points3d.core.multimodal.image import ImageData

# Same as the previous cell, for a test window.
# dataset.val_dataset[0]
# mm_window = dataset.val_dataset.buffer[0]

# NOTE(review): sample 1 is accessed and then buffer entry 1 is read; confirm
# that sample index and buffer index refer to the same window here.
dataset.test_dataset[0][1]
mm_window = dataset.test_dataset[0].buffer[1]

mm_data_large = MMData(mm_window.data, image=ImageData(mm_window.images))

visualize_mm_data(mm_data_large, figsize=1000, pointsize=3, voxel=1, show_2d=False, front='map', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=2, max_points=500000)

In [None]:
import torch_points3d.core.data_transform as cT
from torch_points3d.core.data_transform.multimodal.image import SelectMappingFromPointId, PickImagesFromMappingArea
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.core.multimodal.image import ImageData, SameSettingImageData

# Sample a vertical cylinder at the chosen location; the cylinder keeps its
# image mappings while its surroundings are shown faded and without images.
opacity = 0.4
radius = 6
center = torch.Tensor([1012, 3843, 114])

# Take a cylinder, take images, make into a MMData
data_small = cT.CylinderSampling(radius, center[:2], align_origin=False)(mm_window.data).clone()
# Rebuild a plain Data holding only the attributes used below
data_small = Data(pos=data_small.pos, rgb=data_small.rgb, y=data_small.y, mapping_index=data_small.mapping_index, origin_id=data_small.origin_id)
data_small, images = SelectMappingFromPointId()(data_small, mm_window.images)
data_small, images = PickImagesFromMappingArea(area_ratio=0.02, use_bbox=True)(data_small, images)
# data_small, images = PickImagesFromMemoryCredit(img_size=images.ref_size, n_img=8)(data_small, images)
data_small = MMData(data_small, image=ImageData([images]))

# Take the surroundings, remove the center, remove image mappings, convert to MMData
# NOTE(review): SphereSampling receives a 2D center (center[:2]) while the
# small sample uses CylinderSampling with the same 2D center; confirm that
# SphereSampling accepts a 2D center, or whether CylinderSampling was intended.
data_large = cT.SphereSampling(radius * 3, center[:2], align_origin=False)(mm_window.data).clone()
# Mask of surrounding points that also belong to the small sample, matched by origin_id
is_in_small = torch.from_numpy(np.isin(data_large.origin_id.numpy(), data_small.origin_id.numpy()))
data_large = Data(pos=data_large.pos[~is_in_small], rgb=data_large.rgb[~is_in_small], y=data_large.y[~is_in_small], mapping_index=data_large.mapping_index[~is_in_small])
# The mapping_index kept above is immediately replaced by a fresh 0..N-1 range
data_large.mapping_index = torch.arange(data_large.num_nodes)
# Blend colors towards white: rgb' = 1 - opacity * (1 - rgb) (assumes rgb in [0, 1])
data_large.rgb = 1 - opacity * (1 - data_large.rgb)
# Image data with no selected view, and an all-zero pointer array of length
# num_nodes + 1, presumably meaning that no point has any mapping
empty_images = images[0][[]]
empty_images.mappings.pointers = torch.zeros(data_large.num_nodes + 1, dtype=torch.long)
data_large = MMData(data_large, image=ImageData([empty_images]))

# Combine both into a MMBatch
mm_data = MMBatch.from_mm_data_list([data_large, data_small])

# Remove the ceiling
# mm_data = mm_data[mm_data.data.pos[:, 2] < 2.5]
# mm_data = mm_data[mm_data.data.pos[:, 0] < 3.3]

#-----------
# Alternative (disabled): map the surroundings to all images of the test window
# data, images = SelectMappingFromPointId()(data_large, dataset.test_dataset[0]._images[0])

# # Reduce the number of images
# data, images = PickImagesFromMappingArea(area_ratio=0.02, use_bbox=True)(data, images)
# data, images = PickImagesFromMemoryCredit(img_size=[512, 256], n_img=6)(data, images)

# # Convert to MMData
# mm_data = MMData(data, image=images)

In [None]:
# Visualize the composed batch (faded surroundings + sampled cylinder) at 5cm voxels
visualize_mm_data(mm_data, class_names=CLASS_NAMES, class_colors=CLASS_COLORS, figsize=1600, voxel=0.05, show_2d=False, pointsize=3, front=None)

In [None]:
from torch_points3d.models.model_factory import instantiate_model

# Build the model from the config and put it in training mode on the GPU.
# Alternatives: `.eval().cuda()` for GPU inference, `.train().cpu()` for CPU debugging.
model = instantiate_model(cfg, dataset).train().cuda()

In [None]:
from torch_points3d.core.multimodal import MMData, MMBatch

# Batch made of a single validation sample
mm_data = MMBatch.from_mm_data_list([dataset.val_dataset[2893*4]])

In [None]:
from torch_points3d.core.multimodal import MMData, MMBatch

# Batch made of the same validation sample twice
mm_data = MMBatch.from_mm_data_list([dataset.val_dataset[4], dataset.val_dataset[4]])

In [None]:
# Forward pass on the current batch; the return value is discarded and the
# results are read from model attributes in the following cells
model.set_input(mm_data, model.device)

_ = model(mm_data)

In [None]:
# Backpropagate the segmentation loss (presumably set by the forward pass)
model.loss_seg.backward()

In [None]:
mm_data

In [None]:
model.set_input(mm_data, model.device)

In [None]:
# Backward on a dummy scalar (sum of all outputs), independent of the labels
model(mm_data)
model.output.sum().backward()

In [None]:
model.output.shape

In [None]:
# Reference only: how the model presumably builds its input dict from an
# MMData batch. Kept as comments because `self`, `data` and `sp3d` are not
# defined at notebook scope, so executing it raised a NameError.
# The next cell builds the same dict with notebook-level names.
# self.input = {
#                 'x_3d': sp3d.nn.SparseTensor(data.x, data.coords, data.batch, self.device),
#                 'x_seen': None,
#                 'modalities': data.to(self.device).modalities}

In [None]:
import torch_points3d.modules.SparseConv3d as sp3d

# Run only the first down module of the backbone on a hand-built input dict,
# to inspect intermediate outputs. Work on a clone of mm_data (presumably so
# that moving it to the device leaves mm_data untouched — TODO confirm).
in_data = mm_data.clone()

mm_data_dict = {
    'x_3d': sp3d.nn.SparseTensor(in_data.data.x, in_data.data.coords, in_data.data.batch, model.device),
    'x_seen': None,
    'modalities': in_data.to(model.device).modalities }

out_dict = model.backbone.down_modules[0](mm_data_dict)

In [None]:
# Image feature shape before the first down module...
mm_data.modalities['image'].x[0].shape

In [None]:
# ...and after it
out_dict['modalities']['image'].x[0].shape

In [None]:
# Inspect the image data of one validation sample
mm_data = dataset.val_dataset[2893*4 + 0]
print(mm_data)

images = mm_data.modalities['image'][0]
# images = images.select_views([])
print(images)


In [None]:
# Index with an empty list to get image data with no view
images = images[[]]
images

In [None]:
# images.select_points([])
# Give the (empty) image data a pointer array for an empty point selection:
# a single 0 pointer, i.e. no point is mapped.
# TODO(review): `idx` was left unassigned in the original cell (SyntaxError);
# an empty LongTensor is assumed, as in the next cell — confirm intent.
idx = torch.LongTensor(())
images.mappings.pointers = torch.zeros(idx.shape[0] + 1).long().to(images.device)

In [None]:
# Select an empty set of points on the image data
images.select_points(torch.LongTensor(()))

In [None]:
from torch_points3d.datasets.segmentation.kitti360 import read_kitti360_window
from tqdm import tqdm

def _window_prefix(file_path):
    """Return the first 21 characters of a file's extension-less basename.

    Used to check that a processed window, its sampling file and its raw
    window share the same name prefix.
    """
    return osp.splitext(osp.basename(file_path))[0][:21]


# Sanity-check every split: the processed window, its sampling file and the
# raw window must share a name prefix and agree on point counts. Mismatches
# are reported and fixed in memory only; uncomment torch.save to persist.
for split in ('train', 'val', 'test'):
    print(split)

    dataset_stage = dataset.get_dataset(split)
    triplets = zip(dataset_stage.paths, dataset_stage.sampling_paths, dataset_stage.raw_3d_paths)

    for i, (path, sampling_path, raw_path) in tqdm(enumerate(triplets)):

        if len({_window_prefix(p) for p in (path, sampling_path, raw_path)}) != 1:
            print(split, i, 'has name issues')

        sampling = torch.load(sampling_path)

        # Point count of the processed window vs. the one stored in sampling
        num_points = torch.load(path).num_nodes
        if num_points != sampling['num_points']:
            print(split, i, 'has num_points issues')
            sampling['num_points'] = num_points
#             torch.save(sampling, sampling_path)

        # Point count of the raw window (positions only) vs. sampling
        num_raw_points = read_kitti360_window(raw_path, xyz=True, rgb=False, semantic=False, instance=False).num_nodes
        if num_raw_points != sampling['num_raw_points']:
            print(split, i, 'has num_raw_points issues')
            sampling['num_raw_points'] = num_raw_points
#             torch.save(sampling, sampling_path)

In [None]:
mm_data

In [None]:
# Select no image view; the result is only displayed here, not assigned
mm_data.modalities['image'].select_views([])

In [None]:
# Fetch the validation sample inspected in the surrounding cells.
# TODO(review): the original cell had an empty index (`dataset.val_dataset[]`,
# a SyntaxError); 2893*4 is assumed from the neighbouring cells — confirm.
dataset.val_dataset[2893*4]

In [None]:
mm_data

In [None]:
from torch_points3d.core.multimodal.data import MMBatch

# Edge case: a batch of 3 copies of one sample whose image views were all
# dropped, presumably to check that forward and backward still run when no
# point is seen by any image
mm_data = dataset.val_dataset[2893*4 + 0]
mm_data.modalities['image'] = mm_data.modalities['image'].select_views([])
mm_data = MMBatch.from_mm_data_list([mm_data, mm_data, mm_data])

# mm_data = MMBatch.from_mm_data_list([dataset.val_dataset[2893*4 + 0]])

model.set_input(mm_data, model.device)
out = model.forward(mm_data)
model.output.sum().backward()

In [None]:
# Output channels of the image branch of the 4th down module
model.backbone.down_modules[3].image.out_channels

In [None]:
# Split the batch back into its individual MMData samples
mm_data.to_mm_data_list()

In [None]:
from torch_points3d.core.multimodal import MMData, ImageData
from torch_points3d.datasets.segmentation.kitti360 import KITTI360_NUM_CLASSES, INV_OBJECT_LABEL, OBJECT_COLOR, CLASS_NAMES, CLASS_COLORS

# Same setup as the earlier transform cell, but slicing the live transform
# lists: re-running is harmless since [:4] of a 4-item list keeps the same items.
dataset.train_dataset.transform = None
dataset.train_dataset.transform_image.transforms = dataset.train_dataset.transform_image.transforms[:4]
dataset.train_dataset.transform_image.transforms[3].credit = 1408 * 376 * 3

dataset.val_dataset.transform = None
dataset.val_dataset.transform_image.transforms = dataset.val_dataset.transform_image.transforms[:4]
dataset.val_dataset.transform_image.transforms[3].credit = 1408 * 376 * 3

# dataset.test_dataset[0].transform = dataset.val_dataset.transform
# dataset.test_dataset[0].transform_image = dataset.val_dataset.transform_image

# Train sample
mm_data = dataset.train_dataset[dataset.train_dataset._pick_random_label_and_window()]

# Val sample
# NOTE(review): len() of a single sample; presumably len(dataset.val_dataset) was intended
# mm_data = dataset.val_dataset[np.random.randint(len(dataset.val_dataset[0]))]

visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.5, show_2d=True, front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=2)

# Reading images

In [None]:
# Raw KITTI-360 root (perspective images, poses, calibration).
# NB: this overwrites the DATA_ROOT set in the config cell.
DATA_ROOT = '/raid/dataset/pointcloud/data/kitti_360/KITTI-360'
raw_2d_dir = osp.join(DATA_ROOT, 'data_2d_raw')

# List the sequences available in data_2d_raw
scan_names = sorted(os.listdir(raw_2d_dir))
print(f'There are {len(scan_names)} scans')
for scan_name in scan_names:
    print(f'  {scan_name}')

camera_names = ['image_00', 'image_01']

# NB: `images` (sorted *.png paths) overwrites the image data `images` from the
# cells above; `image_names` lists every file in data_rect, not only *.png.
images = sorted(glob.glob(osp.join(raw_2d_dir, scan_names[0], camera_names[0], 'data_rect', '*.png')))
image_names = sorted(os.listdir(osp.join(raw_2d_dir, scan_names[0], camera_names[0], 'data_rect')))

# NOTE(review): assumes image_01 holds as many images as image_00
print(f"\nThere are 2 x {len(image_names)} perspective images in '{scan_names[0]}'")

In [None]:
from PIL import Image

# Display a few (image_00, image_01) image pairs of the first sequence:
# every 20th frame from index 200 to 240.
for i in range(200, 260, 20):
    image_path_0 = osp.join(raw_2d_dir, scan_names[0], camera_names[0], 'data_rect', image_names[i])
    image_path_1 = osp.join(raw_2d_dir, scan_names[0], camera_names[1], 'data_rect', image_names[i])

    fig, axes = plt.subplots(1, 2, figsize=(24, 4))

    for ax, image_path, camera_name in zip(axes, (image_path_0, image_path_1), camera_names):
        # Context manager: release the file handle deterministically once read
        with Image.open(image_path) as img:
            ax.imshow(np.asarray(img))
        # Label each panel so the figure stands on its own
        ax.set_title(f'{camera_name} / {image_names[i]}')
    plt.show()
    # Free the figure once displayed instead of accumulating open figures
    plt.close(fig)

We note that the two perspective cameras are largely redundant: they see roughly the same scene. We will therefore only use one of them for our application.

In [None]:
def readVariable(fid,name,M,N):
    """Read the M x N matrix stored under `name` in an open calibration file.

    The file is rewound, then scanned for the first line that starts with
    `name`; that line, minus its ``name:`` prefix, must contain exactly M*N
    whitespace-separated numbers.

    Returns:
        (M, N) float numpy array, or None when no line starts with `name`.

    Raises:
        AssertionError: if the matching line does not hold M*N values.
    """
    fid.seek(0,0)

    while True:
        current = fid.readline()
        if current.startswith(name):
            break
        if not current:
            # reached end of file without finding the identifier
            return None

    tokens = current.replace('%s:' % name, '').split()
    assert(len(tokens) == M*N)
    return np.array([float(tok) for tok in tokens]).reshape(M, N)

def checkfile(filename):
    """Raise RuntimeError unless `filename` is an existing regular file."""
    if os.path.isfile(filename):
        return
    raise RuntimeError('%s does not exist!' % filename)
        
def loadCalibrationCameraToPose(filename):
    """Load the camera-to-pose calibration of the 4 KITTI-360 cameras.

    Reads the 3x4 transforms stored under 'image_00' ... 'image_03' in
    `filename` (e.g. calib_cam_to_pose.txt) and promotes each one to a 4x4
    homogeneous matrix.

    Returns:
        dict mapping camera name -> (4, 4) numpy array.

    Raises:
        RuntimeError: if `filename` does not exist.
        ValueError: if a camera entry is missing from the file.
    """
    # check file
    checkfile(filename)

    Tr = {}
    cameras = ['image_00', 'image_01', 'image_02', 'image_03']
    lastrow = np.array([0,0,0,1]).reshape(1,4)
    # `with` closes the file even if parsing fails part-way
    with open(filename, 'r') as fid:
        for camera in cameras:
            transform = readVariable(fid, camera, 3, 4)
            if transform is None:
                raise ValueError('%s: no entry for %s' % (filename, camera))
            Tr[camera] = np.concatenate((transform, lastrow))
    return Tr

class Camera:
    """Base class for a KITTI-360 camera.

    Loads the per-frame poses and stores, for every frame id, the 4x4
    camera-to-world matrix in ``self.cam2world``. It also projects world-space
    vertices into the image through the subclass' ``cam2image``.

    Subclasses must set ``intrinsic_file``, ``pose_file``, ``camToPose`` and
    ``cam_id`` before calling ``Camera.__init__``, and implement
    ``load_intrinsics`` and ``cam2image``. Perspective cameras (ids 0 and 1)
    must also set ``R_rect`` in ``load_intrinsics``.

    Credit: https://github.com/autonomousvision/kitti360Scripts
    """

    def __init__(self):

        # load intrinsics
        self.load_intrinsics(self.intrinsic_file)

        # load poses: each row is a frame id followed by a flattened 3x4 pose.
        # ndmin=2 keeps a single-row file 2D; otherwise poses[:,0] would fail.
        poses = np.loadtxt(self.pose_file, ndmin=2)
        frames = poses[:,0].astype(np.int64)
        poses = np.reshape(poses[:,1:],[-1,3,4])
        self.cam2world = {}
        self.frames = frames

        is_perspective = self.cam_id==0 or self.cam_id==1
        is_fisheye = self.cam_id==2 or self.cam_id==3
        # loop-invariant terms, computed once instead of per frame
        lastrow = np.array([0.,0.,0.,1.]).reshape(1,4)
        if is_perspective:
            rect_inv = np.linalg.inv(self.R_rect)

        for frame, pose in zip(frames, poses):
            pose = np.concatenate((pose, lastrow))
            # consider the rectification for perspective cameras
            if is_perspective:
                self.cam2world[frame] = np.matmul(np.matmul(pose, self.camToPose), rect_inv)
            # fisheye cameras
            elif is_fisheye:
                self.cam2world[frame] = np.matmul(pose, self.camToPose)
            else:
                raise RuntimeError('Unknown Camera ID!')


    def world2cam(self, points, R, T, inverse=False):
        """Apply the rigid transform (R, T) to points, or its inverse.

        Args:
            points: (N, 3) array, or a batch with the same ndim as R.
            R: (3, 3) rotation, or a batch (B, 3, 3).
            T: translation with R.ndim or R.ndim - 1 dimensions.
            inverse: if False, return ``R @ p + T`` laid out as (N, 3)
                (or (B, N, 3)). If True, return ``R^T @ (p - T)`` laid out as
                (3, N) (or (B, 3, N)), the layout ``cam2image`` consumes.
        """
        assert (points.ndim==R.ndim)
        assert (T.ndim==R.ndim or T.ndim==(R.ndim-1))
        ndim=R.ndim
        if ndim==2:
            R = np.expand_dims(R, 0)
            T = np.reshape(T, [1, -1, 3])
            points = np.expand_dims(points, 0)
        if not inverse:
            points = np.matmul(R, points.transpose(0,2,1)).transpose(0,2,1) + T
        else:
            # NB: result is left as (B, 3, N), not transposed back
            points = np.matmul(R.transpose(0,2,1), (points - T).transpose(0,2,1))

        if ndim==2:
            points = points[0]

        return points

    def cam2image(self, points):
        """Project camera-frame points to pixels; implemented by subclasses."""
        raise NotImplementedError

    def load_intrinsics(self, intrinsic_file):
        """Load the camera intrinsics; implemented by subclasses."""
        raise NotImplementedError

    def project_vertices(self, vertices, frameId, inverse=True):
        """Project world-space `vertices` (N, 3) into the image of `frameId`.

        With the default ``inverse=True``, the frame's camera-to-world pose
        is inverted, which maps world points into the camera frame.

        Returns:
            ((u, v), depth) as produced by ``cam2image``.

        Raises:
            KeyError: if `frameId` has no pose.
        """

        # current camera pose
        curr_pose = self.cam2world[frameId]
        T = curr_pose[:3,  3]
        R = curr_pose[:3, :3]

        # convert points from world coordinate to local coordinate
        points_local = self.world2cam(vertices, R, T, inverse)

        # perspective projection
        u,v,depth = self.cam2image(points_local)

        return (u,v), depth

    def __call__(self, obj3d, frameId):
        """Project `obj3d.vertices` into frame `frameId` and store the result.

        Sets ``obj3d.vertices_proj`` and ``obj3d.vertices_depth``, then calls
        ``obj3d.generateMeshes()``.
        """

        vertices = obj3d.vertices

        uv, depth = self.project_vertices(vertices, frameId)

        obj3d.vertices_proj = uv
        obj3d.vertices_depth = depth
        obj3d.generateMeshes()


class CameraPerspective(Camera):
    """KITTI-360 perspective (rectified) camera, id 0 or 1.

    Credit: https://github.com/autonomousvision/kitti360Scripts
    """

    def __init__(self, root_dir, seq='2013_05_28_drive_0009_sync', cam_id=0):
        # perspective camera ids: {0,1}, fisheye camera ids: {2,3}
        assert (cam_id==0 or cam_id==1)

        pose_dir = os.path.join(root_dir, 'data_poses', seq)
        calib_dir = os.path.join(root_dir, 'calibration')
        self.pose_file = os.path.join(pose_dir, "poses.txt")
        self.intrinsic_file = os.path.join(calib_dir, 'perspective.txt')
        fileCameraToPose = os.path.join(calib_dir, 'calib_cam_to_pose.txt')
        self.camToPose = loadCalibrationCameraToPose(fileCameraToPose)['image_%02d' % cam_id]
        self.cam_id = cam_id
        super(CameraPerspective, self).__init__()

    def load_intrinsics(self, intrinsic_file):
        ''' load perspective intrinsics

        Sets ``self.K`` (3x4, from P_rect_XX), ``self.R_rect`` (4x4 identity
        with its 3x3 block from R_rect_XX) and ``self.width``/``self.height``
        (from S_rect_XX). Lines are split on any whitespace and blank lines
        are skipped.

        Raises:
            AssertionError: if P_rect_XX or R_rect_XX is missing, or the image
                size is missing/non-positive.
        '''
        p_key = 'P_rect_%02d:' % self.cam_id
        r_key = 'R_rect_%02d:' % self.cam_id
        s_key = 'S_rect_%02d:' % self.cam_id

        K = None
        R_rect = None
        width = -1
        height = -1
        with open(intrinsic_file) as f:
            intrinsics = f.read().splitlines()
        for line in intrinsics:
            # split() tolerates repeated/trailing spaces, which split(' ')
            # would turn into empty tokens and crash float()
            tokens = line.split()
            if not tokens:
                continue
            key, values = tokens[0], tokens[1:]
            if key == p_key:
                K = np.reshape([float(x) for x in values], [3,4])
            elif key == r_key:
                R_rect = np.eye(4)
                R_rect[:3,:3] = np.array([float(x) for x in values]).reshape(3,3)
            elif key == s_key:
                width = int(float(values[0]))
                height = int(float(values[1]))
        assert K is not None, '%s not found in %s' % (p_key, intrinsic_file)
        assert R_rect is not None, '%s not found in %s' % (r_key, intrinsic_file)
        assert(width>0 and height>0)

        self.K = K
        self.width, self.height = width, height
        self.R_rect = R_rect

    def cam2image(self, points):
        """Pinhole projection of camera-frame points with ``self.K[:3, :3]``.

        Args:
            points: (3, N) or (B, 3, N) camera-frame points, the layout
                returned by ``world2cam(..., inverse=True)``.

        Returns:
            u, v: rounded int64 pixel coordinates, shape (N,) or (B, N).
            depth: projected depth; exact zeros are replaced by -1e-6.
        """
        ndim = points.ndim
        if ndim == 2:
            points = np.expand_dims(points, 0)
        points_proj = np.matmul(self.K[:3,:3].reshape([1,3,3]), points)
        depth = points_proj[:,2,:]
        # avoid division by zero below
        depth[depth==0] = -1e-6
        # divide by |depth| so points behind the camera are not mirrored
        u = np.round(points_proj[:,0,:]/np.abs(depth)).astype(np.int64)
        v = np.round(points_proj[:,1,:]/np.abs(depth)).astype(np.int64)

        if ndim==2:
            u = u[0]; v=v[0]; depth=depth[0]
        return u, v, depth


In [None]:
# NB: DATA_ROOT here is the raw KITTI-360 root reassigned in the image-reading
# cell, not the processed-data root from the config cell.
camera = CameraPerspective(DATA_ROOT, seq='2013_05_28_drive_0000_sync', cam_id=0)

In [None]:
# Smoke test: project three world points (the rows of the 3x3 identity) into
# frame 1. NOTE(review): raises KeyError if frame 1 is absent from poses.txt.
uv, d = camera.project_vertices(np.eye(3), 1)
uv, d

In [None]:
def load_intrinsics(intrinsic_file, cam_id=0):
    ''' load KITTI-360 perspective camera intrinsics

    Parses the intrinsics file (e.g. perspective.txt) for perspective camera
    `cam_id`. Keys are zero-padded to two digits ('P_rect_00:'). Lines are
    split on any whitespace and blank lines are skipped.

    Credit: https://github.com/autonomousvision/kitti360Scripts

    Args:
        intrinsic_file: path to the intrinsics file.
        cam_id: perspective camera id.

    Returns:
        K: (3, 4) rectified projection matrix read from P_rect_XX.
        R_rect: (4, 4) identity with its top-left 3x3 block from R_rect_XX.
        width, height: image size read from S_rect_XX.

    Raises:
        AssertionError: if P_rect_XX or R_rect_XX is missing, or the image
            size is missing/non-positive.
    '''
    p_key = f'P_rect_{cam_id:02d}:'
    r_key = f'R_rect_{cam_id:02d}:'
    s_key = f'S_rect_{cam_id:02d}:'

    K = None
    R_rect = None
    width = -1
    height = -1
    with open(intrinsic_file) as f:
        intrinsics = f.read().splitlines()
    for line in intrinsics:
        # split() tolerates repeated/trailing spaces, which split(' ') would
        # turn into empty tokens and crash float()
        tokens = line.split()
        if not tokens:
            continue
        key, values = tokens[0], tokens[1:]
        if key == p_key:
            K = np.reshape([float(x) for x in values], [3,4])
        elif key == r_key:
            R_rect = np.eye(4)
            R_rect[:3,:3] = np.array([float(x) for x in values]).reshape(3,3)
        elif key == s_key:
            width = int(float(values[0]))
            height = int(float(values[1]))
    assert K is not None, f'{p_key} not found in {intrinsic_file}'
    assert R_rect is not None, f'{r_key} not found in {intrinsic_file}'
    assert(width>0 and height>0)

    return K, R_rect, width, height
        
        
def loadCalibrationCameraToPose(filename):
    ''' load KITTI-360 camera-to-pose calibration

    Returns a dict mapping 'image_00' ... 'image_03' to the 4x4 homogeneous
    camera-to-pose transform read from `filename` (e.g. calib_cam_to_pose.txt).
    NB: this redefines the earlier loadCalibrationCameraToPose of this
    notebook; a missing file raises from open() (no checkfile call here).

    Credit: https://github.com/autonomousvision/kitti360Scripts

    Raises:
        ValueError: if a camera entry is missing from the file.
    '''
    Tr = {}
    lastrow = np.array([0,0,0,1]).reshape(1,4)
    with open(filename,'r') as fid:
        for camera in ['image_00', 'image_01', 'image_02', 'image_03']:
            transform = readVariable(fid, camera, 3, 4)
            # readVariable returns None for a missing entry, which would
            # otherwise surface as an opaque np.concatenate error
            if transform is None:
                raise ValueError('%s: no entry for %s' % (filename, camera))
            Tr[camera] = np.concatenate((transform, lastrow))
    return Tr

In [None]:
# Calibration of the first sequence for perspective camera 0, loaded with the
# standalone helpers defined above.
seq = scan_names[0]
cam_id = 0

pose_dir = osp.join(DATA_ROOT, 'data_poses', seq)
calib_dir = osp.join(DATA_ROOT, 'calibration')

pose_file = osp.join(pose_dir, "poses.txt")
intrinsic_file = osp.join(calib_dir, 'perspective.txt')
fileCameraToPose = osp.join(calib_dir, 'calib_cam_to_pose.txt')
# 4x4 camera-to-pose transform of the chosen camera
camToPose = loadCalibrationCameraToPose(fileCameraToPose)[f'image_0{cam_id}']

# Rectified projection matrix, rectifying rotation and image size
K, R_rect, width, height = load_intrinsics(intrinsic_file, cam_id=cam_id)

# Reading points

In [None]:
IGNORE = -1

# Credit: https://github.com/autonomousvision/kitti360Scripts/blob/master/kitti360scripts/helpers/labels.py

from collections import namedtuple

# Record type for one row of the label table. The field order matches the
# positional arguments used by the `Label(...)` rows (see the table header).
# Without this definition the table below raises NameError on a fresh kernel.
Label = namedtuple('Label', [
    'name',
    'id',
    'kittiId',
    'trainId',       # IGNORE for classes left out of training
    'category',
    'categoryId',
    'hasInstances',
    'ignoreInEval',
    'color',         # RGB tuple
])

# KITTI-360 label table: one `Label(...)` row per class, positional fields in
# the order given by the column header comment. Ids 0..44 appear in order and
# the final 'license plate' row uses id -1.
# NOTE(review): requires `Label` and `IGNORE` to be defined beforehand.
labels = [
    #       name                     id    kittiId,    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        1 ,         0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        3 ,         1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,        2 ,    IGNORE , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,        10,    IGNORE , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        11,         2 , 'construction'    , 2       , True         , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        7 ,         3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        8 ,         4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,        30,    IGNORE , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,        31,    IGNORE , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,        32,    IGNORE , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        21,         5 , 'object'          , 3       , True         , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,       -1 ,    IGNORE , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        23,         6 , 'object'          , 3       , True         , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        24,         7 , 'object'          , 3       , True         , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        5 ,         8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        4 ,         9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,        9 ,        10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,        19,        11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,        20,        12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,        13,        13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,        14,        14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,        34,        15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,        16,    IGNORE , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,        15,    IGNORE , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,        33,        16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,        17,        17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,        18,        18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'garage'               , 34 ,        12,         2 , 'construction'    , 2       , True         , False        , ( 64,128,128) ),
    Label(  'gate'                 , 35 ,        6 ,         4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'stop'                 , 36 ,        29,    IGNORE , 'construction'    , 2       , True         , True         , (150,120, 90) ),
    Label(  'smallpole'            , 37 ,        22,         5 , 'object'          , 3       , True         , False        , (153,153,153) ),
    Label(  'lamp'                 , 38 ,        25,    IGNORE , 'object'          , 3       , True         , False        , (0,   64, 64) ),
    Label(  'trash bin'            , 39 ,        26,    IGNORE , 'object'          , 3       , True         , False        , (0,  128,192) ),
    Label(  'vending machine'      , 40 ,        27,    IGNORE , 'object'          , 3       , True         , False        , (128, 64,  0) ),
    Label(  'box'                  , 41 ,        28,    IGNORE , 'object'          , 3       , True         , False        , (64,  64,128) ),
    Label(  'unknown construction' , 42 ,        35,    IGNORE , 'void'            , 0       , False        , True         , (102,  0,  0) ),
    Label(  'unknown vehicle'      , 43 ,        36,    IGNORE , 'void'            , 0       , False        , True         , ( 51,  0, 51) ),
    Label(  'unknown object'       , 44 ,        37,    IGNORE , 'void'            , 0       , False        , True         , ( 32, 32, 32) ),
    Label(  'license plate'        , -1 ,        -1,        -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]

# Lookup tables over `labels`. For duplicated keys the last row wins, except
# in TRAINID2LABEL, which is built from the reversed table so that the FIRST
# row carrying a given trainId is kept.
NAME2LABEL, ID2LABEL, KITTIID2LABEL, CATEGORY2LABELS = {}, {}, {}, {}
for label in labels:
    category = label.category
    NAME2LABEL[label.name] = label
    ID2LABEL[label.id] = label
    KITTIID2LABEL[label.kittiId] = label  # KITTI-360 ID to cityscapes ID
    # Group rows by category, keeping table order inside each group
    CATEGORY2LABELS.setdefault(category, []).append(label)
TRAINID2LABEL = {lbl.trainId: lbl for lbl in reversed(labels)}

# Class names / colors indexed by semantic id 0..44
CLASS_NAMES = [ID2LABEL[i].name for i in range(45)]
CLASS_COLORS = [ID2LABEL[i].color for i in range(45)]

In [None]:
# KITTI-360 label table (identical copy of the table defined in an earlier
# cell); positional fields follow the column header comment.
# NOTE(review): requires `Label` and `IGNORE` to be defined beforehand.
labels = [
    #       name                     id    kittiId,    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,       -1 ,    IGNORE , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        1 ,         0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        3 ,         1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,        2 ,    IGNORE , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,        10,    IGNORE , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        11,         2 , 'construction'    , 2       , True         , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        7 ,         3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        8 ,         4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,        30,    IGNORE , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,        31,    IGNORE , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,        32,    IGNORE , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        21,         5 , 'object'          , 3       , True         , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,       -1 ,    IGNORE , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        23,         6 , 'object'          , 3       , True         , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        24,         7 , 'object'          , 3       , True         , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        5 ,         8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        4 ,         9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,        9 ,        10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,        19,        11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,        20,        12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,        13,        13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,        14,        14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,        34,        15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,        16,    IGNORE , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,        15,    IGNORE , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,        33,        16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,        17,        17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,        18,        18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'garage'               , 34 ,        12,         2 , 'construction'    , 2       , True         , False        , ( 64,128,128) ),
    Label(  'gate'                 , 35 ,        6 ,         4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'stop'                 , 36 ,        29,    IGNORE , 'construction'    , 2       , True         , True         , (150,120, 90) ),
    Label(  'smallpole'            , 37 ,        22,         5 , 'object'          , 3       , True         , False        , (153,153,153) ),
    Label(  'lamp'                 , 38 ,        25,    IGNORE , 'object'          , 3       , True         , False        , (0,   64, 64) ),
    Label(  'trash bin'            , 39 ,        26,    IGNORE , 'object'          , 3       , True         , False        , (0,  128,192) ),
    Label(  'vending machine'      , 40 ,        27,    IGNORE , 'object'          , 3       , True         , False        , (128, 64,  0) ),
    Label(  'box'                  , 41 ,        28,    IGNORE , 'object'          , 3       , True         , False        , (64,  64,128) ),
    Label(  'unknown construction' , 42 ,        35,    IGNORE , 'void'            , 0       , False        , True         , (102,  0,  0) ),
    Label(  'unknown vehicle'      , 43 ,        36,    IGNORE , 'void'            , 0       , False        , True         , ( 51,  0, 51) ),
    Label(  'unknown object'       , 44 ,        37,    IGNORE , 'void'            , 0       , False        , True         , ( 32, 32, 32) ),
    Label(  'license plate'        , -1 ,        -1,        -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]

# Lookup tables over `labels`. For duplicated keys the last row wins, except
# in TRAINID2LABEL, which is built from the reversed table so that the FIRST
# row carrying a given trainId is kept.
KITTI360_NUM_CLASSES = 19  # train ids 0..18 (IGNORE = -1 is excluded)
NAME2LABEL, ID2LABEL, KITTIID2LABEL, CATEGORY2LABELS = {}, {}, {}, {}
for label in labels:
    category = label.category
    NAME2LABEL[label.name] = label
    ID2LABEL[label.id] = label
    KITTIID2LABEL[label.kittiId] = label  # KITTI-360 ID to cityscapes ID
    # Group rows by category, keeping table order inside each group
    CATEGORY2LABELS.setdefault(category, []).append(label)
TRAINID2LABEL = {lbl.trainId: lbl for lbl in reversed(labels)}

# Train id -> class name, train id -> RGB color, and class name -> train id
INV_OBJECT_LABEL = {tid: TRAINID2LABEL[tid].name for tid in range(KITTI360_NUM_CLASSES)}
OBJECT_COLOR = np.asarray([TRAINID2LABEL[tid].color for tid in range(KITTI360_NUM_CLASSES)])
OBJECT_LABEL = {name: tid for tid, name in INV_OBJECT_LABEL.items()}

In [None]:
# Number of distinct train ids, minus one for the IGNORE (-1) key
len(TRAINID2LABEL)-1

In [None]:
# Train id -> class name for the KITTI360_NUM_CLASSES evaluated classes
INV_OBJECT_LABEL

In [None]:
# trainId of every row of `labels`, in table order. NOTE: the list position is
# the row index, which equals `label.id` for rows 0..44 only; the final
# 'license plate' row (id -1) sits at position 45.
ID2TRAINID = [label.trainId for label in labels]

In [None]:
# Display the id -> trainId list
ID2TRAINID

In [None]:
# Semantic ids in table order (position == id except for the final -1 row)
[label.id for label in labels]

In [None]:
# Train id -> name, skipping the smallest key (IGNORE = -1)
{k: TRAINID2LABEL[k].name for k in sorted(TRAINID2LABEL.keys())[1:]}

In [None]:
# NOTE(review): these assignments repeat ones made in an earlier label cell
KITTI360_NUM_CLASSES = 19
INV_OBJECT_LABEL = {k: TRAINID2LABEL[k].name for k in range(KITTI360_NUM_CLASSES)}
OBJECT_COLOR = np.asarray([TRAINID2LABEL[k].color for k in range(KITTI360_NUM_CLASSES)])

In [None]:
# Inverse mapping: class name -> train id
{name: i for i, name in INV_OBJECT_LABEL.items()}

In [None]:
# Class names indexed by semantic id 0..44
CLASS_NAMES

In [None]:
# Per-sequence size statistics of the raw 3D windows (train and test splits).
# Point counts are printed in millions.
DATA_ROOT = '/raid/dataset/pointcloud/data/kitti_360/KITTI-360'

for raw_3d_dir in [osp.join(DATA_ROOT, 'data_3d_semantics'), osp.join(DATA_ROOT, 'data_3d_semantics_test')]:
    scan_names = sorted(os.listdir(raw_3d_dir))
    print(f'\nThere are {len(scan_names)} {osp.basename(raw_3d_dir)} scans')
    for scan_name in scan_names:
        # Sorted for a deterministic iteration order (os.listdir is arbitrary)
        window_names = sorted(os.listdir(osp.join(raw_3d_dir, scan_name, "static")))
        if not window_names:
            # np.min / np.max would raise on an empty list
            print(f'  {scan_name} - no windows')
            continue
        num_points = []
        for window_name in window_names:
            with open(osp.join(raw_3d_dir, scan_name, 'static', window_name), "rb") as f:
                data = PlyData.read(f)
                # Only the vertex count is needed: no need to copy a color
                # channel into a tensor just to read its length
                num_points.append(len(data["vertex"].data))
        print(f'  {scan_name} - {len(window_names):>2} windows - total={np.sum(num_points) / 10**6:0.1f} - mean={np.mean(num_points) / 10**6:0.1f} - min={np.min(num_points) / 10**6:0.1f} - max={np.max(num_points) / 10**6:0.1f}')

In [None]:
def read_chunk(filepath, instance=False):
    """Read a KITTI-360 window PLY file into a torch_geometric `Data`.

    :param filepath: path to the `.ply` window file
    :param instance: if True, also load the per-point 'instance' field
    :return: `Data` with `pos` (N, 3), `rgb` (N, 3) float colors divided by
        255, `y` (N,) int64 'semantic' labels, `mapping_index` = arange(N),
        and, if requested, `instance` (N,) int64
    """
    with open(filepath, "rb") as f:
        chunk = PlyData.read(f)
        vertex = chunk["vertex"]

        pos = torch.stack([torch.tensor(vertex[axis]) for axis in ["x", "y", "z"]], dim=-1)

        rgb = torch.stack([torch.tensor(vertex[axis]) for axis in ["red", "green", "blue"]], dim=-1).float() / 255

        # Cast labels to int64 whatever integer type the PLY stores, so they
        # can be used directly as indices / loss targets
        y = torch.tensor(vertex['semantic']).long()

        data = Data(pos=pos, rgb=rgb, y=y, mapping_index=torch.arange(pos.shape[0]))

        if instance:
            data.instance = torch.tensor(vertex['instance']).long()

        return data

In [None]:
# Load one raw window and measure how much voxel-grid subsampling shrinks it
DATA_ROOT = '/raid/dataset/pointcloud/data/kitti_360/KITTI-360'
raw_3d_dir = osp.join(DATA_ROOT, 'data_3d_semantics')
scan_name = sorted(os.listdir(raw_3d_dir))[0]
# Sorted so the same window is picked on every run (os.listdir order is arbitrary)
window_name = sorted(os.listdir(osp.join(raw_3d_dir, scan_name, "static")))[0]

mm_data = MMData(data=read_chunk(osp.join(raw_3d_dir, scan_name, 'static', window_name)), image=ImageData([SameSettingImageData()]))

print(f'Raw cloud size: {mm_data.data.num_nodes / 10**6:0.3f} M points')

from torch_points3d.core.data_transform.grid_transform import GridSampling3D

print('\nSubsampling the cloud...')
for voxel in [1.0, 0.5, 0.25, 0.1, 0.08, 0.05, 0.01]:
    # Subsample a clone so the raw cloud is reused unchanged for every voxel size
    sub_cloud = GridSampling3D(voxel, mode='last')(mm_data.data.clone())
    print(f'voxel={voxel:>4}, n={sub_cloud.num_nodes:>7}, ratio={sub_cloud.num_nodes / mm_data.data.num_nodes * 100:0.2f}')

In [None]:
# Reload the raw window (no images attached) and display it in 3D only.
# NOTE(review): `class_names` / `class_colors` are not defined in the visible
# part of this notebook (only upper-case CLASS_NAMES / CLASS_COLORS are); also
# check that they are indexed like the raw semantic ids stored in `y`.
mm_data = MMData(data=read_chunk(osp.join(raw_3d_dir, scan_name, 'static', window_name)), image=ImageData([SameSettingImageData()]))

visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.5, show_2d=False, class_names=class_names, class_colors=class_colors)

# Trying to visualize image centers

In [None]:
# Attach one "image" to every 20th pose among the first 350 poses of the
# sequence, to see where the poses lie relative to the point cloud.
seq = scan_name
cam_id = 0

pose_dir = osp.join(DATA_ROOT, 'data_poses', seq)
calib_dir = osp.join(DATA_ROOT, 'calibration')

pose_file = osp.join(pose_dir, "poses.txt")

# As parsed here: column 0 is the frame index, the 12 remaining columns are a
# flattened 3x4 pose matrix
poses = np.loadtxt(pose_file)[:350:20]
frames = poses[:, 0].astype(np.int64)
poses = poses[:, 1:].reshape(-1, 3, 4)

n_images = poses.shape[0]
# NOTE(review): these are system poses with no camera-to-pose calibration
# applied, and they are 3x4 rather than 4x4 — confirm `SameSettingImageData`
# accepts this as `extrinsic`
extrinsic = torch.from_numpy(poses)
xyz = extrinsic[:, :3, 3]  # translation column = pose position
# Zero intrinsics and dummy paths: presumably only the positions are meant to
# be visualized here
fx = torch.zeros(n_images)
fy = torch.zeros(n_images)
mx = torch.zeros(n_images)
my = torch.zeros(n_images)

image_data = SameSettingImageData(ref_size=(width, height), proj_upscale=1, path=np.zeros(n_images).astype('O'), pos=xyz, fx=fx, fy=fy, mx=mx, my=my, extrinsic=extrinsic)

In [None]:
# Display the raw window together with the pose positions. The poses in
# `image_data` come from sequence `seq = scan_name`, so the cloud must be read
# from that same sequence: use `scan_name` / `window_name` from the loading
# cell. (`scan_names` was overwritten by the statistics loop, which ends on the
# test split, and `scan_ply_chunks` is only defined further down.)
mm_data = MMData(data=read_chunk(osp.join(raw_3d_dir, scan_name, 'static', window_name)), image=ImageData(image_data))

visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.5, show_2d=False, class_names=class_names, class_colors=class_colors)

In [None]:
from PIL import Image

# Raw 2D image locations (same values as in the dataset cell further down),
# defined here so this cell does not depend on cells executed later
raw_2d_dir = osp.join(DATA_ROOT, 'data_2d_raw')
camera_names = ['image_00', 'image_01']

# Show both perspective images of every selected frame. `frames` were read
# from the poses of sequence `seq`, so the images must come from `seq` too.
for i_frame, frame in enumerate(frames):
    image_path_0 = osp.join(raw_2d_dir, seq, camera_names[0], 'data_rect', f'{frame:010d}.png')
    image_path_1 = osp.join(raw_2d_dir, seq, camera_names[1], 'data_rect', f'{frame:010d}.png')

    fig, axes = plt.subplots(1, 2, figsize=(24, 4))

    # Close the image files once their pixels have been copied to arrays
    with Image.open(image_path_0) as img_0, Image.open(image_path_1) as img_1:
        axes[0].imshow(np.asarray(img_0))
        axes[1].imshow(np.asarray(img_1))
    fig.suptitle(f'{i_frame}')
    plt.show()
    # Many figures are created in this loop: release each one after display
    plt.close(fig)

# KITTI360 dataset

In [None]:
def read_chunk(filepath, instance=False):
    """Load one KITTI-360 PLY window as a torch_geometric `Data` object.

    Fields: `pos` (N, 3) float32 xyz, `rgb` (N, 3) float32 colors divided by
    255, `y` (N,) int64 'semantic' labels and `mapping_index` = arange(N).
    When `instance` is True the int64 'instance' field is attached as well.
    """
    with open(filepath, "rb") as f:
        vertex = PlyData.read(f)["vertex"]

        def _float_columns(names):
            # Stack the named vertex properties column-wise as float32
            return torch.stack([torch.tensor(vertex[name], dtype=torch.float32) for name in names], dim=-1)

        pos = _float_columns(["x", "y", "z"])
        rgb = _float_columns(["red", "green", "blue"]) / 255
        y = torch.tensor(vertex['semantic'], dtype=torch.long)

        data = Data(pos=pos, rgb=rgb, y=y, mapping_index=torch.arange(pos.shape[0]))

        if instance:
            data.instance = torch.tensor(vertex['instance'], dtype=torch.long)

        return data

def readVariable(fid,name,M,N):
    """Read the M x N matrix stored after `name:` in an open text file.

    The file is rewound, then scanned for the first line starting with
    `name`; the whitespace-separated values that follow the `name:` prefix
    are parsed as floats and reshaped to (M, N).

    :return: numpy float array of shape (M, N), or None if no line matches
    :raises AssertionError: if the line does not hold exactly M*N values
    """
    fid.seek(0, 0)

    # Advance until a matching line or EOF (readline returns '' at EOF)
    line = fid.readline()
    while line and not line.startswith(name):
        line = fid.readline()

    if not line.startswith(name):
        return None

    values = line.replace('%s:' % name, '').split()
    assert(len(values) == M*N)
    return np.array([float(v) for v in values]).reshape(M, N)

def load_intrinsics(intrinsic_file, cam_id=0):
    '''Load KITTI-360 perspective camera intrinsics.

    Parses a calibration text file (e.g. `calibration/perspective.txt`) and
    extracts, for camera `cam_id`, the lines keyed `P_rect_0{cam_id}:`,
    `R_rect_0{cam_id}:` and `S_rect_0{cam_id}:`.

    Credit: https://github.com/autonomousvision/kitti360Scripts

    :param intrinsic_file: path to the intrinsics text file
    :param cam_id: camera index used to build the line keys
    :return: tuple `(K, R_rect, width, height)` where `K` is the 3x4
        rectified projection matrix, `R_rect` a 4x4 homogeneous matrix whose
        top-left 3x3 block holds the rectifying rotation, and `width` /
        `height` are the rectified image size as ints
    :raises AssertionError: if the projection matrix, the rectifying
        rotation or a positive image size is missing from the file
    '''
    K = None
    R_rect = None
    width = -1
    height = -1
    with open(intrinsic_file) as f:
        intrinsics = f.read().splitlines()
    for line in intrinsics:
        # Split on whitespace runs so repeated / trailing spaces do not
        # produce empty tokens that `float` would fail on
        line = line.split()
        if not line:
            continue
        if line[0] == f'P_rect_0{cam_id}:':
            K = np.reshape([float(x) for x in line[1:]], [3, 4])
        elif line[0] == f'R_rect_0{cam_id}:':
            R_rect = np.eye(4)
            R_rect[:3, :3] = np.array([float(x) for x in line[1:]]).reshape(3, 3)
        elif line[0] == f'S_rect_0{cam_id}:':
            width = int(float(line[1]))
            height = int(float(line[2]))
    assert K is not None, f"'P_rect_0{cam_id}:' not found in {intrinsic_file}"
    assert R_rect is not None, f"'R_rect_0{cam_id}:' not found in {intrinsic_file}"
    assert width > 0 and height > 0, f"Invalid or missing 'S_rect_0{cam_id}:' in {intrinsic_file}"

    return K, R_rect, width, height
        
        
def loadCalibrationCameraToPose(filename):
    '''Load KITTI-360 camera-to-pose calibration.

    Reads the 3x4 transforms stored under the keys `image_00` .. `image_03`
    (e.g. in `calibration/calib_cam_to_pose.txt`) and lifts each one to a
    4x4 homogeneous matrix by appending the row `[0, 0, 0, 1]`.

    Credit: https://github.com/autonomousvision/kitti360Scripts

    :param filename: path to the calibration text file
    :return: dict mapping camera name (`'image_00'`, ...) to a 4x4 array
    :raises ValueError: if one of the cameras is missing from the file
    '''
    cameras = ['image_00', 'image_01', 'image_02', 'image_03']
    lastrow = np.array([0, 0, 0, 1]).reshape(1, 4)
    Tr = {}
    with open(filename, 'r') as fid:
        for camera in cameras:
            # `readVariable` returns None when the key is absent: fail with an
            # explicit message instead of an opaque concatenate error
            mat = readVariable(fid, camera, 3, 4)
            if mat is None:
                raise ValueError(f"Camera '{camera}' not found in {filename}")
            Tr[camera] = np.concatenate((mat, lastrow))
    return Tr

In [None]:
# Paths, calibration, poses and intrinsics used by the rest of the notebook
DATA_ROOT = '/raid/dataset/pointcloud/data/kitti_360/KITTI-360'
raw_3d_dir = osp.join(DATA_ROOT, 'data_3d_semantics')
raw_2d_dir = osp.join(DATA_ROOT, 'data_2d_raw')
camera_names = ['image_00', 'image_01']

scan_names = sorted(os.listdir(raw_3d_dir))

seq = scan_names[0]
cam_id = 0

# Initialize file paths
pose_dir = osp.join(DATA_ROOT, 'data_poses', seq)
calib_dir = osp.join(DATA_ROOT, 'calibration')
intrinsic_file = osp.join(calib_dir, 'perspective.txt')
pose_file = osp.join(pose_dir, 'poses.txt')
fileCameraToPose = osp.join(calib_dir, 'calib_cam_to_pose.txt')

# Camera-to-pose calibration (4x4 homogeneous matrix, as a tensor)
camToPose = torch.from_numpy(loadCalibrationCameraToPose(fileCameraToPose)[f'image_{cam_id:02d}'])

# System poses (different from camera pose)
# Keep every 20th row among the first 350: column 0 is the frame index, the
# 12 remaining columns are viewed as a 3x4 matrix
poses = np.loadtxt(pose_file)[:350:20]
frames = poses[:, 0].astype(np.int64)
poses = torch.from_numpy(poses[:, 1:]).view(-1, 3, 4)

n_images = poses.shape[0]

# Intrinsic parameters
# Focal lengths (fx, fy) and principal point (mx, my) read from the
# projection matrix K, repeated once per image
K, R_rect, width, height = load_intrinsics(intrinsic_file, cam_id=cam_id)
fx = torch.Tensor([K[0, 0]]).repeat(n_images)
fy = torch.Tensor([K[1, 1]]).repeat(n_images)
mx = torch.Tensor([K[0, 2]]).repeat(n_images)
my = torch.Tensor([K[1, 2]]).repeat(n_images)
R_rect = torch.from_numpy(R_rect)

# Recover the cam2world from system pose and calibration:
# cam2world = pose @ camToPose @ inv(R_rect), all 4x4 homogeneous matrices
if cam_id != 0:
    # Only the rectified perspective camera 0 is handled here
    raise RuntimeError(f"Unknown Camera ID '{cam_id}'!")

# Homogeneous bottom row [0, 0, 0, 1] lifting each 3x4 pose to 4x4. It must
# NOT be all ones: that would corrupt the bottom row of every cam2world matrix
# (and therefore its inverse when used as an extrinsic).
homogeneous_row = torch.tensor([[0., 0., 0., 1.]], dtype=poses.dtype)
# The rectification inverse is the same for every frame: compute it once
R_rect_inv = torch.inverse(R_rect)

cam2world = {}
for frame, pose in zip(frames, poses):
    pose = torch.cat((pose, homogeneous_row), dim=0)
    # consider the rectification for perspective cameras
    cam2world[frame] = pose @ camToPose @ R_rect_inv

In [None]:
# Pick one frame and split its 4x4 cam2world matrix into the camera position
# T (translation column) and rotation R (top-left 3x3 block)
i_frame = 15
T = cam2world[frames[i_frame]][:3, 3]
R = cam2world[frames[i_frame]][:3, :3]

In [None]:
from torch_points3d.core.data_transform.grid_transform import GridSampling3D

# Voxel size for subsampling (also reused for visualization and mapping below)
# voxel = 0.05
voxel = 0.1
    
scan_ply_chunks = sorted(os.listdir(osp.join(raw_3d_dir, scan_names[0], 'static')))
data = read_chunk(osp.join(raw_3d_dir, scan_names[0], 'static', scan_ply_chunks[0]))

data = GridSampling3D(voxel, mode='last')(data)
# NOTE(review): presumably dropped because it is not a per-point attribute and
# would break the per-point boolean indexing below — confirm
del data.grid_size

# Keep only the points within `radius` of the selected camera position T
radius = 50
in_range = (data.pos - T).norm(dim=1) < radius
# Rebuild a Data object, masking every attribute with `in_range`
data = Data(**{x: data[x][in_range] for x in data.keys})
# Re-enumerate the remaining points
data.mapping_index = torch.arange(data.num_nodes)

In [None]:
from torch_points3d.core.multimodal.visibility import pinhole_projection_cuda, field_of_view_cuda

# Project the cropped points with the selected frame's cam2world matrix and K.
# NOTE(review): u, v are presumably pixel coordinates and z depths — confirm
# against the helper's documentation
u, v, z = pinhole_projection_cuda(data.pos, cam2world[frames[i_frame]], K, camera='kitti360_perspective')

In [None]:
# Boolean mask of the points selected by `field_of_view_cuda` for the
# [0, width] x [0, height] image frame
in_fov = torch.zeros(data.num_nodes, dtype=torch.bool)
in_fov[field_of_view_cuda(u, v, x_min=0, x_max=width, y_min=0, y_max=height, z=z, img_mask=None)] = True

# Blend in-view point colors towards red for visualization. NOTE: this
# modifies `data.rgb` in place, so re-running the cell tints them again, and
# the tinted colors are what later cells visualize and map.
alpha = 0.6
data.rgb[in_fov] = data.rgb[in_fov] * alpha + torch.Tensor([1, 0, 0]) * (1 - alpha)

In [None]:
image_path = np.array([
    osp.join(raw_2d_dir, scan_names[0], camera_names[0], 'data_rect', f'{frames[i_frame]:010d}.png')
])
# Image data for the selected frame only: every per-image attribute is given
# a leading dimension of size 1
image_data = SameSettingImageData(
    ref_size=(width, height),
    proj_upscale=1,
    path=image_path,
    pos=T.unsqueeze(0),
    fx=fx[0].unsqueeze(0),
    fy=fy[0].unsqueeze(0),
    mx=mx[0].unsqueeze(0),
    my=my[0].unsqueeze(0),
    extrinsic=cam2world[frames[i_frame]].unsqueeze(0),
)

In [None]:
# Visualize the cropped (and tinted) cloud together with the selected image
mm_data = MMData(data=data, image=ImageData(image_data))

visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=voxel, show_2d=True, class_names=class_names, class_colors=class_colors)

In [None]:
from torch_points3d.core.data_transform.multimodal.image import MapImages

# Compute point-image mappings with the 'SplattingVisibility' method, reusing
# the subsampling voxel and the crop radius (as `r_max`)
transform = MapImages(method='SplattingVisibility', ref_size=(width, height), proj_upscale=1, use_cuda=True, voxel=voxel, r_max=radius, r_min=0, k_swell=1.5, d_swell=10**6, exact=False, camera='kitti360_perspective', verbose=True)
out = transform(data, image_data)
# `out` is used as a (points, images) pair
mm_data = MMData(data=out[0], image=ImageData(out[1]))


In [None]:
# Visualize the mapped multimodal sample (points + image mappings)
visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=voxel, show_2d=True, class_names=class_names, class_colors=class_colors, front='y')

In [None]:
from PIL import Image
# Display the raw image of the selected frame for comparison
Image.open(osp.join(raw_2d_dir, scan_names[0], camera_names[0], 'data_rect', f'{frames[i_frame]:010d}.png'))