In [1]:
import os
import pdb
import time
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import numpy as np
import torch

import torch.nn.functional as F
# For visualizer
import rospy
from ContinuousBKI import *
from visualization_msgs.msg import *
rospy.init_node('talker', disable_signals=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torch.device never compares equal to a plain string, so `device == "cuda"`
# is always False; compare the device *type* instead so the CUDA timing
# events are actually created on GPU machines.
if device.type == "cuda":
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
else:
    start = None
    end = None
print("device is ", device)

home_dir = os.path.expanduser('~')
dataset_loc = os.path.join(home_dir, "Data/Rellis-3D/00004/")

device is  cuda


In [2]:
class DiscreteBKI(torch.nn.Module):
    '''
    Discrete Bayesian Kernel Inference (BKI) semantic map update.

    Labelled points are binned into a fixed voxel grid. Per-class hit counts
    are accumulated and spread to neighbouring voxels with a learnable
    filter_size**3 kernel (applied with F.conv3d). The result is added to the
    current map.
    '''

    def __init__(self, grid_size, min_bound, max_bound, filter_size=3,
                 num_classes=21, prior=0.001, device="cpu",
                max_dist=0.5):
        '''
        Input:
            grid_size: (x, y, z) integer tensor, number of voxels per axis
            min_bound: (x, y, z) float tensor, lower bound on local map
            max_bound: (x, y, z) float tensor, upper bound on local map
            filter_size: odd int, side length of the cubic BKI kernel
            num_classes: size of the last (class) dimension of the map
            prior: value every cell of a freshly initialised map holds
            device: torch device used for all internal tensors
            max_dist: kernel cut-off distance; taps farther than this from
                the centre start with (near-)zero weight
        '''
        super().__init__()
        self.min_bound = min_bound.view(-1, 3).to(device)
        self.max_bound = max_bound.view(-1, 3).to(device)
        # Keep grid_size on the same device as the bounds: grid_ind combines it
        # with point-cloud tensors that live on `device`.
        self.grid_size = grid_size.to(device)
        self.prior = prior

        self.device = device
        self.num_classes = num_classes

        self.voxel_sizes = (self.max_bound.view(-1) - self.min_bound.view(-1)) / self.grid_size

        self.pi = torch.acos(torch.zeros(1)).item() * 2
        self.max_dist = max_dist
        self.filter_size = torch.tensor(filter_size).to(self.device)
        self.initialize_kernel()

        # Voxel centres: the lower edge of every cell plus half a voxel, per
        # axis, combined into an (X*Y*Z, 3) list in x-major order.
        axes = []
        for i in range(3):
            n_vox = int(self.grid_size[i])
            lo = float(self.min_bound[0, i])
            hi = float(self.max_bound[0, i])
            lower_edges = torch.linspace(lo, hi, steps=n_vox + 1, device=self.device)[:-1]
            axes.append(lower_edges + (hi - lo) / (2 * n_vox))
        self.centroids = torch.cartesian_prod(*axes)

    def initialize_kernel(self):
        '''
        Initialise the kernel weights from the sparse kernel.

        Each tap's raw weight is the inverse sigmoid of the sparse kernel,
        evaluated at the tap's metric distance from the centre. As a result,
        sigmoid(self.weights) reproduces that kernel at initialisation. The
        centre tap stores 1.0; forward() pins its effective weight to exactly 1.
        The weights are stored as ONE registered Parameter, so they appear in
        .parameters() and state_dict and receive gradients.
        '''
        fs = int(self.filter_size)
        assert fs % 2 == 1, "filter_size must be odd"
        middle_ind = fs // 2

        self.sigma = torch.nn.Parameter(torch.tensor(1.0)) # Kernel must map to 0 to 1
        self.ell = torch.nn.Parameter(torch.tensor(self.max_dist)) # Max distance to consider

        values = []
        # Initial values are constants: don't build an autograd graph through
        # sigma/ell while computing them.
        with torch.no_grad():
            for x_ind in range(fs):
                for y_ind in range(fs):
                    for z_ind in range(fs):
                        x_dist = abs(x_ind - middle_ind) * self.voxel_sizes[0]
                        y_dist = abs(y_ind - middle_ind) * self.voxel_sizes[1]
                        z_dist = abs(z_ind - middle_ind) * self.voxel_sizes[2]
                        total_dist = torch.sqrt(x_dist**2 + y_dist**2 + z_dist**2)
                        # Edge case: middle
                        if total_dist == 0:
                            values.append(1.0)
                        else:
                            kernel_value = self.calculate_kernel(total_dist)
                            values.append(float(self.inverse_sigmoid(kernel_value)))
        # Flat order is x-major / z-fastest, matching .view(fs, fs, fs) in forward().
        self.weights = torch.nn.Parameter(torch.tensor(values, device=self.device))

    def inverse_sigmoid(self, x):
        '''Logit of x; the 1e-8 offset keeps x == 0 finite (about -18.4).'''
        # NOTE(review): x values at (or rounding to) 1 yield +inf; only the
        # centre tap reaches 1 and it is handled separately.
        return -torch.log((1 / (x + 1e-8)) - 1)

    def calculate_kernel(self, d):
        '''
        Sparse kernel value at distance d.

        Returns 0 beyond max_dist and 1 at d == 0. Otherwise returns
        sigma * [(2 + cos(2*pi*d/ell)) / 3 * (1 - d/ell) + sin(2*pi*d/ell) / (2*pi)].
        '''
        if d > self.max_dist:
            return torch.tensor(0.0, device=self.device)
        if d == 0:
            return 1
        return self.sigma * (
                (1/3)*(2 + torch.cos(2 * self.pi * d/self.ell))*(1 - d/self.ell) +
                         1/(2*self.pi) * torch.sin(2 * self.pi * d / self.ell)
                         )

    def initialize_grid(self):
        '''Return a fresh (X, Y, Z, num_classes) map with every cell set to the prior.'''
        return torch.zeros(int(self.grid_size[0]), int(self.grid_size[1]), int(self.grid_size[2]),
                           self.num_classes, device=self.device) + self.prior

    def grid_ind(self, input_pc):
        '''
        Input:
            input_pc: N * (x, y, z, c) float tensor, labelled point cloud
        Output:
            N' * (x, y, z, c) float tensor: voxel indices of the N' points lying
            inside [min_bound, max_bound), followed by each point's label
        '''
        input_xyz   = input_pc[:, :3]
        labels      = input_pc[:, 3].view(-1, 1)

        valid_input_mask = torch.all((input_xyz < self.max_bound) & (input_xyz >= self.min_bound), dim=1)

        valid_xyz = input_xyz[valid_input_mask]
        valid_labels = labels[valid_input_mask]

        grid_inds = torch.floor((valid_xyz - self.min_bound) / self.voxel_sizes)
        # Clamp in grid_inds' dtype (not the integer grid_size dtype); float
        # rounding can place a point just below max_bound at index grid_size.
        maxes = (self.grid_size - 1).view(1, 3).to(grid_inds.dtype)
        clipped_inds = torch.clamp(grid_inds, torch.zeros_like(maxes), maxes)

        return torch.hstack( (clipped_inds, valid_labels) )

    def forward(self, current_map, point_cloud):
        '''
        Input:
            current_map: (x, y, z, c) float32 tensor, prior dirichlet distribution over map
            point_cloud: N * (x, y, z, c) float32 tensor, semantically labeled points
        Output:
            updated_map: (x, y, z, c) float32 tensor, posterior dirichlet distribution over map
        '''
        # Assume map and point cloud are already aligned
        update = torch.zeros_like(current_map)

        # 1: Discretize and count hits per (voxel, class). accumulate=True
        # sums duplicate indices, giving the same totals as unique + counts.
        grid_pc = self.grid_ind(point_cloud).to(torch.long)
        if grid_pc.shape[0] > 0:
            hits = torch.ones(grid_pc.shape[0], dtype=update.dtype, device=update.device)
            update.index_put_(tuple(grid_pc.T), hits, accumulate=True)

        # 2: Apply BKI filters. sigmoid keeps taps in (0, 1). The centre tap is
        # forced to 1 with torch.where rather than an in-place write, so
        # gradients still flow to self.weights.
        fs = int(self.filter_size)
        mid = fs // 2
        kernel = torch.sigmoid(self.weights).view(1, 1, fs, fs, fs)
        center = torch.zeros(kernel.shape, dtype=torch.bool, device=kernel.device)
        center[0, 0, mid, mid, mid] = True
        filters = torch.where(center, torch.ones_like(kernel), kernel).to(update.dtype)

        # Classes go on the batch axis: (X, Y, Z, C) -> (C, 1, X, Y, Z).
        update = update.permute(3, 0, 1, 2).unsqueeze(1)
        update = F.conv3d(update, filters, padding="same")
        # Only drop the channel axis; a blanket squeeze would also collapse a
        # size-1 class or grid dimension.
        update = update.squeeze(1).permute(1, 2, 3, 0)

        return current_map + update

    # TODO: def propagate(self, current_map, transformation)

In [3]:
# Local map: 256 x 256 x 16 voxels over x, y in [-25.6, 25.6) and z in [-2.0, 1.2),
# i.e. a voxel edge of 0.2 along every axis.
bki_map = DiscreteBKI(
    grid_size=torch.tensor([256, 256, 16], device=device),
    min_bound=torch.tensor([-25.6, -25.6, -2.0], device=device),
    max_bound=torch.tensor([25.6, 25.6, 1.2], device=device),
    device=device,
)

# MarkerArray publisher on the 'SemMap' topic, used for visualisation.
map_pub = rospy.Publisher('SemMap', MarkerArray, queue_size=10)

In [4]:
# Load labelled LiDAR scans from RELLIS and fuse them into the semantic map.
velo_loc = os.path.join(dataset_loc, "os1_cloud_node_kitti_bin")
label_base_loc = os.path.join(dataset_loc, "os1_cloud_node_semantickitti_label_id")
os_files = os.listdir(velo_loc)

curr_frame_id = 0
end_frame_id = 0  # index (inclusive) of the last frame to process
# Reset so a map left over from an earlier kernel run can't be mistaken for a
# fresh result if no scans are processed below.
posterior_map = None
for curr_frame_id, velo_file in enumerate(sorted(os_files)[:end_frame_id + 1]):
    # Each .bin record holds 4 float32 values; only the first three are kept.
    velo = np.fromfile(os.path.join(velo_loc, velo_file), dtype=np.float32).reshape(-1, 4)[:, :3]
    velo = torch.from_numpy(velo).to(device)
    # The label file shares the scan's stem; splitext strips only the final
    # extension, unlike split(".")[0], which breaks on stems containing dots.
    label_file = os.path.splitext(velo_file)[0] + ".label"
    labels = np.fromfile(os.path.join(label_base_loc, label_file), dtype=np.uint32)
    labels_remapped = torch.from_numpy(class_remap[labels]).to(device=device) # Remap labels to be contiguous

    # Drop points whose remapped label is 0 (original note: ego vehicle = 0)
    non_void = labels_remapped != 0
    velo = velo[non_void]
    labels_remapped = labels_remapped[non_void]

    labeled_pc = torch.hstack( (velo, labels_remapped.reshape(-1, 1)) )

    # Drop points with raw labels 8 and 17 (presumably dynamic classes, per the
    # variable name -- confirm against the RELLIS ontology).
    non_dynamic = (labels_remapped != class_remap[8]) & (labels_remapped != class_remap[17])
    labeled_pc = labeled_pc[non_dynamic]

    # NOTE(review): the map is re-initialised for every frame and there is no
    # pose propagation yet, so only the last processed frame ends up in it.
    current_map = bki_map.initialize_grid()
    posterior_map = bki_map(current_map, labeled_pc)

assert posterior_map is not None, f"no scans found in {velo_loc}"

In [5]:
# Spatial extent of the fused map (the trailing class axis is discarded).
H, W, D, _ = posterior_map.shape

# Send the map to the visualizer: bounds and grid size go in as flat 1-D tensors.
publish_voxels(
    posterior_map,
    map_pub,
    bki_map.centroids,
    *(t.reshape(-1) for t in (bki_map.min_bound, bki_map.max_bound, bki_map.grid_size)),
)

In [6]:
# Sanity check for the conv layout used in DiscreteBKI.forward: classes go on
# the batch axis, and a 3x3x3 kernel that is 1 at the centre and 0 elsewhere
# must act as the identity under F.conv3d with padding="same".

num_classes = 20

# Kernel layout: (out_channels, in_channels, X, Y, Z)
filters = torch.zeros(1, 1, 3, 3, 3, dtype=torch.float)
filters[0, 0, 1, 1, 1] = 1  # centre tap (flat index 13 of 27)

print(filters[0, 0, 1, :, :])

inputs = torch.ones(num_classes, 1, 5, 5, 5)

output = F.conv3d(inputs, filters, padding="same")
print(output[0, 0, :, :, 0])

# An all-ones input can't reveal an off-centre tap in the interior, so also
# check identity on a non-constant probe.
assert torch.equal(output, inputs)
probe = torch.arange(125, dtype=torch.float).view(1, 1, 5, 5, 5)
assert torch.equal(F.conv3d(probe, filters, padding="same"), probe), \
    "centre-only kernel should be the identity"

tensor([[0., 0., 0.],
        [0., 1., 0.],
        [0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
