## A simple analytical PointNetLK example

Load packages.

In [1]:
# Install the package if needed. GitHub no longer serves the unauthenticated git:// protocol,
# so the URL must use https:
# %pip install git+https://github.com/Lilac-Lee/PointNetLK_Revisited

import argparse
import os
import sys
import numpy as np
import torch
import torch.utils.data
import torchvision

# visualize the point cloud
import open3d as o3d
# The WebRTC web visualizer only exists in open3d>=0.13.0. Enable it when available
# instead of requiring this line to be commented out by hand on older versions.
if hasattr(o3d.visualization, 'webrtc_server'):
    o3d.visualization.webrtc_server.enable_webrtc()

# make the parent directory importable so the project modules below can be found
sys.path.insert(0, '../')
import data_utils
import trainer

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] Resetting default logger to print to terminal.


Collect the testing and network parameters in an `argparse.Namespace`, which stands in for command-line arguments.

In [2]:
# Testing / network settings, collected in a Namespace that stands in for parsed CLI arguments.
args = argparse.Namespace(
    dim_k=1024,                       # dimension of the PointNet embedding
    device='cpu',                     # 'cpu', or a CUDA device such as 'cuda:0'
    max_iter=10,                      # maximum number of LK iterations
    embedding='pointnet',             # embedding function: pointnet
    outfile='toyexample_2021_04_17',  # output log file name
    data_type='real',                 # data type: real
    vis=True,                         # whether to visualize the results
)

Load the toy point-cloud pair and define the ground-truth rigid pose (as fixed twist parameters).
Then set the voxelization parameters and build the test dataset.

In [3]:
# load the toy point-cloud pair; [np.newaxis, ...] prepends a leading batch axis of size 1
p0 = np.load('./p0.npy')[np.newaxis,...]
p1 = np.load('./p1.npy')[np.newaxis,...]

# Fixed (hand-picked, NOT random) twist parameters, shape (1, 6), defining the ground-truth
# rigid pose. Keeping them constant makes this example reproducible.
x = np.array([[0.57, -0.29, 0.73, -0.37, 0.48, -0.54]])

# set voxelization parameters (their exact semantics are defined in data_utils.ToyExampleData)
# NOTE(review): num_voxels == voxel ** 3 here (8 == 2 ** 3), so `voxel` is presumably the
# number of voxels per axis; confirm against data_utils before changing one without the other.
voxel_ratio = 0.05
voxel = 2
max_voxel_points = 1000
num_voxels = 8

# construct the testing dataset
testset = data_utils.ToyExampleData(p0, p1, voxel_ratio, voxel, max_voxel_points, num_voxels, x, args.vis)

Create the model, load the pre-trained weights, and begin testing!

In [4]:
# Resolve the CPU fallback BEFORE building the trainer, so anything the trainer derives from
# `args` at construction time sees the device that will actually be used. It stays a string
# here, exactly as the trainer received it before.
if not torch.cuda.is_available():
    args.device = 'cpu'

# create model
dptnetlk = trainer.TrainerAnalyticalPointNetLK(args)
model = dptnetlk.create_model()

# specify device
args.device = torch.device(args.device)
model.to(args.device)

# Load the pre-trained weights directly onto the target device.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code; only load
# trusted checkpoints. On torch>=1.13 consider passing weights_only=True.
pretrained_path = '../logs/model_trained_on_ModelNet40_model_best.pth'
model.load_state_dict(torch.load(pretrained_path, map_location=args.device))

# testloader: one un-shuffled pass over the toy dataset, one sample per batch
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=0, drop_last=False)

# begin testing
dptnetlk.test_one_epoch(model, testloader, args.device, 'test', args.data_type, args.vis, toyexample=True)

  0%|                                              | 0/1 [00:00<?, ?it/s]

WebVisualizer(window_uid='window_0')

INFO - 2021-06-15 13:36:53,437 - trainer - test, 0/1, 0 iterations, 1.276627


WebVisualizer(window_uid='window_1')

INFO - 2021-06-15 13:36:56,794 - trainer - test, 0/1, 1 iterations, 0.874349


WebVisualizer(window_uid='window_2')

INFO - 2021-06-15 13:36:59,965 - trainer - test, 0/1, 2 iterations, 0.413566


WebVisualizer(window_uid='window_3')

INFO - 2021-06-15 13:37:03,296 - trainer - test, 0/1, 3 iterations, 0.152394


WebVisualizer(window_uid='window_4')

INFO - 2021-06-15 13:37:06,579 - trainer - test, 0/1, 4 iterations, 0.084110


WebVisualizer(window_uid='window_5')

INFO - 2021-06-15 13:37:09,831 - trainer - test, 0/1, 5 iterations, 0.071174


WebVisualizer(window_uid='window_6')

INFO - 2021-06-15 13:37:13,345 - trainer - test, 0/1, 6 iterations, 0.069102


WebVisualizer(window_uid='window_7')

INFO - 2021-06-15 13:37:16,841 - trainer - test, 0/1, 7 iterations, 0.068973


WebVisualizer(window_uid='window_8')

INFO - 2021-06-15 13:37:20,172 - trainer - test, 0/1, 8 iterations, 0.069051


WebVisualizer(window_uid='window_9')

INFO - 2021-06-15 13:37:23,592 - trainer - test, 0/1, 9 iterations, 0.069104


WebVisualizer(window_uid='window_10')

INFO - 2021-06-15 13:37:27,086 - trainer - test, 0/1, 10 iterations, 0.069128
                                                                         

success cases are [0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
Source to Template:
toyexample_2021_04_17
********************mean********************
rot_MSE: 2.437836505812705, rot_RMSE: 1.5613572639894768, rot_MAE: 1.158016039794037, trans_MSE: 0.0013231680495664477, trans_RMSE: 0.03637537732720375, trans_MAE: 0.036363422870635986
********************median********************
rot_MSE: 0.7500058401062544, rot_RMSE: 0.8660287755647929, rot_MAE: 0.8660287755647929, trans_MSE: 0.0012785122962668538, trans_RMSE: 0.03575628995895386, trans_MAE: 0.03575628995895386


