In [None]:
import torch
# Global torch configuration used by every tensor created in this notebook.
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without a usable CUDA device instead of failing on
# the first tensor allocation.
torchdevice = "cuda" if torch.cuda.is_available() else "cpu"
torchdtype = torch.float32
torch.set_default_dtype(torchdtype)
import numpy as np
import open3d as o3d
import kde
import open3dtools
import copy
import wasserstein_metric
from tqdm import tqdm

In [None]:
# Optimisation hyper-parameters: gradient step size and the minimum
# per-iteration improvement used by the stopping test further down.
learning_rate = 0.7 * 1e-8
tolerance = 1e-10

# Initial rotation parameters: three independent scalar leaf tensors, all
# starting at zero and tracked by autograd.
alpha_init = torch.tensor(0, requires_grad=True, dtype=torchdtype, device=torchdevice)
beta_init = torch.tensor(0, requires_grad=True, dtype=torchdtype, device=torchdevice)
gamma_init = torch.tensor(0, requires_grad=True, dtype=torchdtype, device=torchdevice)

# Initial translation: a 3-vector of zeros, also tracked by autograd.
t_init = torch.tensor([0, 0, 0], dtype=torchdtype, device=torchdevice, requires_grad=True)

# One [-1, 1] range per axis.
# NOTE(review): unlike the tensors above, this one is created without an
# explicit device - confirm the metric moves it to `torchdevice` itself.
interval = torch.tensor([[-1, 1]] * 3, dtype=torchdtype)

# ---------------LOAD POINT CLOUDS---------------
source = o3d.io.read_point_cloud('data/bun000.ply')
# Open3D reports an unreadable file with a warning and an empty cloud rather
# than an exception; fail here instead of propagating NaNs downstream.
if not source.has_points():
    raise ValueError("Point cloud 'data/bun000.ply' could not be read or is empty")
source = source.voxel_down_sample(voxel_size=0.05)

# The target is the down-sampled cloud in its original pose.
target = copy.deepcopy(source)

# Rotate (about the original centre) and translate the source in place, so
# that it is misaligned with the target by a known rigid transformation.
source.rotate(target.get_rotation_matrix_from_xyz([0.1*np.pi,-0.1*np.pi,0.2*np.pi]),
            center=target.get_center())
source.translate([-0.05,0.08,0.02])

# Extract the coordinates only AFTER the transformation has been applied.
# Extracting them earlier yields the transformed points only if the helper
# returns a view that shares memory with the cloud; if it returns a copy,
# source and target points would be identical.
source_points = open3dtools.o3dpcd2np(source)
target_points = open3dtools.o3dpcd2np(target)

# ---------------BANDWIDTH---------------
# Rule-of-thumb bandwidth: 1.06 * n^(-1/5) * (mean of the per-axis standard
# deviations of the source points), where n is the number of source points.
# The intermediate is deliberately not called `vars`, which would shadow the
# built-in `vars()` for the rest of the kernel session.
source_std = np.std(source_points, axis=0)
bandwidth = 1.06 * np.power(source_points.shape[0], -0.2) * np.mean(source_std)

# ---------------KDE---------------
# One kernel density estimator per cloud (source first, then target), both
# sharing the same bandwidth, Gaussian kernel and torch device/dtype.
source_kde, target_kde = (
    kde.KernelDensityEstimator(
        bandwidth,
        kde.gaussian_kernel_3d_torch,
        cloud_points,
        use_pytorch=True,
        torch_device=torchdevice,
        dtype=torchdtype,
    )
    for cloud_points in (source_points, target_points)
)

# Centres of the two Open3D clouds as torch tensors on the working device.
source_center, target_center = (
    torch.tensor(cloud.get_center(), dtype=torchdtype, device=torchdevice)
    for cloud in (source, target)
)

## ---------------WD---------------
# Differentiable Wasserstein metric between the source and target densities,
# parameterised by the translation and the three rotation angles created
# above. Same keyword arguments as before, grouped by role.
wd_instance = wasserstein_metric.DifferentiableWassersteinMetric(
    # densities being compared and the centres of their clouds
    p=source_kde.pdf,
    q=target_kde.pdf,
    p_center=source_center,
    q_center=target_center,
    enable_move_pdf_to_center=False,
    # initial transformation estimate and which parts of it are enabled
    initial_t=t_init,
    init_alpha=alpha_init,
    init_beta=beta_init,
    init_gamma=gamma_init,
    enable_translation=True,
    enable_rotation=True,
    # optimisation step and integration settings
    # NOTE(review): the exact meaning of `s` and `integration_num_steps` is
    # defined by wasserstein_metric - confirm there before tuning.
    step_size=learning_rate,
    interval=interval,
    s=6,
    integration_num_steps=65,
    # torch placement
    torch_device=torchdevice,
    dtype=torchdtype,
)


num_iteration = 280

# Metric value before any update step has been taken.
wd2 = float(wd_instance.calc_1_wasserstein_metric()[0])
print(f"Initial (scaled) approximate Wasserstein distance: {wd2}")

for i in tqdm(range(num_iteration)):
    # wd1 holds the previous value, wd2 the value after this update step.
    wd1 = wd2
    wd_instance.update_transformation()
    wd2 = float(wd_instance.calc_1_wasserstein_metric()[0])

    # Never stop during the first iterations (i = 0..50).
    if i <= 50:
        continue
    # Afterwards stop as soon as one step improves the metric by less than
    # `tolerance` (this includes steps where the metric got worse).
    if (wd1 - wd2) < tolerance:
        break

# Report the estimated transformation that aligns the source point cloud
# with the target point cloud.
estimated_rotation = wd_instance.get_rotation()  # 3 tensors: alpha, beta, gamma
print(estimated_rotation)
estimated_translation = wd_instance.get_t()
print(estimated_translation)