This is a demonstration of a 1-degree-of-freedom registration experiment in which the moving point cloud can only be rotated around its z-axis.

In [None]:
import torch
torchdevice = "cuda"
torchdtype = torch.float32
torch.set_default_dtype(torchdtype)
import numpy as np
import open3d as o3d
import kde
import open3dtools
import copy
import wasserstein_metric
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib.ticker import PercentFormatter
from joblib import Parallel, delayed

In [None]:
# Given a ground-truth rotation, return the rotation estimated by the proposed registration algorithm.
def diff_r(rotate):
    """Estimate a known rotation by registering a point cloud to a rotated copy of itself.

    The point cloud in ``data/bun000.ply`` is voxel-downsampled (voxel size 0.05) and
    duplicated. The duplicate (the *target*) is rotated by ``rotate`` about its own center.
    A kernel density estimate is built for each cloud. The source transformation is then
    refined by repeatedly calling ``update_transformation()`` on an approximate
    1-Wasserstein metric between the two densities.

    Args:
        rotate: Euler angles (radians) passed to Open3D's ``get_rotation_matrix_from_xyz``.
            In this notebook only the z entry is non-zero.

    Returns:
        The value of ``wd_instance.get_rotation()`` after optimization.
        NOTE(review): its exact type/shape is defined in ``wasserstein_metric``. The caller
        converts the results with ``torch.tensor`` and reads column 2 as the z rotation.
    """
    print(f"-------{rotate}--------")
    learning_rate = 0.7*1e-8
    # Early-stop threshold on the per-step decrease of the metric (see the loop below).
    tolerance = 1e-10
    alpha_init, beta_init, gamma_init = [
        torch.tensor(a, requires_grad=True, dtype=torchdtype, device=torchdevice)
        for a in [0, 0, 0]
    ]

    t_init = torch.tensor([0, 0, 0], dtype=torchdtype, device=torchdevice, requires_grad=True)

    interval = torch.tensor([[-1, 1], [-1, 1], [-1, 1]], dtype=torchdtype)

    # ---------------LOAD POINT CLOUDS---------------
    source = o3d.io.read_point_cloud('data/bun000.ply')
    source = source.voxel_down_sample(voxel_size=0.05)
    source_points = open3dtools.o3dpcd2np(source)

    # The target is an exact copy of the source, rotated about its own center.
    target = copy.deepcopy(source)
    target.rotate(target.get_rotation_matrix_from_xyz(rotate),
                  center=target.get_center())

    target_points = open3dtools.o3dpcd2np(target)

    # ---------------BANDWIDTH---------------
    # Rule-of-thumb bandwidth 1.06 * sigma * n^(-1/5), with sigma taken as the mean of the
    # per-axis standard deviations of the source points.
    axis_stds = np.std(source_points, axis=0)
    bandwidth = 1.06*np.power(source_points.shape[0], -0.2)*np.mean(axis_stds)

    # ---------------KDE---------------
    source_kde = kde.KernelDensityEstimator(
        bandwidth, kde.gaussian_kernel_3d_torch, source_points, use_pytorch=True, torch_device=torchdevice, dtype=torchdtype)
    target_kde = kde.KernelDensityEstimator(
        bandwidth, kde.gaussian_kernel_3d_torch, target_points, use_pytorch=True, torch_device=torchdevice, dtype=torchdtype)

    source_center = torch.tensor(source.get_center(), dtype=torchdtype, device=torchdevice)
    target_center = torch.tensor(target.get_center(), dtype=torchdtype, device=torchdevice)

    ## ---------------WD---------------
    # DifferentiableWassersteinMetric_RZ constrains the point cloud to rotate only around its z-axis.
    # We estimate the transformation applied to the *source* so that it aligns with the *target*.
    # Because the target was produced by applying the ground-truth transformation T to a copy of
    # the source, this directly estimates T rather than its inverse T^(-1). That makes the error
    # easy to compute. Estimating the target-to-source transformation instead would require
    # inverting either T or the estimate, which adds unnecessary computation.
    wd_instance = wasserstein_metric.DifferentiableWassersteinMetric_RZ(p=source_kde.pdf, q=target_kde.pdf,
                                                                    p_center=source_center, q_center=target_center,
                                                                    enable_move_pdf_to_center=False,
                                                                    initial_t=t_init,
                                                                    init_alpha=alpha_init, init_beta=beta_init, init_gamma=gamma_init,
                                                                    step_size=learning_rate,
                                                                    interval=interval,
                                                                    torch_device=torchdevice,
                                                                    dtype=torchdtype,
                                                                    s=6,
                                                                    enable_translation=False,
                                                                    enable_rotation=True,
                                                                    integration_num_steps=65,)

    num_iteration = 280

    for i in tqdm(range(num_iteration)):
        wd1 = float(wd_instance.calc_1_wasserstein_metric()[0])
        wd_instance.update_transformation()

        # No early stop for the first 51 iterations (i <= 50). After that, stop as soon as a
        # step lowers the metric by less than `tolerance`; a step that raises it also stops.
        if i > 50 and (wd1 - float(wd_instance.calc_1_wasserstein_metric()[0])) < tolerance:
            break

    return wd_instance.get_rotation()



In [None]:
# Approximate Wasserstein distance between a point cloud and a copy rotated by the given angles.
def diff_r_wd(rotate):
    """Return the approximate 1-Wasserstein metric between a point cloud and its rotated copy.

    The point cloud in ``data/bun000.ply`` is voxel-downsampled (voxel size 0.05) and
    duplicated. The duplicate (target) is rotated by ``rotate`` about its own center. The
    metric is evaluated once at the initial parameters (all angles and the translation set
    to zero), and no optimization step is taken. Sweeping ``rotate`` therefore traces how
    the metric varies with the ground-truth rotation.

    Args:
        rotate: Euler angles (radians) passed to Open3D's ``get_rotation_matrix_from_xyz``.

    Returns:
        float: the first element of ``calc_1_wasserstein_metric()``.
    """
    # Passed as step_size, but update_transformation() is never called here, so it has no effect.
    learning_rate = 1e2

    alpha_init, beta_init, gamma_init = [
        torch.tensor(a, requires_grad=True, dtype=torchdtype, device=torchdevice)
        for a in [0, 0, 0]
    ]

    t_init = torch.tensor([0, 0, 0], dtype=torchdtype, device=torchdevice, requires_grad=True)

    interval = torch.tensor([[-1, 1], [-1, 1], [-1, 1]], dtype=torchdtype)

    # ---------------LOAD POINT CLOUDS---------------
    source = o3d.io.read_point_cloud('data/bun000.ply')
    source = source.voxel_down_sample(voxel_size=0.05)
    source_points = open3dtools.o3dpcd2np(source)

    # The target is an exact copy of the source, rotated about its own center.
    target = copy.deepcopy(source)
    target.rotate(target.get_rotation_matrix_from_xyz(rotate),
                  center=target.get_center())

    target_points = open3dtools.o3dpcd2np(target)

    # ---------------BANDWIDTH---------------
    # Rule-of-thumb bandwidth 1.06 * sigma * n^(-1/5), with sigma taken as the mean of the
    # per-axis standard deviations of the source points.
    axis_stds = np.std(source_points, axis=0)
    bandwidth = 1.06*np.power(source_points.shape[0], -0.2)*np.mean(axis_stds)

    # ---------------KDE---------------
    source_kde = kde.KernelDensityEstimator(
        bandwidth, kde.gaussian_kernel_3d_torch, source_points, use_pytorch=True, torch_device=torchdevice, dtype=torchdtype)
    target_kde = kde.KernelDensityEstimator(
        bandwidth, kde.gaussian_kernel_3d_torch, target_points, use_pytorch=True, torch_device=torchdevice, dtype=torchdtype)

    source_center = torch.tensor(source.get_center(), dtype=torchdtype, device=torchdevice)
    target_center = torch.tensor(target.get_center(), dtype=torchdtype, device=torchdevice)

    ## ---------------WD---------------
    wd_instance = wasserstein_metric.DifferentiableWassersteinMetric(p=source_kde.pdf, q=target_kde.pdf,
                                                                    p_center=source_center, q_center=target_center,
                                                                    enable_move_pdf_to_center=False,
                                                                    initial_t=t_init,
                                                                    init_alpha=alpha_init, init_beta=beta_init, init_gamma=gamma_init,
                                                                    step_size=learning_rate,
                                                                    interval=interval,
                                                                    torch_device=torchdevice,
                                                                    dtype=torchdtype,
                                                                    s=6,
                                                                    enable_translation=False,
                                                                    enable_rotation=True,
                                                                    integration_num_steps=65,)

    return float(wd_instance.calc_1_wasserstein_metric()[0]) # type: ignore


In [None]:
# Wrap angles into one full turn, [0, 2*pi).
def normalize_angle(angle):
    """Map ``angle`` (radians; scalar or array) into [0, 2*pi), up to floating-point rounding."""
    full_turn = 2 * np.pi
    return np.remainder(angle, full_turn)

In [None]:
# Do experiment and plot results.
# 1) Accuracy: for each ground-truth z rotation, estimate it with diff_r and plot the angular
#    error as a percentage of pi.
# 2) Landscape: plot the approximate Wasserstein distance at the unoptimized pose (diff_r_wd)
#    as a function of the ground-truth z rotation, normalized by its maximum.
r_axis = 2
r_arange = np.arange(-np.pi, np.pi, 0.01*np.pi)
# Skip ground-truth rotations within 0.01*pi of zero.
r_arange = r_arange[np.abs(r_arange) >= 0.01*np.pi]

rs = np.zeros((len(r_arange), 3))
rs[:, r_axis] = r_arange
# n_jobs=1: the registrations run one after another.
evaluated_r = Parallel(n_jobs=1)(delayed(diff_r)(rotation) for rotation in rs)
evaluated_r = torch.tensor(evaluated_r).cpu().detach().numpy()

np.save("register_wd_rotate.npy", evaluated_r)

# Wrap the signed angular difference into [-pi, pi) before taking its magnitude. Otherwise an
# estimate of -0.01*pi for a ground truth of 0.02*pi would count as an error of 1.97*pi
# instead of 0.03*pi. Dividing by pi expresses the error as a fraction of the largest
# possible angular error.
angle_diff = normalize_angle(evaluated_r[:, r_axis] - rs[:, r_axis] + np.pi) - np.pi
rel_errors = np.absolute(angle_diff / np.pi)

fig, ax1 = plt.subplots()

color = 'tab:red'
ax1.set_xlabel('rotation around z-axis')
ax1.set_ylabel('relative error')
ax1.plot(r_arange, rel_errors*100, color=color, label="relative error")
ax1.tick_params(axis='y')
ax1.yaxis.set_major_formatter(PercentFormatter())

# The landscape uses the full sweep, including rotations near zero.
r_arange = np.arange(-np.pi, np.pi, 0.01*np.pi)
rs = np.zeros((len(r_arange), 3))
rs[:, r_axis] = r_arange
wds = Parallel(n_jobs=4)(delayed(diff_r_wd)(rotation) for rotation in rs)
wds = np.array(wds)
norm_wds = wds/wds.max()

ax2 = ax1.twinx()

color = 'tab:blue'
ax2.set_ylabel('normalized approximate Wasserstein distance')
ax2.plot(r_arange, norm_wds, color=color, label="normalized approximate WD")
ax2.tick_params(axis='y')

ax1.legend(loc="upper left")
ax2.legend(loc="upper right")

fig.tight_layout()  # otherwise the right y-label is slightly clipped
plt.show()