# Test plane-based registration
Compare different plane-based registration methods (decoupled, Gauss-Newton, Torch Gauss-Newton, and Torch SGD) on AirSim LiDAR scans.

In [None]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import torch

import planeslam.io as io
from planeslam.general import plot_3D_setup, color_legend
from planeslam.scan import pc_to_scan
from planeslam.registration import get_correspondences, extract_corresponding_features
from planeslam.geometry.util import quat_to_rot_mat

%load_ext autoreload
%autoreload 2

In [None]:
# Print floats in fixed-point notation instead of scientific notation
np.set_printoptions(suppress=True)

### Load AirSim data

In [None]:
# Root of the AirSim dataset, resolved relative to the current working directory
dataset_dir = os.path.join(os.getcwd(), '..', '..', 'data', 'airsim', 'blocks_20_samples_1')

# LiDAR point cloud data for Drone0
binpath = os.path.join(dataset_dir, 'lidar', 'Drone0')
PC_data = io.read_lidar_bin(binpath)

# Ground-truth poses for Drone0 (in drone local frame)
posepath = os.path.join(dataset_dir, 'poses', 'Drone0')
drone_positions, drone_orientations = io.read_poses(posepath)

In [None]:
# Convert every point cloud into a scan, keeping the original ordering
num_scans = len(PC_data)
scans = [pc_to_scan(PC_data[idx]) for idx in range(num_scans)]
# Placeholder list with one empty slot per scan (not filled in this cell)
scans_transformed = [None] * num_scans

### Decoupled

In [None]:
from planeslam.registration import decoupled_register

# Register each scan against its predecessor with decoupled_register,
# accumulate the resulting transforms, and report the mean time per call.
# NOTE(review): R_hat / t_hat are assumed to be a 3x3 rotation and a 3x1
# translation (required by the hstack/vstack below) - confirm against
# planeslam.registration.

# One registration per consecutive scan pair; never negative
num_pairs = max(len(scans) - 1, 0)
abs_traj_transformations = np.zeros((num_pairs, 4, 4))
T_abs = np.eye(4)
avg_runtime = 0

for i in range(num_pairs):
    # perf_counter is monotonic and high-resolution, so it is the right
    # clock for measuring short elapsed intervals
    start_time = time.perf_counter()
    R_hat, t_hat = decoupled_register(scans[i+1], scans[i])
    avg_runtime += time.perf_counter() - start_time
    print(t_hat)
    # Assemble the 4x4 homogeneous transform [[R, t], [0, 0, 0, 1]]
    T_hat = np.vstack((np.hstack((R_hat, t_hat)), np.hstack((np.zeros(3), 1))))
    # Accumulate by left-multiplying the newest relative transform
    T_abs = T_hat @ T_abs
    abs_traj_transformations[i,:,:] = T_abs

# Avoid dividing by zero when there are fewer than two scans
if num_pairs > 0:
    avg_runtime /= num_pairs
print("average registration time: ", avg_runtime)

In [None]:
# Translation component of every accumulated transform, rounded to 2 decimals
abs_traj_transformations[:, :3, 3].round(2)

### Gauss-Newton

In [None]:
from planeslam.registration import GN_register

# Register each scan against its predecessor with GN_register, accumulate
# the resulting transforms, and report the mean time per call.
# NOTE(review): R_hat / t_hat are assumed to be a 3x3 rotation and a 3x1
# translation (required by the hstack/vstack below) - confirm against
# planeslam.registration.

# One registration per consecutive scan pair; never negative
num_pairs = max(len(scans) - 1, 0)
abs_traj_transformations = np.zeros((num_pairs, 4, 4))
T_abs = np.eye(4)
avg_runtime = 0

for i in range(num_pairs):
    # perf_counter is monotonic and high-resolution, so it is the right
    # clock for measuring short elapsed intervals
    start_time = time.perf_counter()
    R_hat, t_hat = GN_register(scans[i+1], scans[i])
    avg_runtime += time.perf_counter() - start_time
    print(t_hat)
    # Assemble the 4x4 homogeneous transform [[R, t], [0, 0, 0, 1]]
    T_hat = np.vstack((np.hstack((R_hat, t_hat)), np.hstack((np.zeros(3), 1))))
    # Accumulate by left-multiplying the newest relative transform
    T_abs = T_hat @ T_abs
    abs_traj_transformations[i,:,:] = T_abs

# Avoid dividing by zero when there are fewer than two scans
if num_pairs > 0:
    avg_runtime /= num_pairs
print("average registration time: ", avg_runtime)

### Torch GN

In [None]:
# Choose the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU and warn about the expected slowdown
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")
if not use_gpu:
    print("WARNING: CPU only, this will be slow!")

In [None]:
from planeslam.registration import torch_GN_register

# Register each scan against its predecessor with torch_GN_register on the
# selected device, accumulate the resulting transforms, and report the mean
# time per call.
# NOTE(review): R_hat / t_hat are assumed to be a 3x3 rotation and a 3x1
# translation (required by the hstack/vstack below) - confirm against
# planeslam.registration.

# One registration per consecutive scan pair; never negative
num_pairs = max(len(scans) - 1, 0)
abs_traj_transformations = np.zeros((num_pairs, 4, 4))
T_abs = np.eye(4)
avg_runtime = 0

for i in range(num_pairs):
    # perf_counter is monotonic and high-resolution, so it is the right
    # clock for measuring short elapsed intervals
    start_time = time.perf_counter()
    R_hat, t_hat = torch_GN_register(scans[i+1], scans[i], device)
    avg_runtime += time.perf_counter() - start_time
    # Assemble the 4x4 homogeneous transform [[R, t], [0, 0, 0, 1]]
    T_hat = np.vstack((np.hstack((R_hat, t_hat)), np.hstack((np.zeros(3), 1))))
    # Accumulate by left-multiplying the newest relative transform
    T_abs = T_hat @ T_abs
    abs_traj_transformations[i,:,:] = T_abs

# Avoid dividing by zero when there are fewer than two scans
if num_pairs > 0:
    avg_runtime /= num_pairs
print("average registration time: ", avg_runtime)

In [None]:
# Display the translation component (rows 0-2 of column 3) of every accumulated transform
abs_traj_transformations[:,:3,3]

In [None]:
# 3D scatter of the translation component of every accumulated transform
traj = abs_traj_transformations[:, :3, 3]
ax = plot_3D_setup(traj)
# traj has exactly three columns, so its transpose unpacks into x, y, z
ax.scatter(*traj.T)

In [None]:
# 3D scatter of the ground-truth drone positions read from the pose files
ax = plot_3D_setup(drone_positions)
gt_x = drone_positions[:, 0]
gt_y = drone_positions[:, 1]
gt_z = drone_positions[:, 2]
ax.scatter(gt_x, gt_y, gt_z)

In [None]:
# Store the accumulated transforms under data/results, resolved relative to
# the current working directory; np.save appends ".npy" to the file name
results_dir = os.path.join(os.getcwd(), '..', '..', 'data', 'results')
filepath = os.path.join(results_dir, 'l2l_abs_traj_transformations')
np.save(filepath, np.array(abs_traj_transformations))

### Torch SGD

In [None]:
from planeslam.registration import torch_register

# Register each scan against its predecessor with torch_register on the
# selected device, accumulate the resulting transforms, and report the mean
# time per call.
# NOTE(review): R_hat / t_hat are assumed to be a 3x3 rotation and a 3x1
# translation (required by the hstack/vstack below) - confirm against
# planeslam.registration.

# One registration per consecutive scan pair; never negative
num_pairs = max(len(scans) - 1, 0)
abs_traj_transformations = np.zeros((num_pairs, 4, 4))
T_abs = np.eye(4)
avg_runtime = 0

for i in range(num_pairs):
    # perf_counter is monotonic and high-resolution, so it is the right
    # clock for measuring short elapsed intervals
    start_time = time.perf_counter()
    R_hat, t_hat = torch_register(scans[i+1], scans[i], device)
    avg_runtime += time.perf_counter() - start_time
    # Assemble the 4x4 homogeneous transform [[R, t], [0, 0, 0, 1]]
    T_hat = np.vstack((np.hstack((R_hat, t_hat)), np.hstack((np.zeros(3), 1))))
    # Accumulate by left-multiplying the newest relative transform
    T_abs = T_hat @ T_abs
    abs_traj_transformations[i,:,:] = T_abs

# Avoid dividing by zero when there are fewer than two scans
if num_pairs > 0:
    avg_runtime /= num_pairs
print("average registration time: ", avg_runtime)

In [None]:
# Display the translation component (rows 0-2 of column 3) of every accumulated transform
abs_traj_transformations[:,:3,3]

In [None]:
# Display the ground-truth positions returned by io.read_poses
drone_positions