In [1]:
# https://github.com/davezdeng8/tlfpad/blob/master/util_ec.py

import torch
import numpy as np
import time
import os
import pdb

import torch.nn.functional as F
# For visualizer
import rospy
from visualizer import publish_voxels
from visualization_msgs.msg import *
# Register this kernel as the ROS node 'talker' (needed by the visualizer imports above).
# disable_signals=True asks rospy not to install its own signal handlers --
# presumably so kernel interrupts stay with Jupyter; TODO confirm against rospy docs.
rospy.init_node('talker',disable_signals=True)

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CUDA events for GPU timing; they only exist when running on a GPU.
# `device` is a torch.device, so its `.type` has to be compared: the previous
# `device == "cuda"` compared a torch.device with a str, which is always False,
# so the events were never created.
if device.type == "cuda":
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
else:
    start = None
    end = None
print("device is ", device)

# Root directory of the sequence to load. The RELLIS3D_SEQUENCE_DIR environment
# variable overrides the default so the notebook can run on other machines.
dataset_loc = os.environ.get("RELLIS3D_SEQUENCE_DIR", "/home/arthur/Data/rellis3d/00004/")

device is  cpu


In [2]:
def knn_point(k, pos1, pos2):
    '''
    Brute-force k-nearest-neighbour search of query points against input points.

    Input:
        k: int, number of neighbours to return per query point
        pos1: (batch_size, ndataset, c) float array, input points
        pos2: (batch_size, npoint, c) float array, query points
    Output:
        val: (batch_size, npoint, k) float array, L2 distances, nearest first
        idx: (batch_size, npoint, k) integer array (as returned by topk),
             indices into the ndataset axis of pos1
    '''
    # Unpacking keeps the original requirement that pos1 is 3-dimensional.
    B, N, C = pos1.shape

    # Broadcast (B, 1, N, C) against (B, M, 1, C) instead of materialising two
    # repeated (B, M, N, C) copies; the pairwise differences are identical.
    diff = pos1.unsqueeze(1) - pos2.unsqueeze(2)

    # Negated squared distances, so that topk (largest first) picks the closest points.
    neg_sq_dist = torch.sum(-diff**2, -1)
    val, idx = neg_sq_dist.topk(k=k, dim=-1)
    return torch.sqrt(-val), idx

In [3]:
class ContinuousBKI(torch.nn.Module):
    '''
    Kernel-weighted semantic update of a dense voxel grid.

    Every voxel stores one concentration value per class (initialised to
    `prior`). `forward` looks up the nearest labelled points of each voxel
    centroid and adds their kernel weights to the matching class channels.
    '''
    def __init__(self, grid_size, min_bound, max_bound, max_k, batch_size=1000, 
                 num_classes=20, kernel="sparse", prior=0.001, device="cpu",
                max_dist=1.0):
        '''
        Input:
            grid_size: (x, y, z) int32 array, number of voxels
            min_bound: (x, y, z) float32 array, lower bound on local map
            max_bound: (x, y, z) float32 array, upper bound on local map
            max_k: int, number of nearest points considered per voxel
            batch_size: int, number of centroids processed per kNN chunk
            num_classes: int, number of semantic classes (last grid axis)
            kernel: str, kernel name; only "sparse" is implemented
            prior: float, initial value of every grid entry
            device: device on which centroids, bounds and grids are stored
            max_dist: float, distance beyond which a point is ignored. For the
                "sparse" kernel this is overwritten with the initial length
                scale `ell` in initialize_kernel.
        '''
        super().__init__()
        self.max_dist = max_dist
        self.kernel = kernel
        self.min_bound = min_bound.view(-1, 3).to(device)
        self.max_bound = max_bound.view(-1, 3).to(device)
        self.grid_size = grid_size
        self.prior = prior
        self.max_k = max_k
        self.device = device
        self.num_classes = num_classes
        self.batch_size = batch_size
        
        self.pi = torch.acos(torch.zeros(1)).item() * 2
        
        # Voxel centroids per axis: lower edge of each cell plus half a cell width.
        axes = []
        for i in range(3):
            half_cell = (max_bound[i] - min_bound[i]) / (2 * grid_size[i])
            lower_edges = torch.linspace(min_bound[i], max_bound[i], steps=grid_size[i] + 1)[:-1]
            axes.append(half_cell + lower_edges)
        # (X*Y*Z, 3) list of centroids, x varying slowest.
        self.centroids = torch.cartesian_prod(*axes).to(device)
        
        self.initialize_kernel()
        
        
    def knn_point(self, pos1, pos2, k=None):
        '''
        Brute-force nearest-neighbour search (unbatched).

        Input:
            pos1: (ndataset, c) float array, input points
            pos2: (npoint, c) float array, query points
            k: optional int, neighbours per query; defaults to self.max_k
        Output:
            val: (npoint, k) float array, L2 distances, nearest first
            idx: (npoint, k) integer array (as returned by topk), indices to input points
        '''
        if k is None:
            k = self.max_k
        # Broadcast (1, N, c) against (M, 1, c) rather than repeating both
        # operands; the pairwise differences are identical.
        diff = pos1.unsqueeze(0) - pos2.unsqueeze(1)
        # Negated squared distances so that topk (largest first) selects the closest.
        neg_sq_dist = torch.sum(-diff**2, -1)
        val, idx = neg_sq_dist.topk(k=k, dim=-1)
        
        return torch.sqrt(-val), idx
    
    def initialize_kernel(self):
        '''
        Create the learnable kernel parameters. Only the "sparse" kernel is
        set up; any other kernel name leaves the module without parameters.
        '''
        if self.kernel == "sparse":
            # Parameters are created on the module's device.
            self.sigma = torch.nn.Parameter(torch.tensor(0.1, device=self.device))
            self.ell = torch.nn.Parameter(torch.tensor(0.3, device=self.device))
            # Snapshot of ell taken once here; it is not refreshed if ell changes later.
            self.max_dist = self.ell.item()
            
    def calculate_kernel(self, d):
        '''
        Input:
            d: float tensor of distances
        Output:
            tensor of kernel values, same shape as d
        Raises:
            ValueError: if self.kernel is not a supported kernel name
        '''
        if self.kernel == "sparse":
            ratio = d / self.ell
            phase = 2 * self.pi * d / self.ell
            return self.sigma * (
                (1/3) * (2 + torch.cos(phase)) * (1 - ratio) +
                1/(2*self.pi) * torch.sin(phase)
            )
        # Previously fell through and returned None, which only failed later
        # inside forward with an unrelated error.
        raise ValueError(f"Unsupported kernel: {self.kernel!r}")
            
            
    def initialize_grid(self):
        '''
        Output:
            (x, y, z, num_classes) float array filled with the prior value
        '''
        return torch.zeros(self.grid_size[0], self.grid_size[1], self.grid_size[2], 
                           self.num_classes, device=self.device) + self.prior
    
    def grid_ind(self, input_pc):
        '''
        Input:
            input_xyz: N * (x, y, z, c) float32 array, point cloud
        Output:
            grid_inds: N' * (x, y, z, c) int32 array, point cloud mapped to voxels

        NOTE: unfinished. Only the in-bounds mask is computed so far and the
        method currently returns None.
        '''
        input_xyz   = input_pc[:, :3]
        labels      = input_pc[:, 3].view(-1, 1)
        
        # True for points inside [min_bound, max_bound) on every axis.
        valid_input_mask = torch.all((input_xyz < self.max_bound) & (input_xyz >= self.min_bound), dim=1)
        # To be continued..
        
        
    def forward(self, current_map, point_cloud):
        '''
        Input:
            current_map: (x, y, z, c) float32 array, prior dirichlet distribution over map
            point_cloud: N * (x, y, z, c) float32 array, semantically labeled points
        Output:
            updated_map: (x, y, z, c) float32 array, posterior dirichlet distribution over map

        Clouds with fewer than max_k points (including empty clouds) are
        accepted: missing neighbours are padded with infinite distance and so
        contribute nothing.
        '''
        # Optional: downsample points and queries to reduce time complexity

        num_queries = self.centroids.shape[0]
        # topk cannot return more neighbours than there are points.
        k = min(self.max_k, point_cloud.shape[0])

        # First, get matching indices (no gradient flows through the search)
        with torch.no_grad():
            # Padding: infinite distance is never within range; label 0 is a
            # placeholder that always gets zero kernel weight.
            matched_dists = torch.full((num_queries, self.max_k), float("inf"), device=self.device)
            matched_labels = torch.zeros(num_queries, self.max_k, dtype=torch.long, device=self.device)
            if k > 0:
                xyz = point_cloud[:, :3]
                point_labels = point_cloud[:, 3].long()
                # Process centroids in chunks to bound the size of the distance matrix.
                for lo in range(0, num_queries, self.batch_size):
                    hi = min(num_queries, lo + self.batch_size)
                    dists, inds = self.knn_point(xyz, self.centroids[lo:hi], k=k)
                    matched_dists[lo:hi, :k] = dists
                    matched_labels[lo:hi, :k] = point_labels[inds]
            labels_one_hot = F.one_hot(matched_labels, num_classes=self.num_classes)

        # Next, use in-range distances to calculate kernel
        X, Y, Z, C = current_map.shape
        
        dists = matched_dists.reshape(X, Y, Z, 1, self.max_k)
        # (voxel, k, class) -> (voxel, class, k) -> (X, Y, Z, C, k)
        labels = labels_one_hot.permute(0, 2, 1).reshape(X, Y, Z, C, self.max_k)
        
        within_range = dists < self.max_dist
        kernel_values = torch.zeros(X, Y, Z, 1, self.max_k, device=self.device)
        # Only in-range entries are evaluated; gradients reach sigma/ell through them.
        kernel_values[within_range] = self.calculate_kernel(dists[within_range])
        
        # Perform Update: sum kernel weights per class over the k neighbours
        update = torch.sum(kernel_values * labels, 4)
        
        return current_map + update
    
    # def propagate(self, current_map, )

In [4]:
# Create an empty grid
bki_map = ContinuousBKI(
    torch.tensor([128, 128, 8]), # Grid size
    torch.tensor([-25.6, -25.6, -2.0]), # Lower bound
    torch.tensor([25.6, 25.6, 2.0]), # Upper bound
    10, # Max points in neighborhood
    device=device # Same device the point cloud tensor is created on later
)


current_map = bki_map.initialize_grid()

# Load the first scan of the sequence (KITTI-format .bin points + .label ids)
velo_loc = os.path.join(dataset_loc, "os1_cloud_node_kitti_bin")
label_base_loc = os.path.join(dataset_loc, "os1_cloud_node_semantickitti_label_id")
os_files = os.listdir(velo_loc)
if not os_files:
    # Previously an empty directory silently left `labeled_pc` undefined.
    raise FileNotFoundError(f"No point cloud files found in {velo_loc}")

velo_file = sorted(os_files)[0]
# Each record is 4 float32 values; only the first three (x, y, z) are kept.
velo = np.fromfile(os.path.join(velo_loc, velo_file), dtype=np.float32).reshape(-1, 4)[:, :3]
# splitext strips only the final extension, so dotted file stems still map to the right label file.
label_file = os.path.splitext(velo_file)[0] + ".label"
# NOTE(review): label values must be < num_classes (20 by default) for the one-hot
# encoding in ContinuousBKI.forward -- confirm the label ids in this dataset fit.
labels = np.fromfile(os.path.join(label_base_loc, label_file), dtype=np.uint32)
# N x 4 array of (x, y, z, label)
labeled_pc = np.hstack( (velo, labels.reshape(-1, 1)) )

In [5]:
# Start from a fresh prior map so re-running this cell gives the same result.
current_map = bki_map.initialize_grid()
# Build the point cloud on the map's own device: the kNN search subtracts the
# points from the map centroids, so both must live on the same device.
posterior_map = bki_map(current_map, torch.tensor(labeled_pc, device=bki_map.device))

In [6]:
# print(torch.sum(posterior_map))

In [8]:
# Add visualization
map_pub = rospy.Publisher('SemMap', MarkerArray, queue_size=10)

# Write a test pattern directly into `current_map` (in place): class channel 2
# is set to 3 at even x, odd y and odd z voxel indices.
# NOTE(review): the map published below is `current_map`, not `posterior_map` --
# confirm this is intended.
H, W, D, _ = current_map.shape
current_map[::2, 1::2, 1::2, 2] = 3

# Bounds and grid size are passed to the visualizer as flat vectors.
flat_min_bound = bki_map.min_bound.reshape(-1)
flat_max_bound = bki_map.max_bound.reshape(-1)
flat_grid_size = bki_map.grid_size.reshape(-1)
publish_voxels(current_map, map_pub, bki_map.centroids,
               flat_min_bound, flat_max_bound, flat_grid_size)