## A simple IFR example

Load packages.

In [15]:
import os
import sys
import numpy as np
import torch

sys.path.insert(0, '../')
from ifr import IFR
import utils

In [16]:
# Build the IFR registration model used by the rest of this example.
# NOTE(review): the meaning of each option is defined by ifr.IFR, which is
# not shown in this notebook — consult that class before changing values.
ifr_model = IFR(
    scale=2,
    maxiter=20,
    zero_mean=True,
    trunc=True,
    rand_pa=False,
    kp_nb=False,
    encoder_id=4,
)

In [17]:
def do_transform(p0, x):
    """Move the points ``p0`` by the transform built from twist parameters ``x``.

    Args:
        p0: point tensor. The original notes give its shape as [N, 3]; the
            call in this notebook passes an extra leading axis — TODO confirm
            which layouts ``utils.transform`` accepts.
        x: twist parameters. The original notes give the shape as [1, 6]; the
            call in this notebook also adds a leading axis here.

    Returns:
        A pair ``(p1, igt)``: ``p1`` is the result of ``utils.transform`` on
        ``p0``, and ``igt`` is the transform (p0 -> p1) with its leading axis
        squeezed away when that axis has size 1.
    """
    # Build the transform from the twist and match p0's dtype/device.
    # (Original note: the result of utils.exp is [1, 4, 4].)
    transform_mat = utils.exp(x).to(p0)
    moved_points = utils.transform(transform_mat, p0)
    return moved_points, transform_mat.squeeze(0)


# Load the two point clouds stored next to this notebook.
p0 = np.load('./p0.npy')
p1 = np.load('./p1.npy')

# Fixed (hard-coded) twist parameters that define the ground-truth pose.
x = np.array([[0.57, -0.29, 0.73, -0.37, 0.48, -0.54]])

# Give both inputs a leading axis, then move p1 by the ground-truth transform.
p1_batched = torch.from_numpy(p1).unsqueeze(0)
x_batched = torch.from_numpy(x).unsqueeze(0)
p1_pre, igt = do_transform(p1_batched, x_batched)

# Drop the leading axis again and continue with a NumPy array.
p1_pre = p1_pre[0].numpy()
print('GT Transform', igt)

GT Transform tensor([[[ 0.7150, -0.6970, -0.0543, -0.4716],
         [ 0.5443,  0.6037, -0.5824,  0.4599],
         [ 0.4388,  0.3868,  0.8111, -0.4687],
         [ 0.0000,  0.0000,  0.0000,  1.0000]]], dtype=torch.float64)


In [18]:
# Report the sizes of the two clouds handed to the registration call.
print('Input Shape without downsample', p0.shape, p1_pre.shape)

# Register p1_pre against p0 and keep the first matrix of the result.
registration_result = ifr_model.register(p0, p1_pre)
estimated_pose = registration_result[0, :, :]
print('Transform back:\n', estimated_pose)

# Compose the estimate with the ground-truth transform; the closer this
# product is to the identity, the better the estimate undoes the transform.
gt_pose = igt.numpy()[0, :, :]
print('Error Matrix:\n', estimated_pose @ gt_pose)

Input Shape without downsample (103974, 3) (76241, 3)
Transform back:
 [[ 0.7021648   0.55477244  0.446309    0.31933507]
 [-0.71055853  0.5860416   0.38943765 -0.41371375]
 [-0.04550641 -0.590578    0.8056965   0.64275825]
 [ 0.          0.          0.          1.        ]]
Error Matrix:
 [[ 9.99834977e-01  1.81426755e-02  6.99938656e-04  3.41444511e-02]
 [-1.81502537e-02  9.99748879e-01  1.31395401e-02  8.33123176e-03]
 [-4.61194321e-04 -1.31498695e-02  9.99913461e-01  1.50140751e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]


In [19]:
# Visualization prep: map p1_pre through the estimated pose.
# NOTE(review): the original comment said the visualization runs on
# downsampled clouds for fast drawing, but no downsampling is visible in
# this cell — confirm whether one is intended.
est_rotation = estimated_pose[:3, :3]
est_translation = estimated_pose[:3, (3,)]  # tuple index keeps a column
p0_hat = (est_rotation @ p1_pre.T + est_translation).T

import ipyvolume as ipv

def plot(x1,x2):
    """Draw two point sets together in a fresh ipyvolume figure.

    Both arguments are indexed as ``[:, 0]``, ``[:, 1]`` and ``[:, 2]`` to get
    the three coordinates handed to ``ipv.scatter``. ``x1`` is drawn with the
    colour triple (255, 0, 0) and ``x2`` with (0, 0, 255).
    """
    ipv.figure()
    # NOTE(review): colour triples are kept exactly as in the original
    # (0-255 values); confirm against ipyvolume's accepted colour formats.
    clouds_with_colors = (
        (x1, np.array([255,0,0])),
        (x2, np.array([0,0,255])),
    )
    for cloud, rgb in clouds_with_colors:
        ipv.scatter(cloud[:,0], cloud[:,1], cloud[:,2], rgb)
    ipv.show()

In [20]:
# First figure: p0 together with the transformed cloud p1_pre.
# Second figure: p0 together with p0_hat (p1_pre mapped by the estimated pose).
for other_cloud in (p1_pre, p0_hat):
    plot(p0, other_cloud)

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), projectionMatrix=(1.0, 0.0,…

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), projectionMatrix=(1.0, 0.0,…